1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Connects all half, float and double tensors to CheckNumericsOp."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python.eager import context
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import ops
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import control_flow_ops
27from tensorflow.python.util import deprecation
28from tensorflow.python.util.tf_export import tf_export
29
30
@tf_export(v1=["debugging.assert_all_finite", "verify_tensor_all_finite"])
@deprecation.deprecated_endpoints("verify_tensor_all_finite")
def verify_tensor_all_finite(t=None, msg=None, name=None, x=None, message=None):
  """Assert that the tensor does not contain any NaN's or Inf's.

  TF1 entry point. The tensor may be passed as either `t` or `x`, and the
  failure message as either `msg` or `message`. Each pair is reconciled with
  `deprecation.deprecated_argument_lookup` before the work is delegated to
  `verify_tensor_all_finite_v2`.

  Args:
    t: Tensor to check.
    msg: Message to log on failure.
    name: A name for this operation (optional).
    x: Alias for t.
    message: Alias for msg.

  Returns:
    Same tensor as `t`.
  """
  # Pick whichever spelling of each argument the caller supplied.
  tensor = deprecation.deprecated_argument_lookup("x", x, "t", t)
  log_message = deprecation.deprecated_argument_lookup(
      "message", message, "msg", msg)
  return verify_tensor_all_finite_v2(tensor, log_message, name=name)
50
51
@tf_export("debugging.assert_all_finite", v1=[])
def verify_tensor_all_finite_v2(x, message, name=None):
  """Assert that the tensor does not contain any NaN's or Inf's.

  Converts `x` to a tensor and builds an `array_ops.check_numerics` op on it,
  colocated with `x`. It then returns `x` wrapped by
  `control_flow_ops.with_dependencies` on that check.

  Args:
    x: Tensor to check.
    message: Message to log on failure.
    name: A name for this operation (optional).

  Returns:
    Same tensor as `x`.
  """
  with ops.name_scope(name, "VerifyFinite", [x]):
    tensor = ops.convert_to_tensor(x, name="x")
    # Place the check on the same device as the tensor it inspects.
    with ops.colocate_with(tensor):
      numerics_check = array_ops.check_numerics(tensor, message=message)
      return control_flow_ops.with_dependencies([numerics_check], tensor)
70
71
@tf_export(v1=["add_check_numerics_ops"])
def add_check_numerics_ops():
  """Connect a `check_numerics` to every floating point tensor.

  `check_numerics` operations themselves are added for each `half`, `float`,
  or `double` tensor in the graph. For all ops in the graph, the
  `check_numerics` op for all of its (`half`, `float`, or `double`) inputs
  is guaranteed to run before the `check_numerics` op on any of its outputs.

  Note: This API is not compatible with the use of `tf.cond` or
  `tf.while_loop`, and will raise a `ValueError` if you attempt to call it
  in such a graph.

  Returns:
    A `group` op depending on all `check_numerics` ops added.

  Raises:
    ValueError: If the graph contains any numeric operations in a control flow
      structure.
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Not compatible with eager execution. To check for `Inf`s and `NaN`s under
  eager execution, call tfe.seterr(inf_or_nan='raise') once before executing
  the checked operations.
  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError(
        "add_check_numerics_ops() is not compatible with eager execution. "
        "To check for Inf's and NaN's under eager execution, call "
        "tfe.seterr(inf_or_nan='raise') once before executing the "
        "checked operations.")

  # Output dtypes that receive a check (half, float, double); built once
  # rather than once per output.
  checked_dtypes = (dtypes.float16, dtypes.float32, dtypes.float64)

  check_op = []
  # This code relies on the ordering of ops in get_operations().
  # The producer of a tensor always comes before that tensor's consumer in
  # this list. This is true because get_operations() returns ops in the order
  # added, and an op can only be added after its inputs are added.
  for op in ops.get_default_graph().get_operations():
    for output in op.outputs:
      if output.dtype not in checked_dtypes:
        continue
      if op._get_control_flow_context() is not None:  # pylint: disable=protected-access
        raise ValueError("`tf.add_check_numerics_ops()` is not compatible "
                         "with TensorFlow control flow operations such as "
                         "`tf.cond()` or `tf.while_loop()`.")

      message = op.name + ":" + str(output.value_index)
      # Each new check takes a control dependency on the previously created
      # check, so all checks form one chain in graph order. The final group
      # therefore depends on every check through that chain.
      with ops.control_dependencies(check_op):
        check_op = [array_ops.check_numerics(output, message=message)]
  return control_flow_ops.group(*check_op)
123