# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements the graph generation for computation of gradients."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops  # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import image_grad  # pylint: disable=unused-import
from tensorflow.python.ops import linalg_grad  # pylint: disable=unused-import
from tensorflow.python.ops import linalg_ops  # pylint: disable=unused-import
from tensorflow.python.ops import logging_ops  # pylint: disable=unused-import
from tensorflow.python.ops import manip_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import optional_grad  # pylint: disable=unused-import
from tensorflow.python.ops import random_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["gradients"])
def gradients(ys,
              xs,
              grad_ys=None,
              name="gradients",
              colocate_gradients_with_ops=False,
              gate_gradients=False,
              aggregation_method=None,
              stop_gradients=None,
              unconnected_gradients=UnconnectedGradients.NONE):
  """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.

  `ys` and `xs` are each a `Tensor` or a list of tensors.  `grad_ys`
  is a list of `Tensor`, holding the gradients received by the
  `ys`. The list must be the same length as `ys`.

  `gradients()` adds ops to the graph to output the derivatives of `ys` with
  respect to `xs`.  It returns a list of `Tensor` of length `len(xs)` where
  each tensor is the `sum(dy/dx)` for y in `ys` and for x in `xs`.

  `grad_ys` is a list of tensors of the same length as `ys` that holds
  the initial gradients for each y in `ys`.  When `grad_ys` is None,
  we fill in a tensor of '1's of the shape of y for each y in `ys`.  A
  user can provide their own initial `grad_ys` to compute the
  derivatives using a different initial gradient for each y (e.g., if
  one wanted to weight the gradient differently for each value in
  each y).

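  For example, `grad_ys` can be used to weight each output gradient
  differently. A small illustrative sketch (the values in the comments are
  what the gradients evaluate to):

  ```python
  a = tf.constant([1., 2.])
  b = 3. * a
  # Scale the incoming gradient for each element of b before backprop.
  g = tf.gradients([b], [a], grad_ys=[tf.constant([1., 10.])])
  # g evaluates to [3., 30.] rather than the unweighted [3., 3.].
  ```
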
  `stop_gradients` is a `Tensor` or a list of tensors to be considered constant
  with respect to all `xs`. These tensors will not be backpropagated through,
  as though they had been explicitly disconnected using `stop_gradient`.  Among
  other things, this allows computation of partial derivatives as opposed to
  total derivatives. For example:

  ```python
  a = tf.constant(0.)
  b = 2 * a
  g = tf.gradients(a + b, [a, b], stop_gradients=[a, b])
  ```

  Here the partial derivatives `g` evaluate to `[1.0, 1.0]`, compared to the
  total derivatives `tf.gradients(a + b, [a, b])`, which take into account the
  influence of `a` on `b` and evaluate to `[3.0, 1.0]`.  Note that the above is
  equivalent to:

  ```python
  a = tf.stop_gradient(tf.constant(0.))
  b = tf.stop_gradient(2 * a)
  g = tf.gradients(a + b, [a, b])
  ```

  `stop_gradients` provides a way of stopping gradients after the graph has
  already been constructed, as compared to `tf.stop_gradient`, which is used
  during graph construction.  When the two approaches are combined,
  backpropagation stops at both `tf.stop_gradient` nodes and nodes in
  `stop_gradients`, whichever is encountered first.

  All integer tensors are considered constant with respect to all `xs`, as if
  they were included in `stop_gradients`.

  `unconnected_gradients` determines the value returned for each x in `xs` if
  it is unconnected in the graph to `ys`. By default this is `None`, to
  safeguard against errors. Mathematically these gradients are zero, which can
  be requested using the `'zero'` option. `tf.UnconnectedGradients` provides
  the following options and behaviors:

  ```python
  a = tf.ones([1, 2])
  b = tf.ones([3, 1])
  g1 = tf.gradients([b], [a], unconnected_gradients='none')
  sess.run(g1)  # [None]

  g2 = tf.gradients([b], [a], unconnected_gradients='zero')
  sess.run(g2)  # [array([[0., 0.]], dtype=float32)]
  ```

  As a practical example from the backpropagation phase, this function can be
  used to evaluate the derivatives of a cost function with respect to weights
  `Ws` and biases `bs`. The sample implementation below illustrates this use:

  ```python
  Ws = tf.constant(0.)
  bs = 2 * Ws
  cost = Ws + bs  # A toy cost function used only for illustration.
  g = tf.gradients(cost, [Ws, bs])
  dCost_dW, dCost_db = g
  ```


  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grad_ys: Optional. A `Tensor` or list of tensors the same size as
      `ys` and holding the gradients computed for each y in `ys`.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'gradients'.
    colocate_gradients_with_ops: If True, try colocating gradients with
      the corresponding op.
    gate_gradients: If True, add a tuple around the gradients returned
      for an operation.  This avoids some race conditions.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.
    stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
      through.
    unconnected_gradients: Optional. Specifies the gradient value returned when
      the given input tensors are unconnected. Accepted values are constants
      defined in the class `tf.UnconnectedGradients` and the default value is
      `none`.

  Returns:
    A list of `Tensor` of length `len(xs)` where each tensor is the `sum(dy/dx)`
    for y in `ys` and for x in `xs`.

  Raises:
    LookupError: if one of the operations between `x` and `y` does not
      have a registered gradient function.
    ValueError: if the arguments are invalid.
    RuntimeError: if called in Eager mode.

  """
  # Creating the gradient graph for control flow mutates Operations.
  # _mutation_lock ensures a Session.run call cannot occur between creating and
  # mutating new ops.
  # pylint: disable=protected-access
  with ops.get_default_graph()._mutation_lock():
    return gradients_util._GradientsHelper(
        ys, xs, grad_ys, name, colocate_gradients_with_ops,
        gate_gradients, aggregation_method, stop_gradients,
        unconnected_gradients)
  # pylint: enable=protected-access


@tf_export("gradients", v1=[])
def gradients_v2(ys,  # pylint: disable=invalid-name
                 xs,
                 grad_ys=None,
                 name="gradients",
                 gate_gradients=False,
                 aggregation_method=None,
                 stop_gradients=None,
                 unconnected_gradients=UnconnectedGradients.NONE):
  """Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.

  `tf.gradients` is only valid in a graph context. In particular,
  it is valid in the context of a `tf.function` wrapper, where code
  is executing as a graph.

  `ys` and `xs` are each a `Tensor` or a list of tensors.  `grad_ys`
  is a list of `Tensor`, holding the gradients received by the
  `ys`. The list must be the same length as `ys`.

  `gradients()` adds ops to the graph to output the derivatives of `ys` with
  respect to `xs`.  It returns a list of `Tensor` of length `len(xs)` where
  each tensor is the `sum(dy/dx)` for y in `ys` and for x in `xs`.

  `grad_ys` is a list of tensors of the same length as `ys` that holds
  the initial gradients for each y in `ys`.  When `grad_ys` is None,
  we fill in a tensor of '1's of the shape of y for each y in `ys`.  A
  user can provide their own initial `grad_ys` to compute the
  derivatives using a different initial gradient for each y (e.g., if
  one wanted to weight the gradient differently for each value in
  each y).

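  For example, `grad_ys` can be used to weight each output gradient
  differently (a small illustrative sketch):

  >>> @tf.function
  ... def example():
  ...   a = tf.constant([1., 2.])
  ...   b = 3. * a
  ...   # Scale the incoming gradient for each element of b before backprop.
  ...   return tf.gradients([b], [a], grad_ys=[tf.constant([1., 10.])])
  >>> example()
  [<tf.Tensor: shape=(2,), dtype=float32, numpy=array([ 3., 30.], ...)>]
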
  `stop_gradients` is a `Tensor` or a list of tensors to be considered constant
  with respect to all `xs`. These tensors will not be backpropagated through,
  as though they had been explicitly disconnected using `stop_gradient`.  Among
  other things, this allows computation of partial derivatives as opposed to
  total derivatives. For example:

  >>> @tf.function
  ... def example():
  ...   a = tf.constant(0.)
  ...   b = 2 * a
  ...   return tf.gradients(a + b, [a, b], stop_gradients=[a, b])
  >>> example()
  [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>]

  Here the partial derivatives evaluate to `[1.0, 1.0]`, compared to the
  total derivatives `tf.gradients(a + b, [a, b])`, which take into account the
  influence of `a` on `b` and evaluate to `[3.0, 1.0]`.  Note that the above is
  equivalent to:

  >>> @tf.function
  ... def example():
  ...   a = tf.stop_gradient(tf.constant(0.))
  ...   b = tf.stop_gradient(2 * a)
  ...   return tf.gradients(a + b, [a, b])
  >>> example()
  [<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>]

  `stop_gradients` provides a way of stopping gradients after the graph has
  already been constructed, as compared to `tf.stop_gradient`, which is used
  during graph construction.  When the two approaches are combined,
  backpropagation stops at both `tf.stop_gradient` nodes and nodes in
  `stop_gradients`, whichever is encountered first.

  All integer tensors are considered constant with respect to all `xs`, as if
  they were included in `stop_gradients`.

  `unconnected_gradients` determines the value returned for each x in `xs` if
  it is unconnected in the graph to `ys`. By default this is `None`, to
  safeguard against errors. Mathematically these gradients are zero, which can
  be requested using the `'zero'` option. `tf.UnconnectedGradients` provides
  the following options and behaviors:

  >>> @tf.function
  ... def example(use_zero):
  ...   a = tf.ones([1, 2])
  ...   b = tf.ones([3, 1])
  ...   if use_zero:
  ...     return tf.gradients([b], [a], unconnected_gradients='zero')
  ...   else:
  ...     return tf.gradients([b], [a], unconnected_gradients='none')
  >>> example(False)
  [None]
  >>> example(True)
  [<tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0., 0.]], ...)>]

  As a practical example from the backpropagation phase, this function can be
  used to evaluate the derivatives of a cost function with respect to weights
  `Ws` and biases `bs`. The sample implementation below illustrates this use:

  >>> @tf.function
  ... def example():
  ...   Ws = tf.constant(0.)
  ...   bs = 2 * Ws
  ...   cost = Ws + bs  # A toy cost function used only for illustration.
  ...   g = tf.gradients(cost, [Ws, bs])
  ...   dCost_dW, dCost_db = g
  ...   return dCost_dW, dCost_db
  >>> example()
  (<tf.Tensor: shape=(), dtype=float32, numpy=3.0>,
  <tf.Tensor: shape=(), dtype=float32, numpy=1.0>)

  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    grad_ys: Optional. A `Tensor` or list of tensors the same size as
      `ys` and holding the gradients computed for each y in `ys`.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'gradients'.
    gate_gradients: If True, add a tuple around the gradients returned
      for an operation.  This avoids some race conditions.
    aggregation_method: Specifies the method used to combine gradient terms.
      Accepted values are constants defined in the class `AggregationMethod`.
    stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
      through.
    unconnected_gradients: Optional. Specifies the gradient value returned when
      the given input tensors are unconnected. Accepted values are constants
      defined in the class `tf.UnconnectedGradients` and the default value is
      `none`.

  Returns:
    A list of `Tensor` of length `len(xs)` where each tensor is the `sum(dy/dx)`
    for y in `ys` and for x in `xs`.

  Raises:
    LookupError: if one of the operations between `x` and `y` does not
      have a registered gradient function.
    ValueError: if the arguments are invalid.
    RuntimeError: if called in Eager mode.

  """
  # Creating the gradient graph for control flow mutates Operations.
  # _mutation_lock ensures a Session.run call cannot occur between creating and
  # mutating new ops.
  # pylint: disable=protected-access
  with ops.get_default_graph()._mutation_lock():
    return gradients_util._GradientsHelper(
        ys, xs, grad_ys, name, True, gate_gradients,
        aggregation_method, stop_gradients,
        unconnected_gradients)
  # pylint: enable=protected-access


# TODO(vrv): Make this available when we want to make it public.
def _hessian_vector_product(ys, xs, v):
  """Multiply the Hessian of `ys` wrt `xs` by `v`.

  This is an efficient construction that uses a backprop-like approach
  to compute the product between the Hessian and another vector. The
  Hessian is usually too large to be explicitly computed or even
  represented, but this method allows us to at least multiply by it
  for the same big-O cost as backprop.

  Implicit Hessian-vector products are the main practical, scalable way
  of using second derivatives with neural networks. They allow us to
  do things like construct Krylov subspaces and approximate conjugate
  gradient descent.

  Example: if `y` = 1/2 `x`^T A `x`, then `hessian_vector_product(y,
  x, v)` will return an expression that evaluates to the same values
  as 1/2 (A + A.T) `v`.

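  A minimal sketch of that example in code (graph mode; the symmetric matrix
  `A` below is chosen only for illustration):

  ```python
  A = tf.constant([[1., 0.], [0., 2.]])
  x = tf.ones([2, 1])
  y = 0.5 * tf.matmul(tf.matmul(x, A, transpose_a=True), x)
  v = [tf.ones([2, 1])]
  hvp = _hessian_vector_product(y, [x], v)
  # Since A is symmetric, hvp[0] evaluates to A @ v, i.e. [[1.], [2.]].
  ```
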
  Args:
    ys: A scalar value, or a tensor or list of tensors to be summed to
        yield a scalar.
    xs: A list of tensors that we should construct the Hessian over.
    v: A list of tensors, with the same shapes as xs, that we want to
       multiply by the Hessian.

  Returns:
    A list of tensors (or if the list would be length 1, a single tensor)
    containing the product between the Hessian and `v`.

  Raises:
    ValueError: if `xs` and `v` have different lengths.

  """

  # Validate the input
  length = len(xs)
  if len(v) != length:
    raise ValueError("xs and v must have the same length.")

  # First backprop
  grads = gradients(ys, xs)

  assert len(grads) == length
  elemwise_products = [
      math_ops.multiply(grad_elem, array_ops.stop_gradient(v_elem))
      for grad_elem, v_elem in zip(grads, v)
      if grad_elem is not None
  ]

  # Second backprop
  return gradients(elemwise_products, xs)


@tf_export(v1=["hessians"])
def hessians(ys,
             xs,
             name="hessians",
             colocate_gradients_with_ops=False,
             gate_gradients=False,
             aggregation_method=None):
  """Constructs the Hessian of sum of `ys` with respect to `x` in `xs`.

  `hessians()` adds ops to the graph to output the Hessian matrix of `ys`
  with respect to `xs`.  It returns a list of `Tensor` of length `len(xs)`
  where each tensor is the Hessian of `sum(ys)`.

  The Hessian is a matrix of second-order partial derivatives of a scalar
  tensor (see https://en.wikipedia.org/wiki/Hessian_matrix for more details).

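  For example, a small illustrative case (graph mode):

  ```python
  x = tf.constant([1., 2.])
  y = tf.reduce_sum(x ** 3)
  h = tf.hessians(y, x)[0]
  # h evaluates to [[6., 0.], [0., 12.]], i.e. diag(6 * x).
  ```
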
  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'hessians'.
    colocate_gradients_with_ops: See `gradients()` documentation for details.
    gate_gradients: See `gradients()` documentation for details.
    aggregation_method: See `gradients()` documentation for details.

  Returns:
    A list of Hessian matrices of `sum(ys)` for each `x` in `xs`.

  Raises:
    LookupError: if one of the operations between `xs` and `ys` does not
      have a registered gradient function.
  """
  xs = gradients_util._AsList(xs)  # pylint: disable=protected-access
  kwargs = {
      "colocate_gradients_with_ops": colocate_gradients_with_ops,
      "gate_gradients": gate_gradients,
      "aggregation_method": aggregation_method
  }
  # Compute first-order derivatives and iterate for each x in xs.
  hessians = []
  _gradients = gradients(ys, xs, **kwargs)
  for gradient, x in zip(_gradients, xs):
    # Change the shape to one dimension without graph branching.
    gradient = array_ops.reshape(gradient, [-1])

    # Declare an iterator and tensor array loop variables for the gradients.
    n = array_ops.size(x)
    loop_vars = [
        array_ops.constant(0, dtypes.int32),
        tensor_array_ops.TensorArray(x.dtype, n)
    ]
    # Iterate over all elements of the gradient and compute second order
    # derivatives.
    _, hessian = control_flow_ops.while_loop(
        lambda j, _: j < n,
        lambda j, result: (j + 1,
                           result.write(j, gradients(gradient[j], x)[0])),
        loop_vars
    )

    _shape = array_ops.shape(x)
    _reshaped_hessian = array_ops.reshape(hessian.stack(),
                                          array_ops.concat((_shape, _shape), 0))
    hessians.append(_reshaped_hessian)
  return hessians


@tf_export("hessians", v1=[])
def HessiansV2(ys,
               xs,
               gate_gradients=False,
               aggregation_method=None,
               name="hessians"):
  """Constructs the Hessian of sum of `ys` with respect to `x` in `xs`.

  `hessians()` adds ops to the graph to output the Hessian matrix of `ys`
  with respect to `xs`.  It returns a list of `Tensor` of length `len(xs)`
  where each tensor is the Hessian of `sum(ys)`.

  The Hessian is a matrix of second-order partial derivatives of a scalar
  tensor (see https://en.wikipedia.org/wiki/Hessian_matrix for more details).

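  For example, a small illustrative case:

  >>> @tf.function
  ... def example():
  ...   x = tf.constant([1., 2.])
  ...   y = tf.reduce_sum(x ** 3)
  ...   return tf.hessians(y, x)[0]
  >>> example()
  <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
  array([[ 6.,  0.],
         [ 0., 12.]], dtype=float32)>
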
  Args:
    ys: A `Tensor` or list of tensors to be differentiated.
    xs: A `Tensor` or list of tensors to be used for differentiation.
    gate_gradients: See `gradients()` documentation for details.
    aggregation_method: See `gradients()` documentation for details.
    name: Optional name to use for grouping all the gradient ops together.
      Defaults to 'hessians'.

  Returns:
    A list of Hessian matrices of `sum(ys)` for each `x` in `xs`.

  Raises:
    LookupError: if one of the operations between `xs` and `ys` does not
      have a registered gradient function.
  """
  return hessians(
      ys,
      xs,
      name=name,
      colocate_gradients_with_ops=True,
      gate_gradients=gate_gradients,
      aggregation_method=aggregation_method)