1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Shared functions and classes for tfdbg command-line interface."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import math
21
22import numpy as np
23import six
24
25from tensorflow.python.debug.cli import command_parser
26from tensorflow.python.debug.cli import debugger_cli_common
27from tensorflow.python.debug.cli import tensor_format
28from tensorflow.python.debug.lib import common
29from tensorflow.python.framework import ops
30from tensorflow.python.ops import variables
31from tensorflow.python.platform import gfile
32
# Short alias for debugger_cli_common.RichLine, used throughout this module to
# build rich-text lines.
RL = debugger_cli_common.RichLine

# Default threshold number of elements above which ellipses will be used
# when printing the value of the tensor.
DEFAULT_NDARRAY_DISPLAY_THRESHOLD = 2000

# Color-name strings. COLOR_RED, for example, is passed as a RichLine font
# attribute by error() below.
COLOR_BLACK = "black"
COLOR_BLUE = "blue"
COLOR_CYAN = "cyan"
COLOR_GRAY = "gray"
COLOR_GREEN = "green"
COLOR_MAGENTA = "magenta"
COLOR_RED = "red"
COLOR_WHITE = "white"
COLOR_YELLOW = "yellow"

# Time units understood by time_to_readable_str(), ordered from smallest to
# largest. The index of a unit in TIME_UNITS is its power-of-1000 exponent
# relative to microseconds (us -> 0, ms -> 1, s -> 2).
TIME_UNIT_US = "us"
TIME_UNIT_MS = "ms"
TIME_UNIT_S = "s"
TIME_UNITS = [TIME_UNIT_US, TIME_UNIT_MS, TIME_UNIT_S]
53
54
def bytes_to_readable_str(num_bytes, include_b=False):
  """Render a byte count as a short human-readable string.

  Counts below 1024 are printed as plain integers. Larger counts are scaled
  and printed with two decimals and a "k", "M" or "G" suffix (powers of 1024).

  Args:
    num_bytes: (`int` or None) Number of bytes.
    include_b: (`bool`) Whether to append the letter "B" after the unit.

  Returns:
    (`str`) The readable representation, e.g. "512", "1.50k" or "2.00MB".
      A `None` input yields the string "None".
  """
  if num_bytes is None:
    return str(num_bytes)

  unit_tail = "B" if include_b else ""

  if num_bytes < 1024:
    return "%d%s" % (num_bytes, unit_tail)
  if num_bytes < 1048576:
    return "%.2fk%s" % (num_bytes / 1024.0, unit_tail)
  if num_bytes < 1073741824:
    return "%.2fM%s" % (num_bytes / 1048576.0, unit_tail)
  return "%.2fG%s" % (num_bytes / 1073741824.0, unit_tail)
83
84
def time_to_readable_str(value_us, force_time_unit=None):
  """Convert time value to human-readable string.

  Args:
    value_us: time value in microseconds.
    force_time_unit: force the output to use the specified time unit. Must be
      in TIME_UNITS.

  Returns:
    Human-readable string representation of the time value. A zero (falsy)
    value yields "0". Without `force_time_unit`, the unit is the largest one
    in TIME_UNITS that does not exceed the value. Values below 1 us stay in
    "us".

  Raises:
    ValueError: if force_time_unit value is not in TIME_UNITS.
  """
  if not value_us:
    return "0"
  if force_time_unit:
    if force_time_unit not in TIME_UNITS:
      raise ValueError("Invalid time unit: %s" % force_time_unit)
    order = TIME_UNITS.index(force_time_unit)
    time_unit = force_time_unit
    return "{:.10g}{}".format(value_us / math.pow(10.0, 3 * order), time_unit)

  # Use math.log10 rather than math.log(x, 10). The latter is inexact at
  # powers of ten (math.log(1000, 10) == 2.9999999999999996), which would
  # render 1000 us as "1e+03us" instead of "1ms".
  # abs() lets negative durations be formatted instead of raising a math
  # domain error.
  order = int(math.log10(abs(value_us)) / 3)
  # Clamp into the valid index range. Sub-microsecond values would otherwise
  # give a negative order that wraps around to TIME_UNITS[-1] ("s").
  order = max(0, min(len(TIME_UNITS) - 1, order))
  time_unit = TIME_UNITS[order]
  return "{:.3g}{}".format(value_us / math.pow(10.0, 3 * order), time_unit)
111
112
def parse_ranges_highlight(ranges_string):
  """Build highlight options from a ranges string.

  Args:
    ranges_string: (str) A string representing a numerical range or a list of
      numerical ranges. See the help info of the -r flag of the print_tensor
      command for more details.

  Returns:
    A tensor_format.HighlightOptions instance whose filter marks elements that
      fall inside any of the parsed (inclusive) ranges, or None if
      `ranges_string` is empty or None.
  """
  if not ranges_string:
    return None

  parsed_ranges = command_parser.parse_ranges(ranges_string)

  def _in_any_range(arr):
    # Element-wise OR over all ranges, each range inclusive on both ends.
    mask = np.zeros(arr.shape, dtype=bool)
    for low, high in parsed_ranges:
      in_this_range = np.logical_and(arr >= low, arr <= high)
      mask = np.logical_or(mask, in_this_range)
    return mask

  return tensor_format.HighlightOptions(
      _in_any_range, description=ranges_string)
141
142
def numpy_printoptions_from_screen_info(screen_info):
  """Derive numpy print options from screen information.

  Args:
    screen_info: (dict or None) Screen information. If it contains a "cols"
      key, its value is used as the numpy line width.

  Returns:
    (dict) `{"linewidth": screen_info["cols"]}` when available, otherwise an
      empty dict.
  """
  if not screen_info or "cols" not in screen_info:
    return {}
  return {"linewidth": screen_info["cols"]}
148
149
def format_tensor(tensor,
                  tensor_name,
                  np_printoptions,
                  print_all=False,
                  tensor_slicing=None,
                  highlight_options=None,
                  include_numeric_summary=False,
                  write_path=None):
  """Generate formatted str to represent a tensor or its slices.

  Args:
    tensor: (numpy ndarray) The tensor value.
    tensor_name: (str) Name of the tensor, e.g., the tensor's debug watch key.
    np_printoptions: (dict or None) Numpy tensor formatting options. The dict
      is not modified; a copy with the "threshold" option set is used.
    print_all: (bool) Whether the tensor is to be displayed in its entirety,
      instead of printing ellipses, even if its number of elements exceeds
      the default numpy display threshold.
      (Note: Even if this is set to true, the screen output can still be cut
       off by the UI frontend if it consist of more lines than the frontend
       can handle.)
    tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
      None, no slicing will be performed on the tensor.
    highlight_options: (tensor_format.HighlightOptions) options to highlight
      elements of the tensor. See the doc of tensor_format.format_tensor()
      for more details.
    include_numeric_summary: Whether a text summary of the numeric values (if
      applicable) will be included.
    write_path: A path to save the tensor value (after any slicing) to
      (optional). `numpy.save()` is used to save the value.

  Returns:
    An instance of `debugger_cli_common.RichTextLines` representing the
    (potentially sliced) tensor.
  """

  if tensor_slicing:
    # Validate the indexing.
    value = command_parser.evaluate_tensor_slice(tensor, tensor_slicing)
    sliced_name = tensor_name + tensor_slicing
  else:
    value = tensor
    sliced_name = tensor_name

  auxiliary_message = None
  if write_path:
    with gfile.Open(write_path, "wb") as output_file:
      np.save(output_file, value)
    line = debugger_cli_common.RichLine("Saved value to: ")
    line += debugger_cli_common.RichLine(write_path, font_attr="bold")
    line += " (%sB)" % bytes_to_readable_str(gfile.Stat(write_path).length)
    auxiliary_message = debugger_cli_common.rich_text_lines_from_rich_line_list(
        [line, debugger_cli_common.RichLine("")])

  # Work on a copy so that setting the display threshold does not leak into
  # the caller's options dict as a side effect.
  np_printoptions = dict(np_printoptions) if np_printoptions else {}
  if print_all:
    # A threshold equal to the element count means no ellipsis is used.
    np_printoptions["threshold"] = value.size
  else:
    np_printoptions["threshold"] = DEFAULT_NDARRAY_DISPLAY_THRESHOLD

  return tensor_format.format_tensor(
      value,
      sliced_name,
      include_metadata=True,
      include_numeric_summary=include_numeric_summary,
      auxiliary_message=auxiliary_message,
      np_printoptions=np_printoptions,
      highlight_options=highlight_options)
216
217
def error(msg):
  """Wrap an error message as red-colored rich text for screen output.

  Args:
    msg: (str) The error message; it is prefixed with "ERROR: ".

  Returns:
    (debugger_cli_common.RichTextLines) A single-line rich-text block holding
      the prefixed message in COLOR_RED.
  """
  error_line = RL("ERROR: " + msg, COLOR_RED)
  return debugger_cli_common.rich_text_lines_from_rich_line_list([error_line])
231
232
def _recommend_command(command, description, indent=2, create_link=False):
  """Build a two-line RichTextLines entry recommending a command.

  The first line shows the command in bold followed by a colon. The second
  line shows the description, indented two spaces further.

  Args:
    command: (str) The command to recommend.
    description: (str) A description of what the command does.
    indent: (int) Number of leading spaces on the command line.
    create_link: (bool) Whether to attach a command MenuItem link to the
      command text in addition to bold.

  Returns:
    (RichTextLines) Formatted text (with font attributes) for recommending the
      command.
  """
  padding = " " * indent
  command_font = ([debugger_cli_common.MenuItem("", command), "bold"]
                  if create_link else "bold")

  header = RL(padding) + RL(command, command_font) + ":"
  detail = padding + "  " + description
  return debugger_cli_common.rich_text_lines_from_rich_line_list(
      [header, detail])
259
260
def get_tfdbg_logo():
  """Return the ASCII-art tfdbg logo, padded by one blank line on each side."""
  logo_rows = [
      "TTTTTT FFFF DDD  BBBB   GGG ",
      "  TT   F    D  D B   B G    ",
      "  TT   FFF  D  D BBBB  G  GG",
      "  TT   F    D  D B   B G   G",
      "  TT   F    DDD  BBBB   GGG ",
  ]
  return debugger_cli_common.RichTextLines([""] + logo_rows + [""])
274
275
# Separator line placed before and after the call summary in
# get_run_start_intro().
_HORIZONTAL_BAR = "======================================"
277
278
def get_run_start_intro(run_call_count,
                        fetches,
                        feed_dict,
                        tensor_filters,
                        is_callable_runner=False):
  """Generate formatted intro for run-start UI.

  Args:
    run_call_count: (int) Run call counter.
    fetches: Fetches of the `Session.run()` call. See doc of `Session.run()`
      for more details.
    feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()`
      for more details.
    tensor_filters: (dict) A dict from tensor-filter name to tensor-filter
      callable.
    is_callable_runner: (bool) whether a runner returned by
      Session.make_callable is being run. If True, the run-call counter,
      fetch and feed sections are omitted from the intro.

  Returns:
    (RichTextLines) Formatted intro message about the `Session.run()` call.
  """

  out = debugger_cli_common.RichTextLines(_HORIZONTAL_BAR)
  if is_callable_runner:
    out.append("Running a runner returned by Session.make_callable()")
  else:
    # Fetch and feed details are only displayed in this branch, so they are
    # only computed here.
    fetch_lines = common.get_flattened_names(fetches)

    if not feed_dict:
      feed_dict_lines = [debugger_cli_common.RichLine("  (Empty)")]
    else:
      feed_dict_lines = []
      for feed_key in feed_dict:
        feed_key_name = common.get_graph_element_name(feed_key)
        feed_dict_line = debugger_cli_common.RichLine("  ")
        # Surround the name string with quotes in the "pf" command, because
        # feed_key_name may contain spaces in some cases, e.g., SparseTensors.
        feed_dict_line += debugger_cli_common.RichLine(
            feed_key_name,
            debugger_cli_common.MenuItem(None, "pf '%s'" % feed_key_name))
        feed_dict_lines.append(feed_dict_line)

    out.append("Session.run() call #%d:" % run_call_count)
    out.append("")
    out.append("Fetch(es):")
    out.extend(debugger_cli_common.RichTextLines(
        ["  " + line for line in fetch_lines]))
    out.append("")
    out.append("Feed dict:")
    out.extend(debugger_cli_common.rich_text_lines_from_rich_line_list(
        feed_dict_lines))
  out.append(_HORIZONTAL_BAR)
  out.append("")
  out.append("Select one of the following commands to proceed ---->")

  out.extend(
      _recommend_command(
          "run",
          "Execute the run() call with debug tensor-watching",
          create_link=True))
  out.extend(
      _recommend_command(
          "run -n",
          "Execute the run() call without debug tensor-watching",
          create_link=True))
  out.extend(
      _recommend_command(
          "run -t <T>",
          "Execute run() calls (T - 1) times without debugging, then "
          "execute run() once more with debugging and drop back to the CLI"))
  out.extend(
      _recommend_command(
          "run -f <filter_name>",
          "Keep executing run() calls until a dumped tensor passes a given, "
          "registered filter (conditional breakpoint mode)"))

  # List the registered tensor filters, each linked to its "run -f" command.
  more_lines = ["    Registered filter(s):"]
  if tensor_filters:
    for filter_name in tensor_filters:
      command_menu_node = debugger_cli_common.MenuItem(
          "", "run -f %s" % filter_name)
      more_lines.append(RL("        * ") + RL(filter_name, command_menu_node))
  else:
    more_lines.append("        (None)")

  out.extend(
      debugger_cli_common.rich_text_lines_from_rich_line_list(more_lines))

  out.extend(
      _recommend_command(
          "invoke_stepper",
          "Use the node-stepper interface, which allows you to interactively "
          "step through nodes involved in the graph run() call and "
          "inspect/modify their values", create_link=True))

  out.append("")

  out.append_rich_line(RL("For more details, see ") +
                       RL("help.", debugger_cli_common.MenuItem("", "help")) +
                       ".")
  out.append("")

  # Make main menu for the run-start intro.
  menu = debugger_cli_common.Menu()
  menu.append(debugger_cli_common.MenuItem("run", "run"))
  menu.append(debugger_cli_common.MenuItem(
      "invoke_stepper", "invoke_stepper"))
  menu.append(debugger_cli_common.MenuItem("exit", "exit"))
  out.annotations[debugger_cli_common.MAIN_MENU_KEY] = menu

  return out
393
394
def get_run_short_description(run_call_count,
                              fetches,
                              feed_dict,
                              is_callable_runner=False):
  """Get a short description of the run() call.

  Args:
    run_call_count: (int) Run call counter.
    fetches: Fetches of the `Session.run()` call. See doc of `Session.run()`
      for more details.
    feed_dict: Feeds to the `Session.run()` call. See doc of `Session.run()`
      for more details.
    is_callable_runner: (bool) whether a runner returned by
        Session.make_callable is being run.

  Returns:
    (str) A short description of the run() call, including information about
      the fetch(es) and feed(s).
  """
  if is_callable_runner:
    return "runner from make_callable()"

  description = "run #%d: " % run_call_count

  if isinstance(fetches, (ops.Tensor, ops.Operation, variables.Variable)):
    description += "1 fetch (%s); " % common.get_graph_element_name(fetches)
  else:
    # Could be (nested) list, tuple, dict or namedtuple.
    num_fetches = len(common.get_flattened_names(fetches))
    # Only exactly one fetch takes the singular form; zero is "0 fetches".
    if num_fetches == 1:
      description += "1 fetch; "
    else:
      description += "%d fetches; " % num_fetches

  if not feed_dict:
    description += "0 feeds"
  elif len(feed_dict) == 1:
    feed_key = next(iter(feed_dict))
    # Name the key the same way get_run_start_intro() does. Pass it as a
    # one-element tuple so the name is never unpacked as multiple
    # %-arguments. Passing a (nested) tuple key directly used to raise
    # TypeError here.
    feed_key_name = common.get_graph_element_name(feed_key)
    description += "1 feed (%s)" % (feed_key_name,)
  else:
    description += "%d feeds" % len(feed_dict)

  return description
441
442
def get_error_intro(tf_error):
  """Build the intro text shown when a TensorFlow run-time error occurs.

  When the failing op's name can be determined (via `tf_error.op.name`), the
  intro recommends linked inspection commands for that op. Otherwise it shows
  a warning. In both cases it ends with the op name, error type and error
  details.

  Args:
    tf_error: (errors.OpError) TensorFlow run-time error object.

  Returns:
    (RichTextLines) Formatted intro message about the run-time OpError, with
      sample commands for debugging.
  """
  failing_op = getattr(tf_error, "op", None)
  op_name = failing_op.name if hasattr(failing_op, "name") else None

  out = debugger_cli_common.rich_text_lines_from_rich_line_list([
      "--------------------------------------",
      RL("!!! An error occurred during the run !!!", "blink"),
      "",
  ])

  if op_name is None:
    out.extend(debugger_cli_common.RichTextLines([
        "WARNING: Cannot determine the name of the op that caused the error."]))
  else:
    out.extend(debugger_cli_common.RichTextLines(
        ["You may use the following commands to debug:"]))
    suggestions = (
        ("ni -a -d -t %s" % op_name,
         "Inspect information about the failing op."),
        ("li -r %s" % op_name,
         "List inputs to the failing op, recursively."),
        ("lt",
         "List all tensors dumped during the failing run() call."),
    )
    for suggested_command, explanation in suggestions:
      out.extend(
          _recommend_command(suggested_command, explanation, create_link=True))

  trailer = [
      "",
      "Op name:    %s" % op_name,
      "Error type: " + str(type(tf_error)),
      "",
      "Details:",
      str(tf_error),
      "",
      "--------------------------------------",
      "",
  ]
  out.extend(debugger_cli_common.RichTextLines(trailer))
  return out
503