1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for TensorFlow Debugger (tfdbg) Utilities."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.core.protobuf import config_pb2
23from tensorflow.python.client import session
24from tensorflow.python.debug.lib import debug_utils
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import test_util
27from tensorflow.python.ops import math_ops
28# Import resource_variable_ops for the variables-to-tensor implicit conversion.
29from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
30from tensorflow.python.ops import variables
31from tensorflow.python.platform import googletest
32
33
class DebugUtilsTest(test_util.TensorFlowTestCase):
  """Tests for the tensor-watch helpers in `debug_utils`.

  `setUpClass` builds one small graph shared by every test:
    a1 (variable, initialized from constant "a1_init"),
    b  (variable, initialized from constant "b_init"),
    c  (constant),
    p1 = matmul(a1, b),
    s  = add(p1, c).
  `setUp` gives each test a fresh, empty `RunOptions` proto; each test then
  inspects the `debug_tensor_watch_opts` that the helper under test adds.
  """

  @classmethod
  def setUpClass(cls):
    # Let the base test classes perform their own class-level setup first.
    super(DebugUtilsTest, cls).setUpClass()
    cls._sess = session.Session()
    # The ops below are created under the session context so that they land
    # in `cls._sess.graph`, which the watch_graph* tests inspect. The tests
    # only look at graph structure; the session is never run.
    with cls._sess:
      cls._a_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
      cls._b_init_val = np.array([[2.0], [-1.0]])
      cls._c_val = np.array([[-4.0], [np.nan]])

      cls._a_init = constant_op.constant(
          cls._a_init_val, shape=[2, 2], name="a1_init")
      cls._b_init = constant_op.constant(
          cls._b_init_val, shape=[2, 1], name="b_init")

      cls._a = variables.VariableV1(cls._a_init, name="a1")
      cls._b = variables.VariableV1(cls._b_init, name="b")
      cls._c = constant_op.constant(cls._c_val, shape=[2, 1], name="c")

      # Matrix product of a and b.
      cls._p = math_ops.matmul(cls._a, cls._b, name="p1")

      # Sum of two vectors.
      cls._s = math_ops.add(cls._p, cls._c, name="s")

    cls._graph = cls._sess.graph

    # These are all the expected nodes in the graph:
    #   Two variables (a1, b), each with four nodes (the variable itself, its
    #       init constant, Assign and read).
    #   One constant (c).
    #   One matmul operation (p1) and one add operation (s).
    cls._expected_num_nodes = 4 * 2 + 1 + 1 + 1

  def setUp(self):
    # A fresh RunOptions per test, so watches added by one test never leak
    # into another.
    self._run_options = config_pb2.RunOptions()

  def _verify_watches(self, watch_opts, expected_output_slot,
                      expected_debug_ops, expected_debug_urls):
    """Verify a list of debug tensor watches.

    This requires all watches in the watch list have exactly the same
    output_slot, debug_ops and debug_urls.

    Args:
      watch_opts: Repeated protobuf field of DebugTensorWatch.
      expected_output_slot: Expected output slot index, as an integer.
      expected_debug_ops: Expected debug ops, as a list of strings.
      expected_debug_urls: Expected debug URLs, as a list of strings.

    Returns:
      List of node names from the list of debug tensor watches, in the order
      the watches appear in `watch_opts`.
    """
    node_names = []
    for watch in watch_opts:
      node_names.append(watch.node_name)

      self.assertEqual(expected_output_slot, watch.output_slot)
      self.assertEqual(expected_debug_ops, watch.debug_ops)
      self.assertEqual(expected_debug_urls, watch.debug_urls)

    return node_names

  def testAddDebugTensorWatches_defaultDebugOp(self):
    debug_utils.add_debug_tensor_watch(
        self._run_options, "foo/node_a", 1, debug_urls="file:///tmp/tfdbg_1")
    debug_utils.add_debug_tensor_watch(
        self._run_options, "foo/node_b", 0, debug_urls="file:///tmp/tfdbg_2")

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertEqual(2, len(debug_watch_opts))

    watch_0 = debug_watch_opts[0]
    watch_1 = debug_watch_opts[1]

    self.assertEqual("foo/node_a", watch_0.node_name)
    self.assertEqual(1, watch_0.output_slot)
    self.assertEqual("foo/node_b", watch_1.node_name)
    self.assertEqual(0, watch_1.output_slot)
    # Verify default debug op name (no debug_ops argument was passed).
    self.assertEqual(["DebugIdentity"], watch_0.debug_ops)
    self.assertEqual(["DebugIdentity"], watch_1.debug_ops)

    # Verify debug URLs.
    self.assertEqual(["file:///tmp/tfdbg_1"], watch_0.debug_urls)
    self.assertEqual(["file:///tmp/tfdbg_2"], watch_1.debug_urls)

  def testAddDebugTensorWatches_explicitDebugOp(self):
    debug_utils.add_debug_tensor_watch(
        self._run_options,
        "foo/node_a",
        0,
        debug_ops="DebugNanCount",
        debug_urls="file:///tmp/tfdbg_1")

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertEqual(1, len(debug_watch_opts))

    watch_0 = debug_watch_opts[0]

    self.assertEqual("foo/node_a", watch_0.node_name)
    self.assertEqual(0, watch_0.output_slot)

    # Verify that the single, explicitly specified debug op is recorded.
    self.assertEqual(["DebugNanCount"], watch_0.debug_ops)

    # Verify debug URLs.
    self.assertEqual(["file:///tmp/tfdbg_1"], watch_0.debug_urls)

  def testAddDebugTensorWatches_multipleDebugOps(self):
    debug_utils.add_debug_tensor_watch(
        self._run_options,
        "foo/node_a",
        0,
        debug_ops=["DebugNanCount", "DebugIdentity"],
        debug_urls="file:///tmp/tfdbg_1")

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertEqual(1, len(debug_watch_opts))

    watch_0 = debug_watch_opts[0]

    self.assertEqual("foo/node_a", watch_0.node_name)
    self.assertEqual(0, watch_0.output_slot)

    # Verify that both explicitly specified debug ops are recorded, in order.
    self.assertEqual(["DebugNanCount", "DebugIdentity"], watch_0.debug_ops)

    # Verify debug URLs.
    self.assertEqual(["file:///tmp/tfdbg_1"], watch_0.debug_urls)

  def testAddDebugTensorWatches_multipleURLs(self):
    debug_utils.add_debug_tensor_watch(
        self._run_options,
        "foo/node_a",
        0,
        debug_ops="DebugNanCount",
        debug_urls=["file:///tmp/tfdbg_1", "file:///tmp/tfdbg_2"])

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertEqual(1, len(debug_watch_opts))

    watch_0 = debug_watch_opts[0]

    self.assertEqual("foo/node_a", watch_0.node_name)
    self.assertEqual(0, watch_0.output_slot)

    # Verify that the explicitly specified debug op is recorded.
    self.assertEqual(["DebugNanCount"], watch_0.debug_ops)

    # Verify that both debug URLs are recorded, in order.
    self.assertEqual(["file:///tmp/tfdbg_1", "file:///tmp/tfdbg_2"],
                     watch_0.debug_urls)

  @test_util.run_v1_only("b/120545219")
  def testWatchGraph_allNodes(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_ops=["DebugIdentity", "DebugNanCount"],
        debug_urls="file:///tmp/tfdbg_1")

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertEqual(self._expected_num_nodes, len(debug_watch_opts))

    # Verify that every node in the graph that has an output tensor got a
    # debug tensor watch, all on output slot 0 with the same ops and URLs.
    node_names = self._verify_watches(debug_watch_opts, 0,
                                      ["DebugIdentity", "DebugNanCount"],
                                      ["file:///tmp/tfdbg_1"])

    # Verify the node names.
    self.assertIn("a1_init", node_names)
    self.assertIn("a1", node_names)
    self.assertIn("a1/Assign", node_names)
    self.assertIn("a1/read", node_names)

    self.assertIn("b_init", node_names)
    self.assertIn("b", node_names)
    self.assertIn("b/Assign", node_names)
    self.assertIn("b/read", node_names)

    self.assertIn("c", node_names)
    self.assertIn("p1", node_names)
    self.assertIn("s", node_names)

  @test_util.run_v1_only("b/120545219")
  def testWatchGraph_nodeNameWhitelist(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_whitelist="(a1$|a1_init$|a1/.*|p1$)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertEqual(
        sorted(["a1_init", "a1", "a1/Assign", "a1/read", "p1"]),
        sorted(node_names))

  @test_util.run_v1_only("b/120545219")
  def testWatchGraph_opTypeWhitelist(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        op_type_regex_whitelist="(Variable|MatMul)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertEqual(sorted(["a1", "b", "p1"]), sorted(node_names))

  def testWatchGraph_nodeNameAndOpTypeWhitelists(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_whitelist="([a-z]+1$)",
        op_type_regex_whitelist="(MatMul)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertEqual(["p1"], node_names)

  @test_util.run_v1_only("b/120545219")
  def testWatchGraph_tensorDTypeWhitelist(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        tensor_dtype_regex_whitelist=".*_ref")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    # Order-insensitive comparison, consistent with the other tests.
    self.assertEqual(
        sorted(["a1", "a1/Assign", "b", "b/Assign"]), sorted(node_names))

  @test_util.run_v1_only("b/120545219")
  def testWatchGraph_nodeNameAndTensorDTypeWhitelists(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_whitelist="^a.*",
        tensor_dtype_regex_whitelist=".*_ref")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    # Order-insensitive comparison, consistent with the other tests.
    self.assertEqual(sorted(["a1", "a1/Assign"]), sorted(node_names))

  @test_util.run_v1_only("b/120545219")
  def testWatchGraph_nodeNameBlacklist(self):
    debug_utils.watch_graph_with_blacklists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_blacklist="(a1$|a1_init$|a1/.*|p1$)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertEqual(
        sorted(["b_init", "b", "b/Assign", "b/read", "c", "s"]),
        sorted(node_names))

  @test_util.run_v1_only("b/120545219")
  def testWatchGraph_opTypeBlacklist(self):
    debug_utils.watch_graph_with_blacklists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        op_type_regex_blacklist="(Variable|Identity|Assign|Const)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertEqual(sorted(["p1", "s"]), sorted(node_names))

  @test_util.run_v1_only("b/120545219")
  def testWatchGraph_nodeNameAndOpTypeBlacklists(self):
    debug_utils.watch_graph_with_blacklists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_blacklist="p1$",
        op_type_regex_blacklist="(Variable|Identity|Assign|Const)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertEqual(["s"], node_names)

  @test_util.run_v1_only("b/120545219")
  def testWatchGraph_tensorDTypeBlacklists(self):
    debug_utils.watch_graph_with_blacklists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        tensor_dtype_regex_blacklist=".*_ref")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertNotIn("a1", node_names)
    self.assertNotIn("a1/Assign", node_names)
    self.assertNotIn("b", node_names)
    self.assertNotIn("b/Assign", node_names)
    self.assertIn("s", node_names)

  @test_util.run_v1_only("b/120545219")
  def testWatchGraph_nodeNameAndTensorDTypeBlacklists(self):
    debug_utils.watch_graph_with_blacklists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_blacklist="^s$",
        tensor_dtype_regex_blacklist=".*_ref")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertNotIn("a1", node_names)
    self.assertNotIn("a1/Assign", node_names)
    self.assertNotIn("b", node_names)
    self.assertNotIn("b/Assign", node_names)
    self.assertNotIn("s", node_names)
364
365
def _run_tests():
  """Hand control to googletest, which runs every test case in this module."""
  googletest.main()


if __name__ == "__main__":
  _run_tests()
368