1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for TensorFlow Debugger (tfdbg) Utilities."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.core.protobuf import config_pb2
23from tensorflow.python.client import session
24from tensorflow.python.debug.lib import debug_utils
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import test_util
27from tensorflow.python.ops import math_ops
28# Import resource_variable_ops for the variables-to-tensor implicit conversion.
29from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
30from tensorflow.python.ops import variables
31from tensorflow.python.platform import googletest
32
33
@test_util.run_v1_only("Requires tf.Session")
class DebugUtilsTest(test_util.TensorFlowTestCase):
  """Tests for the debug-tensor-watch helpers in `debug_utils`.

  A small graph is built once for the whole class:

    p1 = matmul(a1, b)   # a1: 2x2 variable, b: 2x1 variable
    s = add(p1, c)       # c: 2x1 constant containing a NaN

  Every test starts from a fresh, empty `RunOptions` proto (see `setUp`) and
  inspects the `debug_tensor_watch_opts` entries that `debug_utils` adds.
  """

  @classmethod
  def setUpClass(cls):
    cls._sess = session.Session()
    # Build the ops inside the session's context so that they are created in
    # `cls._sess.graph`, which the tests below inspect.
    with cls._sess:
      cls._a_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
      cls._b_init_val = np.array([[2.0], [-1.0]])
      cls._c_val = np.array([[-4.0], [np.nan]])

      cls._a_init = constant_op.constant(
          cls._a_init_val, shape=[2, 2], name="a1_init")
      cls._b_init = constant_op.constant(
          cls._b_init_val, shape=[2, 1], name="b_init")

      cls._a = variables.VariableV1(cls._a_init, name="a1")
      cls._b = variables.VariableV1(cls._b_init, name="b")
      cls._c = constant_op.constant(cls._c_val, shape=[2, 1], name="c")

      # Matrix product of a and b.
      cls._p = math_ops.matmul(cls._a, cls._b, name="p1")

      # Sum of two vectors.
      cls._s = math_ops.add(cls._p, cls._c, name="s")

    cls._graph = cls._sess.graph

    # These are all the expected nodes in the graph:
    #   - Two variables (a, b), each with four nodes (Variable, init, Assign,
    #     read).
    #   - One constant (c).
    #   - One add operation and one matmul operation.
    #   - One wildcard node name ("*") that covers nodes created internally
    #     by TensorFlow itself (e.g., Grappler).
    cls._expected_num_nodes = 4 * 2 + 1 + 1 + 1 + 1

  def setUp(self):
    # A fresh, empty RunOptions per test so watches never leak across tests.
    self._run_options = config_pb2.RunOptions()

  def _verify_watches(self, watch_opts, expected_output_slot,
                      expected_debug_ops, expected_debug_urls):
    """Verify a list of debug tensor watches.

    Every watch in the list must carry exactly `expected_debug_ops` and
    `expected_debug_urls`. Every watch must also use `expected_output_slot`,
    except the wildcard watch (node name "*"), which must use output slot -1.

    Args:
      watch_opts: Repeated protobuf field of DebugTensorWatch.
      expected_output_slot: Expected output slot index, as an integer, for
        every non-wildcard watch.
      expected_debug_ops: Expected debug ops, as a list of strings.
      expected_debug_urls: Expected debug URLs, as a list of strings.

    Returns:
      List of node names from the list of debug tensor watches, in the order
      in which the watches appear.
    """
    node_names = []
    for watch in watch_opts:
      node_names.append(watch.node_name)

      # The wildcard watch ("*") is expected to use output slot -1 instead of
      # the per-node slot; ops and URLs must match for every watch.
      want_slot = -1 if watch.node_name == "*" else expected_output_slot
      self.assertEqual(want_slot, watch.output_slot)
      self.assertEqual(expected_debug_ops, watch.debug_ops)
      self.assertEqual(expected_debug_urls, watch.debug_urls)

    return node_names

  def testAddDebugTensorWatches_defaultDebugOp(self):
    debug_utils.add_debug_tensor_watch(
        self._run_options, "foo/node_a", 1, debug_urls="file:///tmp/tfdbg_1")
    debug_utils.add_debug_tensor_watch(
        self._run_options, "foo/node_b", 0, debug_urls="file:///tmp/tfdbg_2")

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertEqual(2, len(debug_watch_opts))

    watch_0 = debug_watch_opts[0]
    watch_1 = debug_watch_opts[1]

    self.assertEqual("foo/node_a", watch_0.node_name)
    self.assertEqual(1, watch_0.output_slot)
    self.assertEqual("foo/node_b", watch_1.node_name)
    self.assertEqual(0, watch_1.output_slot)
    # No debug_ops were passed, so the default debug op must be used.
    self.assertEqual(["DebugIdentity"], watch_0.debug_ops)
    self.assertEqual(["DebugIdentity"], watch_1.debug_ops)

    # Verify debug URLs.
    self.assertEqual(["file:///tmp/tfdbg_1"], watch_0.debug_urls)
    self.assertEqual(["file:///tmp/tfdbg_2"], watch_1.debug_urls)

  def testAddDebugTensorWatches_explicitDebugOp(self):
    debug_utils.add_debug_tensor_watch(
        self._run_options,
        "foo/node_a",
        0,
        debug_ops="DebugNanCount",
        debug_urls="file:///tmp/tfdbg_1")

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertEqual(1, len(debug_watch_opts))

    watch_0 = debug_watch_opts[0]

    self.assertEqual("foo/node_a", watch_0.node_name)
    self.assertEqual(0, watch_0.output_slot)

    # A single debug op given as a string is stored as a one-element list.
    self.assertEqual(["DebugNanCount"], watch_0.debug_ops)

    # Verify debug URLs.
    self.assertEqual(["file:///tmp/tfdbg_1"], watch_0.debug_urls)

  def testAddDebugTensorWatches_multipleDebugOps(self):
    debug_utils.add_debug_tensor_watch(
        self._run_options,
        "foo/node_a",
        0,
        debug_ops=["DebugNanCount", "DebugIdentity"],
        debug_urls="file:///tmp/tfdbg_1")

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertEqual(1, len(debug_watch_opts))

    watch_0 = debug_watch_opts[0]

    self.assertEqual("foo/node_a", watch_0.node_name)
    self.assertEqual(0, watch_0.output_slot)

    # All explicitly requested debug ops are kept, in the given order.
    self.assertEqual(["DebugNanCount", "DebugIdentity"], watch_0.debug_ops)

    # Verify debug URLs.
    self.assertEqual(["file:///tmp/tfdbg_1"], watch_0.debug_urls)

  def testAddDebugTensorWatches_multipleURLs(self):
    debug_utils.add_debug_tensor_watch(
        self._run_options,
        "foo/node_a",
        0,
        debug_ops="DebugNanCount",
        debug_urls=["file:///tmp/tfdbg_1", "file:///tmp/tfdbg_2"])

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertEqual(1, len(debug_watch_opts))

    watch_0 = debug_watch_opts[0]

    self.assertEqual("foo/node_a", watch_0.node_name)
    self.assertEqual(0, watch_0.output_slot)

    # Verify the explicitly requested debug op.
    self.assertEqual(["DebugNanCount"], watch_0.debug_ops)

    # All debug URLs are kept, in the given order.
    self.assertEqual(["file:///tmp/tfdbg_1", "file:///tmp/tfdbg_2"],
                     watch_0.debug_urls)

  def testWatchGraph_allNodes(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_ops=["DebugIdentity", "DebugNanCount"],
        debug_urls="file:///tmp/tfdbg_1")

    debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
    self.assertEqual(self._expected_num_nodes, len(debug_watch_opts))

    # Verify that each of the nodes in the graph with output tensors in the
    # graph have debug tensor watch.
    node_names = self._verify_watches(debug_watch_opts, 0,
                                      ["DebugIdentity", "DebugNanCount"],
                                      ["file:///tmp/tfdbg_1"])

    # Verify the node names.
    self.assertIn("a1_init", node_names)
    self.assertIn("a1", node_names)
    self.assertIn("a1/Assign", node_names)
    self.assertIn("a1/read", node_names)

    self.assertIn("b_init", node_names)
    self.assertIn("b", node_names)
    self.assertIn("b/Assign", node_names)
    self.assertIn("b/read", node_names)

    self.assertIn("c", node_names)
    self.assertIn("p1", node_names)
    self.assertIn("s", node_names)

    # Assert that the wildcard node name has been created.
    self.assertIn("*", node_names)

  def testWatchGraph_nodeNameAllowlist(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_allowlist="(a1$|a1_init$|a1/.*|p1$)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertCountEqual(
        ["a1_init", "a1", "a1/Assign", "a1/read", "p1"], node_names)

  def testWatchGraph_opTypeAllowlist(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        op_type_regex_allowlist="(Variable|MatMul)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertCountEqual(["a1", "b", "p1"], node_names)

  def testWatchGraph_nodeNameAndOpTypeAllowlists(self):
    # A node must satisfy both allowlists to be watched; only "p1" is
    # expected to do so.
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_allowlist="([a-z]+1$)",
        op_type_regex_allowlist="(MatMul)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertEqual(["p1"], node_names)

  def testWatchGraph_tensorDTypeAllowlist(self):
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        tensor_dtype_regex_allowlist=".*_ref")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertCountEqual(["a1", "a1/Assign", "b", "b/Assign"], node_names)

  def testWatchGraph_nodeNameAndTensorDTypeAllowlists(self):
    # A node must satisfy both the name and the dtype allowlist.
    debug_utils.watch_graph(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_allowlist="^a.*",
        tensor_dtype_regex_allowlist=".*_ref")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertCountEqual(["a1", "a1/Assign"], node_names)

  def testWatchGraph_nodeNameDenylist(self):
    debug_utils.watch_graph_with_denylists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_denylist="(a1$|a1_init$|a1/.*|p1$)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertCountEqual(
        ["b_init", "b", "b/Assign", "b/read", "c", "s"], node_names)

  def testWatchGraph_opTypeDenylist(self):
    debug_utils.watch_graph_with_denylists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        op_type_regex_denylist="(Variable|Identity|Assign|Const)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertCountEqual(["p1", "s"], node_names)

  def testWatchGraph_nodeNameAndOpTypeDenylists(self):
    # A node matching either denylist is skipped, leaving only "s".
    debug_utils.watch_graph_with_denylists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_denylist="p1$",
        op_type_regex_denylist="(Variable|Identity|Assign|Const)")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertEqual(["s"], node_names)

  def testWatchGraph_tensorDTypeDenylists(self):
    debug_utils.watch_graph_with_denylists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        tensor_dtype_regex_denylist=".*_ref")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    # Ref-typed nodes are excluded; non-ref nodes such as "s" remain.
    self.assertNotIn("a1", node_names)
    self.assertNotIn("a1/Assign", node_names)
    self.assertNotIn("b", node_names)
    self.assertNotIn("b/Assign", node_names)
    self.assertIn("s", node_names)

  def testWatchGraph_nodeNameAndTensorDTypeDenylists(self):
    # A node matching either denylist is skipped: the ref-typed nodes via the
    # dtype denylist, and "s" via the name denylist.
    debug_utils.watch_graph_with_denylists(
        self._run_options,
        self._graph,
        debug_urls="file:///tmp/tfdbg_1",
        node_name_regex_denylist="^s$",
        tensor_dtype_regex_denylist=".*_ref")

    node_names = self._verify_watches(
        self._run_options.debug_options.debug_tensor_watch_opts, 0,
        ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
    self.assertNotIn("a1", node_names)
    self.assertNotIn("a1/Assign", node_names)
    self.assertNotIn("b", node_names)
    self.assertNotIn("b/Assign", node_names)
    self.assertNotIn("s", node_names)
365
366
if __name__ == "__main__":
  # Run the tests in this module through TensorFlow's googletest entry point.
  googletest.main()
369