# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.test.compute_gradient and tf.test.compute_gradient_error."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


@ops.RegisterGradient("BadGrad")
def _bad_grad(unused_op, grad):
  """A gradient that returns the wrong shape."""
  return array_ops.transpose(grad)


@ops.RegisterGradient("NaNGrad")
def _nan_grad(unused_op, grad):
  """A gradient that returns NaN."""
  return np.nan * grad


class GradientCheckerTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testAddSimple(self):
    np.random.seed(1)  # Fix seed to avoid flakiness.
    with self.session(use_gpu=False):
      # A test case for the Add operation.
      size = (2, 3)
      x1 = constant_op.constant(2.0, shape=size, name="x1")
      x2 = constant_op.constant(3.0, shape=size, name="x2")
      y = math_ops.add(x1, x2, name="y")

      # Check the gradients for x1.
      error = gradient_checker.compute_gradient_error(x1, size, y, size)
      tf_logging.info("x1 error = %f", error)
      assert error < 1e-4

  @test_util.run_deprecated_v1
  def testAddSimpleGPU(self):
    np.random.seed(2)  # Fix seed to avoid flakiness.
    with self.session(use_gpu=True):
      # A test case for the Add operation.
      size = (2, 3)
      x1 = constant_op.constant(2.0, shape=size, name="x1")
      x2 = constant_op.constant(3.0, shape=size, name="x2")
      y = math_ops.add(x1, x2, name="y")

      # Check the gradients for x1.
      error = gradient_checker.compute_gradient_error(x1, size, y, size)
      tf_logging.info("x1 error = %f", error)
      assert error < 1e-4

  @test_util.run_deprecated_v1
  def testAddCustomized(self):
    np.random.seed(3)  # Fix seed to avoid flakiness.
    with self.cached_session():
      # A test case for the Add operation.
      size = (2, 3)
      x1 = constant_op.constant(
          2.0, shape=size, dtype=dtypes.float64, name="x1")
      x2 = constant_op.constant(
          3.0, shape=size, dtype=dtypes.float64, name="x2")
      y = math_ops.add(x1, x2, name="y")

      # Check the gradients for x2 using a custom x_init_value and delta.
      x_init_value = np.asarray(np.arange(6, dtype=np.float64).reshape(2, 3))
      error = gradient_checker.compute_gradient_error(
          x2, size, y, size, x_init_value=x_init_value, delta=1e-2)
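      # NOTE: compute_gradient_error returns the maximum absolute difference
      # between the analytic (backprop) Jacobian and a numerical Jacobian
      # estimated with central differences of step `delta`; with float64
      # inputs the two should agree almost exactly, hence the tight bound.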
      tf_logging.info("x2 error = %f", error)
      assert error < 1e-10

  @test_util.run_deprecated_v1
  def testGather(self):
    np.random.seed(4)  # Fix seed to avoid flakiness.
    with self.cached_session():
      p_shape = (4, 2)
      p_size = 8
      index_values = [1, 3]
      y_shape = [2, 2]
      params = constant_op.constant(
          np.arange(p_size).astype(np.float64), shape=p_shape, name="p")
      indices = constant_op.constant(index_values, name="i")
      y = array_ops.gather(params, indices, name="y")

      error = gradient_checker.compute_gradient_error(params, p_shape, y,
                                                      y_shape)
      tf_logging.info("gather error = %f", error)
      assert error < 1e-4

  @test_util.run_deprecated_v1
  def testNestedGather(self):
    np.random.seed(5)  # Fix seed to avoid flakiness.
    with self.cached_session():
      p_shape = (8, 2)
      p_size = 16
      index_values = [1, 3, 5, 6]
      index_values2 = [0, 2]
      y2_shape = [2, 2]

      params = constant_op.constant(
          np.arange(p_size).astype(np.float64), shape=p_shape, name="p")
      indices = constant_op.constant(index_values, name="i")
      y = array_ops.gather(params, indices, name="y")
      indices2 = constant_op.constant(index_values2, name="i2")
      y2 = array_ops.gather(y, indices2, name="y2")

      error = gradient_checker.compute_gradient_error(params, p_shape, y2,
                                                      y2_shape)
      tf_logging.info("nested gather error = %f", error)
      assert error < 1e-4

  @test_util.run_deprecated_v1
  def testComplexMul(self):
    with self.cached_session():
      size = ()
      c = constant_op.constant(5 + 7j, dtype=dtypes.complex64)
      x = constant_op.constant(11 - 13j, dtype=dtypes.complex64)
      y = c * x
      analytical, numerical = gradient_checker.compute_gradient(x, size, y,
                                                                size)
      correct = np.array([[5, 7], [-7, 5]])
      self.assertAllEqual(correct, analytical)
      self.assertAllClose(correct, numerical, rtol=1e-4)
      self.assertLess(
          gradient_checker.compute_gradient_error(x, size, y, size), 2e-4)

  @test_util.run_deprecated_v1
  def testComplexConj(self):
    with self.cached_session():
      size = ()
      x = constant_op.constant(11 - 13j, dtype=dtypes.complex64)
      y = math_ops.conj(x)
      analytical, numerical = gradient_checker.compute_gradient(x, size, y,
                                                                size)
      correct = np.array([[1, 0], [0, -1]])
      self.assertAllEqual(correct, analytical)
      self.assertAllClose(correct, numerical, rtol=2e-5)
      self.assertLess(
          gradient_checker.compute_gradient_error(x, size, y, size), 2e-5)

  @test_util.run_deprecated_v1
  def testEmptySucceeds(self):
    with self.cached_session():
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.identity(x)
      for grad in gradient_checker.compute_gradient(x, (0, 3), y, (0, 3)):
        self.assertEqual(grad.shape, (0, 0))
      error = gradient_checker.compute_gradient_error(x, (0, 3), y, (0, 3))
      self.assertEqual(error, 0)

  def testEmptyFails(self):
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        x = array_ops.placeholder(dtypes.float32)
        with g.gradient_override_map({"Identity": "BadGrad"}):
          y = array_ops.identity(x)
        bad = r"Empty gradient has wrong shape: expected \(0, 3\), got \(3, 0\)"
        with self.assertRaisesRegexp(ValueError, bad):
          gradient_checker.compute_gradient(x, (0, 3), y, (0, 3))
        with self.assertRaisesRegexp(ValueError, bad):
          gradient_checker.compute_gradient_error(x, (0, 3), y, (0, 3))
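
  # The next test drives the checker through the NaNGrad function registered
  # at the top of the file: gradient_override_map makes Identity ops created
  # in its scope use the "NaNGrad" gradient, so the analytic Jacobian is
  # all-NaN and the reported error is NaN rather than a small float.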
  def testNaNGradFails(self):
    with ops.Graph().as_default() as g:
      with self.session(graph=g):
        x = array_ops.placeholder(dtypes.float32)
        with g.gradient_override_map({"Identity": "NaNGrad"}):
          y = array_ops.identity(x)
        error = gradient_checker.compute_gradient_error(x, (), y, ())
        # A typical test would assert error < max_err, so assert that such a
        # test would raise AssertionError, since NaN is not < 1.0.
        with self.assertRaisesRegexp(AssertionError, "False is not true"):
          self.assertTrue(error < 1.0)


class MiniMNISTTest(test.TestCase):

  # Gradient checker for a miniature MNIST model.
  def _BuildAndTestMiniMNIST(self, param_index, tag):
    # Fix seed to avoid occasional flakiness.
    np.random.seed(6)

    # Hyperparameters.
    batch = 3
    inputs = 16
    features = 32
    classes = 10

    # Define the parameters.
    inp_data = np.random.random_sample(inputs * batch)
    hidden_weight_data = np.random.randn(inputs * features) / np.sqrt(inputs)
    hidden_bias_data = np.random.random_sample(features)
    sm_weight_data = np.random.randn(features * classes) / np.sqrt(features)
    sm_bias_data = np.random.random_sample(classes)

    # Labels need special care: each row is normalized to sum to one so that
    # it is a valid probability distribution over the classes.
    label_data = np.random.random(batch * classes).reshape((batch, classes))
    s = label_data.sum(axis=1)
    label_data /= s[:, None]

    with self.session(use_gpu=True):
      # We treat the inputs as "parameters" here as well.
      inp = constant_op.constant(
          inp_data.tolist(),
          shape=[batch, inputs],
          dtype=dtypes.float64,
          name="inp")
      hidden_weight = constant_op.constant(
          hidden_weight_data.tolist(),
          shape=[inputs, features],
          dtype=dtypes.float64,
          name="hidden_weight")
      hidden_bias = constant_op.constant(
          hidden_bias_data.tolist(),
          shape=[features],
          dtype=dtypes.float64,
          name="hidden_bias")
      softmax_weight = constant_op.constant(
          sm_weight_data.tolist(),
          shape=[features, classes],
          dtype=dtypes.float64,
          name="softmax_weight")
      softmax_bias = constant_op.constant(
          sm_bias_data.tolist(),
          shape=[classes],
          dtype=dtypes.float64,
          name="softmax_bias")

      # List all the parameters so that we can test them one at a time.
      all_params = [
          inp, hidden_weight, hidden_bias, softmax_weight, softmax_bias
      ]
      param_sizes = [
          [batch, inputs],  # inp
          [inputs, features],  # hidden_weight
          [features],  # hidden_bias
          [features, classes],  # softmax_weight
          [classes]  # softmax_bias
      ]

      # Now build the mini MNIST network.
      features = nn_ops.relu(
          nn_ops.xw_plus_b(inp, hidden_weight, hidden_bias), name="features")
      logits = nn_ops.xw_plus_b(
          features, softmax_weight, softmax_bias, name="logits")
      labels = constant_op.constant(
          label_data.tolist(),
          shape=[batch, classes],
          dtype=dtypes.float64,
          name="labels")
      cost = nn_ops.softmax_cross_entropy_with_logits(
          labels=labels, logits=logits, name="cost")
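
      # `cost` is the per-example cross entropy of shape [batch], which is
      # why the checker below is given [batch] as the output shape. With
      # float64 throughout, a central-difference step of 1e-5 should keep the
      # numerical Jacobian accurate enough for the 1e-8 tolerances below.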
      # Test the gradients.
      err = gradient_checker.compute_gradient_error(
          all_params[param_index],
          param_sizes[param_index],
          cost, [batch],
          delta=1e-5)

    tf_logging.info("Mini MNIST: %s gradient error = %g", tag, err)
    return err

  @test_util.run_deprecated_v1
  def testInputGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(0, "input"), 1e-8)

  @test_util.run_deprecated_v1
  def testHiddenWeightGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(1, "hidden_weight"), 1e-8)

  @test_util.run_deprecated_v1
  def testHiddenBiasGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(2, "hidden_bias"), 1e-8)

  @test_util.run_deprecated_v1
  def testSoftmaxWeightGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(3, "softmax_weight"), 1e-8)

  @test_util.run_deprecated_v1
  def testSoftmaxBiasGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(4, "softmax_bias"), 1e-8)


if __name__ == "__main__":
  test.main()