1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Decorator to override the gradient for a function."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python import pywrap_tensorflow
22from tensorflow.python.eager import backprop
23from tensorflow.python.eager import context
24from tensorflow.python.eager import tape as tape_lib
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import gen_array_ops
29from tensorflow.python.ops import resource_variable_ops
30from tensorflow.python.ops import variable_scope
31from tensorflow.python.platform import tf_logging as logging
32from tensorflow.python.util import nest
33from tensorflow.python.util import tf_decorator
34from tensorflow.python.util import tf_inspect
35from tensorflow.python.util.tf_export import tf_export
36
37
def copy_handle_data(source_t, target_t):
  """Propagates HandleData from `source_t` to a resource/variant `target_t`.

  The CppShapeInferenceResult::HandleData proto describes the shapes and types
  of the element tensors held by resource/variant type tensors. It has to be
  copied across function boundaries, i.e., when capturing a placeholder or when
  returning a function tensor as output. Without the copy the element tensors
  end up with unknown shapes, e.g., elements popped from a TensorList variant
  tensor that was captured as a placeholder would have unknown shape.

  Nothing is done when `target_t` is neither a resource nor a variant tensor,
  or when `source_t` carries no populated handle data.

  Args:
    source_t: The tensor to copy HandleData from.
    target_t: The tensor to copy HandleData to.
  """
  target_dtype = target_t.dtype
  holds_handle = (target_dtype == dtypes.resource or
                  target_dtype == dtypes.variant)
  if not holds_handle:
    return

  # Eager tensors carry their handle data directly; graph tensors need a
  # lookup through the resource variable helpers.
  if isinstance(source_t, ops.EagerTensor):
    handle_data = source_t._handle_data  # pylint: disable=protected-access
  else:
    handle_data = resource_variable_ops.get_resource_handle_data(source_t)

  if (handle_data is None
      or not handle_data.is_set
      or not handle_data.shape_and_type):
    return

  # pylint: disable=protected-access
  pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
                                          target_t._as_tf_output(),
                                          handle_data.SerializeToString())
  # pylint: enable=protected-access

  # Ensure that shapes and dtypes are propagated.
  entries = list(handle_data.shape_and_type)
  types = tuple(entry.dtype for entry in entries)
  ranks = []
  shapes = []
  for entry in entries:
    shape_proto = entry.shape
    if shape_proto.unknown_rank:
      # Unknown rank is encoded as rank -1 with no dimension list.
      ranks.append(-1)
      shapes.append(None)
    else:
      ranks.append(len(shape_proto.dim))
      shapes.append([dim.size for dim in shape_proto.dim])
  pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
      target_t._op._graph._c_graph,  # pylint: disable=protected-access
      target_t._as_tf_output(),  # pylint: disable=protected-access
      shapes, ranks, types)
77
78
@tf_export("custom_gradient")
def custom_gradient(f):
  """Decorator to define a function with a custom gradient.

  Use this decorator to take fine grained control over the gradients of a
  sequence of operations, for example to supply a gradient that is more
  efficient or more numerically stable than the default one.

  As an example, the following function commonly occurs in the computation of
  cross entropy and log likelihoods:

  ```python
  def log1pexp(x):
    return tf.log(1 + tf.exp(x))
  ```

  Because of numerical instability, the gradient of this function evaluated at
  x=100 is NaN:

  ```python
  x = tf.constant(100.)
  y = log1pexp(x)
  dy = tf.gradients(y, x) # Will be NaN when evaluated.
  ```

  The gradient expression can be simplified analytically so that it becomes
  numerically stable:

  ```python
  @tf.custom_gradient
  def log1pexp(x):
    e = tf.exp(x)
    def grad(dy):
      return dy * (1 - 1 / (1 + e))
    return tf.log(1 + e), grad
  ```

  With this definition, the gradient at x=100 is correctly evaluated as 1.0.

  See also `tf.RegisterGradient`, which registers a gradient function for a
  primitive TensorFlow operation. `tf.custom_gradient`, in contrast, gives fine
  grained control over the gradient computation of a sequence of operations.

  Note that if the decorated function uses `Variable`s, the enclosing variable
  scope must be using `ResourceVariable`s.

  Args:
    f: function `f(*x)` that returns a tuple `(y, grad_fn)` where:
       - `x` is a sequence of `Tensor` inputs to the function.
       - `y` is a `Tensor` or sequence of `Tensor` outputs of applying
         TensorFlow operations in `f` to `x`.
       - `grad_fn` is a function with the signature `g(*grad_ys)` which returns
         a list of `Tensor`s - the derivatives of `Tensor`s in `y` with respect
         to the `Tensor`s in `x`.  `grad_ys` is a `Tensor` or sequence of
         `Tensor`s the same size as `y` holding the initial value gradients for
         each `Tensor` in `y`. In a pure mathematical sense, the derivative of
         a vector-argument vector-valued function `f` is its Jacobian matrix
         `J`. Here the Jacobian `J` is expressed as a function `grad_fn` that
         defines how `J` transforms a vector `grad_ys` when left-multiplied
         with it (`grad_ys * J`). This functional representation of a matrix
         is convenient for chain-rule calculation (e.g. in the
         back-propagation algorithm).

         If `f` uses `Variable`s (that are not part of the
         inputs), i.e. through `get_variable`, then `grad_fn` should have
         signature `g(*grad_ys, variables=None)`, where `variables` is a list of
         the `Variable`s, and return a 2-tuple `(grad_xs, grad_vars)`, where
         `grad_xs` is the same as above, and `grad_vars` is a `list<Tensor>`
         with the derivatives of `Tensor`s in `y` with respect to the variables
         (that is, grad_vars has one Tensor per variable in variables).

  Returns:
    A function `h(x)` which returns the same value as `f(x)[0]` and whose
    gradient (as calculated by `tf.gradients`) is determined by `f(x)[1]`.
  """

  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    # The execution mode is checked on every call, not once at decoration time.
    if context.executing_eagerly():
      mode_decorator = _eager_mode_decorator
    else:
      mode_decorator = _graph_mode_decorator
    return mode_decorator(f, *args, **kwargs)

  return tf_decorator.make_decorator(f, decorated)
165
166
def _graph_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for graph mode.

  Calls `f` on the tensor-converted `args` under a `GradientTape`, then passes
  the flattened outputs, the inputs and the watched variables through a single
  `identity_n` call whose `IdentityN` gradient is overridden to invoke the
  `grad_fn` returned by `f`.

  Args:
    f: The function wrapped by `custom_gradient`. Called as `f(*args)` and
      expected to return a pair `(result, grad_fn)`.
    *args: Positional inputs to `f`; each is converted to a `Tensor`.
    **kwargs: Must be empty; keyword arguments are rejected in graph mode.

  Returns:
    The leading `identity_n` outputs, packed into the same structure as the
    `result` returned by `f`.

  Raises:
    ValueError: If `kwargs` is non-empty. Also raised later, from the gradient
      function, if `grad_fn` does not return one gradient per variable.
    TypeError: If `f` created a variable that is not a resource variable, if
      `f` used variables but `grad_fn` accepts no `variables` keyword
      argument, or if `grad_fn` accepts `variables`, none were watched and the
      current variable scope does not have `use_resource` set.
  """
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = [ops.convert_to_tensor(x) for x in args]

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set(current_var_scope.global_variables() +
                    current_var_scope.local_variables())
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args)
  after_vars = set(current_var_scope.global_variables() +
                   current_var_scope.local_variables())
  new_vars = after_vars - before_vars
  for v in new_vars:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = list(set(tape.watched_variables()) - set(args))
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  # A grad_fn written as `g(*grad_ys, variables=None)` declares `variables` as
  # a keyword-only argument, which getfullargspec reports in `kwonlyargs`
  # rather than `args`, so both places have to be checked.
  variables_in_signature = ("variables" in grad_argspec.args or
                            "variables" in grad_argspec.kwonlyargs or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    if not variable_scope.get_variable_scope().use_resource:
      raise TypeError("If using @custom_gradient with a function that "
                      "uses variables, the enclosing variable scope must "
                      "have use_resource=True.")
    else:
      logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                   "no ResourceVariables were used on the forward pass.")
  flat_result = nest.flatten(result)
  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    # Only the gradients of the real outputs are forwarded to grad_fn; the
    # IdentityN op also has outputs for the inputs and the variables.
    result_grads = result_grads[:len(flat_result)]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
      # Accept any sequence (e.g. a tuple) so the concatenation below works.
      variable_grads = list(variable_grads)
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * len(flat_result)) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  # Route everything through one IdentityN op whose gradient is the function
  # registered under `name` above.
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:len(flat_result)])
254
255
def _eager_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for eager mode.

  Calls `f` under a `GradientTape`, passes each flattened output through an
  identity op and records a tape operation whose backward function invokes
  the `grad_fn` returned by `f`.

  Args:
    f: The function wrapped by `custom_gradient`. Called as
      `f(*args, **kwargs)` and expected to return a pair `(result, grad_fn)`.
    *args: Positional inputs to `f`. Together with the watched variables they
      are the inputs recorded on the tape.
    **kwargs: Keyword inputs to `f`. Their values are excluded from the
      variable list but are not recorded as inputs on the tape.

  Returns:
    The identity-wrapped outputs, packed into the same structure as the
    `result` returned by `f`.

  Raises:
    TypeError: If `f` used variables that are not among its inputs but
      `grad_fn` accepts no `variables` keyword argument.
    ValueError: Raised later, from the recorded gradient function, if
      `grad_fn` does not return one gradient per variable or does not return
      `len(args)` input gradients.
  """
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args, **kwargs)
  all_inputs = list(args) + list(kwargs.values())
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = [v for v in set(tape.watched_variables()) if v not in all_inputs]
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  # A grad_fn written as `g(*grad_ys, variables=None)` declares `variables` as
  # a keyword-only argument, which getfullargspec reports in `kwonlyargs`
  # rather than `args`, so both places have to be checked.
  variables_in_signature = ("variables" in grad_argspec.args or
                            "variables" in grad_argspec.kwonlyargs or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  input_tensors = [ops.convert_to_tensor(x) for x
                   in list(args) + list(variables)]
  arg_count = len(args)
  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
      # Accept any sequence (e.g. a tuple) so the concatenation below works.
      variable_grads = list(variable_grads)
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    flat_grads = nest.flatten(input_grads)
    if len(flat_grads) != arg_count:
      # Format the message; passing the pieces as separate exception
      # arguments would render the error as a tuple repr.
      raise ValueError(
          "custom_gradient function expected to return %d gradients but "
          "returned %d instead." % (arg_count, len(flat_grads)))
    return flat_grads + variable_grads

  tape_lib.record_operation(f.__name__, flat_result, input_tensors,
                            actual_grad_fn)
  return nest.pack_sequence_as(result, flat_result)
298