1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.numerics."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import constant_op
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import test_util
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import control_flow_ops
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import numerics
31from tensorflow.python.platform import test
32
33
class VerifyTensorAllFiniteTest(test.TestCase):
  """Tests for `numerics.verify_tensor_all_finite`."""

  def testVerifyTensorAllFiniteSucceeds(self):
    # `random_sample` draws from [0, 1), so the input holds no NaN/Inf and the
    # verified tensor must evaluate back to the original values.
    shape = [5, 4]
    values = np.random.random_sample(shape).astype(np.float32)
    with test_util.use_gpu():
      checked = numerics.verify_tensor_all_finite(
          constant_op.constant(values, shape=shape, dtype=dtypes.float32),
          "Input is not a number.")
      self.assertAllClose(values, self.evaluate(checked))

  def testVerifyTensorAllFiniteFails(self):
    shape = [5, 4]
    values = np.random.random_sample(shape).astype(np.float32)
    message = "Input is not a number."

    # Overwrite the whole first row with each non-finite value in turn (NaN
    # first, then Inf). Evaluating the verified tensor must raise an op error
    # whose text matches the user-supplied message.
    for bad_value in (np.nan, np.inf):
      values[0] = bad_value
      with test_util.use_gpu():
        with self.assertRaisesOpError(message):
          checked = numerics.verify_tensor_all_finite(
              constant_op.constant(values, shape=shape, dtype=dtypes.float32),
              message)
          self.evaluate(checked)
65
66
@test_util.run_v1_only("b/120545219")
class NumericsTest(test.TestCase):
  """V1-only tests for `numerics.add_check_numerics_ops` and check_numerics."""

  # Pattern for the ValueError raised by `numerics.add_check_numerics_ops()`
  # when the graph contains control-flow ops. Every `.` that belongs to an API
  # name is escaped so it matches only a literal dot (previously the dot in
  # `tf.while_loop` was unescaped and matched any character).
  _CONTROL_FLOW_ERROR_REGEX = (
      r"`tf\.add_check_numerics_ops\(\) is not compatible with "
      r"TensorFlow control flow operations such as `tf\.cond\(\)` "
      r"or `tf\.while_loop\(\)`\.")

  def _assertDivisionCheckFails(self, numerator, denominator, expected_regex):
    """Asserts that `add_check_numerics_ops` flags `numerator / denominator`.

    Builds the division in a fresh graph, makes the quotient depend on the op
    returned by `numerics.add_check_numerics_ops()`, and checks that
    evaluating the quotient raises an op error matching `expected_regex`.

    Args:
      numerator: value passed to `constant_op.constant` for the dividend.
      denominator: value passed to `constant_op.constant` for the divisor.
      expected_regex: pattern the raised op error message must match.
    """
    with self.session(graph=ops.Graph()):
      t1 = constant_op.constant(numerator)
      t2 = constant_op.constant(denominator)
      a = math_ops.div(t1, t2)
      check = numerics.add_check_numerics_ops()
      # Force the numerics check to run whenever the quotient is evaluated.
      a = control_flow_ops.with_dependencies([check], a)
      with self.assertRaisesOpError(expected_regex):
        self.evaluate(a)

  def testInf(self):
    # 1 / 0 yields Inf.
    self._assertDivisionCheckFails(1.0, 0.0, "Inf")

  def testNaN(self):
    # 0 / 0 yields NaN.
    self._assertDivisionCheckFails(0.0, 0.0, "NaN")

  def testBoth(self):
    # Element-wise [1/0, 0/0] yields [Inf, NaN]; both must be reported.
    self._assertDivisionCheckFails([1.0, 0.0], [0.0, 0.0], "Inf and NaN")

  def testPassThrough(self):
    # check_numerics on finite input must return the values and static shape
    # unchanged.
    with self.session(graph=ops.Graph()):
      t1 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3])
      checked = array_ops.check_numerics(t1, message="pass through test")
      value = self.evaluate(checked)
      self.assertAllEqual(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), value)
      self.assertEqual([2, 3], checked.get_shape())

  def testControlFlowCond(self):
    # A `cond` in the default graph must make add_check_numerics_ops() refuse.
    predicate = array_ops.placeholder(dtypes.bool, shape=[])
    _ = control_flow_ops.cond(predicate,
                              lambda: constant_op.constant([37.]),
                              lambda: constant_op.constant([42.]))
    with self.assertRaisesRegexp(ValueError, self._CONTROL_FLOW_ERROR_REGEX):
      numerics.add_check_numerics_ops()

  def testControlFlowWhile(self):
    # A `while_loop` in the default graph must make add_check_numerics_ops()
    # refuse.
    predicate = array_ops.placeholder(dtypes.bool, shape=[])
    _ = control_flow_ops.while_loop(lambda _: predicate,
                                    lambda _: constant_op.constant([37.]),
                                    [constant_op.constant([42.])])
    with self.assertRaisesRegexp(ValueError, self._CONTROL_FLOW_ERROR_REGEX):
      numerics.add_check_numerics_ops()
131
132
if __name__ == "__main__":
  # When run as a script, hand control to TensorFlow's test runner.
  test.main()
135