1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Decorator to override the gradient for a function."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20from tensorflow.python.eager import backprop
21from tensorflow.python.eager import context
22from tensorflow.python.eager import tape as tape_lib
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import gen_array_ops
27from tensorflow.python.ops import handle_data_util
28from tensorflow.python.ops import op_selector
29from tensorflow.python.ops import resource_variable_ops
30from tensorflow.python.ops import variable_scope
31from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
32from tensorflow.python.platform import tf_logging as logging
33from tensorflow.python.util import nest
34from tensorflow.python.util import tf_decorator
35from tensorflow.python.util import tf_inspect
36from tensorflow.python.util.tf_export import tf_export
37
38
# Op types treated as variable ops. `_get_dependent_variables` scans the ops
# between a custom_gradient function's inputs and outputs for these types.
VAR_OP_TYPES = ["VariableV2", "VarHandleOp"]


# Backward-compatible alias kept for existing callers.
# TODO(allenl): Remove this alias and migrate callers.
copy_handle_data = handle_data_util.copy_handle_data
47
48
@tf_export("custom_gradient")
def custom_gradient(f=None):
  """Decorator to define a function with a custom gradient.

  This decorator allows fine grained control over the gradients of a sequence
  of operations.  This may be useful for multiple reasons, including providing
  a more efficient or numerically stable gradient for a sequence of operations.

  For example, consider the following function that commonly occurs in the
  computation of cross entropy and log likelihoods:

  ```python
  def log1pexp(x):
    return tf.math.log(1 + tf.exp(x))
  ```

  Due to numerical instability, the gradient of this function evaluated at x=100
  is NaN.  For example:

  ```python
  x = tf.constant(100.)
  y = log1pexp(x)
  dy = tf.gradients(y, x) # Will be NaN when evaluated.
  ```

  The gradient expression can be analytically simplified to provide numerical
  stability:

  ```python
  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
      return dy * (1 - 1 / (1 + e))
    return tf.math.log(1 + e), grad
  ```

  With this definition, the gradient at x=100 will be correctly evaluated as
  1.0.

  The variable `dy` is defined as the upstream gradient, i.e. the gradient from
  all the layers or functions originating from this layer.

  By the chain rule we know that
  `dy/dx = dy/dx_0 * dx_0/dx_1 * ... * dx_i/dx_i+1 * ... * dx_n/dx`

  In this case the gradient of our current function is defined as
  `dx_i/dx_i+1 = (1 - 1 / (1 + e))`. The upstream gradient `dy` would be
  `dy/dx_0 * dx_0/dx_1 * ... * dx_i-1/dx_i`. The upstream gradient
  multiplied by the current gradient is then passed downstream.

  In case the function takes multiple variables as input, the `grad`
  function must also return the same number of values (one gradient per
  input). We take the function `z = x * y` as an example.

  >>> @tf.custom_gradient
  ... def bar(x, y):
  ...   def grad(upstream):
  ...     dz_dx = y
  ...     dz_dy = x
  ...     return upstream * dz_dx, upstream * dz_dy
  ...   z = x * y
  ...   return z, grad
  >>> x = tf.constant(2.0, dtype=tf.float32)
  >>> y = tf.constant(3.0, dtype=tf.float32)
  >>> with tf.GradientTape(persistent=True) as tape:
  ...   tape.watch(x)
  ...   tape.watch(y)
  ...   z = bar(x, y)
  >>> z
  <tf.Tensor: shape=(), dtype=float32, numpy=6.0>
  >>> tape.gradient(z, x)
  <tf.Tensor: shape=(), dtype=float32, numpy=3.0>
  >>> tape.gradient(z, y)
  <tf.Tensor: shape=(), dtype=float32, numpy=2.0>

  Nesting custom gradients can lead to unintuitive results. The default
  behavior does not correspond to n-th order derivatives. For example

  ```python
  @tf.custom_gradient
  def op(x):
    y = op1(x)
    @tf.custom_gradient
    def grad_fn(dy):
      gdy = op2(x, y, dy)
      def grad_grad_fn(ddy):  # Not the 2nd order gradient of op w.r.t. x.
        return op3(x, y, dy, ddy)
      return gdy, grad_grad_fn
    return y, grad_fn
  ```

  The function `grad_grad_fn` will be calculating the first order gradient
  of `grad_fn` with respect to `dy`, which is used to generate forward-mode
  gradient graphs from backward-mode gradient graphs, but is not the same as
  the second order gradient of `op` with respect to `x`.

  Instead, wrap nested `@tf.custom_gradient`-decorated functions in another
  function:

  ```python
  @tf.custom_gradient
  def op_with_fused_backprop(x):
    y, x_grad = fused_op(x)
    def first_order_gradient(dy):
      @tf.custom_gradient
      def first_order_custom(unused_x):
        def second_order_and_transpose(ddy):
          return second_order_for_x(...), gradient_wrt_dy(...)
        return x_grad, second_order_and_transpose
      return dy * first_order_custom(x)
    return y, first_order_gradient
  ```

  Additional arguments to the inner `@tf.custom_gradient`-decorated function
  control the expected return values of the innermost function.

  See also `tf.RegisterGradient` which registers a gradient function for a
  primitive TensorFlow operation. `tf.custom_gradient` on the other hand allows
  for fine grained control over the gradient computation of a sequence of
  operations.

  Note that if the decorated function uses `Variable`s, the enclosing variable
  scope must be using `ResourceVariable`s. Keyword arguments to the decorated
  function are only supported when executing eagerly; in graph mode passing
  them raises a `ValueError`.

  Args:
    f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
       - `x` is a sequence of (nested structures of) `Tensor` inputs to the
         function.
       - `y` is a (nested structure of) `Tensor` outputs of applying TensorFlow
         operations in `f` to `x`.
       - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
         a list of `Tensor`s the same size as (flattened) `x` - the derivatives
         of `Tensor`s in `y` with respect to the `Tensor`s in `x`.  `grad_ys` is
         a sequence of `Tensor`s the same size as (flattened) `y` holding the
         initial value gradients for each `Tensor` in `y`.

         In a pure mathematical sense, a vector-argument vector-valued function
         `f`'s derivatives should be its Jacobian matrix `J`. Here we are
         expressing the Jacobian `J` as a function `grad_fn` which defines how
         `J` will transform a vector `grad_ys` when left-multiplied with it
         (`grad_ys * J`, the vector-Jacobian product, or VJP). This functional
         representation of a matrix is convenient to use for chain-rule
         calculation (in e.g. the back-propagation algorithm).

         If `f` uses `Variable`s (that are not part of the
         inputs), i.e. through `get_variable`, then `grad_fn` should have
         signature `g(*grad_ys, variables=None)`, where `variables` is a list of
         the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
         `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
         with the derivatives of `Tensor`s in `y` with respect to the variables
         (that is, grad_vars has one Tensor per variable in variables).
    If `f` is omitted (`None`), `custom_gradient` returns a decorator that
    can be applied to `f` later.

  Returns:
    A function `h(x)` which returns the same value as `f(x)[0]` and whose
    gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`.
  """

  if f is None:
    # Support usage both as `@custom_gradient` and `@custom_gradient()`.
    return lambda f: custom_gradient(f=f)

  # `Bind` makes the result usable as a method as well: looking it up on an
  # instance binds `f` to that instance before `decorated` runs.
  @Bind.decorator
  def decorated(wrapped, args, kwargs):
    """Decorated function with custom gradient."""
    if context.executing_eagerly():
      return _eager_mode_decorator(wrapped, args, kwargs)
    else:
      return _graph_mode_decorator(wrapped, args, kwargs)

  return tf_decorator.make_decorator(f, decorated(f))  # pylint: disable=no-value-for-parameter
220
221
class Bind(object):
  """Callable wrapper that evaluates `d(f, args, kwargs)` when invoked.

  A `Bind` is also a descriptor: when stored as a class attribute and looked
  up through an instance, `f` is first bound to that instance, so decorators
  built with `Bind.decorator` work on methods too.

  >>> @Bind.decorator
  ... def my_decorator(f, args, kwargs):
  ...   print("my_decorator called with", args, kwargs)
  ...   return f(*args, **kwargs)

  >>> class Foo(object):
  ...   @my_decorator
  ...   def bar(self, a, b, c):
  ...     return a * b * c

  >>> Foo.bar(None, 1, 2, c=3)
  my_decorator called with (None, 1, 2) {'c': 3}
  6

  >>> foo = Foo()
  >>> foo.bar(1, 2, c=3)
  my_decorator called with (1, 2) {'c': 3}
  6
  """

  @classmethod
  def decorator(cls, d):
    """Returns a function that wraps its argument `f` as `Bind(f, d)`."""

    def _bind(f):
      return Bind(f, d)

    return _bind

  def __init__(self, f, d):
    # `f` is the wrapped callable; `d` is invoked as `d(f, args, kwargs)`.
    self._f = f
    self._d = d

  def __get__(self, instance, owner):
    # Class-level access (e.g. `Foo.bar`) returns the unbound wrapper itself.
    if instance is None:
      return self
    bound_f = self._f.__get__(instance, owner)
    return tf_decorator.make_decorator(bound_f, Bind(bound_f, self._d))

  def __call__(self, *args, **kwargs):
    return self._d(self._f, args, kwargs)
262
263
def get_variable_by_name(var_name):
  """Given a variable name, retrieves a handle on the tensorflow Variable.

  `var_name` is compared exactly against the op name of every entry in the
  `GLOBAL_VARIABLES` collection. An exact comparison is used (rather than the
  regex `scope` filter of `get_collection`) so that names containing regex
  metacharacters such as `.` cannot match unrelated variables.

  Args:
    var_name: Op name of the variable, e.g. `op.name` of a variable op found
      while walking the graph.

  Returns:
    The single trainable variable with that name, or `None` if every variable
    with that name is non-trainable.

  Raises:
    ValueError: If no variable with that name exists, or if more than one
      trainable variable has that name.
  """
  global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)

  def _name_matches(item):
    try:
      return item.op.name == var_name
    except AttributeError:
      # Collection items without an `op` cannot be the variable we want.
      return False

  candidate_vars = [v for v in global_vars if _name_matches(v)]
  if not candidate_vars:
    raise ValueError("Unsuccessful at finding variable {}.".format(var_name))

  # Filter out non-trainable variables.
  candidate_vars = [v for v in candidate_vars if v.trainable]
  if len(candidate_vars) > 1:
    raise ValueError(
        "Unsuccessful at finding trainable variable {}. "
        "Number of candidates: {}. "
        "Candidates: {}".format(var_name, len(candidate_vars), candidate_vars))
  if candidate_vars:
    return candidate_vars[0]
  # The variable is not trainable.
  return None
285
286
def _get_dependent_variables(input_ops, output_ops):
  """Collects trainable variables read between `input_ops` and `output_ops`.

  Walks backwards from the outputs, stopping at the inputs, and looks up every
  op whose type is listed in `VAR_OP_TYPES` with `get_variable_by_name`.

  Args:
    input_ops: Flattened list of input tensors; the backward walk stops at
      these.
    output_ops: Flattened list of output tensors the walk starts from.

  Returns:
    A list of variables, skipping those `get_variable_by_name` reports as
    non-trainable (returned as `None`).
  """
  # Seeding the walk with fresh identity ops keeps it meaningful even when
  # input_ops == output_ops.
  seeds = nest.map_structure(gen_array_ops.identity, output_ops)
  walked_ops = op_selector.get_backward_walk_ops(
      seed_ops=seeds,
      stop_at_ts=input_ops,
      inclusive=False,
      only_differentiable=True)
  found = []
  for op in walked_ops:
    if op.type not in VAR_OP_TYPES:
      continue
    var = get_variable_by_name(op.name)
    if var is not None:
      found.append(var)
  return found
310
311
def _graph_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for graph mode.

  Calls `f(*args)` to obtain `(result, grad_fn)`. It then routes the flattened
  results, inputs and used variables through a single `IdentityN` op whose
  gradient is overridden to call `grad_fn`. The same gradient function is also
  registered via `tape_lib.record_operation`.

  Args:
    f: The decorated function; must return a `(result, grad_fn)` pair.
    args: Positional arguments for `f`; each is converted to a `Tensor`.
    kwargs: Keyword arguments for `f`; must be empty in graph mode.

  Returns:
    `result` with its tensors replaced by the matching `IdentityN` outputs.

  Raises:
    ValueError: If `kwargs` is non-empty, if the tensor outputs come from more
      than one graph, or if `grad_fn` returns the wrong number of variable
      gradients.
    TypeError: If `f` creates non-resource variables, or uses variables while
      `grad_fn` does not accept a `variables` keyword argument.
  """
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = nest.map_structure(ops.convert_to_tensor, args)

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = {
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  }
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args)

  args = nest.flatten(args)
  flat_result = nest.flatten(result)
  flat_result_len = len(flat_result)

  after_vars = {
      v.ref() for v in current_var_scope.global_variables() +
      current_var_scope.local_variables()
  }
  new_vars = after_vars - before_vars
  new_vars_list = [v.deref() for v in new_vars]
  for v in new_vars_list:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")

  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables_in_tape = frozenset(
      v.ref() for v in variable_watcher.watched_variables())

  graphs = {getattr(o, "graph", None) for o in flat_result}
  # Not all results may be tensors. However, we want to ensure all tensor
  # outputs are from the same graph and get a list of captured inputs for
  # variable search
  graphs.discard(None)  # Discard non-graph outputs
  if graphs:
    if len(graphs) > 1:
      raise ValueError(
          "All custom_gradient outputs should be from the same graph")
    output_graph = graphs.pop()
    filtered_input_tensors = [i for i in args if i.graph == output_graph]
  else:
    filtered_input_tensors = args

  variables_in_subgraph = frozenset(
      v.ref() for v in _get_dependent_variables(
          input_ops=filtered_input_tensors, output_ops=flat_result))
  variables = [
      v.deref() for v in variables_in_subgraph.union(variables_in_tape)
  ]

  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            "variables" in grad_argspec.kwonlyargs or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    logging.warning(
        "@custom_gradient grad_fn has 'variables' in signature, but "
        "no ResourceVariables were used on the forward pass.")

  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    # Gradients may arrive for every IdentityN output; only the first
    # flat_result_len of them correspond to `f`'s own outputs.
    result_grads = result_grads[:flat_result_len]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      # Accept any sequence (e.g. a tuple) so the list concatenation below
      # cannot fail with a TypeError.
      variable_grads = list(variable_grads)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * flat_result_len) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:flat_result_len])
433
434
def _eager_mode_decorator(f, args, kwargs):
  """Implement custom gradient decorator for eager mode.

  Calls `f(*args, **kwargs)` to obtain `(result, grad_fn)` and records an
  operation on the tape whose backward function delegates to `grad_fn`.

  Args:
    f: The decorated function; must return a `(result, grad_fn)` pair.
    args: Positional arguments for `f`; these (flattened) are the recorded
      inputs that receive gradients.
    kwargs: Keyword arguments for `f`; they are passed through but are not
      recorded as inputs.

  Returns:
    `result` with each flattened output replaced by an identity of itself.

  Raises:
    TypeError: If `f` uses variables while `grad_fn` does not accept a
      `variables` keyword argument.
  """
  with tape_lib.VariableWatcher() as variable_watcher:
    result, grad_fn = f(*args, **kwargs)
  args = nest.flatten(args)
  all_inputs = list(args) + list(kwargs.values())
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = [
      v.deref()  # pylint: disable=g-complex-comprehension
      for v in set(w.ref() for w in variable_watcher.watched_variables())
      if all(v.deref() is not i for i in all_inputs)
  ]
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  if (variables and ("variables" not in grad_argspec.args) and
      ("variables" not in grad_argspec.kwonlyargs) and
      not grad_argspec.varkw):
    raise TypeError(
        "@tf.custom_gradient grad_fn must accept keyword argument 'variables', "
        "since function uses variables: {}".format(variables))
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  input_tensors = [ops.convert_to_tensor(x) for x
                   in list(args) + list(variables)]

  arg_count = len(args)

  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      # Accept any sequence (e.g. a tuple) so the list concatenation below
      # cannot fail with a TypeError.
      variable_grads = list(variable_grads)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    flat_grads = nest.flatten(input_grads)
    if len(flat_grads) != arg_count:
      # Build one formatted message; passing several positional arguments to
      # ValueError would render the message as a tuple.
      raise ValueError(
          "custom_gradient function expected to return {} gradients but "
          "returned {} instead.".format(arg_count, len(flat_grads)))
    return flat_grads + variable_grads

  tape_lib.record_operation(f.__name__, flat_result, input_tensors,
                            actual_grad_fn)
  return nest.pack_sequence_as(result, flat_result)
486
487
@tf_export("recompute_grad")
def recompute_grad(f):
  """An eager-compatible version of recompute_grad.

  For f(*args, **kwargs), this supports gradients with respect to args and to
  variables used by `f`; kwargs are passed through to `f` and are currently
  only supported in eager-mode.
  Note that for keras layer and model objects, this is handled automatically.

  Warning: If `f` was originally a tf.keras Model or Layer object, `g` will not
  be able to access the member variables of that object, because `g` returns
  through the wrapper function `inner`.  When recomputing gradients through
  objects that inherit from keras, we suggest keeping a reference to the
  underlying object around for the purpose of accessing these variables.

  Args:
    f: function `f(*x)` that returns a `Tensor` or sequence of `Tensor` outputs.

  Returns:
    A function `g` that wraps `f`, but which recomputes `f` on the backwards
    pass of a gradient call.
  """
  # TODO(cdfreeman) Add is_recomputing functionality from graph mode version

  @custom_gradient
  def inner(*args, **kwargs):
    """Inner function closure for calculating gradients."""
    # Remembered so the recomputation on the backward pass runs in the same
    # variable scope as the original forward pass.
    current_var_scope = variable_scope.get_variable_scope()
    # Run the forward pass without recording it on the tape.
    with tape_lib.stop_recording():
      result = f(*args, **kwargs)

    def grad_wrapper(*wrapper_args, **grad_kwargs):
      """Wrapper function to accommodate lack of kwargs in graph mode decorator.

      `inner_recompute_grad` is called with positional arguments only, so the
      optional `variables` keyword is captured here and read in the closure.
      """

      @custom_gradient
      def inner_recompute_grad(*dresult):
        """Recomputes `f` and returns its gradients for reverse-mode autodiff.

        Also decorated with `custom_gradient` so that forward-mode autodiff
        through this recomputation raises `NotImplementedError`.
        """
        # Gradient calculation for reverse mode autodiff.
        variables = grad_kwargs.get("variables")
        with backprop.GradientTape() as t:
          id_args = nest.map_structure(gen_array_ops.identity, args)
          t.watch(id_args)
          if variables is not None:
            t.watch(variables)
          with ops.control_dependencies(dresult):
            with variable_scope.variable_scope(current_var_scope):
              result = f(*id_args, **kwargs)
        kw_vars = []
        if variables is not None:
          kw_vars = list(variables)
        grads = t.gradient(
            result,
            list(id_args) + kw_vars,
            output_gradients=dresult,
            unconnected_gradients=UnconnectedGradients.ZERO)

        def transpose(*t_args, **t_kwargs):
          """Gradient function calculation for forward mode autodiff."""
          # Just throw an error since gradients / activations are not stored on
          # tape for recompute.
          raise NotImplementedError(
              "recompute_grad tried to transpose grad of {}. "
              "Consider not using recompute_grad in forward mode "
              "autodiff".format(f.__name__))

        # Split into (input gradients, variable gradients) as expected by the
        # outer custom_gradient when `variables` is passed.
        return (grads[:len(id_args)], grads[len(id_args):]), transpose

      return inner_recompute_grad(*wrapper_args)

    return result, grad_wrapper

  return inner
559
560
@tf_export("grad_pass_through")
def grad_pass_through(f):
  """Creates a grad-pass-through op with the forward behavior provided in f.

  Use this function to wrap any op, maintaining its behavior in the forward
  pass, but replacing the original op in the backward graph with an identity.
  For example:

  ```python
  x = tf.Variable(1.0, name="x")
  z = tf.Variable(3.0, name="z")

  with tf.GradientTape() as tape:
    # y will evaluate to 9.0
    y = tf.grad_pass_through(x.assign)(z**2)
  # grads will evaluate to 6.0
  grads = tape.gradient(y, z)
  ```

  Another example is a 'differentiable' moving average approximation, where
  gradients are allowed to flow into the last value fed to the moving average,
  but the moving average is still used for the forward pass:

  ```python
  x = ... # Some scalar value
  # A moving average object, we don't need to know how this is implemented
  moving_average = MovingAverage()
  with tf.GradientTape() as tape:
    # mavg_x will evaluate to the current running average value
    mavg_x = tf.grad_pass_through(moving_average)(x)
  grads = tape.gradient(mavg_x, x) # grads will evaluate to 1.0
  ```

  Variables used inside `f` receive `None` as their gradient.

  Args:
    f: function `f(*x)` that returns a `Tensor` or nested structure of `Tensor`
      outputs.

  Returns:
    A function `h(x)` which returns the same values as `f(x)` and whose
    gradients are the same as those of an identity function.
  """

  @custom_gradient
  def _grad_pass_through_op(*args, **kwargs):

    def grad(*upstream, **grad_kwargs):
      variables = grad_kwargs.get("variables")
      if variables is None:
        return upstream
      # Variables involved in the wrapped op will not receive gradients.
      return upstream, [None] * len(variables)

    return f(*args, **kwargs), grad

  return tf_decorator.make_decorator(f, _grad_pass_through_op)
612