# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for compute_gradient.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradient_checker_v2 as gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
# Needed to register the gradient for SoftmaxCrossEntropyWithLogits:
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def _random_complex(shape, dtype):
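  """Returns random values of `shape`, with a random imaginary part when
  `dtype` is complex."""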
  data = np.random.random_sample(shape).astype(dtype.as_numpy_dtype)
  if dtype.is_complex:
    data.imag = np.random.random_sample(shape)
  return data


@test_util.run_all_in_graph_and_eager_modes
class GradientCheckerTest(test.TestCase):
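  """Tests for compute_gradient and max_error in gradient_checker_v2."""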

  def testAddSimple(self):
    size = (2, 3)
    x1 = constant_op.constant(2.0, shape=size, name="x1")
    x2 = constant_op.constant(3.0, shape=size, name="x2")
    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        lambda x1: math_ops.add(x1, x2), [x1]))
    tf_logging.info("x1 error = %f", error)
    assert error < 1e-4

  def testAddCustomized(self):
    size = (2, 3)
    x1 = constant_op.constant(
        2.0, shape=size, dtype=dtypes.float64, name="x1")
    x2 = np.asarray(np.arange(6, dtype=np.float64).reshape(2, 3))
    # Check gradients for x2 using a custom delta.
    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        lambda x2: math_ops.add(x1, x2),
        [x2], delta=1e-2))
    tf_logging.info("x2 error = %f", error)
    assert error < 1e-10

  def testGather(self):
    def f(params):
      index_values = [1, 3]
      indices = constant_op.constant(index_values, name="i")
      return array_ops.gather(params, indices, name="y")
    p_shape = (4, 2)
    p_size = 8
    params = constant_op.constant(
        np.arange(p_size).astype(np.float64), shape=p_shape, name="p")
    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        f, [params]))
    tf_logging.info("gather error = %f", error)
    assert error < 1e-4

  def testNestedGather(self):
    def f(params):
      index_values = [1, 3, 5, 6]
      indices = constant_op.constant(index_values, name="i")
      y = array_ops.gather(params, indices, name="y")
      index_values2 = [0, 2]
      indices2 = constant_op.constant(index_values2, name="i2")
      return array_ops.gather(y, indices2, name="y2")
    p_shape = (8, 2)
    p_size = 16
    params = constant_op.constant(
        np.arange(p_size).astype(np.float64), shape=p_shape, name="p")
    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        f, [params]))
    tf_logging.info("nested gather error = %f", error)
    assert error < 1e-4

  def testComplexMul(self):
    c = constant_op.constant(5 + 7j, dtype=dtypes.complex64)
    def f(x):
      return c * x
    x_shape = c.shape
    x_dtype = c.dtype
    x = constant_op.constant(_random_complex(x_shape, x_dtype))
    analytical, numerical = gradient_checker.compute_gradient(
        f, [x])
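    # Viewed as a map from R^2 to R^2, f(x) = c * x with c = 5 + 7j has
    # Jacobian [[5, 7], [-7, 5]] in the checker's layout (input (real, imag)
    # parts along rows, output parts along columns).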
    correct = np.array([[5, 7], [-7, 5]])
    self.assertAllEqual(correct, analytical[0])
    self.assertAllClose(correct, numerical[0], rtol=1e-4)
    x = constant_op.constant(_random_complex(x_shape, x_dtype))
    self.assertLess(
        gradient_checker.max_error(*gradient_checker.compute_gradient(
            f, [x])), 3e-4)

  def testComplexConj(self):
    def f(x):
      return math_ops.conj(x)
    x_shape = ()
    x_dtype = dtypes.complex64
    x = constant_op.constant(_random_complex(x_shape, x_dtype))
    analytical, numerical = gradient_checker.compute_gradient(
        f, [x])
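    # conj maps (real, imag) to (real, -imag), so its real Jacobian is
    # [[1, 0], [0, -1]].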
    correct = np.array([[1, 0], [0, -1]])
    self.assertAllEqual(correct, analytical[0])
    self.assertAllClose(correct, numerical[0], rtol=2e-5)
    x = constant_op.constant(_random_complex(x_shape, x_dtype))
    self.assertLess(
        gradient_checker.max_error(*gradient_checker.compute_gradient(
            f, [x])), 2e-5)

  def testEmptySucceeds(self):
    def f(x):
      return array_ops.identity(x)
    x = constant_op.constant(np.random.random_sample((0, 3)),
                             dtype=dtypes.float32)
    for grad in gradient_checker.compute_gradient(f, [x]):
      self.assertEqual(grad[0].shape, (0, 0))
    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        f, [x]))
    self.assertEqual(error, 0)

  def testEmptyFails(self):
    @custom_gradient.custom_gradient
    def id_bad_grad(x):
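      # Identity op whose gradient is deliberately transposed, so its
      # shape (3, 0) does not match the (0, 3) input below.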
      y = array_ops.identity(x)
      def grad_fn(dy):
        # dx = constant_op.constant(np.zeros((1, 4)), dtype=dtypes.float32)
        dx = array_ops.transpose(dy)
        return dx
      return y, grad_fn
    def f(x):
      return id_bad_grad(x)
    x = constant_op.constant(np.random.random_sample((0, 3)),
                             dtype=dtypes.float32)
    bad = r"Empty gradient has wrong shape: expected \(0, 3\), got \(3, 0\)"
    with self.assertRaisesRegexp(ValueError, bad):
      gradient_checker.compute_gradient(f, [x])

  def testNaNGradFails(self):
    @custom_gradient.custom_gradient
    def id_nan_grad(x):
      y = array_ops.identity(x)
      def grad_fn(dy):
        dx = np.nan * dy
        # dx = dy
        return dx
      return y, grad_fn
    def f(x):
      return id_nan_grad(x)
    x = constant_op.constant(np.random.random_sample((1, 1)),
                             dtype=dtypes.float32)
    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        f, [x]))
    # Typical test would assert error < max_err, so assert this test would
    # raise AssertionError, since NaN is not < 1.0.
    with self.assertRaisesRegexp(AssertionError, "False is not true"):
      self.assertTrue(error < 1.0)

  def testGradGrad(self):

    def f(x):
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = math_ops.square(x)
        z = math_ops.square(y)
      return tape.gradient(z, x)

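    # f computes d/dx (x^2)^2 = 4 * x^3, whose derivative at x = 2 is
    # 12 * x^2 = 48.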
    analytical, numerical = gradient_checker.compute_gradient(f, [2.0])
    self.assertAllEqual([[[48.]]], analytical)
    self.assertAllClose([[[48.]]], numerical, rtol=1e-4)


@test_util.run_all_in_graph_and_eager_modes
class MiniMNISTTest(test.TestCase):

  # Gradient checker for MNIST.
  def _BuildAndTestMiniMNIST(self, param_index, tag):
    # Fix seed to avoid occasional flakiness
    np.random.seed(6)

    # Hyperparameters
    batch = 3
    inputs = 16
    features = 32
    classes = 10

    # Define the parameters
    inp_data = np.random.random_sample(inputs * batch)
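    # Scale the random weights by 1 / sqrt(fan_in) to keep the layers well
    # conditioned (similar in spirit to Xavier initialization).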
    hidden_weight_data = np.random.randn(inputs * features) / np.sqrt(inputs)
    hidden_bias_data = np.random.random_sample(features)
    sm_weight_data = np.random.randn(features * classes) / np.sqrt(features)
    sm_bias_data = np.random.random_sample(classes)

    # Special care for the labels: normalize each row so it sums to one.
    label_data = np.random.random(batch * classes).reshape((batch, classes))
    s = label_data.sum(axis=1)
    label_data /= s[:, None]

    # We treat the inputs as "parameters" here
    inp = constant_op.constant(
        inp_data.tolist(),
        shape=[batch, inputs],
        dtype=dtypes.float64,
        name="inp")
    hidden_weight = constant_op.constant(
        hidden_weight_data.tolist(),
        shape=[inputs, features],
        dtype=dtypes.float64,
        name="hidden_weight")
    hidden_bias = constant_op.constant(
        hidden_bias_data.tolist(),
        shape=[features],
        dtype=dtypes.float64,
        name="hidden_bias")
    softmax_weight = constant_op.constant(
        sm_weight_data.tolist(),
        shape=[features, classes],
        dtype=dtypes.float64,
        name="softmax_weight")
    softmax_bias = constant_op.constant(
        sm_bias_data.tolist(),
        shape=[classes],
        dtype=dtypes.float64,
        name="softmax_bias")

    # List all the parameters so that we can test them one at a time.
    all_params = [
        inp, hidden_weight, hidden_bias, softmax_weight, softmax_bias
    ]

    # Now build the MNIST network.
    def f(inp, hidden_weight, hidden_bias, softmax_weight, softmax_bias):
      features = nn_ops.relu(
          nn_ops.xw_plus_b(inp, hidden_weight, hidden_bias), name="features")
      logits = nn_ops.xw_plus_b(
          features, softmax_weight, softmax_bias, name="logits")
      labels = constant_op.constant(
          label_data.tolist(),
          shape=[batch, classes],
          dtype=dtypes.float64,
          name="labels")
      cost = nn_ops.softmax_cross_entropy_with_logits(
          labels=labels, logits=logits, name="cost")
      return cost

    def f_restricted(x):
      xs = all_params
      i = param_index
      # use x for the i-th parameter
      xs = xs[0:i] + [x] + xs[i + 1:]
      return f(*xs)
    # Test the gradients.
    err = gradient_checker.max_error(*gradient_checker.compute_gradient(
        f_restricted, [all_params[param_index]], delta=1e-5))

    tf_logging.info("Mini MNIST: %s gradient error = %g", tag, err)
    return err

  def testInputGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(0, "input"), 1e-8)

  def testHiddenWeightGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(1, "hidden_weight"), 1e-8)

  def testHiddenBiasGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(2, "hidden_bias"), 1e-8)

  def testSoftmaxWeightGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(3, "softmax_weight"), 1e-8)

  def testSoftmaxBiasGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(4, "softmax_bias"), 1e-8)


if __name__ == "__main__":
  test.main()