1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Gradient checker for any ops, graphs.
17
The gradient checker verifies numerically that an op/graph properly
computes its gradients.
20"""
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import numpy as np
26
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import gradients
32from tensorflow.python.ops import math_ops
33from tensorflow.python.platform import tf_logging as logging
34from tensorflow.python.util import deprecation
35from tensorflow.python.util.tf_export import tf_export
36
37
def _product(t):
  """Returns the number of elements described by a shape.

  Args:
    t: an int, or an iterable of ints such as a shape tuple.

  Returns:
    `t` itself when it is an int; otherwise the product of its elements
    (1 for an empty iterable, i.e. a scalar shape).
  """
  if isinstance(t, int):
    return t
  total = 1
  for dim in t:
    total = total * dim
  return total
46
47
def _extra_feeds(extra_feed_dict, new_feeds):
  """Merges the caller's extra feeds with the feeds for a single run.

  Args:
    extra_feed_dict: optional dict of additional feeds; may be None or empty.
    new_feeds: dict of feeds for this particular run.

  Returns:
    `new_feeds` itself when `extra_feed_dict` is None or empty; otherwise a
    new dict holding both, where entries of `new_feeds` win on key conflicts.
    Neither input dict is modified.
  """
  if extra_feed_dict:
    merged = dict(extra_feed_dict)
    merged.update(new_feeds)
    return merged
  return new_feeds
55
56
def _compute_theoretical_jacobian(x, x_shape, x_data, dy, dy_shape, dx,
                                  extra_feed_dict):
  """Computes the theoretical Jacobian for dy/dx.

  Computes the theoretical Jacobian by running the gradient `dx` once per
  (real) element of `dy`, feeding `dy` with a one-hot value each time. Each run
  yields one column of the Jacobian.

  Complex tensors are treated as real tensors with twice as many elements:
  every complex entry contributes its real part followed by its imaginary
  part (numpy's layout when a complex array is viewed as reals).

  Args:
    x: the tensor "x".
    x_shape: the dimensions of x as a tuple or an array of ints.
    x_data: a numpy array as the input data for x.
    dy: the tensor "dy"; it is fed with one-hot values.
    dy_shape: the dimensions of dy as a tuple or an array of ints.
    dx: Tensor or IndexedSlices representing dx.
    extra_feed_dict: dict that allows fixing specified tensor values
      during the jacobian calculation.

  Returns:
    A 2-d numpy array representing the Jacobian for dy/dx. It has "x_size" rows
    and "dy_size" columns where "x_size" is the number of (real) elements in x
    and "dy_size" is the number of (real) elements in dy.

  Raises:
    ValueError: If `dy` is empty but the gradient is nonzero, or (for a dense
      gradient) has a shape different from `x_data`.
  """
  # Complex vectors are treated as vectors of twice as many reals.
  if x.dtype.is_complex:
    x_shape = tuple(x_shape) + (2,)
  dy_factor = 2 if dy.dtype.is_complex else 1

  # To compute the jacobian, we treat x and y as one-dimensional vectors.
  x_size = _product(x_shape)
  # Number of Jacobian rows covered by one slice of a sparse gradient, i.e. by
  # one index along the first dimension of x.
  x_val_size = _product(x_shape[1:])
  dy_size = _product(dy_shape) * dy_factor

  # Allocate 2-D Jacobian, with x dimensions smashed into the first
  # dimension and y dimensions smashed into the second.
  jacobian = np.zeros((x_size, dy_size),
                      dtype=x.dtype.real_dtype.as_numpy_dtype)

  # For each entry of dy, we set it to 1 and everything else to 0 and compute
  # the backprop -- this gives us one column of the Jacobian matrix.
  dy_data = np.zeros(dy_shape, dtype=dy.dtype.as_numpy_dtype)
  # Real-valued view of dy_data: writing through it updates dy_data in place,
  # so the (loop-invariant) feed dict below always sees the current one-hot.
  dy_data_flat = dy_data.ravel().view(dy.dtype.real_dtype.as_numpy_dtype)
  feed_dict = _extra_feeds(extra_feed_dict, {x: x_data, dy: dy_data})
  sess = ops.get_default_session()
  for col in range(dy_size):
    dy_data_flat[col] = 1
    if isinstance(dx, ops.IndexedSlices):
      backprop_indices, backprop_values = sess.run(
          [dx.indices, dx.values], feed_dict=feed_dict)
      for i, v in zip(backprop_indices, backprop_values):
        r_begin = i * x_val_size
        r_end = r_begin + x_val_size
        v_flat = np.ravel(v)
        if np.iscomplexobj(v_flat):
          # Unpack complex slices into interleaved (real, imag) pairs so they
          # line up with the Jacobian rows built from x_shape + (2,).
          v_flat = v_flat.view(v_flat.real.dtype)
        # Indices may repeat, so contributions are accumulated.
        jacobian[r_begin:r_end, col] += v_flat
    else:
      assert isinstance(dx, ops.Tensor), "dx = " + str(dx)
      backprop = sess.run(dx, feed_dict=feed_dict)
      jacobian[:, col] = backprop.ravel().view(jacobian.dtype)
    dy_data_flat[col] = 0

  # If the output is empty, run the gradients at least once and make sure
  # they produce zeros.
  if not dy_size:
    if isinstance(dx, ops.IndexedSlices):
      # Fetching an IndexedSlices yields a value object without `.shape`;
      # only its values can be checked for being zero.
      backprop = sess.run(dx.values, feed_dict=feed_dict)
    else:
      backprop = sess.run(dx, feed_dict=feed_dict)
      if backprop.shape != x_data.shape:
        raise ValueError("Empty gradient has wrong shape: expected %s, got %s" %
                         (x_data.shape, backprop.shape))
    if np.any(backprop):
      raise ValueError("Empty tensor with nonzero gradients")

  logging.vlog(1, "Theoretical Jacobian =\n%s", jacobian)
  return jacobian
133
134
def _compute_numeric_jacobian(x, x_shape, x_data, y, y_shape, delta,
                              extra_feed_dict):
  """Computes the numeric Jacobian for dy/dx.

  Estimates the Jacobian with central differences. Each real component of `x`
  is shifted by `+delta` and by `-delta`, and `y` is evaluated for both. The
  difference of the two outputs, divided by `2 * delta`, forms one row. Complex
  tensors are handled as real tensors with twice as many elements, with real
  and imaginary parts interleaved.

  Args:
    x: the tensor "x".
    x_shape: the dimensions of x as a tuple or an array of ints.
    x_data: a numpy array as the input data for x.
    y: the tensor "y".
    y_shape: the dimensions of y as a tuple or an array of ints.
    delta: the amount of perturbation we give to the input.
    extra_feed_dict: dict that allows fixing specified tensor values
      during the jacobian calculation.

  Returns:
    A 2-d numpy array representing the Jacobian for dy/dx. It has "x_size" rows
    and "y_size" columns where "x_size" is the number of (real) elements in x
    and "y_size" is the number of (real) elements in y.
  """
  # bfloat16 cannot represent small perturbations such as `delta` accurately.
  # This Jacobian is the ground truth the theoretical one is compared against,
  # so the computation is carried out in float32 instead.
  if x.dtype == dtypes.bfloat16:
    # TODO(wangpeng): `x` is rebound to a cast *of* the caller's tensor, so the
    # feeds below target the cast output. `y` was built from the original
    # tensor and presumably does not see those feeds. Confirm and fix.
    x = math_ops.cast(x, dtypes.float32)
  if y.dtype == dtypes.bfloat16:
    y = math_ops.cast(y, dtypes.float32)
  if x_data.dtype == dtypes.bfloat16.as_numpy_dtype:
    x_data = x_data.astype(np.float32)

  # Both tensors are flattened to 1-D; a complex value counts as two reals.
  x_real_dtype = x.dtype.real_dtype.as_numpy_dtype
  y_real_dtype = y.dtype.real_dtype.as_numpy_dtype
  x_size = _product(x_shape) * (2 if x.dtype.is_complex else 1)
  y_size = _product(y_shape) * (2 if y.dtype.is_complex else 1)

  # Normalize the input dtype and precompute the central-difference divisor.
  x_data = np.asarray(x_data, dtype=x.dtype.as_numpy_dtype)
  divisor = np.asarray(2 * delta, dtype=y_real_dtype)[()]

  def _eval_y(x_value):
    """Evaluates y in the default session with x fed as `x_value`."""
    return y.eval(feed_dict=_extra_feeds(extra_feed_dict, {x: x_value}))

  jacobian = np.zeros((x_size, y_size), dtype=x_real_dtype)
  for index in range(x_size):
    # ravel() of a fresh C-contiguous copy is a view, so the real-typed view
    # below writes straight into the copy.
    x_plus = x_data.copy()
    x_plus.ravel().view(x_real_dtype)[index] += delta
    y_plus = _eval_y(x_plus)
    x_minus = x_data.copy()
    x_minus.ravel().view(x_real_dtype)[index] -= delta
    y_minus = _eval_y(x_minus)
    row = (y_plus - y_minus) / divisor
    jacobian[index, :] = row.ravel().view(y_real_dtype)

  logging.vlog(1, "Numeric Jacobian =\n%s", jacobian)
  return jacobian
194
195
def _compute_dx_and_dy(x, y, y_shape):
  """Builds the gradient of `y` with respect to `x`.

  Args:
    x: the tensor to differentiate with respect to.
    y: the tensor being differentiated.
    y_shape: the dimensions of y, used to shape the upstream gradient.

  Returns:
    A `(dx, dy)` pair. `dx` is the single element returned by
    `gradients.gradients(y, x, ...)`. `dy` is the constant holding the
    upstream gradient, which callers feed with their own values.
  """
  # The value of dy is irrelevant because callers always feed it. It is routed
  # through an identity so that dx is never the fed tensor itself: for Add, dx
  # would otherwise be dy, and a fed tensor cannot also be fetched.
  with x.graph.as_default():
    upstream = constant_op.constant(1.0, shape=y_shape, dtype=y.dtype)
    upstream_identity = array_ops.identity(upstream)
  grad_list = gradients.gradients(y, x, upstream_identity)
  assert len(grad_list) == 1
  dx = grad_list[0]
  return dx, upstream
209
210
def _compute_gradient(x,
                      x_shape,
                      dx,
                      y,
                      y_shape,
                      dy,
                      x_init_value=None,
                      delta=1e-3,
                      extra_feed_dict=None):
  """Computes the theoretical and numerical Jacobians of y with respect to x.

  Args:
    x: the input tensor.
    x_shape: the dimensions of x.
    dx: the gradient of y w.r.t. x (Tensor or IndexedSlices).
    y: the output tensor.
    y_shape: the dimensions of y.
    dy: the upstream-gradient tensor that is fed while evaluating dx.
    x_init_value: optional numpy array used as the value of x. Its shape must
      equal x_shape. When None, a random value is drawn.
    delta: the perturbation used for the numerical Jacobian.
    extra_feed_dict: dict that allows fixing specified tensor values.

  Returns:
    A `(theoretical, numerical)` tuple of 2-d numpy arrays.
  """
  supported = (dtypes.float16, dtypes.bfloat16, dtypes.float32,
               dtypes.float64, dtypes.complex64, dtypes.complex128)
  x_type = dtypes.as_dtype(x.dtype)
  assert x_type.base_dtype in supported, (
      "Don't support type %s for x" % x_type.name)
  y_type = dtypes.as_dtype(y.dtype)
  assert y_type.base_dtype in supported, (
      "Don't support type %s for y" % y_type.name)

  if x_init_value is None:
    # Random inputs in [0, 1); complex inputs also get a random imaginary part.
    x_data = np.random.random_sample(x_shape).astype(x_type.as_numpy_dtype)
    if x_type.is_complex:
      x_data.imag = np.random.random_sample(x_shape)
  else:
    init_shape = list(x_init_value.shape)
    assert list(x_shape) == init_shape, (
        "x_shape = %s, init_data shape = %s" % (x_shape, init_shape))
    x_data = x_init_value

  theoretical = _compute_theoretical_jacobian(
      x, x_shape, x_data, dy, y_shape, dx, extra_feed_dict=extra_feed_dict)
  numerical = _compute_numeric_jacobian(
      x, x_shape, x_data, y, y_shape, delta, extra_feed_dict=extra_feed_dict)
  return theoretical, numerical
243
244
def _compute_gradient_list(x,
                           x_shape,
                           y,
                           y_shape,
                           x_init_value=None,
                           delta=1e-3,
                           init_targets=None,
                           extra_feed_dict=None):
  """Computes Jacobian pairs for every tensor in the list `x`.

  Gradient ops are built for all inputs first. The optional init targets run
  next, and then each input is checked in turn. Inputs, shapes and initial
  values are paired positionally; `zip` stops at the shortest of them.

  Returns:
    A list with one `(theoretical, numerical)` tuple per element of `x`.
  """
  assert isinstance(x, list)
  grad_pairs = [_compute_dx_and_dy(xi, y, y_shape) for xi in x]
  dx, dy = zip(*grad_pairs)

  if init_targets is not None:
    assert isinstance(init_targets, (list, tuple))
    for target in init_targets:
      target.run()

  init_values = x_init_value
  if init_values is None:
    init_values = [None] * len(x)

  results = []
  for xi, shape_i, dx_i, dy_i, init_i in zip(x, x_shape, dx, dy, init_values):
    results.append(
        _compute_gradient(xi, shape_i, dx_i, y, y_shape, dy_i, init_i, delta,
                          extra_feed_dict=extra_feed_dict))
  return results
269
270
@tf_export(v1=["test.compute_gradient"])
@deprecation.deprecated(
    date=None,
    instructions="Use tf.test.compute_gradient in 2.0, which has better "
    "support for functions. Note that the two versions have different usage, "
    "so code change is needed.")
def compute_gradient(x,
                     x_shape,
                     y,
                     y_shape,
                     x_init_value=None,
                     delta=1e-3,
                     init_targets=None,
                     extra_feed_dict=None):
  """Computes and returns the theoretical and numerical Jacobian.

  If `x` or `y` is complex, the Jacobian will still be real but the
  corresponding Jacobian dimension(s) will be twice as large.  This is required
  even if both input and output are complex since TensorFlow graphs are not
  necessarily holomorphic, and may have gradients not expressible as complex
  numbers.  Each complex element is unpacked into two consecutive real entries,
  its real part followed by its imaginary part. This matches numpy's layout
  when a complex array is viewed as reals.  For example, if `x` is complex with
  shape `[m]` and `y` is complex with shape `[n]`, each Jacobian `J` will have
  shape `[m * 2, n * 2]` with

      J[0::2, 0::2] = d(Re y)/d(Re x)
      J[0::2, 1::2] = d(Im y)/d(Re x)
      J[1::2, 0::2] = d(Re y)/d(Im x)
      J[1::2, 1::2] = d(Im y)/d(Im x)

  Args:
    x: a tensor or list of tensors
    x_shape: the dimensions of x as a tuple or an array of ints. If x is a list,
      then this is the list of shapes.
    y: a tensor
    y_shape: the dimensions of y as a tuple or an array of ints.
    x_init_value: (optional) a numpy array of the same shape as "x"
      representing the initial value of x. If x is a list, this should be a list
      of numpy arrays.  If this is None, the function will pick a random tensor
      as the initial value.
    delta: (optional) the amount of perturbation.
    init_targets: list of targets to run to initialize model params.
    extra_feed_dict: dict that allows fixing specified tensor values
      during the Jacobian calculation.

  Returns:
    A `(theoretical, numerical)` tuple of 2-d numpy arrays representing the
    Jacobian for dy/dx. Each has "x_size" rows and "y_size" columns where
    "x_size" is the number of (real) elements in x and "y_size" is the number
    of (real) elements in y. If x is a list, returns a list of such tuples,
    one per element of x.
  """
  # TODO(mrry): remove argument `init_targets`
  if extra_feed_dict is None:
    extra_feed_dict = {}

  if isinstance(x, list):
    return _compute_gradient_list(x, x_shape, y, y_shape, x_init_value, delta,
                                  init_targets, extra_feed_dict=extra_feed_dict)

  if init_targets is not None:
    assert isinstance(init_targets, (list, tuple))
    for init in init_targets:
      init.run()
  dx, dy = _compute_dx_and_dy(x, y, y_shape)
  return _compute_gradient(x, x_shape, dx, y, y_shape, dy, x_init_value, delta,
                           extra_feed_dict=extra_feed_dict)
336
337
def _compute_error(grad):
  """Returns the largest absolute difference between Jacobian pairs.

  Args:
    grad: a `(theoretical, numerical)` tuple of numpy arrays, or a list of
      such tuples.

  Returns:
    The maximum of `|theoretical - numerical|` over all non-empty pairs, or 0
    when every pair is empty (or the list is empty).
  """
  pairs = [grad] if isinstance(grad, tuple) else grad
  max_err = 0
  for theoretical, numerical in pairs:
    # Zero-size Jacobians have no elements to compare; skip them.
    if not (theoretical.size or numerical.size):
      continue
    max_err = np.maximum(max_err, np.fabs(theoretical - numerical).max())
  return max_err
346
347
@tf_export(v1=["test.compute_gradient_error"])
@deprecation.deprecated(
    date=None,
    instructions="Use tf.test.compute_gradient in 2.0, which has better "
    "support for functions. Note that the two versions have different usage, "
    "so code change is needed.")
def compute_gradient_error(x,
                           x_shape,
                           y,
                           y_shape,
                           x_init_value=None,
                           delta=1e-3,
                           init_targets=None,
                           extra_feed_dict=None):
  """Computes the gradient error.

  Computes the maximum absolute element-wise difference for dy/dx between the
  computed (theoretical) Jacobian and the numerically estimated Jacobian.

  This function modifies the graph holding the given tensors. It adds gradient
  operations, which changes the consumers of the input tensors' operations.
  The added operations are evaluated in the default session.

  To compute the error using a particular device, such as a GPU, use the
  standard methods for setting a device (e.g. using with sess.graph.device() or
  setting a device function in the session constructor).

  Args:
    x: a tensor or list of tensors
    x_shape: the dimensions of x as a tuple or an array of ints. If x is a list,
      then this is the list of shapes.
    y: a tensor
    y_shape: the dimensions of y as a tuple or an array of ints.
    x_init_value: (optional) a numpy array of the same shape as "x"
      representing the initial value of x. If x is a list, this should be a list
      of numpy arrays.  If this is None, the function will pick a random tensor
      as the initial value.
    delta: (optional) the amount of perturbation.
    init_targets: list of targets to run to initialize model params.
    extra_feed_dict: dict that allows fixing specified tensor values
      during the Jacobian calculation.

  Returns:
    The maximum absolute difference between the two Jacobians. If x is a list,
    this is the maximum over all inputs. It is 0 when all Jacobians are empty.
  """
  jacobians = compute_gradient(
      x, x_shape, y, y_shape, x_init_value, delta, init_targets,
      extra_feed_dict=extra_feed_dict)
  return _compute_error(jacobians)
396