# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Connects all half, float and double tensors to CheckNumericsOp."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["debugging.assert_all_finite", "verify_tensor_all_finite"])
@deprecation.deprecated_endpoints("verify_tensor_all_finite")
def verify_tensor_all_finite(t=None, msg=None, name=None, x=None, message=None):
  """Assert that the tensor does not contain any NaN's or Inf's.

  Resolves the `t`/`x` and `msg`/`message` argument aliases through
  `deprecation.deprecated_argument_lookup`, then delegates to
  `verify_tensor_all_finite_v2`.

  Args:
    t: Tensor to check.
    msg: Message to log on failure.
    name: A name for this operation (optional).
    x: Alias for t.
    message: Alias for msg.

  Returns:
    Same tensor as `t`.
  """
  x = deprecation.deprecated_argument_lookup("x", x, "t", t)
  message = deprecation.deprecated_argument_lookup(
      "message", message, "msg", msg)
  return verify_tensor_all_finite_v2(x, message, name)


@tf_export("debugging.assert_all_finite", v1=[])
def verify_tensor_all_finite_v2(x, message, name=None):
  """Assert that the tensor does not contain any NaN's or Inf's.

  Args:
    x: Tensor to check.
    message: Message to log on failure.
    name: A name for this operation (optional).

  Returns:
    Same tensor as `x`.
  """
  with ops.name_scope(name, "VerifyFinite", [x]):
    x = ops.convert_to_tensor(x, name="x")
    with ops.colocate_with(x):
      verify_input = array_ops.check_numerics(x, message=message)
      # Hand back `x` with a control dependency on the check, so anything that
      # consumes the returned tensor cannot run before the check does.
      out = control_flow_ops.with_dependencies([verify_input], x)
  return out


@tf_export(v1=["add_check_numerics_ops"])
def add_check_numerics_ops():
  """Connect a `check_numerics` to every floating point tensor.

  `check_numerics` operations themselves are added for each `half`, `float`,
  or `double` tensor in the graph. For all ops in the graph, the
  `check_numerics` op for all of its (`half`, `float`, or `double`) inputs
  is guaranteed to run before the `check_numerics` op on any of its outputs.

  Note: This API is not compatible with the use of `tf.cond` or
  `tf.while_loop`, and will raise a `ValueError` if you attempt to call it
  in such a graph.

  Returns:
    A `group` op depending on all `check_numerics` ops added.

  Raises:
    ValueError: If the graph contains any numeric operations in a control flow
      structure.
    RuntimeError: If called with eager execution enabled.

  @compatibility(eager)
  Not compatible with eager execution. To check for `Inf`s and `NaN`s under
  eager execution, call tfe.seterr(inf_or_nan='raise') once before executing
  the checked operations.
  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError(
        "add_check_numerics_ops() is not compatible with eager execution. "
        "To check for Inf's and NaN's under eager execution, call "
        "tfe.seterr(inf_or_nan='raise') once before executing the "
        "checked operations.")

  # Output dtypes that get a check: `half`, `float` and `double`.
  checked_dtypes = (dtypes.float16, dtypes.float32, dtypes.float64)

  check_op = []
  # This code relies on the ordering of ops in get_operations().
  # The producer of a tensor always comes before that tensor's consumer in
  # this list. This is true because get_operations() returns ops in the order
  # added, and an op can only be added after its inputs are added.
  for op in ops.get_default_graph().get_operations():
    for output in op.outputs:
      if output.dtype in checked_dtypes:
        if op._get_control_flow_context() is not None:  # pylint: disable=protected-access
          raise ValueError("`tf.add_check_numerics_ops()` is not compatible "
                           "with TensorFlow control flow operations such as "
                           "`tf.cond()` or `tf.while_loop()`.")

        message = op.name + ":" + str(output.value_index)
        # Each new check takes a control dependency on the previous one,
        # forming a single chain in graph order. Combined with the
        # producer-before-consumer ordering above, input checks therefore run
        # before output checks.
        with ops.control_dependencies(check_op):
          check_op = [array_ops.check_numerics(output, message=message)]
  # The last check transitively depends on every earlier one, so grouping it
  # alone (or nothing, if no tensors were checked) covers them all.
  return control_flow_ops.group(*check_op)