1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Format tensors (ndarrays) for screen display and navigation."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import copy
21import re
22
23import numpy as np
24from six.moves import xrange  # pylint: disable=redefined-builtin
25
26from tensorflow.python.debug.cli import debugger_cli_common
27from tensorflow.python.debug.lib import debug_data
28
29_NUMPY_OMISSION = "...,"
30_NUMPY_DEFAULT_EDGE_ITEMS = 3
31
32_NUMBER_REGEX = re.compile(r"[-+]?([0-9][-+0-9eE\.]+|nan|inf)(\s|,|\])")
33
34BEGIN_INDICES_KEY = "i0"
35OMITTED_INDICES_KEY = "omitted"
36
37DEFAULT_TENSOR_ELEMENT_HIGHLIGHT_FONT_ATTR = "bold"
38
39
40class HighlightOptions(object):
41  """Options for highlighting elements of a tensor."""
42
43  def __init__(self,
44               criterion,
45               description=None,
46               font_attr=DEFAULT_TENSOR_ELEMENT_HIGHLIGHT_FONT_ATTR):
47    """Constructor of HighlightOptions.
48
49    Args:
50      criterion: (callable) A callable of the following signature:
51        def to_highlight(X):
52          # Args:
53          #   X: The tensor to highlight elements in.
54          #
55          # Returns:
56          #   (boolean ndarray) A boolean ndarray of the same shape as X
57          #   indicating which elements are to be highlighted (iff True).
58        This callable will be used as the argument of np.argwhere() to
59        determine which elements of the tensor are to be highlighted.
60      description: (str) Description of the highlight criterion embodied by
61        criterion.
62      font_attr: (str) Font attribute to be applied to the
63        highlighted elements.
64
65    """
66
67    self.criterion = criterion
68    self.description = description
69    self.font_attr = font_attr
70
71
72def format_tensor(tensor,
73                  tensor_label,
74                  include_metadata=False,
75                  auxiliary_message=None,
76                  include_numeric_summary=False,
77                  np_printoptions=None,
78                  highlight_options=None):
79  """Generate a RichTextLines object showing a tensor in formatted style.
80
81  Args:
82    tensor: The tensor to be displayed, as a numpy ndarray or other
83      appropriate format (e.g., None representing uninitialized tensors).
84    tensor_label: A label for the tensor, as a string. If set to None, will
85      suppress the tensor name line in the return value.
86    include_metadata: Whether metadata such as dtype and shape are to be
87      included in the formatted text.
88    auxiliary_message: An auxiliary message to display under the tensor label,
89      dtype and shape information lines.
90    include_numeric_summary: Whether a text summary of the numeric values (if
91      applicable) will be included.
92    np_printoptions: A dictionary of keyword arguments that are passed to a
93      call of np.set_printoptions() to set the text format for display numpy
94      ndarrays.
95    highlight_options: (HighlightOptions) options for highlighting elements
96      of the tensor.
97
98  Returns:
99    A RichTextLines object. Its annotation field has line-by-line markups to
100    indicate which indices in the array the first element of each line
101    corresponds to.
102  """
103  lines = []
104  font_attr_segs = {}
105
106  if tensor_label is not None:
107    lines.append("Tensor \"%s\":" % tensor_label)
108    suffix = tensor_label.split(":")[-1]
109    if suffix.isdigit():
110      # Suffix is a number. Assume it is the output slot index.
111      font_attr_segs[0] = [(8, 8 + len(tensor_label), "bold")]
112    else:
113      # Suffix is not a number. It is auxiliary information such as the debug
114      # op type. In this case, highlight the suffix with a different color.
115      debug_op_len = len(suffix)
116      proper_len = len(tensor_label) - debug_op_len - 1
117      font_attr_segs[0] = [
118          (8, 8 + proper_len, "bold"),
119          (8 + proper_len + 1, 8 + proper_len + 1 + debug_op_len, "yellow")
120      ]
121
122  if isinstance(tensor, debug_data.InconvertibleTensorProto):
123    if lines:
124      lines.append("")
125    lines.extend(str(tensor).split("\n"))
126    return debugger_cli_common.RichTextLines(lines)
127  elif not isinstance(tensor, np.ndarray):
128    # If tensor is not a np.ndarray, return simple text-line representation of
129    # the object without annotations.
130    if lines:
131      lines.append("")
132    lines.extend(repr(tensor).split("\n"))
133    return debugger_cli_common.RichTextLines(lines)
134
135  if include_metadata:
136    lines.append("  dtype: %s" % str(tensor.dtype))
137    lines.append("  shape: %s" % str(tensor.shape).replace("L", ""))
138
139  if lines:
140    lines.append("")
141  formatted = debugger_cli_common.RichTextLines(
142      lines, font_attr_segs=font_attr_segs)
143
144  if auxiliary_message:
145    formatted.extend(auxiliary_message)
146
147  if include_numeric_summary:
148    formatted.append("Numeric summary:")
149    formatted.extend(numeric_summary(tensor))
150    formatted.append("")
151
152  # Apply custom string formatting options for numpy ndarray.
153  if np_printoptions is not None:
154    np.set_printoptions(**np_printoptions)
155
156  array_lines = repr(tensor).split("\n")
157  if tensor.dtype.type is not np.string_:
158    # Parse array lines to get beginning indices for each line.
159
160    # TODO(cais): Currently, we do not annotate string-type tensors due to
161    #   difficulty in escaping sequences. Address this issue.
162    annotations = _annotate_ndarray_lines(
163        array_lines, tensor, np_printoptions=np_printoptions)
164  else:
165    annotations = None
166  formatted_array = debugger_cli_common.RichTextLines(
167      array_lines, annotations=annotations)
168  formatted.extend(formatted_array)
169
170  # Perform optional highlighting.
171  if highlight_options is not None:
172    indices_list = list(np.argwhere(highlight_options.criterion(tensor)))
173
174    total_elements = np.size(tensor)
175    highlight_summary = "Highlighted%s: %d of %d element(s) (%.2f%%)" % (
176        "(%s)" % highlight_options.description if highlight_options.description
177        else "", len(indices_list), total_elements,
178        len(indices_list) / float(total_elements) * 100.0)
179
180    formatted.lines[0] += " " + highlight_summary
181
182    if indices_list:
183      indices_list = [list(indices) for indices in indices_list]
184
185      are_omitted, rows, start_cols, end_cols = locate_tensor_element(
186          formatted, indices_list)
187      for is_omitted, row, start_col, end_col in zip(are_omitted, rows,
188                                                     start_cols, end_cols):
189        if is_omitted or start_col is None or end_col is None:
190          continue
191
192        if row in formatted.font_attr_segs:
193          formatted.font_attr_segs[row].append(
194              (start_col, end_col, highlight_options.font_attr))
195        else:
196          formatted.font_attr_segs[row] = [(start_col, end_col,
197                                            highlight_options.font_attr)]
198
199  return formatted
200
201
202def _annotate_ndarray_lines(
203    array_lines, tensor, np_printoptions=None, offset=0):
204  """Generate annotations for line-by-line begin indices of tensor text.
205
206  Parse the numpy-generated text representation of a numpy ndarray to
207  determine the indices of the first element of each text line (if any
208  element is present in the line).
209
210  For example, given the following multi-line ndarray text representation:
211      ["array([[ 0.    ,  0.0625,  0.125 ,  0.1875],",
212       "       [ 0.25  ,  0.3125,  0.375 ,  0.4375],",
213       "       [ 0.5   ,  0.5625,  0.625 ,  0.6875],",
214       "       [ 0.75  ,  0.8125,  0.875 ,  0.9375]])"]
215  the generate annotation will be:
216      {0: {BEGIN_INDICES_KEY: [0, 0]},
217       1: {BEGIN_INDICES_KEY: [1, 0]},
218       2: {BEGIN_INDICES_KEY: [2, 0]},
219       3: {BEGIN_INDICES_KEY: [3, 0]}}
220
221  Args:
222    array_lines: Text lines representing the tensor, as a list of str.
223    tensor: The tensor being formatted as string.
224    np_printoptions: A dictionary of keyword arguments that are passed to a
225      call of np.set_printoptions().
226    offset: Line number offset applied to the line indices in the returned
227      annotation.
228
229  Returns:
230    An annotation as a dict.
231  """
232
233  if np_printoptions and "edgeitems" in np_printoptions:
234    edge_items = np_printoptions["edgeitems"]
235  else:
236    edge_items = _NUMPY_DEFAULT_EDGE_ITEMS
237
238  annotations = {}
239
240  # Put metadata about the tensor in the annotations["tensor_metadata"].
241  annotations["tensor_metadata"] = {
242      "dtype": tensor.dtype, "shape": tensor.shape}
243
244  dims = np.shape(tensor)
245  ndims = len(dims)
246  if ndims == 0:
247    # No indices for a 0D tensor.
248    return annotations
249
250  curr_indices = [0] * len(dims)
251  curr_dim = 0
252  for i in xrange(len(array_lines)):
253    line = array_lines[i].strip()
254
255    if not line:
256      # Skip empty lines, which can appear for >= 3D arrays.
257      continue
258
259    if line == _NUMPY_OMISSION:
260      annotations[offset + i] = {OMITTED_INDICES_KEY: copy.copy(curr_indices)}
261      curr_indices[curr_dim - 1] = dims[curr_dim - 1] - edge_items
262    else:
263      num_lbrackets = line.count("[")  # TODO(cais): String array escaping.
264      num_rbrackets = line.count("]")
265
266      curr_dim += num_lbrackets - num_rbrackets
267
268      annotations[offset + i] = {BEGIN_INDICES_KEY: copy.copy(curr_indices)}
269      if num_rbrackets == 0:
270        line_content = line[line.rfind("[") + 1:]
271        num_elements = line_content.count(",")
272        curr_indices[curr_dim - 1] += num_elements
273      else:
274        if curr_dim > 0:
275          curr_indices[curr_dim - 1] += 1
276          for k in xrange(curr_dim, ndims):
277            curr_indices[k] = 0
278
279  return annotations
280
281
282def locate_tensor_element(formatted, indices):
283  """Locate a tensor element in formatted text lines, given element indices.
284
285  Given a RichTextLines object representing a tensor and indices of the sought
286  element, return the row number at which the element is located (if exists).
287
288  Args:
289    formatted: A RichTextLines object containing formatted text lines
290      representing the tensor.
291    indices: Indices of the sought element, as a list of int or a list of list
292      of int. The former case is for a single set of indices to look up,
293      whereas the latter case is for looking up a batch of indices sets at once.
294      In the latter case, the indices must be in ascending order, or a
295      ValueError will be raised.
296
297  Returns:
298    1) A boolean indicating whether the element falls into an omitted line.
299    2) Row index.
300    3) Column start index, i.e., the first column in which the representation
301       of the specified tensor starts, if it can be determined. If it cannot
302       be determined (e.g., due to ellipsis), None.
303    4) Column end index, i.e., the column right after the last column that
304       represents the specified tensor. Iff it cannot be determined, None.
305
306  For return values described above are based on a single set of indices to
307    look up. In the case of batch mode (multiple sets of indices), the return
308    values will be lists of the types described above.
309
310  Raises:
311    AttributeError: If:
312      Input argument "formatted" does not have the required annotations.
313    ValueError: If:
314      1) Indices do not match the dimensions of the tensor, or
315      2) Indices exceed sizes of the tensor, or
316      3) Indices contain negative value(s).
317      4) If in batch mode, and if not all sets of indices are in ascending
318         order.
319  """
320
321  if isinstance(indices[0], list):
322    indices_list = indices
323    input_batch = True
324  else:
325    indices_list = [indices]
326    input_batch = False
327
328  # Check that tensor_metadata is available.
329  if "tensor_metadata" not in formatted.annotations:
330    raise AttributeError("tensor_metadata is not available in annotations.")
331
332  # Sanity check on input argument.
333  _validate_indices_list(indices_list, formatted)
334
335  dims = formatted.annotations["tensor_metadata"]["shape"]
336  batch_size = len(indices_list)
337  lines = formatted.lines
338  annot = formatted.annotations
339  prev_r = 0
340  prev_line = ""
341  prev_indices = [0] * len(dims)
342
343  # Initialize return values
344  are_omitted = [None] * batch_size
345  row_indices = [None] * batch_size
346  start_columns = [None] * batch_size
347  end_columns = [None] * batch_size
348
349  batch_pos = 0  # Current position in the batch.
350
351  for r in xrange(len(lines)):
352    if r not in annot:
353      continue
354
355    if BEGIN_INDICES_KEY in annot[r]:
356      indices_key = BEGIN_INDICES_KEY
357    elif OMITTED_INDICES_KEY in annot[r]:
358      indices_key = OMITTED_INDICES_KEY
359
360    matching_indices_list = [
361        ind for ind in indices_list[batch_pos:]
362        if prev_indices <= ind < annot[r][indices_key]
363    ]
364
365    if matching_indices_list:
366      num_matches = len(matching_indices_list)
367
368      match_start_columns, match_end_columns = _locate_elements_in_line(
369          prev_line, matching_indices_list, prev_indices)
370
371      start_columns[batch_pos:batch_pos + num_matches] = match_start_columns
372      end_columns[batch_pos:batch_pos + num_matches] = match_end_columns
373      are_omitted[batch_pos:batch_pos + num_matches] = [
374          OMITTED_INDICES_KEY in annot[prev_r]
375      ] * num_matches
376      row_indices[batch_pos:batch_pos + num_matches] = [prev_r] * num_matches
377
378      batch_pos += num_matches
379      if batch_pos >= batch_size:
380        break
381
382    prev_r = r
383    prev_line = lines[r]
384    prev_indices = annot[r][indices_key]
385
386  if batch_pos < batch_size:
387    matching_indices_list = indices_list[batch_pos:]
388    num_matches = len(matching_indices_list)
389
390    match_start_columns, match_end_columns = _locate_elements_in_line(
391        prev_line, matching_indices_list, prev_indices)
392
393    start_columns[batch_pos:batch_pos + num_matches] = match_start_columns
394    end_columns[batch_pos:batch_pos + num_matches] = match_end_columns
395    are_omitted[batch_pos:batch_pos + num_matches] = [
396        OMITTED_INDICES_KEY in annot[prev_r]
397    ] * num_matches
398    row_indices[batch_pos:batch_pos + num_matches] = [prev_r] * num_matches
399
400  if input_batch:
401    return are_omitted, row_indices, start_columns, end_columns
402  else:
403    return are_omitted[0], row_indices[0], start_columns[0], end_columns[0]
404
405
406def _validate_indices_list(indices_list, formatted):
407  prev_ind = None
408  for ind in indices_list:
409    # Check indices match tensor dimensions.
410    dims = formatted.annotations["tensor_metadata"]["shape"]
411    if len(ind) != len(dims):
412      raise ValueError("Dimensions mismatch: requested: %d; actual: %d" %
413                       (len(ind), len(dims)))
414
415    # Check indices is within size limits.
416    for req_idx, siz in zip(ind, dims):
417      if req_idx >= siz:
418        raise ValueError("Indices exceed tensor dimensions.")
419      if req_idx < 0:
420        raise ValueError("Indices contain negative value(s).")
421
422    # Check indices are in ascending order.
423    if prev_ind and ind < prev_ind:
424      raise ValueError("Input indices sets are not in ascending order.")
425
426    prev_ind = ind
427
428
429def _locate_elements_in_line(line, indices_list, ref_indices):
430  """Determine the start and end indices of an element in a line.
431
432  Args:
433    line: (str) the line in which the element is to be sought.
434    indices_list: (list of list of int) list of indices of the element to
435       search for. Assumes that the indices in the batch are unique and sorted
436       in ascending order.
437    ref_indices: (list of int) reference indices, i.e., the indices of the
438      first element represented in the line.
439
440  Returns:
441    start_columns: (list of int) start column indices, if found. If not found,
442      None.
443    end_columns: (list of int) end column indices, if found. If not found,
444      None.
445    If found, the element is represented in the left-closed-right-open interval
446      [start_column, end_column].
447  """
448
449  batch_size = len(indices_list)
450  offsets = [indices[-1] - ref_indices[-1] for indices in indices_list]
451
452  start_columns = [None] * batch_size
453  end_columns = [None] * batch_size
454
455  if _NUMPY_OMISSION in line:
456    ellipsis_index = line.find(_NUMPY_OMISSION)
457  else:
458    ellipsis_index = len(line)
459
460  matches_iter = re.finditer(_NUMBER_REGEX, line)
461
462  batch_pos = 0
463
464  offset_counter = 0
465  for match in matches_iter:
466    if match.start() > ellipsis_index:
467      # Do not attempt to search beyond ellipsis.
468      break
469
470    if offset_counter == offsets[batch_pos]:
471      start_columns[batch_pos] = match.start()
472      # Remove the final comma, right bracket, or whitespace.
473      end_columns[batch_pos] = match.end() - 1
474
475      batch_pos += 1
476      if batch_pos >= batch_size:
477        break
478
479    offset_counter += 1
480
481  return start_columns, end_columns
482
483
484def _pad_string_to_length(string, length):
485  return " " * (length - len(string)) + string
486
487
488def numeric_summary(tensor):
489  """Get a text summary of a numeric tensor.
490
491  This summary is only available for numeric (int*, float*, complex*) and
492  Boolean tensors.
493
494  Args:
495    tensor: (`numpy.ndarray`) the tensor value object to be summarized.
496
497  Returns:
498    The summary text as a `RichTextLines` object. If the type of `tensor` is not
499    numeric or Boolean, a single-line `RichTextLines` object containing a
500    warning message will reflect that.
501  """
502
503  def _counts_summary(counts, skip_zeros=True, total_count=None):
504    """Format values as a two-row table."""
505    if skip_zeros:
506      counts = [(count_key, count_val) for count_key, count_val in counts
507                if count_val]
508    max_common_len = 0
509    for count_key, count_val in counts:
510      count_val_str = str(count_val)
511      common_len = max(len(count_key) + 1, len(count_val_str) + 1)
512      max_common_len = max(common_len, max_common_len)
513
514    key_line = debugger_cli_common.RichLine("|")
515    val_line = debugger_cli_common.RichLine("|")
516    for count_key, count_val in counts:
517      count_val_str = str(count_val)
518      key_line += _pad_string_to_length(count_key, max_common_len)
519      val_line += _pad_string_to_length(count_val_str, max_common_len)
520    key_line += " |"
521    val_line += " |"
522
523    if total_count is not None:
524      total_key_str = "total"
525      total_val_str = str(total_count)
526      max_common_len = max(len(total_key_str) + 1, len(total_val_str))
527      total_key_str = _pad_string_to_length(total_key_str, max_common_len)
528      total_val_str = _pad_string_to_length(total_val_str, max_common_len)
529      key_line += total_key_str + " |"
530      val_line += total_val_str + " |"
531
532    return debugger_cli_common.rich_text_lines_from_rich_line_list(
533        [key_line, val_line])
534
535  if not isinstance(tensor, np.ndarray) or not np.size(tensor):
536    return debugger_cli_common.RichTextLines([
537        "No numeric summary available due to empty tensor."])
538  elif (np.issubdtype(tensor.dtype, np.floating) or
539        np.issubdtype(tensor.dtype, np.complex) or
540        np.issubdtype(tensor.dtype, np.integer)):
541    counts = [
542        ("nan", np.sum(np.isnan(tensor))),
543        ("-inf", np.sum(np.isneginf(tensor))),
544        ("-", np.sum(np.logical_and(
545            tensor < 0.0, np.logical_not(np.isneginf(tensor))))),
546        ("0", np.sum(tensor == 0.0)),
547        ("+", np.sum(np.logical_and(
548            tensor > 0.0, np.logical_not(np.isposinf(tensor))))),
549        ("+inf", np.sum(np.isposinf(tensor)))]
550    output = _counts_summary(counts, total_count=np.size(tensor))
551
552    valid_array = tensor[
553        np.logical_not(np.logical_or(np.isinf(tensor), np.isnan(tensor)))]
554    if np.size(valid_array):
555      stats = [
556          ("min", np.min(valid_array)),
557          ("max", np.max(valid_array)),
558          ("mean", np.mean(valid_array)),
559          ("std", np.std(valid_array))]
560      output.extend(_counts_summary(stats, skip_zeros=False))
561    return output
562  elif tensor.dtype == np.bool:
563    counts = [
564        ("False", np.sum(tensor == 0)),
565        ("True", np.sum(tensor > 0)),]
566    return _counts_summary(counts, total_count=np.size(tensor))
567  else:
568    return debugger_cli_common.RichTextLines([
569        "No numeric summary available due to tensor dtype: %s." % tensor.dtype])
570