1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=unidiomatic-typecheck
16"""Prototype decorator for defining graph functions with eager semantics."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import functools
23import weakref
24
25from tensorflow.python.eager import context
26from tensorflow.python.eager import function as function_lib
27from tensorflow.python.eager import lift_to_graph
28from tensorflow.python.framework import func_graph as func_graph_module
29from tensorflow.python.framework import ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import resource_variable_ops
33from tensorflow.python.platform import tf_logging as logging
34from tensorflow.python.training.tracking import base as trackable
35from tensorflow.python.util import nest
36from tensorflow.python.util import tf_decorator
37from tensorflow.python.util.tf_export import tf_export
38
39
40class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable):
41  """Variable which does not lift its initializer out of function context.
42
43  Instances of this variable, when created, build a graph which runs their
44  initializer inside a tf.cond(is_initialized) block.
45
46  This can only be created inside a defun called from (eventually) eager
47  mode. That is, non-function-building graphs are not supported.
48  """
49
50  def __init__(self,  # pylint: disable=super-init-not-called
51               initial_value=None,
52               trainable=None,
53               caching_device=None,
54               name=None,
55               dtype=None,
56               constraint=None,
57               add_initializers_to=None,
58               lifted_initializer_graph=None,
59               **unused_kwargs):
60    """Creates a variable.
61
62    Args:
63      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
64        which is the initial value for the Variable. The initial value must have
65        a shape specified unless `validate_shape` is set to False. Can also be a
66        callable with no argument that returns the initial value when called.
67        (Note that initializer functions from init_ops.py must first be bound
68         to a shape before being used here.)
69      trainable: If `True`, GradientTapes automatically watch uses of this
70        Variable.
71      caching_device: Optional device string or function describing where the
72        Variable should be cached for reading.  Defaults to the Variable's
73        device.  If not `None`, caches on another device.  Typical use is to
74        cache on the device where the Ops using the Variable reside, to
75        deduplicate copying through `Switch` and other conditional statements.
76      name: Optional name for the variable. Defaults to `'Variable'` and gets
77        uniquified automatically.
78      dtype: If set, initial_value will be converted to the given type.
79        If None, either the datatype will be kept (if initial_value is
80       a Tensor) or float32 will be used (if it is a Python object convertible
81       to a Tensor).
82      constraint: An optional projection function to be applied to the variable
83        after being updated by an `Optimizer` (e.g. used to implement norm
84        constraints or value constraints for layer weights). The function must
85        take as input the unprojected Tensor representing the value of the
86        variable and return the Tensor for the projected value
87        (which must have the same shape). Constraints are not safe to
88        use when doing asynchronous distributed training.
89      add_initializers_to: if not None and not in legacy graph mode, the
90        initializer tensor will be added to this map in addition to adding the
91        assignment to the function.
92      lifted_initializer_graph: FuncGraph to try to lift initializers to.
93
94    Raises:
95      ValueError: If the initial value is not specified, or does not have a
96        shape and `validate_shape` is `True`.
97      RuntimeError: If called outside of a function definition.
98    """
99    if not ops.inside_function():
100      # If we've been init_scope()d out of the function definition nothing to do
101      # here; we can't really do the capturing or conditional logic.
102      resource_variable_ops.ResourceVariable.__init__(
103          self, initial_value=initial_value, trainable=trainable,
104          caching_device=caching_device, name=name, dtype=dtype,
105          constraint=constraint)
106      return
107    with ops.init_scope():
108      self._in_graph_mode = not context.executing_eagerly()
109    if initial_value is None:
110      raise ValueError("initial_value must be specified.")
111    init_from_fn = callable(initial_value)
112
113    if constraint is not None and not callable(constraint):
114      raise ValueError("The `constraint` argument must be a callable.")
115
116    if isinstance(initial_value, trackable.CheckpointInitialValue):
117      self._maybe_initialize_trackable()
118      self._update_uid = initial_value.checkpoint_position.restore_uid
119      initial_value = initial_value.wrapped_value
120
121    if trainable is None:
122      trainable = True
123    self._trainable = trainable
124    self._save_slice_info = None
125    self._initial_value = None
126    self._initializer_op = None
127    self._is_initialized_op = None
128    self._graph_element = None
129    self._cached_value = None
130    # Store the graph key so optimizers know how to only retrieve variables from
131    # this graph. Guaranteed to be the same as the eager graph_key.
132    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
133    with ops.name_scope(name, "Variable", []
134                        if init_from_fn else [initial_value]) as name:
135      # pylint: disable=protected-access
136      with ops.init_scope():
137        handle_name = ops._name_from_scope_name(name)
138        unique_id = "%s_%d" % (handle_name, ops.uid())
139        shared_name = context.shared_name(unique_id)
140      with ops.name_scope("Initializer"), ops.device(None):
141        initial_value = ops.convert_to_tensor(
142            initial_value() if init_from_fn else initial_value,
143            name="initial_value", dtype=dtype)
144      with ops.init_scope():
145        self._handle = resource_variable_ops.eager_safe_variable_handle(
146            initial_value=initial_value,
147            shared_name=shared_name,
148            name=name,
149            graph_mode=self._in_graph_mode)
150      self._shape = initial_value.shape
151      self._unique_id = unique_id
152      self._handle_name = handle_name + ":0"
153      self._dtype = initial_value.dtype.base_dtype
154      self._constraint = constraint
155      assert initial_value is not None
156      if self._in_graph_mode:
157        with ops.init_scope():
158          outer_graph = ops.get_default_graph()
159        func_graph = ops.get_default_graph()
160        function_placeholders = (
161            func_graph.inputs + func_graph.internal_captures)
162        placeholder_ops = set(
163            [tensor.op for tensor in function_placeholders])
164        lifted_initializer = lift_to_graph.lift_to_graph(
165            [initial_value], outer_graph,
166            disallowed_placeholders=placeholder_ops)[initial_value]
167        with ops.init_scope():
168          self._initial_value = lifted_initializer
169          with ops.name_scope("IsInitialized"):
170            self._is_initialized_op = (
171                resource_variable_ops.var_is_initialized_op(self._handle))
172          if initial_value is not None:
173            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
174              self._initializer_op = resource_variable_ops.assign_variable_op(
175                  self._handle, lifted_initializer, name=n)
176          with ops.name_scope("Read"), ops.colocate_with(self._handle):
177            # Manually assign reads to the handle's device to avoid log
178            # messages.
179            with ops.device(self._handle.device):
180              value = self._read_variable_op()
181            self._graph_element = value
182          ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
183      else:
184        if add_initializers_to is not None:
185          add_initializers_to[self] = initial_value
186        def assign_fn():
187          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
188            resource_variable_ops.assign_variable_op(
189                self._handle,
190                initial_value,
191                name=n)
192            # Returning values to keep tf.cond happy.
193          return ops.convert_to_tensor(1)
194        def not_assign_fn():
195          return ops.convert_to_tensor(0)
196        # Note: this cond is always guaranteed to run because we're inside a
197        # defun which will insert automatic control dependencies.
198        control_flow_ops.cond(
199            resource_variable_ops.var_is_initialized_op(self._handle),
200            not_assign_fn, assign_fn)
201
202    # After the handle has been created, set up a way to clean it up when
203    # executing eagerly. We'll hold the only reference to the deleter, so that
204    # when this object is garbage collected the deleter will be too. This
205    # means ResourceVariables can be part of reference cycles without those
206    # cycles being uncollectable.
207    if not self._in_graph_mode:
208      self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
209          handle=self._handle, handle_device=self._handle.device)
210    self._cached_shape_as_list = None
211
212
213RUN_FUNCTIONS_EAGERLY = False
214
215
216@tf_export("config.experimental_run_functions_eagerly")
217def run_functions_eagerly(run_eagerly):
218  """Enables / disables eager execution of `tf.function`s.
219
220  After calling `tf.config.experimental_run_functions_eagerly(True)` all
221  invocations of tf.function will run eagerly instead of running through a graph
222  function.
223
224  This can be useful for debugging or profiling.
225
226  Similarly, calling `tf.config.experimental_run_functions_eagerly(False)` will
227  revert the behavior of all functions to graph functions.
228
229  Args:
230    run_eagerly: Boolean. Whether to run functions eagerly.
231  """
232  global RUN_FUNCTIONS_EAGERLY
233  RUN_FUNCTIONS_EAGERLY = bool(run_eagerly)
234
235
236class FunctionDeleter(object):
237
238  def __init__(self, func_graph):
239    self.func_graph = func_graph
240
241  def __del__(self):
242    try:
243      func_graph_module.dismantle_func_graph(self.func_graph)
244    except:  # pylint: disable=bare-except
245      # Note: bare except here because this can be noisy at shutdown time.
246      pass
247
248
249class Function(object):
250  """Wrapper class for the graph functions defined for a Python function.
251
252  See the documentation for `tf.function` for more information on the semantics
253  of defined functions.
254
255  `Function` is thread-compatible.
256  """
257
258  def __init__(self,
259               python_function,
260               name,
261               input_signature=None,
262               autograph=True,
263               experimental_autograph_options=None):
264    """Initializes a `Function`.
265
266    Args:
267      python_function: the function to be wrapped.
268      name: the name given to it.
269      input_signature: a possibly nested sequence of `TensorSpec` objects
270        specifying the input signature of this function. If `None`, a separate
271        function is instantiated for each inferred input signature.
272      autograph: whether `python_function` should be converted to graph mode.
273        See https://www.tensorflow.org/guide/autograph for more information.
274      experimental_autograph_options: optional tuple of
275        tensorflow.autograph.Feature values. Allows enabling additional
276        conversion options when autograph is set to True.
277
278    Raises:
279      ValueError: if `input_signature` is not None and the `python_function`'s
280        argspec has keyword arguments.
281    """
282    self._python_function = python_function
283    # TODO(vbardiovsky): Both _stateful_fn and _stateless_fn are populating the
284    # same FunctionSpec. Consider removing it from both and passing in instead.
285    self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
286        python_function, input_signature)
287    self._autograph = autograph
288    self._experimental_autograph_options = experimental_autograph_options
289    self._created_variables = None
290    self._stateful_fn = None
291    self._stateless_fn = None
292    self._descriptor_cache = weakref.WeakKeyDictionary()
293    self._name = name
294
295  def _defun_with_scope(self, scope):
296    """Creates a defun wrapped inside a variable creator scope."""
297
298    weak_wrapped_fn = None
299    def wrapped_fn(*args, **kwds):
300      """Wraps `self._python_function` in a variable creator scope."""
301      # We register a variable creator with reduced priority. If an outer
302      # variable creator is just modifying keyword arguments to the variable
303      # constructor, this will work harmoniously. Since the `scope` registered
304      # here actually creates the variable, it taking priority would otherwise
305      # ignore the outer creator.
306      #
307      # If an outer variable creator calls the variable constructor manually,
308      # for example creating a MirroredVariable, then they won't call our
309      # creator. This means we won't be able to trace the initialization graph,
310      # and so variable initializers can't depend on function arguments. This is
311      # better than the alternative, tracing the initialization graph but giving
312      # the user a variable type they didn't want.
313      with ops.get_default_graph()._variable_creator_scope(scope, priority=50):  # pylint: disable=protected-access
314        # __wrapped__ allows AutoGraph to swap in a converted function. We give
315        # the function a weak reference to itself to avoid a reference cycle.
316        return weak_wrapped_fn().__wrapped__(*args, **kwds)
317    weak_wrapped_fn = weakref.ref(wrapped_fn)
318
319    # TODO(mdan): Pipe self._experimental_autograph_options through.
320    return function_lib.defun(
321        tf_decorator.make_decorator(
322            self._python_function,
323            wrapped_fn,
324            decorator_argspec=self._function_spec.fullargspec),
325        input_signature=self.input_signature,
326        autograph=self._autograph,
327        experimental_autograph_options=self._experimental_autograph_options)
328
329  def _initialize(self, args, kwds, add_initializers_to=None):
330    """Initializes, on the first call.
331
332    Creates two `Function`s, one that will allow creation of variables
333    and one that won't.
334
335    Additionally runs a trace for the `Function` that allows creation
336    of variables.
337
338    Args:
339      args: Arguments to the underlying python callable.
340      kwds: Keyword arguments to the python callable.
341      add_initializers_to: Where to collect variable initializers, if not None.
342    """
343
344    created_variables = []
345    lifted_initializer_graph = func_graph_module.FuncGraph("initializer")
346
347    def variable_capturing_scope(unused_next_creator, **kwds):
348      """Creates UnliftedInitializerVariables and saves references to them."""
349      v = UnliftedInitializerVariable(
350          add_initializers_to=add_initializers_to,
351          lifted_initializer_graph=lifted_initializer_graph, **kwds)
352      created_variables.append(weakref.ref(v))
353      return v
354
355    self._created_variables = created_variables
356    self._stateful_fn = self._defun_with_scope(variable_capturing_scope)
357    self._stateful_fn._name = self._name  # pylint: disable=protected-access
358    # Force the definition of the function for these arguments
359    self._lifted_initializer_graph = lifted_initializer_graph
360    self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
361    self._concrete_stateful_fn = (
362        self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
363            *args, **kwds))
364
365    def invalid_creator_scope(*unused_args, **unused_kwds):
366      """Disables variable creation."""
367      raise ValueError(
368          "tf.function-decorated function tried to create "
369          "variables on non-first call.")
370
371    self._stateless_fn = self._defun_with_scope(invalid_creator_scope)
372    self._stateless_fn._name = self._name  # pylint: disable=protected-access
373
374  def _decorate(self, decorator):
375    """Allows the captured Python function to be decorated in place.
376
377    This method is only safe to call when the Function has not been called by a
378    user. It makes sense to use this method to push a decorator into the
379    function rather than wrapping the function in the decorator.
380
381    We use this in tf.Module to allow user annotated `tf.functions` to remain as
382    `Function` objects but still automatically enter the Module name_scope
383    when they are evaluated like all other methods.
384
385    Args:
386      decorator: A callable accepting a single argument which is the function
387        to decorate and returning a callable result.
388
389    Raises:
390      ValueError: If the function has been called a ValueError is raised.
391    """
392    if self._stateful_fn is not None or self._stateless_fn is not None:
393      raise ValueError(
394          "Functions cannot be decorated after they have been traced.")
395
396    self._python_function = decorator(self._python_function)
397    self._function_spec = function_lib.FunctionSpec.from_function_and_signature(
398        self._python_function, self.input_signature)
399
400  def __call__(self, *args, **kwds):
401    """Calls the graph function."""
402    if RUN_FUNCTIONS_EAGERLY:
403      return self._python_function(*args, **kwds)
404    if self._created_variables:
405      # In this case we have created variables on the first call, so we run the
406      # defunned version which is guaranteed to never create variables.
407      return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
408    elif self._stateful_fn is not None:
409      # In this case we have not created variables on the first call. So we can
410      # run the first trace but we should fail if variables are created.
411      results = self._stateful_fn(*args, **kwds)
412      if self._created_variables:
413        raise ValueError("Creating variables on a non-first call to a function"
414                         " decorated with tf.function.")
415      return results
416
417    # This is the first call of __call__, so we have to initialize.
418    initializer_map = {}
419    self._initialize(args, kwds, add_initializers_to=initializer_map)
420    if self._created_variables:
421      try:
422        # Attempt to initialize variables eagerly and without conds by lifting
423        # out initialization graphs. This is the only initialization strategy
424        # compatible with XLA at the moment.
425        self._initialize_uninitialized_variables(initializer_map)
426      except lift_to_graph.UnliftableError:
427        pass  # Fall through to cond-based initialization.
428      else:
429        # Lifting succeeded, so variables are initialized and we can run the
430        # stateless function.
431        return self._stateless_fn(*args, **kwds)
432    else:
433      canon_args, canon_kwds = \
434          self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
435              *args, **kwds)
436      # If we did not create any variables the trace we have is good enough.
437      return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access
438
439    def fn_with_cond(*inner_args, **inner_kwds):
440      """Conditionally runs initialization if it's needed."""
441      condition = True
442      for wr in self._created_variables:
443        variable = wr()
444        if variable is None:
445          raise ValueError(
446              "A tf.Variable created inside your tf.function has been"
447              " garbage-collected. Your code needs to keep Python references"
448              " to variables created inside `tf.function`s.\n"
449              "\n"
450              "A common way to raise this error is to create and return a"
451              " variable only referenced inside your function:\n"
452              "\n"
453              "@tf.function\n"
454              "def f():\n"
455              "  v = tf.Variable(1.0)\n"
456              "  return v\n"
457              "\n"
458              "v = f()  # Crashes with this error message!\n"
459              "\n"
460              "The reason this crashes is that @tf.function annotated"
461              " function returns a **`tf.Tensor`** with the **value** of the"
462              " variable when the function is called rather than the"
463              " variable instance itself. As such there is no code holding a"
464              " reference to the `v` created inside the function and Python"
465              " garbage collects it.\n"
466              "\n"
467              "The simplest way to fix this issue is to create variables"
468              " outside the function and capture them:\n"
469              "\n"
470              "v = tf.Variable(1.0)\n"
471              "\n"
472              "@tf.function\n"
473              "def f():\n"
474              "  return v\n"
475              "\n"
476              "f()  # <tf.Tensor: ... numpy=1.>\n"
477              "v.assign_add(1.)\n"
478              "f()  # <tf.Tensor: ... numpy=2.>")
479        condition = math_ops.logical_and(
480            condition, resource_variable_ops.var_is_initialized_op(
481                variable.handle))
482      # We want to call stateless_fn if possible because it avoids recomputing
483      # potentially expensive initializers.
484      return control_flow_ops.cond(
485          condition,
486          lambda: self._stateless_fn(*inner_args, **inner_kwds),
487          functools.partial(self._concrete_stateful_fn._filtered_call,  # pylint: disable=protected-access
488                            inner_args, inner_kwds))
489
490    # We've created variables and are unable to lift the initialization graphs,
491    # so we fall back to initializing with conds while running the function.
492    canon_args, canon_kwds = \
493        self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
494            *args, **kwds)
495    return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds)
496
497  @property
498  def python_function(self):
499    """The python function wrapped in this tf.function."""
500    return self._python_function
501
502  @property
503  def input_signature(self):
504    return self._function_spec.input_signature
505
506  @property
507  def function_spec(self):
508    return self._function_spec
509
510  def _initialize_uninitialized_variables(self, initializer_map):
511    """Make and call a `ConcreteFunction` which initializes variables."""
512
513    # Note: using defun here avoids an infinite recursion.
514    # Note: there is no reason not to autograph once the overhead is negligible.
515    @function_lib.defun(autograph=False)  # tf.function internal, pure graph
516    def initialize_variables():
517      for v, init in initializer_map.items():
518        with ops.init_scope():
519          if resource_variable_ops.var_is_initialized_op(v.handle):
520            # Ignore variables which are already initialized at trace time.
521            continue
522        v.assign(lift_to_graph.lift_to_graph(
523            [init], ops.get_default_graph())[init])
524
525    with ops.init_scope():
526      return initialize_variables.get_concrete_function()()
527
528  def get_initialization_function(self, *args, **kwargs):
529    """Returns a `ConcreteFunction` which initializes this function's variables.
530
531    Requires that this function hasn't been accessed yet through either calling
532    it or calling get_concrete_function. Fails if we cannot build an initializer
533    function which does not depend on the concrete values of the inputs to this
534    function.
535
536    Note that running this function will overwrite any values currently assigned
537    to variables, for example restores from a checkpoint.
538
539    Args:
540      *args: arguments to the underlying python callable.
541      **kwargs: keyword arguments to the python callable.
542
543    Returns:
544      A `ConcreteFunction` object which initializes the variables of this
545      function.
546
547    Raises:
548      RuntimeError: if called after the variables have been initialized.
549    """
550    if self._stateful_fn is not None:
551      raise RuntimeError(
552          "get_initialization_function cannot be called after the function "
553          "has been used")
554    # Here we trace the function, collect the initializers, and attempt to
555    # extract them and run them eagerly. Fail only if we cannot do so.
556    initializer_map = {}
557    self._initialize(args, kwargs, add_initializers_to=initializer_map)
558
559    # Note: using defun here avoids an infinite recursion.
560    @function_lib.defun
561    def initialize_variables():
562      for v, init in initializer_map.items():
563        v.assign(lift_to_graph.lift_to_graph(
564            [init], ops.get_default_graph())[init])
565
566    return initialize_variables.get_concrete_function()
567
568  def _list_all_concrete_functions_for_serialization(self):
569    """Returns all concrete functions for serialization.
570
571    Returns:
572      A list of instances of `Function`.
573    """
574    if self.input_signature is not None:
575      self.get_concrete_function()
576    concrete_functions = []
577    # pylint: disable=protected-access
578    if self._stateful_fn:
579      concrete_functions.extend(
580          self._stateful_fn._function_cache.all_values())
581    if self._stateless_fn:
582      concrete_functions.extend(
583          self._stateless_fn._function_cache.all_values())
584    # pylint: enable=protected-access
585    deduplicated_concrete_functions = list()
586    seen_signatures = list()
587    # We are using a list so that:
588    #  - the returned collection is deterministic, and
589    #  - we can use a custom equality operator (is_same_structure).
590    # This is run only at serialization time on likely very small inputs so we
591    # are not concerned about O(n^2) runtime.
592    for concrete_function in concrete_functions:
593      signature, _ = concrete_function.structured_input_signature
594      flattened = nest.flatten(signature)
595      if any(
596          isinstance(arg, func_graph_module.UnknownArgument)
597          for arg in flattened):
598        logging.info("Unsupported signature for serialization: %s.", signature)
599        continue
600      equal_to_signature = functools.partial(
601          function_lib.is_same_structure, signature, check_values=True)
602      if not any(equal_to_signature(s) for s in seen_signatures):
603        deduplicated_concrete_functions.append(concrete_function)
604        seen_signatures.append(signature)
605    return deduplicated_concrete_functions
606
607  def get_concrete_function(self, *args, **kwargs):
608    """Returns a `ConcreteFunction` specialized to inputs and execution context.
609
610    If this `Function` was created with an `input_signature`, `args` and
611    `kwargs` may be omitted. With an input signature there is only one
612    concrete function associated with this `Function`.
613
614    If there is no fixed `input_signature` associated with this
615    `Function`, positional and keyword arguments to `get_concrete_function`
616    follow the same rules as input signature specification, with `tf.TensorSpec`
617    objects describing `tf.Tensor`s which will be passed to the concrete
618    function.
619
620    Each `tf.Tensor` argument to the concrete function must have a unique name,
621    either because it is the only one associated with a named argument of the
622    Python function or because an explicit `name=` was passed to its
623    `tf.TensorSpec` object. These names become the argument names for the
624    concrete function.
625
626    Arguments to the concrete function may always be specified as keyword
627    arguments, naming the Tensor input. Positional arguments may be used instead
628    when each preceding argument to the Python function is a Tensor.
629
630    ```python
631    @tf.function
632    def f(x):
633      return x
634
635    f_concrete = f.get_concrete_function(tf.TensorSpec([], tf.float64))
636    f_concrete(tf.constant(1.))
637    f_concrete(x=tf.constant(1.))
638    ```
639
640    Nested structures containing Tensors may be specified when retrieving
641    concrete functions. Structures with multiple Tensors are expanded into
642    multiple arguments of the concrete function. Since multiple concrete
643    function arguments are associated with one argument to the original
644    function, these Tensors must be named explicitly. Tensors in nested
645    structures may not be passed using positional arguments when calling the
646    concrete function.
647
648    ```python
649    f_concrete2 = f.get_concrete_function(
650        (tf.TensorSpec(None, tf.float64, name="first"),
651         tf.TensorSpec([], tf.float32, name="second")))
652    # Keyword arguments are required when identifying Tensors in nested
653    # structures.
654    f_concrete2(first=tf.constant([1.]), second=tf.constant(0.))
655    ```
656
657    Functions with fixed input signatures have only one concrete function
658    associated with them, which can be retrieved without specifying any
659    arguments. As before Tensors must have unique names, either inferred from
660    the argument names in the original Python function or specified
661    explicitly.
662
663    ```python
664    @tf.function(input_signature=(tf.TensorSpec(None, tf.float32)))
665    def f_sig(y):
666      return y
667
668    f_sig_concrete = f.get_concrete_function()
669    f_sig_concrete(tf.constant(1.))
670    f_sig_concrete(y=tf.constant(1.))
671    ```
672
673    Args:
674      *args: inputs to specialize on.
675      **kwargs: inputs to specialize on.
676
677    Returns:
678      A TensorFlow function which takes exactly one `tf.Tensor` per argument.
679
680    Raises:
681      ValueError: if this object has not yet been called on concrete values.
682    """
683    if self._stateful_fn is None:
684      initializer_map = {}
685      self._initialize(args, kwargs, add_initializers_to=initializer_map)
686      self._initialize_uninitialized_variables(initializer_map)
687
688    if self._created_variables:
689      # In this case we have created variables on the first call, so we run the
690      # defunned version which is guaranteed to never create variables.
691      return self._stateless_fn.get_concrete_function(*args, **kwargs)
692    elif self._stateful_fn is not None:
693      # In this case we have not created variables on the first call. So we can
694      # run the first trace but we should fail if variables are created.
695      concrete = self._stateful_fn.get_concrete_function(*args, **kwargs)
696      if self._created_variables:
697        raise ValueError("Creating variables on a non-first call to a function"
698                         " decorated with tf.function.")
699      return concrete
700
701  def __get__(self, instance, owner):
702    """Makes it possible to defun instance methods."""
703    del owner
704    # `instance` here is the instance that this `Function` was accessed through
705    # e.g., for
706    #
707    #   class Foo(object):
708    #
709    #     @function.defun
710    #     def bar(self):
711    #       ...
712    #
713    #   foo = Foo()
714    #   foo.bar()  # `foo.bar` is a `Function` instance
715    #
716    # then `instance` will be `foo` (and `owner` will be `Foo`).  We create a
717    # new instance of `Function` here to allow different instances each
718    # to create variables once, thereby allowing methods to be decorated with
719    # tf.function. Keeps a cache to avoid retracing the function every time the
720    # descriptor is accessed.
721    if instance not in self._descriptor_cache:
722      if instance is None:
723        return self
724      self._descriptor_cache[instance] = (
725          function_lib.class_method_to_instance_method(self, instance))
726    return self._descriptor_cache[instance]
727
728
729@tf_export("function")
730def function(func=None,
731             input_signature=None,
732             autograph=True,
733             experimental_autograph_options=None):
734  """Creates a callable TensorFlow graph from a Python function.
735
736  `function` constructs a callable that executes a TensorFlow graph
737  (`tf.Graph`) created by tracing the TensorFlow operations in `func`.
738  This allows the TensorFlow runtime to apply optimizations and exploit
739  parallelism in the computation defined by `func`.
740
741  _Example Usage_
742
743  ```python
744  def f(x, y):
745    return tf.reduce_mean(tf.multiply(x ** 2, 3) + y)
746
747  g = tf.function(f)
748
749  x = tf.constant([[2.0, 3.0]])
750  y = tf.constant([[3.0, -2.0]])
751
752  # `f` and `g` will return the same value, but `g` will be executed as a
753  # TensorFlow graph.
754  assert f(x, y).numpy() == g(x, y).numpy()
755
756  # Tensors and tf.Variables used by the Python function are captured in the
757  # graph.
758  @tf.function
759  def h():
760    return f(x, y)
761
762  assert (h().numpy() == f(x, y).numpy()).all()
763
764  # Data-dependent control flow is also captured in the graph. Supported
765  # control flow statements include `if`, `for`, `break`, `continue`, `return`.
766  @tf.function
767  def g(x):
768    if tf.reduce_sum(x) > 0:
769      return x * x
770    else:
771      return -x // 2
772
773  # print and TensorFlow side effects are supported, but exercise caution when
774  # using Python side effects like mutating objects, saving to files, etc.
775  l = []
776
777  @tf.function
778  def g(x):
779    for i in x:
780      print(i)                              # Works
781      tf.assign(v, i)                       # Works
782      tf.py_func(lambda i: l.append(i))(i)  # Works
783      l.append(i)                           # Caution! Doesn't work.
784  ```
785
786  Note that unlike other TensorFlow operations, we don't convert python
787  numerical inputs to tensors.
788
789  _Referencing `tf.Variable`s_
790
791  The Python function `func` may reference stateful objects (such as
792  `tf.Variable`).
793  These are captured as implicit inputs to the callable returned by `function`.
794  For example:
795
796  ```python
797  c = tf.Variable(0)
798
799  @tf.function
800  def f(x):
801    c.assign_add(1)
802    return x + tf.to_float(c)
803
804  assert int(c) == 0
805  assert f(1.0) == 2.0
806  assert int(c) == 1
807  assert f(1.0) == 3.0
808  assert int(c) == 2
809  ```
810
811  `function` can be applied to methods of an object. For example:
812
813  ```python
814  class Dense(object):
815    def __init__(self):
816      self.W = tf.Variable(tf.glorot_uniform_initializer()((10, 10)))
817      self.b = tf.Variable(tf.zeros(10))
818
819    @tf.function
820    def compute(self, x):
821      return tf.matmul(x, self.W) + self.b
822
823  d1 = Dense()
824  d2 = Dense()
825  x = tf.random_uniform((10, 10))
826  # d1 and d2 are using distinct variables
827  assert not (d1.compute(x).numpy() == d2.compute(x).numpy()).all()
828  ```
829
830  _Usage with `tf.keras`_
831
832  The `call` methods of a `tf.keras.Model` subclass can be decorated with
833  `function` in order to apply graph execution optimizations on it.
834  For example:
835
836  ```python
837  class MyModel(tf.keras.Model):
838    def __init__(self, keep_probability=0.2):
839      super(MyModel, self).__init__()
840      self.dense1 = tf.keras.layers.Dense(4)
841      self.dense2 = tf.keras.layers.Dense(5)
842      self.keep_probability = keep_probability
843
844    @tf.function
845    def call(self, inputs, training=True):
846      y = self.dense2(self.dense1(inputs))
847      if training:
848        return tf.nn.dropout(y, self.keep_probability)
849      else:
850        return y
851
852  model = MyModel()
853  model(x, training=True)  # executes a graph, with dropout
854  model(x, training=False) # executes a graph, without dropout
855  ```
856
857  _Input Signatures_
858
859  `function` instantiates a separate graph for every unique set of input
860  shapes and datatypes. For example, the following code snippet will result
861  in three distinct graphs being traced, as each input has a different
862  shape.
863
864  ```python
865  @tf.function
866  def f(x): return tf.add(x, 1.)
867
868  scalar = tf.constant(1.0)
869  vector = tf.constant([1.0, 1.0])
870  matrix = tf.constant([[3.0]])
871
872  f(scalar)
873  f(vector)
874  f(matrix)
875  ```
876
877  An "input signature" can be optionally provided to `function` to control
878  the graphs traced. The input signature specifies the shape and type of each
879  `Tensor` argument to the function using a `tf.TensorSpec` object. For example,
880  the following code snippet ensures that a single graph is created where the
881  input `Tensor` is required to be a floating point tensor with no restrictions
882  on shape.
883
884  ```python
885  @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
886  def f(x): return tf.add(x, 1.)
887  ```
888
889  When an `input_signature` is specified, the callable will convert the inputs
890  to the specified TensorSpecs.
891
892  _Tracing and staging_
893
894  When `autograph` is `True`, all Python code that depends on `Tensor` values is
895  staged into a TensorFlow graph. When `autograph` is `False`, the function is
896  traced and control flow is not allowed to depend on data.
897
898  Note that `function` only stages TensorFlow operations, all Python code that
899  `func` executes and does not depend on data will shape the _construction_ of
900  the graph.
901  For example, consider the following:
902
903  ```python
904  import numpy as np
905
906  def add_noise():
907    return tf.eye(5) + np.random.randn(5, 5)
908
909  traced = tf.function(add_noise)
910  ```
911
912  `add_noise()` will return a different output every time it is invoked.
913  However, `traced()` will return the same value every time it is called,
914  since a particular random value generated by the `np.random.randn` call will
915  be inserted in the traced/staged TensorFlow graph as a constant. In this
916  particular example, replacing `np.random.randn(5, 5)` with
917  `tf.random_normal((5, 5))` will result in the same behavior for `add_noise()`
918  and `traced()`.
919
920  _Python Side-Effects_
921
922  A corollary of the previous discussion on tracing is the following: If a
923  Python function `func` has Python side-effects, then executing `func` multiple
924  times may not be semantically equivalent to executing `F = tf.function(func)`
925  multiple times; this difference is due to the fact that `function` only
926  captures the subgraph of TensorFlow operations that is constructed when `func`
927  is invoked to trace a graph.
928
929  The same is true if code with Python side effects is used inside control flow,
930  such as a loop. If your code uses side effects that are not intended to
931  control graph construction, wrap them inside `tf.py_func`.
932
933  Args:
934    func: function to be compiled. If `func` is None, returns a decorator that
935      can be invoked with a single argument - `func`. The end result is
936      equivalent to providing all the arguments up front. In other words,
937      `tf.function(input_signature=...)(func)` is equivalent to
938      `tf.function(func, input_signature=...)`. The former can be used to
939      decorate Python functions, for example:
940        @tf.function(input_signature=...)
941        def foo(...): ...
942    input_signature: A possibly nested sequence of `tf.TensorSpec` objects
943      specifying the shapes and dtypes of the Tensors that will be supplied to
944      this function. If `None`, a separate function is instantiated for each
945      inferred input signature.  If input_signature is specified, every input to
946      `func` must be a `Tensor`, and `func` cannot accept `**kwargs`.
947    autograph: Whether autograph should be applied on `func` before tracing a
948      graph. This allows for dynamic control flow (Python if's, loops etc.)
949      in the traced graph. See https://www.tensorflow.org/guide/autograph for
950        more information.
951    experimental_autograph_options: Experimental knobs (in the form of a tuple
952      of tensorflow.autograph.Feature values) to control behavior when
953      autograph=True.
954
955  Returns:
956     If `func` is not None, returns a callable that will execute the compiled
957     function (and return zero or more `tf.Tensor` objects).
958     If `func` is None, returns a decorator that, when invoked with a single
959     `func` argument, returns a callable equivalent to the case above.
960
961  Raises:
962    TypeError: If `input_signature` is neither `None` nor a sequence of
963      `TensorSpec` objects.
964  """
965  if input_signature is not None:
966    function_lib.validate_signature(input_signature)
967
968  def decorated(inner_function):
969    try:
970      name = inner_function.__name__
971    except AttributeError:
972      name = "function"
973    return tf_decorator.make_decorator(
974        inner_function,
975        Function(
976            inner_function,
977            name,
978            input_signature=input_signature,
979            autograph=autograph,
980            experimental_autograph_options=experimental_autograph_options))
981
982  # This code path is for the `foo = tf.function(foo, ...)` use case
983  if func is not None:
984    return decorated(func)
985
986  # This code path is for the
987  #
988  # @tf.function(...)
989  # def foo(...):
990  #    ...
991  #
992  # use case, which is equivalent to `foo = tf.function(...)(foo)`
993  return decorated
994