1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""TensorFlow Debugger (tfdbg) Utilities."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import re
22
23from six.moves import xrange  # pylint: disable=redefined-builtin
24
25
def add_debug_tensor_watch(run_options,
                           node_name,
                           output_slot=0,
                           debug_ops="DebugIdentity",
                           debug_urls=None,
                           tolerate_debug_op_creation_failures=False,
                           global_step=-1):
  """Append a single debug watch on one output `Tensor` to `RunOptions`.

  A new entry is added to `run_options.debug_options.debug_tensor_watch_opts`
  describing which node / output slot to watch, which debug op(s) to apply and
  where the resulting debug values are sent.

  N.B.:
    1. Under certain circumstances, the `Tensor` may not get actually watched
      (e.g., if the node of the `Tensor` is constant-folded during runtime).
    2. For debugging purposes, the `parallel_iterations` attribute of all
      `tf.while_loop`s in the graph are set to 1 to prevent any node from
      being executed multiple times concurrently. This change does not affect
      subsequent non-debugged runs of the same `tf.while_loop`s.

  Args:
    run_options: An instance of `config_pb2.RunOptions`; modified in place.
    node_name: (`str`) name of the node to watch.
    output_slot: (`int`) output slot index of the tensor from the watched node.
    debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s). A single
      `str` is treated as a one-element `list`.
      For debug op types with customizable attributes, each debug op string can
      optionally contain a list of attribute names, in the syntax of:
        debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...)
    debug_urls: (`str` or `list` of `str`) URL(s) to send debug values to,
      e.g., `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`. A single
      `str` is treated as a one-element `list`; a falsy value (e.g. `None`)
      leaves the watch without any URL.
    tolerate_debug_op_creation_failures: (`bool`) Whether to tolerate debug op
      creation failures by not throwing exceptions.
    global_step: (`int`) Optional global_step count. It is stored in
      `run_options.debug_options.global_step`, which is shared by all watches
      of `run_options`, so each call overwrites the previous value.
  """
  debug_options = run_options.debug_options
  # Run-level field: shared by every watch, not specific to this one.
  debug_options.global_step = global_step

  new_watch = debug_options.debug_tensor_watch_opts.add()
  new_watch.tolerate_debug_op_creation_failures = (
      tolerate_debug_op_creation_failures)
  new_watch.node_name = node_name
  new_watch.output_slot = output_slot

  op_names = [debug_ops] if isinstance(debug_ops, str) else debug_ops
  new_watch.debug_ops.extend(op_names)

  # No URLs requested: leave the repeated field empty.
  if not debug_urls:
    return

  url_list = [debug_urls] if isinstance(debug_urls, str) else debug_urls
  new_watch.debug_urls.extend(url_list)
80
81
def watch_graph(run_options,
                graph,
                debug_ops="DebugIdentity",
                debug_urls=None,
                node_name_regex_allowlist=None,
                op_type_regex_allowlist=None,
                tensor_dtype_regex_allowlist=None,
                tolerate_debug_op_creation_failures=False,
                global_step=-1,
                reset_disk_byte_usage=False):
  """Add debug watches to `RunOptions` for a TensorFlow graph.

  One watch is added per output slot of every op in `graph` that has outputs
  and passes all of the given allowlists.

  To watch all `Tensor`s on the graph, leave all three allowlists
  (`node_name_regex_allowlist`, `op_type_regex_allowlist` and
  `tensor_dtype_regex_allowlist`) at their default (`None`). In that case an
  additional wildcard watch (node name `"*"`, output slot `-1`) is added as
  well, so that nodes created internally by TensorFlow itself (e.g., by
  Grappler) can also be watched.

  N.B.:
    1. Under certain circumstances, the `Tensor` may not get actually watched
      (e.g., if the node of the `Tensor` is constant-folded during runtime).
    2. For debugging purposes, the `parallel_iterations` attribute of all
      `tf.while_loop`s in the graph are set to 1 to prevent any node from
      being executed multiple times concurrently. This change does not affect
      subsequent non-debugged runs of the same `tf.while_loop`s.

  Args:
    run_options: An instance of `config_pb2.RunOptions` to be modified.
    graph: An instance of `ops.Graph`.
    debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s) to use.
      A single `str` is equivalent to a `list` with one element.
      For debug op types with customizable attributes, each debug op name string
      can optionally contain a list of attribute names, in the syntax of:
        debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...)
    debug_urls: URLs to send debug values to. Can be a list of strings or a
      single string; the latter is equivalent to a list consisting of a single
      string. E.g., `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`.
      Must not be empty or `None`.
    node_name_regex_allowlist: Regular-expression allowlist for node_name,
      e.g., `"(weight_[0-9]+|bias_.*)"`. Patterns are applied with
      `re.match`, i.e. anchored at the start of the name.
    op_type_regex_allowlist: Regular-expression allowlist for the op type of
      nodes, e.g., `"(Variable|Add)"`.
      If both `node_name_regex_allowlist` and `op_type_regex_allowlist`
      are set, the two filtering operations will occur in a logical `AND`
      relation. In other words, a node will be included if and only if it
      hits both allowlists.
    tensor_dtype_regex_allowlist: Regular-expression allowlist for Tensor
      data type (matched against the dtype's `name`), e.g., `"^int.*"`.
      This allowlist operates in logical `AND` relations to the two allowlists
      above.
    tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
      failures (e.g., due to dtype incompatibility) are to be tolerated by not
      throwing exceptions.
    global_step: (`int`) Optional global_step count for this debug tensor
      watch.
    reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
      usage to zero (default: `False`). Always written to
      `run_options.debug_options`.

  Raises:
    ValueError: If `debug_ops` or `debug_urls` is empty or `None`.
  """
  if not debug_ops:
    raise ValueError("debug_ops must not be empty or None.")
  if not debug_urls:
    raise ValueError("debug_urls must not be empty or None.")

  if isinstance(debug_ops, str):
    debug_ops = [debug_ops]

  # Empty-string / None regexes mean "no filter".
  node_name_pattern = (
      re.compile(node_name_regex_allowlist)
      if node_name_regex_allowlist else None)
  op_type_pattern = (
      re.compile(op_type_regex_allowlist) if op_type_regex_allowlist else None)
  tensor_dtype_pattern = (
      re.compile(tensor_dtype_regex_allowlist)
      if tensor_dtype_regex_allowlist else None)

  for op in graph.get_operations():
    # Skip nodes without any output tensors.
    if not op.outputs:
      continue

    node_name = op.name
    op_type = op.type

    if node_name_pattern and not node_name_pattern.match(node_name):
      continue
    if op_type_pattern and not op_type_pattern.match(op_type):
      continue

    for slot, output in enumerate(op.outputs):
      if (tensor_dtype_pattern and
          not tensor_dtype_pattern.match(output.dtype.name)):
        continue

      add_debug_tensor_watch(
          run_options,
          node_name,
          output_slot=slot,
          debug_ops=debug_ops,
          debug_urls=debug_urls,
          tolerate_debug_op_creation_failures=(
              tolerate_debug_op_creation_failures),
          global_step=global_step)

  # If no filter for node or tensor is used, will add a wildcard node name, so
  # that all nodes, including the ones created internally by TensorFlow itself
  # (e.g., by Grappler), can be watched during debugging.
  use_node_name_wildcard = (not node_name_pattern and
                            not op_type_pattern and
                            not tensor_dtype_pattern)
  if use_node_name_wildcard:
    add_debug_tensor_watch(
        run_options,
        "*",
        output_slot=-1,
        debug_ops=debug_ops,
        debug_urls=debug_urls,
        tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures,
        global_step=global_step)

  run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage
200
201
def watch_graph_with_denylists(run_options,
                               graph,
                               debug_ops="DebugIdentity",
                               debug_urls=None,
                               node_name_regex_denylist=None,
                               op_type_regex_denylist=None,
                               tensor_dtype_regex_denylist=None,
                               tolerate_debug_op_creation_failures=False,
                               global_step=-1,
                               reset_disk_byte_usage=False):
  """Add debug tensor watches, denylisting nodes and op types.

  This is similar to `watch_graph()`, but the node names and op types are
  denylisted, instead of allowlisted. Unlike `watch_graph()`, this function
  does not validate `debug_ops` / `debug_urls` and never adds a wildcard
  (`"*"`) watch.

  N.B.:
    1. Under certain circumstances, the `Tensor` may not get actually watched
      (e.g., if the node of the `Tensor` is constant-folded during runtime).
    2. For debugging purposes, the `parallel_iterations` attribute of all
      `tf.while_loop`s in the graph are set to 1 to prevent any node from
      being executed multiple times concurrently. This change does not affect
      subsequent non-debugged runs of the same `tf.while_loop`s.

  Args:
    run_options: An instance of `config_pb2.RunOptions` to be modified.
    graph: An instance of `ops.Graph`.
    debug_ops: (`str` or `list` of `str`) name(s) of the debug op(s) to use. See
      the documentation of `watch_graph` for more details.
    debug_urls: URL(s) to send debug values to, e.g.,
      `file:///tmp/tfdbg_dump_1`, `grpc://localhost:12345`.
    node_name_regex_denylist: Regular-expression denylist for node_name. This
      should be a string, e.g., `"(weight_[0-9]+|bias_.*)"`. Patterns are
      applied with `re.match`, i.e. anchored at the start of the name.
    op_type_regex_denylist: Regular-expression denylist for the op type of
      nodes, e.g., `"(Variable|Add)"`. If both node_name_regex_denylist and
      op_type_regex_denylist are set, the two filtering operations will occur in
      a logical `OR` relation. In other words, a node will be excluded if it
      hits either of the two denylists; a node will be included if and only if
      it hits neither of the denylists.
    tensor_dtype_regex_denylist: Regular-expression denylist for Tensor data
      type (matched against the dtype's `name`), e.g., `"^int.*"`. This
      denylist operates in logical `OR` relations to the two denylists above.
    tolerate_debug_op_creation_failures: (`bool`) whether debug op creation
      failures (e.g., due to dtype incompatibility) are to be tolerated by not
      throwing exceptions.
    global_step: (`int`) Optional global_step count for this debug tensor watch.
    reset_disk_byte_usage: (`bool`) whether to reset the tracked disk byte
      usage to zero (default: `False`). Always written to
      `run_options.debug_options`, even when no tensor ends up being watched.
  """

  if isinstance(debug_ops, str):
    debug_ops = [debug_ops]

  # Empty-string / None regexes mean "no filter".
  node_name_pattern = (
      re.compile(node_name_regex_denylist)
      if node_name_regex_denylist else None)
  op_type_pattern = (
      re.compile(op_type_regex_denylist) if op_type_regex_denylist else None)
  tensor_dtype_pattern = (
      re.compile(tensor_dtype_regex_denylist)
      if tensor_dtype_regex_denylist else None)

  for op in graph.get_operations():
    # Skip nodes without any output tensors.
    if not op.outputs:
      continue

    node_name = op.name
    op_type = op.type

    if node_name_pattern and node_name_pattern.match(node_name):
      continue
    if op_type_pattern and op_type_pattern.match(op_type):
      continue

    for slot, output in enumerate(op.outputs):
      if (tensor_dtype_pattern and
          tensor_dtype_pattern.match(output.dtype.name)):
        continue

      add_debug_tensor_watch(
          run_options,
          node_name,
          output_slot=slot,
          debug_ops=debug_ops,
          debug_urls=debug_urls,
          tolerate_debug_op_creation_failures=(
              tolerate_debug_op_creation_failures),
          global_step=global_step)

  # Set once, after the loop (previously this sat inside the loop and was
  # skipped entirely for empty or fully-denylisted graphs).
  run_options.debug_options.reset_disk_byte_usage = reset_disk_byte_usage
292