1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Format tensors (ndarrays) for screen display and navigation.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import copy 21import re 22 23import numpy as np 24from six.moves import xrange # pylint: disable=redefined-builtin 25 26from tensorflow.python.debug.cli import debugger_cli_common 27from tensorflow.python.debug.lib import debug_data 28 29_NUMPY_OMISSION = "...," 30_NUMPY_DEFAULT_EDGE_ITEMS = 3 31 32_NUMBER_REGEX = re.compile(r"[-+]?([0-9][-+0-9eE\.]+|nan|inf)(\s|,|\])") 33 34BEGIN_INDICES_KEY = "i0" 35OMITTED_INDICES_KEY = "omitted" 36 37DEFAULT_TENSOR_ELEMENT_HIGHLIGHT_FONT_ATTR = "bold" 38 39 40class HighlightOptions(object): 41 """Options for highlighting elements of a tensor.""" 42 43 def __init__(self, 44 criterion, 45 description=None, 46 font_attr=DEFAULT_TENSOR_ELEMENT_HIGHLIGHT_FONT_ATTR): 47 """Constructor of HighlightOptions. 48 49 Args: 50 criterion: (callable) A callable of the following signature: 51 def to_highlight(X): 52 # Args: 53 # X: The tensor to highlight elements in. 54 # 55 # Returns: 56 # (boolean ndarray) A boolean ndarray of the same shape as X 57 # indicating which elements are to be highlighted (iff True). 58 This callable will be used as the argument of np.argwhere() to 59 determine which elements of the tensor are to be highlighted. 60 description: (str) Description of the highlight criterion embodied by 61 criterion. 62 font_attr: (str) Font attribute to be applied to the 63 highlighted elements. 64 65 """ 66 67 self.criterion = criterion 68 self.description = description 69 self.font_attr = font_attr 70 71 72def format_tensor(tensor, 73 tensor_label, 74 include_metadata=False, 75 auxiliary_message=None, 76 include_numeric_summary=False, 77 np_printoptions=None, 78 highlight_options=None): 79 """Generate a RichTextLines object showing a tensor in formatted style. 80 81 Args: 82 tensor: The tensor to be displayed, as a numpy ndarray or other 83 appropriate format (e.g., None representing uninitialized tensors). 84 tensor_label: A label for the tensor, as a string. If set to None, will 85 suppress the tensor name line in the return value. 86 include_metadata: Whether metadata such as dtype and shape are to be 87 included in the formatted text. 88 auxiliary_message: An auxiliary message to display under the tensor label, 89 dtype and shape information lines. 90 include_numeric_summary: Whether a text summary of the numeric values (if 91 applicable) will be included. 92 np_printoptions: A dictionary of keyword arguments that are passed to a 93 call of np.set_printoptions() to set the text format for display numpy 94 ndarrays. 95 highlight_options: (HighlightOptions) options for highlighting elements 96 of the tensor. 97 98 Returns: 99 A RichTextLines object. Its annotation field has line-by-line markups to 100 indicate which indices in the array the first element of each line 101 corresponds to. 102 """ 103 lines = [] 104 font_attr_segs = {} 105 106 if tensor_label is not None: 107 lines.append("Tensor \"%s\":" % tensor_label) 108 suffix = tensor_label.split(":")[-1] 109 if suffix.isdigit(): 110 # Suffix is a number. Assume it is the output slot index. 111 font_attr_segs[0] = [(8, 8 + len(tensor_label), "bold")] 112 else: 113 # Suffix is not a number. It is auxiliary information such as the debug 114 # op type. In this case, highlight the suffix with a different color. 115 debug_op_len = len(suffix) 116 proper_len = len(tensor_label) - debug_op_len - 1 117 font_attr_segs[0] = [ 118 (8, 8 + proper_len, "bold"), 119 (8 + proper_len + 1, 8 + proper_len + 1 + debug_op_len, "yellow") 120 ] 121 122 if isinstance(tensor, debug_data.InconvertibleTensorProto): 123 if lines: 124 lines.append("") 125 lines.extend(str(tensor).split("\n")) 126 return debugger_cli_common.RichTextLines(lines) 127 elif not isinstance(tensor, np.ndarray): 128 # If tensor is not a np.ndarray, return simple text-line representation of 129 # the object without annotations. 130 if lines: 131 lines.append("") 132 lines.extend(repr(tensor).split("\n")) 133 return debugger_cli_common.RichTextLines(lines) 134 135 if include_metadata: 136 lines.append(" dtype: %s" % str(tensor.dtype)) 137 lines.append(" shape: %s" % str(tensor.shape).replace("L", "")) 138 139 if lines: 140 lines.append("") 141 formatted = debugger_cli_common.RichTextLines( 142 lines, font_attr_segs=font_attr_segs) 143 144 if auxiliary_message: 145 formatted.extend(auxiliary_message) 146 147 if include_numeric_summary: 148 formatted.append("Numeric summary:") 149 formatted.extend(numeric_summary(tensor)) 150 formatted.append("") 151 152 # Apply custom string formatting options for numpy ndarray. 153 if np_printoptions is not None: 154 np.set_printoptions(**np_printoptions) 155 156 array_lines = repr(tensor).split("\n") 157 if tensor.dtype.type is not np.string_: 158 # Parse array lines to get beginning indices for each line. 159 160 # TODO(cais): Currently, we do not annotate string-type tensors due to 161 # difficulty in escaping sequences. Address this issue. 162 annotations = _annotate_ndarray_lines( 163 array_lines, tensor, np_printoptions=np_printoptions) 164 else: 165 annotations = None 166 formatted_array = debugger_cli_common.RichTextLines( 167 array_lines, annotations=annotations) 168 formatted.extend(formatted_array) 169 170 # Perform optional highlighting. 171 if highlight_options is not None: 172 indices_list = list(np.argwhere(highlight_options.criterion(tensor))) 173 174 total_elements = np.size(tensor) 175 highlight_summary = "Highlighted%s: %d of %d element(s) (%.2f%%)" % ( 176 "(%s)" % highlight_options.description if highlight_options.description 177 else "", len(indices_list), total_elements, 178 len(indices_list) / float(total_elements) * 100.0) 179 180 formatted.lines[0] += " " + highlight_summary 181 182 if indices_list: 183 indices_list = [list(indices) for indices in indices_list] 184 185 are_omitted, rows, start_cols, end_cols = locate_tensor_element( 186 formatted, indices_list) 187 for is_omitted, row, start_col, end_col in zip(are_omitted, rows, 188 start_cols, end_cols): 189 if is_omitted or start_col is None or end_col is None: 190 continue 191 192 if row in formatted.font_attr_segs: 193 formatted.font_attr_segs[row].append( 194 (start_col, end_col, highlight_options.font_attr)) 195 else: 196 formatted.font_attr_segs[row] = [(start_col, end_col, 197 highlight_options.font_attr)] 198 199 return formatted 200 201 202def _annotate_ndarray_lines( 203 array_lines, tensor, np_printoptions=None, offset=0): 204 """Generate annotations for line-by-line begin indices of tensor text. 205 206 Parse the numpy-generated text representation of a numpy ndarray to 207 determine the indices of the first element of each text line (if any 208 element is present in the line). 209 210 For example, given the following multi-line ndarray text representation: 211 ["array([[ 0. , 0.0625, 0.125 , 0.1875],", 212 " [ 0.25 , 0.3125, 0.375 , 0.4375],", 213 " [ 0.5 , 0.5625, 0.625 , 0.6875],", 214 " [ 0.75 , 0.8125, 0.875 , 0.9375]])"] 215 the generate annotation will be: 216 {0: {BEGIN_INDICES_KEY: [0, 0]}, 217 1: {BEGIN_INDICES_KEY: [1, 0]}, 218 2: {BEGIN_INDICES_KEY: [2, 0]}, 219 3: {BEGIN_INDICES_KEY: [3, 0]}} 220 221 Args: 222 array_lines: Text lines representing the tensor, as a list of str. 223 tensor: The tensor being formatted as string. 224 np_printoptions: A dictionary of keyword arguments that are passed to a 225 call of np.set_printoptions(). 226 offset: Line number offset applied to the line indices in the returned 227 annotation. 228 229 Returns: 230 An annotation as a dict. 231 """ 232 233 if np_printoptions and "edgeitems" in np_printoptions: 234 edge_items = np_printoptions["edgeitems"] 235 else: 236 edge_items = _NUMPY_DEFAULT_EDGE_ITEMS 237 238 annotations = {} 239 240 # Put metadata about the tensor in the annotations["tensor_metadata"]. 241 annotations["tensor_metadata"] = { 242 "dtype": tensor.dtype, "shape": tensor.shape} 243 244 dims = np.shape(tensor) 245 ndims = len(dims) 246 if ndims == 0: 247 # No indices for a 0D tensor. 248 return annotations 249 250 curr_indices = [0] * len(dims) 251 curr_dim = 0 252 for i in xrange(len(array_lines)): 253 line = array_lines[i].strip() 254 255 if not line: 256 # Skip empty lines, which can appear for >= 3D arrays. 257 continue 258 259 if line == _NUMPY_OMISSION: 260 annotations[offset + i] = {OMITTED_INDICES_KEY: copy.copy(curr_indices)} 261 curr_indices[curr_dim - 1] = dims[curr_dim - 1] - edge_items 262 else: 263 num_lbrackets = line.count("[") # TODO(cais): String array escaping. 264 num_rbrackets = line.count("]") 265 266 curr_dim += num_lbrackets - num_rbrackets 267 268 annotations[offset + i] = {BEGIN_INDICES_KEY: copy.copy(curr_indices)} 269 if num_rbrackets == 0: 270 line_content = line[line.rfind("[") + 1:] 271 num_elements = line_content.count(",") 272 curr_indices[curr_dim - 1] += num_elements 273 else: 274 if curr_dim > 0: 275 curr_indices[curr_dim - 1] += 1 276 for k in xrange(curr_dim, ndims): 277 curr_indices[k] = 0 278 279 return annotations 280 281 282def locate_tensor_element(formatted, indices): 283 """Locate a tensor element in formatted text lines, given element indices. 284 285 Given a RichTextLines object representing a tensor and indices of the sought 286 element, return the row number at which the element is located (if exists). 287 288 Args: 289 formatted: A RichTextLines object containing formatted text lines 290 representing the tensor. 291 indices: Indices of the sought element, as a list of int or a list of list 292 of int. The former case is for a single set of indices to look up, 293 whereas the latter case is for looking up a batch of indices sets at once. 294 In the latter case, the indices must be in ascending order, or a 295 ValueError will be raised. 296 297 Returns: 298 1) A boolean indicating whether the element falls into an omitted line. 299 2) Row index. 300 3) Column start index, i.e., the first column in which the representation 301 of the specified tensor starts, if it can be determined. If it cannot 302 be determined (e.g., due to ellipsis), None. 303 4) Column end index, i.e., the column right after the last column that 304 represents the specified tensor. Iff it cannot be determined, None. 305 306 For return values described above are based on a single set of indices to 307 look up. In the case of batch mode (multiple sets of indices), the return 308 values will be lists of the types described above. 309 310 Raises: 311 AttributeError: If: 312 Input argument "formatted" does not have the required annotations. 313 ValueError: If: 314 1) Indices do not match the dimensions of the tensor, or 315 2) Indices exceed sizes of the tensor, or 316 3) Indices contain negative value(s). 317 4) If in batch mode, and if not all sets of indices are in ascending 318 order. 319 """ 320 321 if isinstance(indices[0], list): 322 indices_list = indices 323 input_batch = True 324 else: 325 indices_list = [indices] 326 input_batch = False 327 328 # Check that tensor_metadata is available. 329 if "tensor_metadata" not in formatted.annotations: 330 raise AttributeError("tensor_metadata is not available in annotations.") 331 332 # Sanity check on input argument. 333 _validate_indices_list(indices_list, formatted) 334 335 dims = formatted.annotations["tensor_metadata"]["shape"] 336 batch_size = len(indices_list) 337 lines = formatted.lines 338 annot = formatted.annotations 339 prev_r = 0 340 prev_line = "" 341 prev_indices = [0] * len(dims) 342 343 # Initialize return values 344 are_omitted = [None] * batch_size 345 row_indices = [None] * batch_size 346 start_columns = [None] * batch_size 347 end_columns = [None] * batch_size 348 349 batch_pos = 0 # Current position in the batch. 350 351 for r in xrange(len(lines)): 352 if r not in annot: 353 continue 354 355 if BEGIN_INDICES_KEY in annot[r]: 356 indices_key = BEGIN_INDICES_KEY 357 elif OMITTED_INDICES_KEY in annot[r]: 358 indices_key = OMITTED_INDICES_KEY 359 360 matching_indices_list = [ 361 ind for ind in indices_list[batch_pos:] 362 if prev_indices <= ind < annot[r][indices_key] 363 ] 364 365 if matching_indices_list: 366 num_matches = len(matching_indices_list) 367 368 match_start_columns, match_end_columns = _locate_elements_in_line( 369 prev_line, matching_indices_list, prev_indices) 370 371 start_columns[batch_pos:batch_pos + num_matches] = match_start_columns 372 end_columns[batch_pos:batch_pos + num_matches] = match_end_columns 373 are_omitted[batch_pos:batch_pos + num_matches] = [ 374 OMITTED_INDICES_KEY in annot[prev_r] 375 ] * num_matches 376 row_indices[batch_pos:batch_pos + num_matches] = [prev_r] * num_matches 377 378 batch_pos += num_matches 379 if batch_pos >= batch_size: 380 break 381 382 prev_r = r 383 prev_line = lines[r] 384 prev_indices = annot[r][indices_key] 385 386 if batch_pos < batch_size: 387 matching_indices_list = indices_list[batch_pos:] 388 num_matches = len(matching_indices_list) 389 390 match_start_columns, match_end_columns = _locate_elements_in_line( 391 prev_line, matching_indices_list, prev_indices) 392 393 start_columns[batch_pos:batch_pos + num_matches] = match_start_columns 394 end_columns[batch_pos:batch_pos + num_matches] = match_end_columns 395 are_omitted[batch_pos:batch_pos + num_matches] = [ 396 OMITTED_INDICES_KEY in annot[prev_r] 397 ] * num_matches 398 row_indices[batch_pos:batch_pos + num_matches] = [prev_r] * num_matches 399 400 if input_batch: 401 return are_omitted, row_indices, start_columns, end_columns 402 else: 403 return are_omitted[0], row_indices[0], start_columns[0], end_columns[0] 404 405 406def _validate_indices_list(indices_list, formatted): 407 prev_ind = None 408 for ind in indices_list: 409 # Check indices match tensor dimensions. 410 dims = formatted.annotations["tensor_metadata"]["shape"] 411 if len(ind) != len(dims): 412 raise ValueError("Dimensions mismatch: requested: %d; actual: %d" % 413 (len(ind), len(dims))) 414 415 # Check indices is within size limits. 416 for req_idx, siz in zip(ind, dims): 417 if req_idx >= siz: 418 raise ValueError("Indices exceed tensor dimensions.") 419 if req_idx < 0: 420 raise ValueError("Indices contain negative value(s).") 421 422 # Check indices are in ascending order. 423 if prev_ind and ind < prev_ind: 424 raise ValueError("Input indices sets are not in ascending order.") 425 426 prev_ind = ind 427 428 429def _locate_elements_in_line(line, indices_list, ref_indices): 430 """Determine the start and end indices of an element in a line. 431 432 Args: 433 line: (str) the line in which the element is to be sought. 434 indices_list: (list of list of int) list of indices of the element to 435 search for. Assumes that the indices in the batch are unique and sorted 436 in ascending order. 437 ref_indices: (list of int) reference indices, i.e., the indices of the 438 first element represented in the line. 439 440 Returns: 441 start_columns: (list of int) start column indices, if found. If not found, 442 None. 443 end_columns: (list of int) end column indices, if found. If not found, 444 None. 445 If found, the element is represented in the left-closed-right-open interval 446 [start_column, end_column]. 447 """ 448 449 batch_size = len(indices_list) 450 offsets = [indices[-1] - ref_indices[-1] for indices in indices_list] 451 452 start_columns = [None] * batch_size 453 end_columns = [None] * batch_size 454 455 if _NUMPY_OMISSION in line: 456 ellipsis_index = line.find(_NUMPY_OMISSION) 457 else: 458 ellipsis_index = len(line) 459 460 matches_iter = re.finditer(_NUMBER_REGEX, line) 461 462 batch_pos = 0 463 464 offset_counter = 0 465 for match in matches_iter: 466 if match.start() > ellipsis_index: 467 # Do not attempt to search beyond ellipsis. 468 break 469 470 if offset_counter == offsets[batch_pos]: 471 start_columns[batch_pos] = match.start() 472 # Remove the final comma, right bracket, or whitespace. 473 end_columns[batch_pos] = match.end() - 1 474 475 batch_pos += 1 476 if batch_pos >= batch_size: 477 break 478 479 offset_counter += 1 480 481 return start_columns, end_columns 482 483 484def _pad_string_to_length(string, length): 485 return " " * (length - len(string)) + string 486 487 488def numeric_summary(tensor): 489 """Get a text summary of a numeric tensor. 490 491 This summary is only available for numeric (int*, float*, complex*) and 492 Boolean tensors. 493 494 Args: 495 tensor: (`numpy.ndarray`) the tensor value object to be summarized. 496 497 Returns: 498 The summary text as a `RichTextLines` object. If the type of `tensor` is not 499 numeric or Boolean, a single-line `RichTextLines` object containing a 500 warning message will reflect that. 501 """ 502 503 def _counts_summary(counts, skip_zeros=True, total_count=None): 504 """Format values as a two-row table.""" 505 if skip_zeros: 506 counts = [(count_key, count_val) for count_key, count_val in counts 507 if count_val] 508 max_common_len = 0 509 for count_key, count_val in counts: 510 count_val_str = str(count_val) 511 common_len = max(len(count_key) + 1, len(count_val_str) + 1) 512 max_common_len = max(common_len, max_common_len) 513 514 key_line = debugger_cli_common.RichLine("|") 515 val_line = debugger_cli_common.RichLine("|") 516 for count_key, count_val in counts: 517 count_val_str = str(count_val) 518 key_line += _pad_string_to_length(count_key, max_common_len) 519 val_line += _pad_string_to_length(count_val_str, max_common_len) 520 key_line += " |" 521 val_line += " |" 522 523 if total_count is not None: 524 total_key_str = "total" 525 total_val_str = str(total_count) 526 max_common_len = max(len(total_key_str) + 1, len(total_val_str)) 527 total_key_str = _pad_string_to_length(total_key_str, max_common_len) 528 total_val_str = _pad_string_to_length(total_val_str, max_common_len) 529 key_line += total_key_str + " |" 530 val_line += total_val_str + " |" 531 532 return debugger_cli_common.rich_text_lines_from_rich_line_list( 533 [key_line, val_line]) 534 535 if not isinstance(tensor, np.ndarray) or not np.size(tensor): 536 return debugger_cli_common.RichTextLines([ 537 "No numeric summary available due to empty tensor."]) 538 elif (np.issubdtype(tensor.dtype, np.floating) or 539 np.issubdtype(tensor.dtype, np.complex) or 540 np.issubdtype(tensor.dtype, np.integer)): 541 counts = [ 542 ("nan", np.sum(np.isnan(tensor))), 543 ("-inf", np.sum(np.isneginf(tensor))), 544 ("-", np.sum(np.logical_and( 545 tensor < 0.0, np.logical_not(np.isneginf(tensor))))), 546 ("0", np.sum(tensor == 0.0)), 547 ("+", np.sum(np.logical_and( 548 tensor > 0.0, np.logical_not(np.isposinf(tensor))))), 549 ("+inf", np.sum(np.isposinf(tensor)))] 550 output = _counts_summary(counts, total_count=np.size(tensor)) 551 552 valid_array = tensor[ 553 np.logical_not(np.logical_or(np.isinf(tensor), np.isnan(tensor)))] 554 if np.size(valid_array): 555 stats = [ 556 ("min", np.min(valid_array)), 557 ("max", np.max(valid_array)), 558 ("mean", np.mean(valid_array)), 559 ("std", np.std(valid_array))] 560 output.extend(_counts_summary(stats, skip_zeros=False)) 561 return output 562 elif tensor.dtype == np.bool: 563 counts = [ 564 ("False", np.sum(tensor == 0)), 565 ("True", np.sum(tensor > 0)),] 566 return _counts_summary(counts, total_count=np.size(tensor)) 567 else: 568 return debugger_cli_common.RichTextLines([ 569 "No numeric summary available due to tensor dtype: %s." % tensor.dtype]) 570