# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfdbg module debug_graphs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.debug.lib import debug_graphs
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class ParseNodeOrTensorNameTest(test_util.TensorFlowTestCase):
  """Tests for `debug_graphs.parse_node_or_tensor_name`."""

  def testParseNodeName(self):
    """A name without a ":<slot>" suffix yields the name and a None slot."""
    node_name, slot = debug_graphs.parse_node_or_tensor_name(
        "namespace1/node_1")

    self.assertEqual("namespace1/node_1", node_name)
    self.assertIsNone(slot)

  def testParseTensorName(self):
    """A name with a ":<slot>" suffix yields the node name and integer slot."""
    node_name, slot = debug_graphs.parse_node_or_tensor_name(
        "namespace1/node_2:3")

    self.assertEqual("namespace1/node_2", node_name)
    self.assertEqual(3, slot)


class GetNodeNameAndOutputSlotTest(test_util.TensorFlowTestCase):
  """Tests for `debug_graphs.get_node_name` and `get_output_slot`."""

  def testParseTensorNameInputWorks(self):
    """Tensor names are split into their node name and output slot."""
    self.assertEqual("a", debug_graphs.get_node_name("a:0"))
    self.assertEqual(0, debug_graphs.get_output_slot("a:0"))

    self.assertEqual("_b", debug_graphs.get_node_name("_b:1"))
    self.assertEqual(1, debug_graphs.get_output_slot("_b:1"))

  def testParseNodeNameInputWorks(self):
    """A bare node name is returned unchanged and its slot is reported as 0."""
    self.assertEqual("a", debug_graphs.get_node_name("a"))
    self.assertEqual(0, debug_graphs.get_output_slot("a"))


class NodeNameChecksTest(test_util.TensorFlowTestCase):
  """Tests for the `is_copy_node` and `is_debug_node` name predicates."""

  def testIsCopyNode(self):
    """Only names starting with the exact "__copy_" prefix are copy nodes."""
    self.assertTrue(debug_graphs.is_copy_node("__copy_ns1/ns2/node3_0"))

    self.assertFalse(debug_graphs.is_copy_node("copy_ns1/ns2/node3_0"))
    self.assertFalse(debug_graphs.is_copy_node("_copy_ns1/ns2/node3_0"))
    self.assertFalse(debug_graphs.is_copy_node("_copyns1/ns2/node3_0"))
    self.assertFalse(debug_graphs.is_copy_node("__dbg_ns1/ns2/node3_0"))

  def testIsDebugNode(self):
    """Only names starting with the exact "__dbg_" prefix are debug nodes."""
    self.assertTrue(
        debug_graphs.is_debug_node("__dbg_ns1/ns2/node3:0_0_DebugIdentity"))

    self.assertFalse(
        debug_graphs.is_debug_node("dbg_ns1/ns2/node3:0_0_DebugIdentity"))
    self.assertFalse(
        debug_graphs.is_debug_node("_dbg_ns1/ns2/node3:0_0_DebugIdentity"))
    self.assertFalse(
        debug_graphs.is_debug_node("_dbgns1/ns2/node3:0_0_DebugIdentity"))
    self.assertFalse(debug_graphs.is_debug_node("__copy_ns1/ns2/node3_0"))


class ParseDebugNodeNameTest(test_util.TensorFlowTestCase):
  """Tests for `debug_graphs.parse_debug_node_name`."""

  def _assertParseFails(self, debug_node_name, expected_message):
    """Asserts that parsing `debug_node_name` raises a matching ValueError.

    Args:
      debug_node_name: The debug node name passed to
        `debug_graphs.parse_debug_node_name`.
      expected_message: Text that must occur in the raised error's message.
    """
    # `assertRaisesRegexp` is a deprecated alias that newer Python versions no
    # longer provide. The expected messages used in this class contain no
    # regex metacharacters, so a plain substring check is equivalent and runs
    # on every Python version.
    with self.assertRaises(ValueError) as raised:
      debug_graphs.parse_debug_node_name(debug_node_name)
    self.assertIn(expected_message, str(raised.exception))

  def testParseDebugNodeName_valid(self):
    """A well-formed debug node name is split into its four components."""
    debug_node_name_1 = "__dbg_ns_a/ns_b/node_c:1_0_DebugIdentity"
    (watched_node, watched_output_slot, debug_op_index,
     debug_op) = debug_graphs.parse_debug_node_name(debug_node_name_1)

    self.assertEqual("ns_a/ns_b/node_c", watched_node)
    self.assertEqual(1, watched_output_slot)
    self.assertEqual(0, debug_op_index)
    self.assertEqual("DebugIdentity", debug_op)

  def testParseDebugNodeName_invalidPrefix(self):
    """A name with the "__copy_" prefix is rejected as an invalid prefix."""
    invalid_debug_node_name_1 = "__copy_ns_a/ns_b/node_c:1_0_DebugIdentity"

    self._assertParseFails(invalid_debug_node_name_1, "Invalid prefix")

  def testParseDebugNodeName_missingDebugOpIndex(self):
    """A name lacking the debug op index is rejected."""
    invalid_debug_node_name_1 = "__dbg_node1:0_DebugIdentity"

    self._assertParseFails(invalid_debug_node_name_1,
                           "Invalid debug node name")

  def testParseDebugNodeName_invalidWatchedTensorName(self):
    """A name whose watched tensor has no ":<slot>" part is rejected."""
    invalid_debug_node_name_1 = "__dbg_node1_0_DebugIdentity"

    self._assertParseFails(invalid_debug_node_name_1,
                           "Invalid tensor name in debug node name")


if __name__ == "__main__":
  test.main()