1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for debugger functionalities in tf.Session."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import functools
22import glob
23import os
24import shutil
25import tempfile
26import threading
27
28import numpy as np
29from six.moves import xrange  # pylint: disable=redefined-builtin
30
31from tensorflow.core.protobuf import config_pb2
32from tensorflow.core.protobuf import rewriter_config_pb2
33from tensorflow.core.util import event_pb2
34from tensorflow.python.client import session
35from tensorflow.python.debug.lib import debug_data
36from tensorflow.python.debug.lib import debug_graphs
37from tensorflow.python.debug.lib import debug_utils
38from tensorflow.python.framework import constant_op
39from tensorflow.python.framework import dtypes
40from tensorflow.python.framework import errors
41from tensorflow.python.framework import ops
42from tensorflow.python.framework import test_util
43from tensorflow.python.ops import array_ops
44from tensorflow.python.ops import control_flow_ops
45from tensorflow.python.ops import data_flow_ops
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import parsing_ops
48from tensorflow.python.ops import rnn
49from tensorflow.python.ops import rnn_cell_impl
50from tensorflow.python.ops import state_ops
51from tensorflow.python.ops import variables
52import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
53from tensorflow.python.platform import googletest
54from tensorflow.python.platform import test
55from tensorflow.python.training import gradient_descent
56
57
58def no_rewrite_session_config():
59  rewriter_config = rewriter_config_pb2.RewriterConfig(
60      disable_model_pruning=True,
61      arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
62      dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF)
63  graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
64  return config_pb2.ConfigProto(graph_options=graph_options)
65
66
67class _RNNCellForTest(rnn_cell_impl.RNNCell):
68  """RNN cell for testing."""
69
70  def __init__(self, input_output_size, state_size):
71    self._input_output_size = input_output_size
72    self._state_size = state_size
73    self._w = variables.VariableV1(1.0, dtype=dtypes.float32, name="w")
74
75  @property
76  def output_size(self):
77    return self._input_output_size
78
79  @property
80  def state_size(self):
81    return self._state_size
82
83  def __call__(self, input_, state, scope=None):
84    return (math_ops.multiply(self._w, input_), state)
85
86
87@test_util.run_v1_only("b/120545219")
88class SessionDebugTestBase(test_util.TensorFlowTestCase):
89  """Base class for unit tests of tfdbg running with tf.Session."""
90
91  @classmethod
92  def setUpClass(cls):
93    if test.is_gpu_available():
94      cls._expected_partition_graph_count = 2
95      cls._expected_num_devices = 2
96      gpu_name = test_util.gpu_device_name()
97      cls._main_device = "/job:localhost/replica:0/task:0" + gpu_name
98    else:
99      cls._expected_partition_graph_count = 1
100      cls._expected_num_devices = 1
101      cls._main_device = "/job:localhost/replica:0/task:0/device:CPU:0"
102
103  @classmethod
104  def tearDownClass(cls):
105    pass
106
107  def setUp(self):
108    self._dump_root = tempfile.mkdtemp()
109
110  def tearDown(self):
111    ops.reset_default_graph()
112
113    # Tear down temporary dump directory.
114    if os.path.isdir(self._dump_root):
115      shutil.rmtree(self._dump_root)
116
117  def _debug_urls(self, run_number=None):
118    raise NotImplementedError(
119        "_debug_urls() method is not implemented in the base test class.")
120
121  def _debug_dump_dir(self, run_number=None):
122    raise NotImplementedError(
123        "_debug_dump_dir() method is not implemented in the base test class.")
124
125  def _debug_run_and_get_dump(self,
126                              sess,
127                              fetches,
128                              feed_dict=None,
129                              debug_ops="DebugIdentity",
130                              tolerate_debug_op_creation_failures=False,
131                              global_step=-1,
132                              validate=True,
133                              expected_partition_graph_count=None):
134    """Run fetches with debugging and obtain DebugDumpDir.
135
136    Args:
137      sess: the tf.Session to be used.
138      fetches: fetches of the Session.run().
139      feed_dict: feed dict for the Session.run().
140      debug_ops: name(s) of the debug ops to be used.
141      tolerate_debug_op_creation_failures: whether to tolerate debug op
142        creation failures.
143      global_step: Optional global step.
144      validate: whether to validate dumped tensors against graph.
145      expected_partition_graph_count: optional count of partition graphs to
146        assert on.
147
148    Returns:
149      1. Return values of the Session.run().
150      2. The DebugDumpDir object from the debugged run().
151    """
152
153    run_options = config_pb2.RunOptions(output_partition_graphs=True)
154    debug_utils.watch_graph(
155        run_options,
156        sess.graph,
157        debug_ops=debug_ops,
158        debug_urls=self._debug_urls(),
159        tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
160        global_step=global_step)
161    run_metadata = config_pb2.RunMetadata()
162    run_output = sess.run(fetches,
163                          feed_dict=feed_dict,
164                          options=run_options,
165                          run_metadata=run_metadata)
166
167    if expected_partition_graph_count is not None:
168      self.assertEqual(expected_partition_graph_count,
169                       len(run_metadata.partition_graphs))
170    return run_output, debug_data.DebugDumpDir(
171        self._dump_root, partition_graphs=run_metadata.partition_graphs,
172        validate=validate)
173
174  def _generate_dump_from_simple_addition_graph(self):
175    with session.Session(config=no_rewrite_session_config()) as sess:
176      u_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
177      v_init_val = np.array([[2.0], [-1.0]])
178
179      # Use node names with overlapping namespace (i.e., parent directory) to
180      # test concurrent, non-racing directory creation.
181      u_name = "u"
182      v_name = "v"
183      w_name = "w"
184
185      u_init = constant_op.constant(u_init_val, shape=[2, 2])
186      u = variables.VariableV1(u_init, name=u_name)
187      v_init = constant_op.constant(v_init_val, shape=[2, 1])
188      v = variables.VariableV1(v_init, name=v_name)
189
190      w = math_ops.matmul(u, v, name=w_name)
191
192      u.initializer.run()
193      v.initializer.run()
194
195      run_options = config_pb2.RunOptions(output_partition_graphs=True)
196      debug_urls = "file://%s" % self._dump_root
197
198      # Add debug tensor watch for u.
199      debug_utils.add_debug_tensor_watch(
200          run_options, "%s/read" % u_name, 0, debug_urls=debug_urls)
201      # Add debug tensor watch for v.
202      debug_utils.add_debug_tensor_watch(
203          run_options, "%s/read" % v_name, 0, debug_urls=debug_urls)
204
205      run_metadata = config_pb2.RunMetadata()
206
207      # Invoke Session.run().
208      sess.run(w, options=run_options, run_metadata=run_metadata)
209
210      self.assertEqual(self._expected_partition_graph_count,
211                       len(run_metadata.partition_graphs))
212
213      dump = debug_data.DebugDumpDir(
214          self._dump_root, partition_graphs=run_metadata.partition_graphs)
215
216    simple_add_results = collections.namedtuple("SimpleAddResults", [
217        "u_init_val", "v_init_val", "u", "v", "w", "u_name", "v_name", "w_name",
218        "dump"
219    ])
220    return simple_add_results(u_init_val, v_init_val, u, v, w, u_name, v_name,
221                              w_name, dump)
222
223  def testCopyNodesHaveCorrectDebugOpsAndURLsAttributeValues(self):
224    with session.Session() as sess:
225      u = variables.VariableV1(2.1, name="u")
226      v = variables.VariableV1(20.0, name="v")
227      w = math_ops.multiply(u, v, name="w")
228
229      sess.run(variables.global_variables_initializer())
230
231      run_options = config_pb2.RunOptions(output_partition_graphs=True)
232      debug_urls = self._debug_urls()
233      debug_utils.add_debug_tensor_watch(
234          run_options,
235          "u",
236          0, ["DebugNumericSummary(gated_grpc=True)", "DebugIdentity"],
237          debug_urls=debug_urls)
238      debug_utils.add_debug_tensor_watch(
239          run_options, "v", 0, ["DebugNumericSummary"], debug_urls=debug_urls)
240
241      run_metadata = config_pb2.RunMetadata()
242      r = sess.run(w, options=run_options, run_metadata=run_metadata)
243      self.assertAllClose(42.0, r)
244
245      u_copy_node_def = None
246      v_copy_node_def = None
247      for partition_graph in run_metadata.partition_graphs:
248        for node_def in partition_graph.node:
249          if debug_graphs.is_copy_node(node_def.name):
250            if node_def.name == "__copy_u_0":
251              u_copy_node_def = node_def
252            elif node_def.name == "__copy_v_0":
253              v_copy_node_def = node_def
254
255      self.assertIsNotNone(u_copy_node_def)
256      debug_ops_spec = u_copy_node_def.attr["debug_ops_spec"].list.s
257      self.assertEqual(2, len(debug_ops_spec))
258      self.assertEqual("DebugNumericSummary;%s;1" % debug_urls[0],
259                       debug_ops_spec[0].decode("utf-8"))
260      self.assertEqual("DebugIdentity;%s;0" % debug_urls[0],
261                       debug_ops_spec[1].decode("utf-8"))
262
263      self.assertIsNotNone(v_copy_node_def)
264      debug_ops_spec = v_copy_node_def.attr["debug_ops_spec"].list.s
265      self.assertEqual(1, len(debug_ops_spec))
266      self.assertEqual("DebugNumericSummary;%s;0" % debug_urls[0],
267                       debug_ops_spec[0].decode("utf-8"))
268
269  def testConcurrentDumpingToPathsWithOverlappingParentDirsWorks(self):
270    results = self._generate_dump_from_simple_addition_graph()
271    self.assertTrue(results.dump.loaded_partition_graphs())
272
273    # Since global_step is not explicitly specified, it should take its default
274    # value: -1.
275    self.assertEqual(-1, results.dump.core_metadata.global_step)
276    self.assertGreaterEqual(results.dump.core_metadata.session_run_index, 0)
277    self.assertGreaterEqual(results.dump.core_metadata.executor_step_index, 0)
278    self.assertEqual([], results.dump.core_metadata.input_names)
279    self.assertEqual([results.w.name], results.dump.core_metadata.output_names)
280    self.assertEqual([], results.dump.core_metadata.target_nodes)
281
282    # Verify the dumped tensor values for u and v.
283    self.assertEqual(2, results.dump.size)
284
285    self.assertAllClose([results.u_init_val],
286                        results.dump.get_tensors("%s/read" % results.u_name, 0,
287                                                 "DebugIdentity"))
288    self.assertAllClose([results.v_init_val],
289                        results.dump.get_tensors("%s/read" % results.v_name, 0,
290                                                 "DebugIdentity"))
291
292    self.assertGreaterEqual(
293        results.dump.get_rel_timestamps("%s/read" % results.u_name, 0,
294                                        "DebugIdentity")[0], 0)
295    self.assertGreaterEqual(
296        results.dump.get_rel_timestamps("%s/read" % results.v_name, 0,
297                                        "DebugIdentity")[0], 0)
298
299    self.assertGreater(
300        results.dump.get_dump_sizes_bytes("%s/read" % results.u_name, 0,
301                                          "DebugIdentity")[0], 0)
302    self.assertGreater(
303        results.dump.get_dump_sizes_bytes("%s/read" % results.v_name, 0,
304                                          "DebugIdentity")[0], 0)
305
306  def testGetOpTypeWorks(self):
307    results = self._generate_dump_from_simple_addition_graph()
308
309    self.assertEqual(results.u.op.type,
310                     results.dump.node_op_type(results.u_name))
311    self.assertIn(results.v.op.type, results.dump.node_op_type(results.v_name))
312    self.assertIn(results.w.op.type, results.dump.node_op_type(results.w_name))
313
314    with self.assertRaisesRegexp(
315        ValueError, r"None of the .* device\(s\) has a node named "):
316      results.dump.node_op_type("foo_bar")
317
318  def testDumpStringTensorsWorks(self):
319    with session.Session(config=no_rewrite_session_config()) as sess:
320      str1_init_val = np.array(b"abc")
321      str2_init_val = np.array(b"def")
322
323      str1_init = constant_op.constant(str1_init_val)
324      str2_init = constant_op.constant(str2_init_val)
325
326      str1_name = "str1"
327      str2_name = "str2"
328      str1 = variables.VariableV1(str1_init, name=str1_name)
329      str2 = variables.VariableV1(str2_init, name=str2_name)
330      # Concatenate str1 and str2
331      str_concat = math_ops.add(str1, str2, name="str_concat")
332
333      str1.initializer.run()
334      str2.initializer.run()
335
336      run_options = config_pb2.RunOptions(output_partition_graphs=True)
337      debug_urls = self._debug_urls()
338
339      # Add debug tensor watch for u.
340      debug_utils.add_debug_tensor_watch(
341          run_options, "%s/read" % str1_name, 0, debug_urls=debug_urls)
342      # Add debug tensor watch for v.
343      debug_utils.add_debug_tensor_watch(
344          run_options, "%s/read" % str2_name, 0, debug_urls=debug_urls)
345
346      run_metadata = config_pb2.RunMetadata()
347      sess.run(str_concat, options=run_options, run_metadata=run_metadata)
348
349      # String ops are located on CPU.
350      self.assertEqual(1, len(run_metadata.partition_graphs))
351
352      dump = debug_data.DebugDumpDir(
353          self._dump_root, partition_graphs=run_metadata.partition_graphs)
354
355      self.assertIn(str1_name, dump.nodes())
356      self.assertIn(str2_name, dump.nodes())
357
358      self.assertEqual(2, dump.size)
359
360      self.assertEqual([str1_init_val],
361                       dump.get_tensors("%s/read" % str1_name, 0,
362                                        "DebugIdentity"))
363      self.assertEqual([str2_init_val],
364                       dump.get_tensors("%s/read" % str2_name, 0,
365                                        "DebugIdentity"))
366
367      self.assertGreaterEqual(
368          dump.get_rel_timestamps("%s/read" % str1_name, 0, "DebugIdentity")[0],
369          0)
370      self.assertGreaterEqual(
371          dump.get_rel_timestamps("%s/read" % str2_name, 0, "DebugIdentity")[0],
372          0)
373
374      self.assertGreater(
375          dump.get_dump_sizes_bytes("%s/read" % str1_name, 0,
376                                    "DebugIdentity")[0], 0)
377      self.assertGreater(
378          dump.get_dump_sizes_bytes("%s/read" % str2_name, 0,
379                                    "DebugIdentity")[0], 0)
380
381  def testDumpUninitializedVariable(self):
382    op_namespace = "testDumpUninitializedVariable"
383    with session.Session() as sess:
384      u_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
385      s_init_val = b"str1"
386
387      u_name = "%s/u" % op_namespace
388      s_name = "%s/s" % op_namespace
389
390      u_init = constant_op.constant(u_init_val, shape=[2, 2])
391      u = variables.VariableV1(u_init, name=u_name)
392      s_init = constant_op.constant(s_init_val)
393      s = variables.VariableV1(s_init, name=s_name)
394
395      run_options = config_pb2.RunOptions(output_partition_graphs=True)
396      debug_urls = self._debug_urls()
397
398      # Add debug tensor watch for u.
399      debug_utils.add_debug_tensor_watch(
400          run_options, u_name, 0, debug_urls=debug_urls)
401      debug_utils.add_debug_tensor_watch(
402          run_options, s_name, 0, debug_urls=debug_urls)
403
404      run_metadata = config_pb2.RunMetadata()
405
406      # Initialize u and s.
407      sess.run(variables.global_variables_initializer(),
408               options=run_options,
409               run_metadata=run_metadata)
410
411      # Verify the dump file for the uninitialized value of u.
412      dump = debug_data.DebugDumpDir(
413          self._dump_root, partition_graphs=run_metadata.partition_graphs)
414
415      self.assertEqual(2, dump.size)
416      self.assertEqual(self._expected_partition_graph_count,
417                       len(run_metadata.partition_graphs))
418
419      # Verify that the variable is properly initialized by the run() call.
420      u_vals = dump.get_tensors(u_name, 0, "DebugIdentity")
421      s_vals = dump.get_tensors(s_name, 0, "DebugIdentity")
422      self.assertEqual(1, len(u_vals))
423      self.assertIsInstance(u_vals[0], debug_data.InconvertibleTensorProto)
424      self.assertFalse(u_vals[0].initialized)
425      self.assertEqual(1, len(s_vals))
426      self.assertIsInstance(s_vals[0], debug_data.InconvertibleTensorProto)
427      self.assertFalse(s_vals[0].initialized)
428
429      # Call run() again, to check that u is initialized properly.
430      self.assertAllClose(u_init_val, sess.run(u))
431      self.assertEqual(s_init_val, sess.run(s))
432
433  def testDebugWhileLoopGeneratesMultipleDumps(self):
434    with session.Session(config=no_rewrite_session_config()) as sess:
435      num_iter = 10
436
437      # "u" is the Variable being updated in the loop.
438      u_name = "testDumpToFileWhileLoop/u"
439      u_namespace = u_name.split("/")[0]
440
441      u_init_val = np.array(11.0)
442      u_init = constant_op.constant(u_init_val)
443      u = variables.VariableV1(u_init, name=u_name)
444
445      # "v" is the increment.
446      v_name = "testDumpToFileWhileLoop/v"
447      v_namespace = v_name.split("/")[0]
448
449      v_init_val = np.array(2.0)
450      v_init = constant_op.constant(v_init_val)
451      v = variables.VariableV1(v_init, name=v_name)
452
453      u.initializer.run()
454      v.initializer.run()
455
456      i = constant_op.constant(0, name="testDumpToFileWhileLoop/i")
457
458      def cond(i):
459        return math_ops.less(i, num_iter)
460
461      def body(i):
462        new_u = state_ops.assign_add(u, v)
463        new_i = math_ops.add(i, 1)
464        op = control_flow_ops.group(new_u)
465        new_i = control_flow_ops.with_dependencies([op], new_i)
466        return [new_i]
467
468      loop = control_flow_ops.while_loop(
469          cond, body, [i], parallel_iterations=10)
470
471      # Create RunOptions for debug-watching tensors
472      run_options = config_pb2.RunOptions(output_partition_graphs=True)
473      debug_urls = self._debug_urls()
474
475      # Add debug tensor watch for u.
476      debug_utils.add_debug_tensor_watch(
477          run_options, u_name, 0, debug_urls=debug_urls)
478      # Add debug tensor watch for v.
479      debug_utils.add_debug_tensor_watch(
480          run_options, "%s/read" % v_name, 0, debug_urls=debug_urls)
481      # Add debug tensor watch for while/Identity.
482      debug_utils.add_debug_tensor_watch(
483          run_options, "while/Identity", 0, debug_urls=debug_urls)
484      # Add debug tensor watch for while/Add/y.
485      debug_utils.add_debug_tensor_watch(
486          run_options, "while/Add/y", 0, debug_urls=debug_urls)
487
488      run_metadata = config_pb2.RunMetadata()
489      r = sess.run(loop, options=run_options, run_metadata=run_metadata)
490
491      self.assertEqual(self._expected_partition_graph_count,
492                       len(run_metadata.partition_graphs))
493
494      self.assertEqual(num_iter, r)
495      u_val_final = sess.run(u)
496      self.assertAllClose(u_init_val + num_iter * v_init_val, u_val_final)
497
498      # Verify dump files
499      self.assertTrue(os.path.isdir(self._dump_root))
500
501      u_glob_out = glob.glob(os.path.join(self._dump_root, "*", u_namespace))
502      v_glob_out = glob.glob(os.path.join(
503          self._dump_root, "*", v_namespace, "v"))
504      self.assertTrue(os.path.isdir(u_glob_out[0]))
505      self.assertTrue(os.path.isdir(v_glob_out[0]))
506
507      dump = debug_data.DebugDumpDir(
508          self._dump_root, partition_graphs=run_metadata.partition_graphs)
509
510      # Expected dumped tensors: u, v/read, 10 iterations of while/Identity,
511      # and 10 iterations of while/Add/y.
512      self.assertEqual(1 + 1 + num_iter + num_iter, dump.size)
513
514      # Verify tensor values.
515      self.assertAllClose([u_init_val],
516                          dump.get_tensors(u_name, 0, "DebugIdentity"))
517      self.assertAllClose([v_init_val],
518                          dump.get_tensors("%s/read" % v_name, 0,
519                                           "DebugIdentity"))
520
521      while_id_tensors = dump.get_tensors("while/Identity", 0, "DebugIdentity")
522      self.assertEqual(10, len(while_id_tensors))
523      for k in xrange(len(while_id_tensors)):
524        self.assertAllClose(np.array(k), while_id_tensors[k])
525
526      # Verify ascending timestamps from the while loops.
527      while_id_rel_timestamps = dump.get_rel_timestamps("while/Identity", 0,
528                                                        "DebugIdentity")
529      while_id_dump_sizes_bytes = dump.get_dump_sizes_bytes("while/Identity", 0,
530                                                            "DebugIdentity")
531      self.assertEqual(10, len(while_id_rel_timestamps))
532      prev_rel_time = 0
533      prev_dump_size_bytes = while_id_dump_sizes_bytes[0]
534      for rel_time, dump_size_bytes in zip(while_id_rel_timestamps,
535                                           while_id_dump_sizes_bytes):
536        self.assertGreaterEqual(rel_time, prev_rel_time)
537        self.assertEqual(dump_size_bytes, prev_dump_size_bytes)
538        prev_rel_time = rel_time
539        prev_dump_size_bytes = dump_size_bytes
540
541      # Test querying debug watch keys from node name.
542      watch_keys = dump.debug_watch_keys("while/Identity")
543      self.assertEqual(["while/Identity:0:DebugIdentity"], watch_keys)
544
545      # Test querying debug datum instances from debug watch key.
546      self.assertEqual(10, len(dump.watch_key_to_data(watch_keys[0])))
547      self.assertEqual([], dump.watch_key_to_data("foo"))
548
549  def testDebugWhileLoopWatchingWholeGraphWorks(self):
550    with session.Session() as sess:
551      loop_body = lambda i: math_ops.add(i, 2)
552      loop_cond = lambda i: math_ops.less(i, 16)
553
554      i = constant_op.constant(10, name="i")
555      loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])
556
557      loop_result, dump = self._debug_run_and_get_dump(sess, loop)
558      self.assertEqual(16, loop_result)
559
560      self.assertEqual(
561          [[10]], dump.get_tensors("while/Enter", 0, "DebugIdentity"))
562      self.assertEqual(
563          [[12], [14], [16]],
564          dump.get_tensors("while/NextIteration", 0, "DebugIdentity"))
565
566  def testDebugTrainingDynamicRNNWorks(self):
567    with session.Session() as sess:
568      input_size = 3
569      state_size = 2
570      time_steps = 4
571      batch_size = 2
572
573      input_values = np.random.randn(time_steps, batch_size, input_size)
574      sequence_length = np.random.randint(0, time_steps, size=batch_size)
575      concat_inputs = array_ops.placeholder(
576          dtypes.float32, shape=(time_steps, batch_size, input_size))
577
578      outputs_dynamic, _ = rnn.dynamic_rnn(
579          _RNNCellForTest(input_size, state_size),
580          inputs=concat_inputs,
581          sequence_length=sequence_length,
582          time_major=True,
583          dtype=dtypes.float32)
584      toy_loss = math_ops.reduce_sum(outputs_dynamic * outputs_dynamic)
585      train_op = gradient_descent.GradientDescentOptimizer(
586          learning_rate=0.1).minimize(toy_loss, name="train_op")
587
588      sess.run(variables.global_variables_initializer())
589
590      run_options = config_pb2.RunOptions(output_partition_graphs=True)
591      debug_utils.watch_graph_with_blacklists(
592          run_options,
593          sess.graph,
594          node_name_regex_blacklist="(.*rnn/while/.*|.*TensorArray.*)",
595          debug_urls=self._debug_urls())
596      # b/36870549: Nodes with these name patterns need to be excluded from
597      # tfdbg in order to prevent MSAN warnings of uninitialized Tensors
598      # under both file:// and grpc:// debug URL schemes.
599
600      run_metadata = config_pb2.RunMetadata()
601      sess.run(train_op, feed_dict={concat_inputs: input_values},
602               options=run_options, run_metadata=run_metadata)
603
604      debug_data.DebugDumpDir(
605          self._dump_root, partition_graphs=run_metadata.partition_graphs)
606
607  def testDebugCondWatchingWholeGraphWorks(self):
608    with session.Session() as sess:
609      x = variables.VariableV1(10.0, name="x")
610      y = variables.VariableV1(20.0, name="y")
611      cond = control_flow_ops.cond(
612          x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
613
614      sess.run(variables.global_variables_initializer())
615
616      cond_result, dump = self._debug_run_and_get_dump(sess, cond)
617      self.assertEqual(21, cond_result)
618
619      self.assertAllClose(
620          [21.0], dump.get_tensors("cond/Merge", 0, "DebugIdentity"))
621
622  def testFindNodesWithBadTensorValues(self):
623    with session.Session() as sess:
624      u_name = "testFindNodesWithBadTensorValues/u"
625      v_name = "testFindNodesWithBadTensorValues/v"
626      w_name = "testFindNodesWithBadTensorValues/w"
627      x_name = "testFindNodesWithBadTensorValues/x"
628      y_name = "testFindNodesWithBadTensorValues/y"
629      z_name = "testFindNodesWithBadTensorValues/z"
630
631      u_init = constant_op.constant([2.0, 4.0])
632      u = variables.VariableV1(u_init, name=u_name)
633      v_init = constant_op.constant([2.0, 1.0])
634      v = variables.VariableV1(v_init, name=v_name)
635
636      # Expected output: [0.0, 3.0]
637      w = math_ops.subtract(u, v, name=w_name)
638
639      # Expected output: [inf, 1.3333]
640      x = math_ops.div(u, w, name=x_name)
641
642      # Expected output: [nan, 4.0]
643      y = math_ops.multiply(w, x, name=y_name)
644
645      z = math_ops.multiply(y, y, name=z_name)
646
647      u.initializer.run()
648      v.initializer.run()
649
650      _, dump = self._debug_run_and_get_dump(
651          sess, z,
652          expected_partition_graph_count=self._expected_partition_graph_count)
653
654      def has_bad_value(_, tensor):
655        return np.any(np.isnan(tensor)) or np.any(np.isinf(tensor))
656
657      # Find all "offending tensors".
658      bad_data = dump.find(has_bad_value)
659
660      # Verify that the nodes with bad values are caught through running find
661      # on the debug dump.
662      self.assertEqual(3, len(bad_data))
663      self.assertEqual(x_name, bad_data[0].node_name)
664      self.assertEqual(y_name, bad_data[1].node_name)
665      self.assertEqual(z_name, bad_data[2].node_name)
666
667      # Test first_n kwarg of find(): Find the first offending tensor.
668      first_bad_datum = dump.find(has_bad_value, first_n=1)
669
670      self.assertEqual(1, len(first_bad_datum))
671      self.assertEqual(x_name, first_bad_datum[0].node_name)
672
673  def testFindInfOrNanWithOpNameExclusion(self):
674    with session.Session() as sess:
675      u_name = "testFindInfOrNanWithOpNameExclusion/u"
676      v_name = "testFindInfOrNanWithOpNameExclusion/v"
677      w_name = "testFindInfOrNanWithOpNameExclusion/w"
678      x_name = "testFindInfOrNanWithOpNameExclusion/x"
679      y_name = "testFindInfOrNanWithOpNameExclusion/y"
680      z_name = "testFindInfOrNanWithOpNameExclusion/z"
681
682      u_init = constant_op.constant([2.0, 4.0])
683      u = variables.VariableV1(u_init, name=u_name)
684      v_init = constant_op.constant([2.0, 1.0])
685      v = variables.VariableV1(v_init, name=v_name)
686
687      # Expected output: [0.0, 3.0]
688      w = math_ops.subtract(u, v, name=w_name)
689
690      # Expected output: [inf, 1.3333]
691      x = math_ops.div(u, w, name=x_name)
692
693      # Expected output: [nan, 4.0]
694      y = math_ops.multiply(w, x, name=y_name)
695
696      z = math_ops.multiply(y, y, name=z_name)
697
698      u.initializer.run()
699      v.initializer.run()
700
701      _, dump = self._debug_run_and_get_dump(
702          sess, z,
703          expected_partition_graph_count=self._expected_partition_graph_count)
704
705      # Find all "offending tensors".
706      bad_data = dump.find(debug_data.has_inf_or_nan,
707                           exclude_node_names=".*/x$")
708
709      # Verify that the nodes with bad values are caught through running find
710      # on the debug dump.
711      self.assertEqual(2, len(bad_data))
712      # Assert that the node `x` should have been excluded.
713      self.assertEqual(y_name, bad_data[0].node_name)
714      self.assertEqual(z_name, bad_data[1].node_name)
715
716      first_bad_datum = dump.find(
717          debug_data.has_inf_or_nan, first_n=1, exclude_node_names=".*/x$")
718
719      self.assertEqual(1, len(first_bad_datum))
720      self.assertEqual(y_name, first_bad_datum[0].node_name)
721
722  def _session_run_for_graph_structure_lookup(self):
723    with session.Session(config=no_rewrite_session_config()) as sess:
724      u_name = "testDumpGraphStructureLookup/u"
725      v_name = "testDumpGraphStructureLookup/v"
726      w_name = "testDumpGraphStructureLookup/w"
727
728      u_init = constant_op.constant([2.0, 4.0])
729      u = variables.VariableV1(u_init, name=u_name)
730      v = math_ops.add(u, u, name=v_name)
731      w = math_ops.add(v, v, name=w_name)
732
733      u.initializer.run()
734
735      _, dump = self._debug_run_and_get_dump(
736          sess, w,
737          expected_partition_graph_count=self._expected_partition_graph_count)
738
739    return u_name, v_name, w_name, dump
740
741  def testGraphStructureLookupGivesDevicesAndNodesInfo(self):
742    u_name, _, _, dump = self._session_run_for_graph_structure_lookup()
743
744    # Test num_devices().
745    self.assertEqual(self._expected_num_devices, len(dump.devices()))
746
747    # Test node_device().
748    self.assertEqual(self._main_device, dump.node_device(u_name))
749
750    with self.assertRaisesRegexp(ValueError,
751                                 "does not exist in partition graphs"):
752      dump.node_device(u_name + "foo")
753
754    # Test node_exists().
755    self.assertTrue(dump.node_exists(u_name))
756    self.assertTrue(dump.node_exists(u_name + "/read"))
757    self.assertFalse(dump.node_exists(u_name + "/read" + "/foo"))
758
759  def testGraphStructureLookupGivesNodesAndAttributes(self):
760    u_name, _, _, dump = self._session_run_for_graph_structure_lookup()
761
762    u_read_name = u_name + "/read"
763
764    # Test node name list lookup of the DebugDumpDir object.
765    if test_util.gpu_device_name():
766      node_names = dump.nodes(
767          device_name="/job:localhost/replica:0/task:0/device:GPU:0")
768    else:
769      node_names = dump.nodes()
770    self.assertTrue(u_name in node_names)
771    self.assertTrue(u_read_name in node_names)
772
773    # Test querying node attributes.
774    u_attr = dump.node_attributes(u_name)
775    self.assertEqual(dtypes.float32, u_attr["dtype"].type)
776    self.assertEqual(1, len(u_attr["shape"].shape.dim))
777    self.assertEqual(2, u_attr["shape"].shape.dim[0].size)
778
779    with self.assertRaisesRegexp(
780        ValueError, r"None of the .* device\(s\) has a node named "):
781      dump.node_attributes("foo")
782
783  def testGraphStructureLookupGivesDebugWatchKeys(self):
784    u_name, v_name, w_name, dump = (
785        self._session_run_for_graph_structure_lookup())
786
787    # Test querying the debug watch keys with node names.
788    self.assertEqual(["%s:0:DebugIdentity" % u_name],
789                     dump.debug_watch_keys(u_name))
790    self.assertEqual(["%s:0:DebugIdentity" % v_name],
791                     dump.debug_watch_keys(v_name))
792    self.assertEqual(["%s:0:DebugIdentity" % w_name],
793                     dump.debug_watch_keys(w_name))
794    self.assertEqual([], dump.debug_watch_keys("foo"))
795
796    # Test querying debug datum instances from debug watch.
797    u_data = dump.watch_key_to_data(dump.debug_watch_keys(u_name)[0])
798    self.assertEqual(1, len(u_data))
799    self.assertEqual(u_name, u_data[0].node_name)
800    self.assertEqual(0, u_data[0].output_slot)
801    self.assertEqual("DebugIdentity", u_data[0].debug_op)
802    self.assertGreaterEqual(u_data[0].timestamp, 0)
803    self.assertEqual([], dump.watch_key_to_data("foo"))
804
805  def testGraphStructureLookupGivesNodeInputsAndRecipients(self):
806    u_name, v_name, w_name, dump = (
807        self._session_run_for_graph_structure_lookup())
808
809    u_read_name = u_name + "/read"
810
811    # Test the inputs lookup of the DebugDumpDir object.
812    self.assertEqual([], dump.node_inputs(u_name))
813    self.assertEqual([u_name], dump.node_inputs(u_read_name))
814    self.assertEqual([u_read_name] * 2, dump.node_inputs(v_name))
815    self.assertEqual([v_name] * 2, dump.node_inputs(w_name))
816
817    self.assertEqual([], dump.node_inputs(u_name, is_control=True))
818    self.assertEqual([], dump.node_inputs(u_read_name, is_control=True))
819    self.assertEqual([], dump.node_inputs(v_name, is_control=True))
820    self.assertEqual([], dump.node_inputs(w_name, is_control=True))
821
822    # Test the outputs recipient lookup of the DebugDumpDir object.
823    self.assertTrue(u_read_name in dump.node_recipients(u_name))
824    self.assertEqual(2, dump.node_recipients(u_read_name).count(v_name))
825    self.assertEqual(2, dump.node_recipients(v_name).count(w_name))
826
827    self.assertEqual([], dump.node_recipients(u_name, is_control=True))
828    self.assertEqual([], dump.node_recipients(u_read_name, is_control=True))
829    self.assertEqual([], dump.node_recipients(v_name, is_control=True))
830    self.assertEqual([], dump.node_recipients(w_name, is_control=True))
831
832    # Test errors raised on invalid node names.
833    with self.assertRaisesRegexp(
834        ValueError, r"None of the .* device\(s\) has a node named "):
835      dump.node_inputs(u_name + "foo")
836    with self.assertRaisesRegexp(
837        ValueError, r"None of the .* device\(s\) has a node named "):
838      dump.node_recipients(u_name + "foo")
839
840    # Test transitive_inputs().
841    self.assertEqual([], dump.transitive_inputs(u_name))
842    self.assertEqual([u_name], dump.transitive_inputs(u_read_name))
843    self.assertEqual(
844        set([u_name, u_read_name]), set(dump.transitive_inputs(v_name)))
845    self.assertEqual(
846        set([u_name, u_read_name, v_name]), set(dump.transitive_inputs(w_name)))
847
848    with self.assertRaisesRegexp(
849        ValueError, r"None of the .* device\(s\) has a node named "):
850      dump.transitive_inputs(u_name + "foo")
851
852  def testGraphStructureLookupWithoutPartitionGraphsDoesNotErrorOut(self):
853    _, _, _, dump = self._session_run_for_graph_structure_lookup()
854
855    # Now load the dump again, without the partition graphs, so we can check
856    # errors are not raised because the partition graphs are loaded from the
857    # dump directory.
858    dump = debug_data.DebugDumpDir(self._dump_root, validate=False)
859    self.assertTrue(dump.loaded_partition_graphs())
860
861  def testGraphPathFindingOnControlEdgesWorks(self):
862    with session.Session(config=no_rewrite_session_config()) as sess:
863      v1 = variables.VariableV1(1.0, name="v1")
864      v2 = variables.VariableV1(2.0, name="v2")
865      v3 = variables.VariableV1(3.0, name="v3")
866      a = math_ops.add(v1, v2, name="a")
867      with ops.control_dependencies([a]):
868        c = math_ops.subtract(v3, v3, name="c")
869
870      sess.run(variables.global_variables_initializer())
871      _, dump = self._debug_run_and_get_dump(sess, c)
872
873      self.assertEqual(["v1", "v1/read", "a", "c"],
874                       dump.find_some_path("v1", "c"))
875      self.assertIsNone(dump.find_some_path("v1", "c", include_control=False))
876
877  def testGraphPathFindingReverseRefEdgeWorks(self):
878    with session.Session(config=no_rewrite_session_config()) as sess:
879      v = variables.VariableV1(10.0, name="v")
880      delta = variables.VariableV1(1.0, name="delta")
881      inc_v = state_ops.assign_add(v, delta, name="inc_v")
882
883      sess.run(variables.global_variables_initializer())
884      _, dump = self._debug_run_and_get_dump(sess, inc_v)
885
886      self.assertEqual(
887          ["delta", "delta/read", "inc_v", "v"],
888          dump.find_some_path("delta", "v", include_reversed_ref=True))
889      self.assertIsNone(dump.find_some_path("delta", "v"))
890
891  def testCausalityCheckOnDumpsDetectsWrongTemporalOrder(self):
892    with session.Session(config=no_rewrite_session_config()) as sess:
893      u_name = "testDumpCausalityCheck/u"
894      v_name = "testDumpCausalityCheck/v"
895      w_name = "testDumpCausalityCheck/w"
896
897      u_init = constant_op.constant([2.0, 4.0])
898      u = variables.VariableV1(u_init, name=u_name)
899      v = math_ops.add(u, u, name=v_name)
900      w = math_ops.add(v, v, name=w_name)
901
902      u.initializer.run()
903
904      run_options = config_pb2.RunOptions(output_partition_graphs=True)
905      debug_utils.watch_graph(
906          run_options,
907          sess.graph,
908          debug_ops=["DebugIdentity"],
909          debug_urls=self._debug_urls())
910
911      run_metadata = config_pb2.RunMetadata()
912      sess.run(w, options=run_options, run_metadata=run_metadata)
913
914      self.assertEqual(self._expected_partition_graph_count,
915                       len(run_metadata.partition_graphs))
916
917      # First, loading the original dump without supplying the
918      # partition_graphs should not cause a LookupError, validation occurs
919      # only with partition_graphs loaded.
920      debug_data.DebugDumpDir(self._dump_root)
921
922      # Now, loading the original dump with partition graphs supplied should
923      # succeed. The validation should pass quietly.
924      dump = debug_data.DebugDumpDir(
925          self._dump_root, partition_graphs=run_metadata.partition_graphs)
926
927      # Get the dump file names and compute their timestamps.
928      self.assertEqual(
929          1, len(dump.get_tensor_file_paths(v_name, 0, "DebugIdentity")))
930      v_file_path = dump.get_tensor_file_paths(v_name, 0, "DebugIdentity")[0]
931
932      self.assertEqual(
933          1, len(dump.get_tensor_file_paths(w_name, 0, "DebugIdentity")))
934      w_file_path = dump.get_tensor_file_paths(w_name, 0, "DebugIdentity")[0]
935
936      v_timestamp = int(v_file_path[v_file_path.rindex("_") + 1:])
937      w_timestamp = int(w_file_path[w_file_path.rindex("_") + 1:])
938
939      # Swap and slightly shift the time stamps of the last two dumped tensors,
940      # to simulate "causality violation", which can happen if the dump
941      # directory contains incomplete data and/or mixes data from different
942      # Session.run() calls.
943      v_file_path_1 = v_file_path[:v_file_path.rindex(
944          "_")] + "_%d" % w_timestamp
945      w_file_path_1 = w_file_path[:w_file_path.rindex("_")] + "_%d" % (
946          v_timestamp - 1)
947
948      os.rename(v_file_path, v_file_path_1)
949      os.rename(w_file_path, w_file_path_1)
950
951      # Load the dump directory again. Now a ValueError is expected to be
952      # raised due to the timestamp swap.
953      with self.assertRaisesRegexp(ValueError, "Causality violated"):
954        dump = debug_data.DebugDumpDir(
955            self._dump_root, partition_graphs=run_metadata.partition_graphs)
956
957      # Loading the dump directory with kwarg "validate" set explicitly to
958      # False should get rid of the error.
959      dump = debug_data.DebugDumpDir(
960          self._dump_root,
961          partition_graphs=run_metadata.partition_graphs,
962          validate=False)
963
964      # Next, set the two times stamps to be the same, which should be fine.
965      v_file_path_2 = v_file_path[:v_file_path.rindex(
966          "_")] + "_%d" % w_timestamp
967      w_file_path_2 = w_file_path[:w_file_path.rindex(
968          "_")] + "_%d" % w_timestamp
969
970      os.rename(v_file_path_1, v_file_path_2)
971      os.rename(w_file_path_1, w_file_path_2)
972
973      debug_data.DebugDumpDir(
974          self._dump_root, partition_graphs=run_metadata.partition_graphs)
975
976  def testWatchingOnlyOneOfTwoOutputSlotsDoesNotLeadToCausalityFailure(self):
977    with session.Session() as sess:
978      x_name = "oneOfTwoSlots/x"
979      u_name = "oneOfTwoSlots/u"
980      v_name = "oneOfTwoSlots/v"
981      w_name = "oneOfTwoSlots/w"
982      y_name = "oneOfTwoSlots/y"
983
984      x = variables.VariableV1([1, 3, 3, 7], dtype=dtypes.int32, name=x_name)
985      sess.run(x.initializer)
986
987      unique_x, indices, _ = array_ops.unique_with_counts(x, name=u_name)
988
989      v = math_ops.add(unique_x, unique_x, name=v_name)
990      w = math_ops.add(indices, indices, name=w_name)
991      y = math_ops.add(w, w, name=y_name)
992
993      run_options = config_pb2.RunOptions(output_partition_graphs=True)
994      # Watch only the first output slot of u, even though it has two output
995      # slots.
996      debug_utils.add_debug_tensor_watch(
997          run_options, u_name, 0, debug_urls=self._debug_urls())
998      debug_utils.add_debug_tensor_watch(
999          run_options, w_name, 0, debug_urls=self._debug_urls())
1000      debug_utils.add_debug_tensor_watch(
1001          run_options, y_name, 0, debug_urls=self._debug_urls())
1002
1003      run_metadata = config_pb2.RunMetadata()
1004      sess.run([v, y], options=run_options, run_metadata=run_metadata)
1005
1006      dump = debug_data.DebugDumpDir(
1007          self._dump_root,
1008          partition_graphs=run_metadata.partition_graphs,
1009          validate=True)
1010
1011      self.assertAllClose([1, 3, 7],
1012                          dump.get_tensors(u_name, 0, "DebugIdentity")[0])
1013
1014  def testOutputSlotWithoutOutgoingEdgeCanBeWatched(self):
1015    """Test watching output slots not attached to any outgoing edges."""
1016
1017    with session.Session(config=no_rewrite_session_config()) as sess:
1018      u_init_val = np.array([[5.0, 3.0], [-1.0, 0.0]])
1019      u = constant_op.constant(u_init_val, shape=[2, 2], name="u")
1020
1021      # Create a control edge from a node with an output: From u to z.
1022      # Node u will get executed only because of the control edge. The output
1023      # tensor u:0 is not attached to any outgoing edge in the graph. This test
1024      # checks that the debugger can watch such a tensor.
1025      with ops.control_dependencies([u]):
1026        z = control_flow_ops.no_op(name="z")
1027
1028      _, dump = self._debug_run_and_get_dump(sess, z)
1029
1030      # Assert that the DebugIdentity watch on u works properly.
1031      self.assertEqual(1, len(dump.dumped_tensor_data))
1032      datum = dump.dumped_tensor_data[0]
1033      self.assertEqual("u", datum.node_name)
1034      self.assertEqual(0, datum.output_slot)
1035      self.assertEqual("DebugIdentity", datum.debug_op)
1036      self.assertAllClose([[5.0, 3.0], [-1.0, 0.0]], datum.get_tensor())
1037
1038  def testWatchingVariableUpdateOpsSeesUpdatedValues(self):
1039    """Watch output slots on Variable-updating ops, with no emitted edges."""
1040
1041    with session.Session(config=no_rewrite_session_config()) as sess:
1042      u_init = constant_op.constant(10.0)
1043      u = variables.VariableV1(u_init, name="gdo/u")
1044      v_init = constant_op.constant(20.0)
1045      v = variables.VariableV1(v_init, name="gdo/v")
1046
1047      w = math_ops.multiply(u, v, name="gdo/w")
1048      # gdo stands for GradientDescentOptimizer.
1049
1050      train_op = gradient_descent.GradientDescentOptimizer(
1051          learning_rate=0.1).minimize(
1052              w, name="gdo/train")
1053
1054      u.initializer.run()
1055      v.initializer.run()
1056
1057      _, dump = self._debug_run_and_get_dump(sess, train_op)
1058
1059      update_u_data = dump.watch_key_to_data(
1060          "gdo/train/update_gdo/u/ApplyGradientDescent:0:DebugIdentity")
1061      self.assertEqual(1, len(update_u_data))
1062
1063      # Gradient descent on u: w = u * v, so dw / du = v.
1064      # Updated value of u should be:
1065      #   10.0 - learning_rate * v = 10.0 - 0.1 * 20.0 = 8.0
1066      self.assertAllClose(8.0, update_u_data[0].get_tensor())
1067
1068      update_v_data = dump.watch_key_to_data(
1069          "gdo/train/update_gdo/v/ApplyGradientDescent:0:DebugIdentity")
1070      self.assertEqual(1, len(update_v_data))
1071
1072      # Gradient descent on u: w = u * v, so dw / dv = u.
1073      # Updated value of u should be:
1074      #   20.0 - learning_rate * u = 20.0 - 0.1 * 10.0 = 19.0
1075      self.assertAllClose(19.0, update_v_data[0].get_tensor())
1076
1077      # Verify that the Variables u and v are updated properly.
1078      self.assertAllClose(8.0, sess.run(u))
1079      self.assertAllClose(19.0, sess.run(v))
1080
1081  def testAllowsWatchingUnconnectedOutputTensor(self):
1082    """Watch an output slot not emitting any edges.
1083
1084    (Not even control edges from the node.)
1085    """
1086
1087    with session.Session() as sess:
1088      x_init = constant_op.constant([2, 2, 3, 5, 5])
1089      x = variables.VariableV1(x_init, name="unconnected/x")
1090
1091      # The UniqueOp (tf.unique) has two output slots. Use only slot 0 in the
1092      # graph. Let the debugger watch the unused slot 1.
1093      unique_x, _ = array_ops.unique(x, name="unconnected/unique_x")
1094      y = math_ops.add(unique_x, [0, 1, 2], name="unconnected/y")
1095
1096      x.initializer.run()
1097
1098      # Verify that only slot 0 of unique_x has recipients, while slot 1 of the
1099      # same node does not have recipients.
1100      unique_x_slot_0_recipients = []
1101      unique_x_slot_1_recipients = []
1102      for op in sess.graph.get_operations():
1103        for inp in op.inputs:
1104          if inp.name == "unconnected/unique_x:0":
1105            unique_x_slot_0_recipients.append(op.name)
1106          elif inp.name == "unconnected/unique_x:1":
1107            unique_x_slot_1_recipients.append(op.name)
1108
1109      self.assertEqual(["unconnected/y"], unique_x_slot_0_recipients)
1110      self.assertEqual([], unique_x_slot_1_recipients)
1111
1112      y_result, dump = self._debug_run_and_get_dump(sess, y)
1113      self.assertAllClose([2, 4, 7], y_result)
1114
1115      # Assert that the connected slot (slot 0) is dumped properly.
1116      unique_x_slot_0_dumps = dump.watch_key_to_data(
1117          "unconnected/unique_x:0:DebugIdentity")
1118      self.assertEqual(1, len(unique_x_slot_0_dumps))
1119      self.assertEqual("unconnected/unique_x",
1120                       unique_x_slot_0_dumps[0].node_name)
1121      self.assertEqual(0, unique_x_slot_0_dumps[0].output_slot)
1122      self.assertAllClose([2, 3, 5], unique_x_slot_0_dumps[0].get_tensor())
1123
1124      # Assert that the unconnected slot (slot 1) is dumped properly.
1125      unique_x_slot_1_dumps = dump.watch_key_to_data(
1126          "unconnected/unique_x:1:DebugIdentity")
1127      self.assertEqual(1, len(unique_x_slot_1_dumps))
1128      self.assertEqual("unconnected/unique_x",
1129                       unique_x_slot_1_dumps[0].node_name)
1130      self.assertEqual(1, unique_x_slot_1_dumps[0].output_slot)
1131      self.assertAllClose([0, 0, 1, 2, 2],
1132                          unique_x_slot_1_dumps[0].get_tensor())
1133
1134  def testSuccessiveDebuggingRunsIncreasesCounters(self):
1135    """Test repeated Session.run() calls with debugger increments counters."""
1136
1137    with session.Session() as sess:
1138      ph = array_ops.placeholder(dtypes.float32, name="successive/ph")
1139      x = array_ops.transpose(ph, name="mismatch/x")
1140      y = array_ops.squeeze(ph, name="mismatch/y")
1141
1142      _, dump1 = self._debug_run_and_get_dump(
1143          sess, x, feed_dict={ph: np.array([[7.0, 8.0]])}, global_step=1)
1144      self.assertEqual(1, dump1.core_metadata.global_step)
1145      self.assertGreaterEqual(dump1.core_metadata.session_run_index, 0)
1146      self.assertEqual(0, dump1.core_metadata.executor_step_index)
1147      self.assertEqual([ph.name], dump1.core_metadata.input_names)
1148      self.assertEqual([x.name], dump1.core_metadata.output_names)
1149      self.assertEqual([], dump1.core_metadata.target_nodes)
1150      shutil.rmtree(self._dump_root)
1151
1152      # Calling run() with the same feed, same output and same debug watch
1153      # options should increment both session_run_index and
1154      # executor_step_index.
1155      _, dump2 = self._debug_run_and_get_dump(
1156          sess, x, feed_dict={ph: np.array([[7.0, 8.0]])}, global_step=2)
1157      self.assertEqual(2, dump2.core_metadata.global_step)
1158      self.assertEqual(dump1.core_metadata.session_run_index + 1,
1159                       dump2.core_metadata.session_run_index)
1160      self.assertEqual(dump1.core_metadata.executor_step_index + 1,
1161                       dump2.core_metadata.executor_step_index)
1162      self.assertEqual([ph.name], dump2.core_metadata.input_names)
1163      self.assertEqual([x.name], dump2.core_metadata.output_names)
1164      self.assertEqual([], dump2.core_metadata.target_nodes)
1165      shutil.rmtree(self._dump_root)
1166
1167      run_options = config_pb2.RunOptions(output_partition_graphs=True)
1168      debug_utils.watch_graph(
1169          run_options, sess.graph, debug_urls=self._debug_urls(), global_step=3)
1170
1171      # Calling run() with a different output should increment
1172      # session_run_index, but not executor_step_index.
1173      _, dump3 = self._debug_run_and_get_dump(
1174          sess, y, feed_dict={ph: np.array([[7.0, 8.0]])}, global_step=3)
1175      self.assertEqual(3, dump3.core_metadata.global_step)
1176      self.assertEqual(dump2.core_metadata.session_run_index + 1,
1177                       dump3.core_metadata.session_run_index)
1178      self.assertEqual(0, dump3.core_metadata.executor_step_index)
1179      self.assertEqual([ph.name], dump3.core_metadata.input_names)
1180      self.assertEqual([y.name], dump3.core_metadata.output_names)
1181      self.assertEqual([], dump3.core_metadata.target_nodes)
1182
1183  def testDebuggingDuringOpError(self):
1184    """Test the debug tensor dumping when error occurs in graph runtime."""
1185
1186    with session.Session() as sess:
1187      ph = array_ops.placeholder(dtypes.float32, name="mismatch/ph")
1188      x = array_ops.transpose(ph, name="mismatch/x")
1189      m = constant_op.constant(
1190          np.array(
1191              [[1.0, 2.0]], dtype=np.float32), name="mismatch/m")
1192      y = math_ops.matmul(m, x, name="mismatch/y")
1193
1194      run_options = config_pb2.RunOptions(output_partition_graphs=True)
1195      debug_utils.watch_graph(
1196          run_options,
1197          sess.graph,
1198          debug_ops=["DebugIdentity"],
1199          debug_urls=self._debug_urls())
1200
1201      with self.assertRaises(errors.OpError):
1202        sess.run(y,
1203                 options=run_options,
1204                 feed_dict={ph: np.array([[-3.0], [0.0]])})
1205
1206      dump = debug_data.DebugDumpDir(self._dump_root)
1207
1208      self.assertGreaterEqual(dump.core_metadata.session_run_index, 0)
1209      self.assertGreaterEqual(dump.core_metadata.executor_step_index, 0)
1210      self.assertEqual([ph.name], dump.core_metadata.input_names)
1211      self.assertEqual([y.name], dump.core_metadata.output_names)
1212      self.assertEqual([], dump.core_metadata.target_nodes)
1213
1214      # Despite the fact that the run() call errored out and partition_graphs
1215      # are not available via run_metadata, the partition graphs should still
1216      # have been loaded from the dump directory.
1217      self.assertTrue(dump.loaded_partition_graphs())
1218
1219      m_dumps = dump.watch_key_to_data("mismatch/m:0:DebugIdentity")
1220      self.assertEqual(1, len(m_dumps))
1221      self.assertAllClose(np.array([[1.0, 2.0]]), m_dumps[0].get_tensor())
1222
1223      x_dumps = dump.watch_key_to_data("mismatch/x:0:DebugIdentity")
1224      self.assertEqual(1, len(x_dumps))
1225      self.assertAllClose(np.array([[-3.0, 0.0]]), x_dumps[0].get_tensor())
1226
1227  def testDebugNumericSummaryOnInitializedTensorGivesCorrectResult(self):
1228    with session.Session(config=no_rewrite_session_config()) as sess:
1229      a = variables.VariableV1(
1230          [
1231              np.nan, np.nan, 0.0, 0.0, 0.0, -1.0, -3.0, 3.0, 7.0, -np.inf,
1232              -np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.nan, np.nan
1233          ],
1234          dtype=np.float32,
1235          name="numeric_summary/a")
1236      b = variables.VariableV1(
1237          [0.0] * 18, dtype=np.float32, name="numeric_summary/b")
1238      c = math_ops.add(a, b, name="numeric_summary/c")
1239
1240      sess.run(variables.global_variables_initializer())
1241
1242      _, dump = self._debug_run_and_get_dump(
1243          sess, c, debug_ops=["DebugNumericSummary"])
1244      self.assertTrue(dump.loaded_partition_graphs())
1245
1246      self.assertAllClose([[
1247          1.0, 18.0, 4.0, 2.0, 2.0, 3.0, 2.0, 5.0, -3.0, 7.0, 0.85714286,
1248          8.97959184, 1.0, 1.0, 18.0
1249      ]], dump.get_tensors("numeric_summary/a/read", 0, "DebugNumericSummary"))
1250
1251  def testDebugNumericSummaryOnUninitializedTensorGivesCorrectResult(self):
1252    with session.Session() as sess:
1253      a = variables.VariableV1(
1254          [42], dtype=np.float32, name="numeric_summary_uninit/a")
1255
1256      _, dump = self._debug_run_and_get_dump(
1257          sess, a.initializer, debug_ops=["DebugNumericSummary"])
1258
1259      self.assertTrue(dump.loaded_partition_graphs())
1260
1261      # DebugNumericSummary output should reflect the uninitialized state of
1262      # the watched tensor.
1263      numeric_summary = dump.get_tensors("numeric_summary_uninit/a", 0,
1264                                         "DebugNumericSummary")[0]
1265      self.assertAllClose([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
1266                          numeric_summary[0:8])
1267      # Check dtype (index 12), ndims (index 13) and dimension sizes (index
1268      # 14+).
1269      self.assertAllClose([1.0, 1.0, 1.0], numeric_summary[12:])
1270      self.assertTrue(np.isinf(numeric_summary[8]))
1271      self.assertGreater(numeric_summary[8], 0.0)
1272      self.assertTrue(np.isinf(numeric_summary[9]))
1273      self.assertLess(numeric_summary[9], 0.0)
1274      self.assertTrue(np.isnan(numeric_summary[10]))
1275      self.assertTrue(np.isnan(numeric_summary[11]))
1276
1277  def testDebugNumericSummaryFailureIsToleratedWhenOrdered(self):
1278    with session.Session() as sess:
1279      a = variables.VariableV1("1", name="a")
1280      b = variables.VariableV1("3", name="b")
1281      c = variables.VariableV1("2", name="c")
1282
1283      d = math_ops.add(a, b, name="d")
1284      e = math_ops.add(d, c, name="e")
1285      n = parsing_ops.string_to_number(e, name="n")
1286      m = math_ops.add(n, n, name="m")
1287
1288      sess.run(variables.global_variables_initializer())
1289
1290      # Using DebugNumericSummary on sess.run(m) with the default
1291      # tolerate_debug_op_creation_failures=False should error out due to the
1292      # presence of string-dtype Tensors in the graph.
1293      run_metadata = config_pb2.RunMetadata()
1294      run_options = config_pb2.RunOptions(output_partition_graphs=True)
1295      debug_utils.watch_graph(
1296          run_options,
1297          sess.graph,
1298          debug_ops=["DebugNumericSummary"],
1299          debug_urls=self._debug_urls())
1300      with self.assertRaises(errors.FailedPreconditionError):
1301        sess.run(m, options=run_options, run_metadata=run_metadata)
1302
1303      # Using tolerate_debug_op_creation_failures=True should get rid of the
1304      # error.
1305      m_result, dump = self._debug_run_and_get_dump(
1306          sess, m, debug_ops=["DebugNumericSummary"],
1307          tolerate_debug_op_creation_failures=True)
1308      self.assertEqual(264, m_result)
1309
1310      # The integer-dtype Tensors in the graph should have been dumped
1311      # properly.
1312      self.assertIn("n:0:DebugNumericSummary", dump.debug_watch_keys("n"))
1313      self.assertIn("m:0:DebugNumericSummary", dump.debug_watch_keys("m"))
1314
1315  def testDebugNumericSummaryInvalidAttributesStringAreCaught(self):
1316    with session.Session(config=no_rewrite_session_config()) as sess:
1317      a = variables.VariableV1(10.0, name="a")
1318      b = variables.VariableV1(0.0, name="b")
1319      c = variables.VariableV1(0.0, name="c")
1320
1321      x = math_ops.divide(a, b, name="x")
1322      y = math_ops.multiply(x, c, name="y")
1323
1324      sess.run(variables.global_variables_initializer())
1325
1326      run_metadata = config_pb2.RunMetadata()
1327      run_options = config_pb2.RunOptions(output_partition_graphs=True)
1328      debug_utils.watch_graph(
1329          run_options,
1330          sess.graph,
1331          debug_ops=["DebugNumericSummary(foo=1.0)"],
1332          debug_urls=self._debug_urls())
1333      with self.assertRaisesRegexp(
1334          errors.FailedPreconditionError,
1335          r"1 attribute key\(s\) were not valid for debug node "
1336          r"__dbg_.:0_0_DebugNumericSummary: foo"):
1337        sess.run(y, options=run_options, run_metadata=run_metadata)
1338
1339      run_options = config_pb2.RunOptions(output_partition_graphs=True)
1340      debug_utils.watch_graph(
1341          run_options,
1342          sess.graph,
1343          debug_ops=["DebugNumericSummary(foo=1.0; bar=false)"],
1344          debug_urls=self._debug_urls())
1345      with self.assertRaisesRegexp(
1346          errors.FailedPreconditionError,
1347          r"2 attribute key\(s\) were not valid for debug node "
1348          r"__dbg_.:0_0_DebugNumericSummary:"):
1349        sess.run(y, options=run_options, run_metadata=run_metadata)
1350
1351      run_options = config_pb2.RunOptions(output_partition_graphs=True)
1352      debug_utils.watch_graph(
1353          run_options,
1354          sess.graph,
1355          debug_ops=["DebugNumericSummary(foo=1.0; mute_if_healthy=true)"],
1356          debug_urls=self._debug_urls())
1357      with self.assertRaisesRegexp(
1358          errors.FailedPreconditionError,
1359          r"1 attribute key\(s\) were not valid for debug node "
1360          r"__dbg_.:0_0_DebugNumericSummary: foo"):
1361        sess.run(y, options=run_options, run_metadata=run_metadata)
1362
1363  def testDebugNumericSummaryMuteOnHealthyMutesOnlyHealthyTensorDumps(self):
1364    with session.Session(config=no_rewrite_session_config()) as sess:
1365      a = variables.VariableV1(10.0, name="a")
1366      b = variables.VariableV1(0.0, name="b")
1367      c = variables.VariableV1(0.0, name="c")
1368
1369      x = math_ops.divide(a, b, name="x")
1370      y = math_ops.multiply(x, c, name="y")
1371
1372      sess.run(variables.global_variables_initializer())
1373
1374      # Here, validate=False is necessary to avoid causality check error.
1375      # TODO(cais): Maybe let DebugDumpDir constructor automatically ignore
1376      #   debug ops with mute_if_healthy=false attribute during validation.
1377      _, dump = self._debug_run_and_get_dump(
1378          sess, y, debug_ops=["DebugNumericSummary(mute_if_healthy=true)"],
1379          validate=False)
1380
1381      self.assertEqual(2, dump.size)
1382      self.assertAllClose([[
1383          1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, np.inf, -np.inf, np.nan,
1384          np.nan, 1.0, 0.0
1385      ]], dump.get_tensors("x", 0, "DebugNumericSummary"))
1386      self.assertAllClose([[
1387          1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, np.inf, -np.inf, np.nan,
1388          np.nan, 1.0, 0.0
1389      ]], dump.get_tensors("y", 0, "DebugNumericSummary"))
1390
1391      # Another run with the default mute_if_healthy (false) value should
1392      # dump all the tensors.
1393      shutil.rmtree(self._dump_root)
1394      _, dump = self._debug_run_and_get_dump(
1395          sess, y, debug_ops=["DebugNumericSummary()"])
1396      self.assertEqual(8, dump.size)
1397
1398  def testDebugNumericSummaryMuteOnHealthyAndCustomBoundsWork(self):
1399    with session.Session() as sess:
1400      a = variables.VariableV1([10.0, 10.0], name="a")
1401      b = variables.VariableV1([10.0, 2.0], name="b")
1402
1403      x = math_ops.add(a, b, name="x")  # [20.0, 12.0]
1404      y = math_ops.divide(x, b, name="y")  # [2.0, 6.0]
1405
1406      sess.run(variables.global_variables_initializer())
1407
1408      # Here, validate=False is necessary to avoid causality check error.
1409      # TODO(cais): Maybe let DebugDumpDir constructor automatically ignore
1410      #   debug ops with mute_if_healthy=false attribute during validation.
1411      _, dump = self._debug_run_and_get_dump(
1412          sess, y, debug_ops=[
1413              "DebugNumericSummary(mute_if_healthy=true; upper_bound=11.0)"],
1414          validate=False)
1415
1416      self.assertEqual(1, dump.size)
1417      self.assertAllClose([[
1418          1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 12.0, 20.0, 16.0, 16.0, 1.0,
1419          1.0, 2.0]], dump.get_tensors("x", 0, "DebugNumericSummary"))
1420
1421  def testDebugQueueOpsDoesNotoErrorOut(self):
1422    with session.Session() as sess:
1423      q = data_flow_ops.FIFOQueue(3, "float", name="fifo_queue")
1424      q_init = q.enqueue_many(([101.0, 202.0, 303.0],), name="enqueue_many")
1425
1426      _, dump = self._debug_run_and_get_dump(sess, q_init)
1427      self.assertTrue(dump.loaded_partition_graphs())
1428
1429      fifo_queue_tensor = dump.get_tensors("fifo_queue", 0, "DebugIdentity")[0]
1430      self.assertIsInstance(fifo_queue_tensor,
1431                            debug_data.InconvertibleTensorProto)
1432      self.assertTrue(fifo_queue_tensor.initialized)
1433      self.assertAllClose(
1434          [101.0, 202.0, 303.0],
1435          dump.get_tensors("enqueue_many/component_0", 0, "DebugIdentity")[0])
1436
1437  def testLookUpNodePythonTracebackWorks(self):
1438    with session.Session() as sess:
1439      u_init = constant_op.constant(10.0)
1440      u = variables.VariableV1(u_init, name="traceback/u")
1441      v_init = constant_op.constant(20.0)
1442      v = variables.VariableV1(v_init, name="traceback/v")
1443
1444      w = math_ops.multiply(u, v, name="traceback/w")
1445
1446      sess.run(variables.global_variables_initializer())
1447      _, dump = self._debug_run_and_get_dump(sess, w)
1448
1449      # Prior to setting the Python graph, attempts to do traceback lookup
1450      # should lead to exceptions.
1451      with self.assertRaisesRegexp(
1452          LookupError, "Python graph is not available for traceback lookup"):
1453        dump.node_traceback("traceback/w")
1454
1455      dump.set_python_graph(sess.graph)
1456
1457      # After setting the Python graph, attempts to look up nonexistent nodes
1458      # should lead to exceptions.
1459      with self.assertRaisesRegexp(KeyError,
1460                                   r"Cannot find node \"foo\" in Python graph"):
1461        dump.node_traceback("foo")
1462
1463      # Lookup should work with node name input.
1464      traceback = dump.node_traceback("traceback/w")
1465      self.assertIsInstance(traceback, list)
1466      self.assertGreater(len(traceback), 0)
1467      for trace in traceback:
1468        self.assertIsInstance(trace, tuple)
1469
1470      # Lookup should also work with tensor name input.
1471      traceback = dump.node_traceback("traceback/w:0")
1472      self.assertIsInstance(traceback, list)
1473      self.assertGreater(len(traceback), 0)
1474      for trace in traceback:
1475        self.assertIsInstance(trace, tuple)
1476
1477
1478class DebugConcurrentRunCallsTest(test_util.TensorFlowTestCase):
1479  """Test for debugging concurrent Session.run() calls."""
1480
1481  def _get_concurrent_debug_urls(self):
1482    """Abstract method to generate debug URLs for concurrent debugged runs."""
1483    raise NotImplementedError(
1484        "_get_concurrent_debug_urls is not implemented in the base test class")
1485
1486  def testDebugConcurrentVariableUpdates(self):
1487    if test.is_gpu_available():
1488      self.skipTest("No testing concurrent runs on a single GPU.")
1489
1490    with session.Session() as sess:
1491      v = variables.VariableV1(30.0, name="v")
1492      constants = []
1493      for i in xrange(self._num_concurrent_runs):
1494        constants.append(constant_op.constant(1.0, name="c%d" % i))
1495      incs = [
1496          state_ops.assign_add(
1497              v, c, use_locking=True, name=("inc%d" % i))
1498          for (i, c) in enumerate(constants)
1499      ]
1500      sess.run(v.initializer)
1501
1502      concurrent_debug_urls = self._get_concurrent_debug_urls()
1503
1504      def inc_job(index):
1505        run_options = config_pb2.RunOptions(output_partition_graphs=True)
1506        debug_utils.watch_graph(
1507            run_options, sess.graph, debug_urls=concurrent_debug_urls[index])
1508        for _ in xrange(100):
1509          sess.run(incs[index], options=run_options)
1510
1511      inc_threads = []
1512      for index in xrange(self._num_concurrent_runs):
1513        inc_thread = threading.Thread(target=functools.partial(inc_job, index))
1514        inc_thread.start()
1515        inc_threads.append(inc_thread)
1516      for inc_thread in inc_threads:
1517        inc_thread.join()
1518
1519      self.assertAllClose(30.0 + 1.0 * self._num_concurrent_runs * 100,
1520                          sess.run(v))
1521
1522      all_session_run_indices = []
1523      for index in xrange(self._num_concurrent_runs):
1524        dump = debug_data.DebugDumpDir(self._dump_roots[index])
1525        self.assertTrue(dump.loaded_partition_graphs())
1526
1527        v_data = dump.get_tensors("v", 0, "DebugIdentity")
1528        self.assertEqual(100, len(v_data))
1529
1530        # Examine all the core metadata files
1531        core_metadata_files = glob.glob(
1532            os.path.join(self._dump_roots[index], "_tfdbg_core*"))
1533
1534        timestamps = []
1535        session_run_indices = []
1536        executor_step_indices = []
1537        for core_metadata_file in core_metadata_files:
1538          with open(core_metadata_file, "rb") as f:
1539            event = event_pb2.Event()
1540            event.ParseFromString(f.read())
1541            core_metadata = (
1542                debug_data.extract_core_metadata_from_event_proto(event))
1543            timestamps.append(event.wall_time)
1544            session_run_indices.append(core_metadata.session_run_index)
1545            executor_step_indices.append(core_metadata.executor_step_index)
1546
1547        all_session_run_indices.extend(session_run_indices)
1548
1549        # Assert that executor_step_index increases by one at a time.
1550        executor_step_indices = zip(timestamps, executor_step_indices)
1551        executor_step_indices = sorted(
1552            executor_step_indices, key=lambda x: x[0])
1553        for i in xrange(len(executor_step_indices) - 1):
1554          self.assertEquals(executor_step_indices[i][1] + 1,
1555                            executor_step_indices[i + 1][1])
1556
1557        # Assert that session_run_index increase monotonically.
1558        session_run_indices = zip(timestamps, session_run_indices)
1559        session_run_indices = sorted(session_run_indices, key=lambda x: x[0])
1560        for i in xrange(len(session_run_indices) - 1):
1561          self.assertGreater(session_run_indices[i + 1][1],
1562                             session_run_indices[i][1])
1563
1564      # Assert that the session_run_indices from the concurrent run() calls are
1565      # all unique.
1566      self.assertEqual(len(all_session_run_indices),
1567                       len(set(all_session_run_indices)))
1568
1569
1570if __name__ == "__main__":
1571  googletest.main()
1572