1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.ops.numerics.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import test_util 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import control_flow_ops 29from tensorflow.python.ops import math_ops 30from tensorflow.python.ops import numerics 31from tensorflow.python.platform import test 32 33 34class VerifyTensorAllFiniteTest(test.TestCase): 35 36 def testVerifyTensorAllFiniteSucceeds(self): 37 x_shape = [5, 4] 38 x = np.random.random_sample(x_shape).astype(np.float32) 39 with test_util.use_gpu(): 40 t = constant_op.constant(x, shape=x_shape, dtype=dtypes.float32) 41 t_verified = numerics.verify_tensor_all_finite(t, 42 "Input is not a number.") 43 self.assertAllClose(x, self.evaluate(t_verified)) 44 45 def testVerifyTensorAllFiniteFails(self): 46 x_shape = [5, 4] 47 x = np.random.random_sample(x_shape).astype(np.float32) 48 my_msg = "Input is not a number." 49 50 # Test NaN. 51 x[0] = np.nan 52 with test_util.use_gpu(): 53 with self.assertRaisesOpError(my_msg): 54 t = constant_op.constant(x, shape=x_shape, dtype=dtypes.float32) 55 t_verified = numerics.verify_tensor_all_finite(t, my_msg) 56 self.evaluate(t_verified) 57 58 # Test Inf. 59 x[0] = np.inf 60 with test_util.use_gpu(): 61 with self.assertRaisesOpError(my_msg): 62 t = constant_op.constant(x, shape=x_shape, dtype=dtypes.float32) 63 t_verified = numerics.verify_tensor_all_finite(t, my_msg) 64 self.evaluate(t_verified) 65 66 67@test_util.run_v1_only("b/120545219") 68class NumericsTest(test.TestCase): 69 70 def testInf(self): 71 with self.session(graph=ops.Graph()): 72 t1 = constant_op.constant(1.0) 73 t2 = constant_op.constant(0.0) 74 a = math_ops.div(t1, t2) 75 check = numerics.add_check_numerics_ops() 76 a = control_flow_ops.with_dependencies([check], a) 77 with self.assertRaisesOpError("Inf"): 78 self.evaluate(a) 79 80 def testNaN(self): 81 with self.session(graph=ops.Graph()): 82 t1 = constant_op.constant(0.0) 83 t2 = constant_op.constant(0.0) 84 a = math_ops.div(t1, t2) 85 check = numerics.add_check_numerics_ops() 86 a = control_flow_ops.with_dependencies([check], a) 87 with self.assertRaisesOpError("NaN"): 88 self.evaluate(a) 89 90 def testBoth(self): 91 with self.session(graph=ops.Graph()): 92 t1 = constant_op.constant([1.0, 0.0]) 93 t2 = constant_op.constant([0.0, 0.0]) 94 a = math_ops.div(t1, t2) 95 check = numerics.add_check_numerics_ops() 96 a = control_flow_ops.with_dependencies([check], a) 97 with self.assertRaisesOpError("Inf and NaN"): 98 self.evaluate(a) 99 100 def testPassThrough(self): 101 with self.session(graph=ops.Graph()): 102 t1 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3]) 103 checked = array_ops.check_numerics(t1, message="pass through test") 104 value = self.evaluate(checked) 105 self.assertAllEqual(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), value) 106 self.assertEqual([2, 3], checked.get_shape()) 107 108 def testControlFlowCond(self): 109 predicate = array_ops.placeholder(dtypes.bool, shape=[]) 110 _ = control_flow_ops.cond(predicate, 111 lambda: constant_op.constant([37.]), 112 lambda: constant_op.constant([42.])) 113 with self.assertRaisesRegexp( 114 ValueError, 115 r"`tf\.add_check_numerics_ops\(\) is not compatible with " 116 r"TensorFlow control flow operations such as `tf\.cond\(\)` " 117 r"or `tf.while_loop\(\)`\."): 118 numerics.add_check_numerics_ops() 119 120 def testControlFlowWhile(self): 121 predicate = array_ops.placeholder(dtypes.bool, shape=[]) 122 _ = control_flow_ops.while_loop(lambda _: predicate, 123 lambda _: constant_op.constant([37.]), 124 [constant_op.constant([42.])]) 125 with self.assertRaisesRegexp( 126 ValueError, 127 r"`tf\.add_check_numerics_ops\(\) is not compatible with " 128 r"TensorFlow control flow operations such as `tf\.cond\(\)` " 129 r"or `tf.while_loop\(\)`\."): 130 numerics.add_check_numerics_ops() 131 132 133if __name__ == "__main__": 134 test.main() 135