# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for tf.test.compute_gradient and tf.compute_gradient_error."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


@ops.RegisterGradient("BadGrad")
def _bad_grad(unused_op, grad):
  """A gradient that returns the wrong shape."""
  return array_ops.transpose(grad)


@ops.RegisterGradient("NaNGrad")
def _nan_grad(unused_op, grad):
  """A gradient that returns NaN."""
  return np.nan * grad


class GradientCheckerTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testAddSimple(self):
    np.random.seed(1)  # Fix seed to avoid flakiness
    with self.session(use_gpu=False):
      # A test case for the Add operation.
      size = (2, 3)
      x1 = constant_op.constant(2.0, shape=size, name="x1")
      x2 = constant_op.constant(3.0, shape=size, name="x2")
      y = math_ops.add(x1, x2, name="y")
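      # For y = x1 + x2, dy/dx1 is the identity matrix, so the numerically
      # estimated Jacobian should match the analytic one to within
      # finite-difference precision.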

      # Check the gradients for x1.
      error = gradient_checker.compute_gradient_error(x1, size, y, size)
    tf_logging.info("x1 error = %f", error)
    self.assertLess(error, 1e-4)

  @test_util.run_deprecated_v1
  def testAddSimpleGPU(self):
    np.random.seed(2)  # Fix seed to avoid flakiness
    with self.session(use_gpu=True):
      # The same Add test case, run on the GPU.
      size = (2, 3)
      x1 = constant_op.constant(2.0, shape=size, name="x1")
      x2 = constant_op.constant(3.0, shape=size, name="x2")
      y = math_ops.add(x1, x2, name="y")

      # Check the gradients for x1.
      error = gradient_checker.compute_gradient_error(x1, size, y, size)
    tf_logging.info("x1 error = %f", error)
    self.assertLess(error, 1e-4)

  @test_util.run_deprecated_v1
  def testAddCustomized(self):
    np.random.seed(3)  # Fix seed to avoid flakiness
    with self.cached_session():
      # An Add test case with a custom init value and delta.
      size = (2, 3)
      x1 = constant_op.constant(
          2.0, shape=size, dtype=dtypes.float64, name="x1")
      x2 = constant_op.constant(
          3.0, shape=size, dtype=dtypes.float64, name="x2")
      y = math_ops.add(x1, x2, name="y")
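      # Addition is linear, so with float64 inputs both the analytic and the
      # finite-difference Jacobians are exact up to rounding; this is what
      # allows the very tight 1e-10 error bound below.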

      # Check the gradients for x2, using a special init_value and delta.
      x_init_value = np.asarray(np.arange(6, dtype=np.float64).reshape(2, 3))
      error = gradient_checker.compute_gradient_error(
          x2, size, y, size, x_init_value=x_init_value, delta=1e-2)
    tf_logging.info("x2 error = %f", error)
    self.assertLess(error, 1e-10)

  @test_util.run_deprecated_v1
  def testGather(self):
    np.random.seed(4)  # Fix seed to avoid flakiness
    with self.cached_session():
      p_shape = (4, 2)
      p_size = 8
      index_values = [1, 3]
      y_shape = [2, 2]
      params = constant_op.constant(
          np.arange(p_size).astype(np.float64), shape=p_shape, name="p")
      indices = constant_op.constant(index_values, name="i")
      y = array_ops.gather(params, indices, name="y")
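      # The gradient of gather scatters the incoming gradient back into the
      # gathered rows of params; rows that were not gathered receive zero.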

      error = gradient_checker.compute_gradient_error(params, p_shape, y,
                                                      y_shape)
    tf_logging.info("gather error = %f", error)
    self.assertLess(error, 1e-4)

  @test_util.run_deprecated_v1
  def testNestedGather(self):
    np.random.seed(5)  # Fix seed to avoid flakiness
    with self.cached_session():
      p_shape = (8, 2)
      p_size = 16
      index_values = [1, 3, 5, 6]
      index_values2 = [0, 2]
      y2_shape = [2, 2]

      params = constant_op.constant(
          np.arange(p_size).astype(np.float64), shape=p_shape, name="p")
      indices = constant_op.constant(index_values, name="i")
      y = array_ops.gather(params, indices, name="y")
      indices2 = constant_op.constant(index_values2, name="i2")
      y2 = array_ops.gather(y, indices2, name="y2")
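      # The two gathers compose: y2 selects rows 0 and 2 of y, which are rows
      # 1 and 5 of params, so the gradient must scatter through both levels.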

      error = gradient_checker.compute_gradient_error(params, p_shape, y2,
                                                      y2_shape)
    tf_logging.info("nested gather error = %f", error)
    self.assertLess(error, 1e-4)

  @test_util.run_deprecated_v1
  def testComplexMul(self):
    with self.cached_session():
      size = ()
      c = constant_op.constant(5 + 7j, dtype=dtypes.complex64)
      x = constant_op.constant(11 - 13j, dtype=dtypes.complex64)
      y = c * x
      analytical, numerical = gradient_checker.compute_gradient(x, size, y,
                                                                size)
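      # Writing x = u + iv, y = c * x = (5u - 7v) + i(7u + 5v). Viewed as a
      # map from R^2 to R^2, its Jacobian (rows indexed by (u, v), columns by
      # (Re y, Im y)) is [[5, 7], [-7, 5]].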
      correct = np.array([[5, 7], [-7, 5]])
      self.assertAllEqual(correct, analytical)
      self.assertAllClose(correct, numerical, rtol=1e-4)
      self.assertLess(
          gradient_checker.compute_gradient_error(x, size, y, size), 2e-4)

  @test_util.run_deprecated_v1
  def testComplexConj(self):
    with self.cached_session():
      size = ()
      x = constant_op.constant(11 - 13j, dtype=dtypes.complex64)
      y = math_ops.conj(x)
      analytical, numerical = gradient_checker.compute_gradient(x, size, y,
                                                                size)
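      # Writing x = u + iv, conj(x) = u - iv, so the real-valued Jacobian is
      # [[1, 0], [0, -1]]: conjugation flips the sign of the imaginary part.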
      correct = np.array([[1, 0], [0, -1]])
      self.assertAllEqual(correct, analytical)
      self.assertAllClose(correct, numerical, rtol=2e-5)
      self.assertLess(
          gradient_checker.compute_gradient_error(x, size, y, size), 2e-5)

  @test_util.run_deprecated_v1
  def testEmptySucceeds(self):
    with self.cached_session():
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.identity(x)
      for grad in gradient_checker.compute_gradient(x, (0, 3), y, (0, 3)):
        self.assertEqual(grad.shape, (0, 0))
      error = gradient_checker.compute_gradient_error(x, (0, 3), y, (0, 3))
      self.assertEqual(error, 0)

  def testEmptyFails(self):
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        x = array_ops.placeholder(dtypes.float32)
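        # Reroute Identity's gradient to the shape-mangling BadGrad function
        # registered above; the checker should reject the transposed shape.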
        with g.gradient_override_map({"Identity": "BadGrad"}):
          y = array_ops.identity(x)
        bad = r"Empty gradient has wrong shape: expected \(0, 3\), got \(3, 0\)"
        with self.assertRaisesRegexp(ValueError, bad):
          gradient_checker.compute_gradient(x, (0, 3), y, (0, 3))
        with self.assertRaisesRegexp(ValueError, bad):
          gradient_checker.compute_gradient_error(x, (0, 3), y, (0, 3))

  def testNaNGradFails(self):
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        x = array_ops.placeholder(dtypes.float32)
        with g.gradient_override_map({"Identity": "NaNGrad"}):
          y = array_ops.identity(x)
          error = gradient_checker.compute_gradient_error(x, (), y, ())
          # A typical test would assert error < max_err; here we assert that
          # such a check raises AssertionError, since NaN is not < 1.0.
          with self.assertRaisesRegexp(AssertionError, "False is not true"):
            self.assertTrue(error < 1.0)


class MiniMNISTTest(test.TestCase):

  # Gradient checker for a mini MNIST model.
  def _BuildAndTestMiniMNIST(self, param_index, tag):
    # Fix seed to avoid occasional flakiness.
    np.random.seed(6)

    # Hyperparameters.
    batch = 3
    inputs = 16
    features = 32
    classes = 10

    # Define the parameters.
    inp_data = np.random.random_sample(inputs * batch)
    hidden_weight_data = np.random.randn(inputs * features) / np.sqrt(inputs)
    hidden_bias_data = np.random.random_sample(features)
    sm_weight_data = np.random.randn(features * classes) / np.sqrt(features)
    sm_bias_data = np.random.random_sample(classes)

    # Labels need special care since they must be normalized per batch.
    label_data = np.random.random(batch * classes).reshape((batch, classes))
    s = label_data.sum(axis=1)
    label_data /= s[:, None]
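    # Each label row now sums to 1, i.e. forms a valid probability
    # distribution, as softmax_cross_entropy_with_logits expects.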

    with self.session(use_gpu=True):
      # We treat the inputs as "parameters" here.
      inp = constant_op.constant(
          inp_data.tolist(),
          shape=[batch, inputs],
          dtype=dtypes.float64,
          name="inp")
      hidden_weight = constant_op.constant(
          hidden_weight_data.tolist(),
          shape=[inputs, features],
          dtype=dtypes.float64,
          name="hidden_weight")
      hidden_bias = constant_op.constant(
          hidden_bias_data.tolist(),
          shape=[features],
          dtype=dtypes.float64,
          name="hidden_bias")
      softmax_weight = constant_op.constant(
          sm_weight_data.tolist(),
          shape=[features, classes],
          dtype=dtypes.float64,
          name="softmax_weight")
      softmax_bias = constant_op.constant(
          sm_bias_data.tolist(),
          shape=[classes],
          dtype=dtypes.float64,
          name="softmax_bias")

      # List all the parameters so that we can test them one at a time.
      all_params = [
          inp, hidden_weight, hidden_bias, softmax_weight, softmax_bias
      ]
      param_sizes = [
          [batch, inputs],  # inp
          [inputs, features],  # hidden_weight
          [features],  # hidden_bias
          [features, classes],  # softmax_weight
          [classes],  # softmax_bias
      ]

      # Now build the mini MNIST network. Note that this rebinds `features`
      # from the layer-width integer to the hidden activations tensor.
      features = nn_ops.relu(
          nn_ops.xw_plus_b(inp, hidden_weight, hidden_bias), name="features")
      logits = nn_ops.xw_plus_b(
          features, softmax_weight, softmax_bias, name="logits")
      labels = constant_op.constant(
          label_data.tolist(),
          shape=[batch, classes],
          dtype=dtypes.float64,
          name="labels")
      cost = nn_ops.softmax_cross_entropy_with_logits(
          labels=labels, logits=logits, name="cost")
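      # softmax_cross_entropy_with_logits returns one loss value per example,
      # so cost has shape [batch]; that is the y-shape passed to the checker.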

      # Test the gradients.
      err = gradient_checker.compute_gradient_error(
          all_params[param_index],
          param_sizes[param_index],
          cost, [batch],
          delta=1e-5)

    tf_logging.info("Mini MNIST: %s gradient error = %g", tag, err)
    return err

  @test_util.run_deprecated_v1
  def testInputGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(0, "input"), 1e-8)

  @test_util.run_deprecated_v1
  def testHiddenWeightGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(1, "hidden_weight"), 1e-8)

  @test_util.run_deprecated_v1
  def testHiddenBiasGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(2, "hidden_bias"), 1e-8)

  @test_util.run_deprecated_v1
  def testSoftmaxWeightGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(3, "softmax_weight"), 1e-8)

  @test_util.run_deprecated_v1
  def testSoftmaxBiasGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(4, "softmax_bias"), 1e-8)


if __name__ == "__main__":
  test.main()
