1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Command parsing module for TensorFlow Debugger (tfdbg)."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import argparse
21import ast
22import re
23import sys
24
25
# Matches one bracketed span such as "[1, :]"; nested brackets are not handled.
_BRACKETS_PATTERN = re.compile(r"\[[^\]]*\]")
# Matches one double- or single-quoted span; escapes and nesting are not
# handled.
_QUOTES_PATTERN = re.compile(r"(\"[^\"]*\"|\'[^\']*\')")
# Matches a run of one or more whitespace characters (argument separators).
_WHITESPACE_PATTERN = re.compile(r"\s+")

# Matches a signed integer, decimal or scientific-notation number such as
# "3", "-2", ".98" or "1e-8". It is applied with .match(), so only a numeric
# prefix is required and unit suffixes like "kB" or "ms" may follow.
_NUMBER_PATTERN = re.compile(r"[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?")
31
32
class Interval(object):
  """Represents an interval between a start and end value.

  Attributes:
    start: Lower bound of the interval, or None for no lower bound. Bounds
      are only compared with <, > and ==, so any mutually comparable values
      (e.g. numbers) work.
    start_included: (bool) Whether `start` itself belongs to the interval.
    end: Upper bound of the interval, or None for no upper bound.
    end_included: (bool) Whether `end` itself belongs to the interval.
  """

  def __init__(self, start, start_included, end, end_included):
    self.start = start
    self.start_included = start_included
    self.end = end
    self.end_included = end_included

  def contains(self, value):
    """Return whether `value` lies inside this interval.

    A bound of None is treated as unbounded on that side, which is how
    one-sided intervals such as "<N" and ">N" are represented.

    Args:
      value: Value comparable with `start` and `end`.

    Returns:
      (bool) True if and only if `value` is within the interval.
    """
    if self.start is not None:
      if value < self.start or (value == self.start and
                                not self.start_included):
        return False
    if self.end is not None:
      if value > self.end or (value == self.end and not self.end_included):
        return False
    return True

  def __eq__(self, other):
    # Let Python fall back to its default handling for foreign types instead
    # of raising AttributeError on e.g. `interval == None`.
    if not isinstance(other, Interval):
      return NotImplemented
    return (self.start == other.start and
            self.start_included == other.start_included and
            self.end == other.end and
            self.end_included == other.end_included)

  def __ne__(self, other):
    # Python 2 does not derive != from __eq__, so define it explicitly.
    result = self.__eq__(other)
    if result is NotImplemented:
      return result
    return not result

  def __repr__(self):
    return ("Interval(start=%r, start_included=%r, end=%r, end_included=%r)" %
            (self.start, self.start_included, self.end, self.end_included))
54
55
def parse_command(command):
  """Parse command string into a list of arguments.

  - Disregards whitespace inside quotes (double or single) and brackets.
  - Strips paired leading and trailing quotes (double or single) in
    arguments.
  - Splits the command at whitespace.

  Nested quotes and brackets are not handled.

  Args:
    command: (str) Input command.

  Returns:
    (list of str) List of arguments.
  """

  command = command.strip()
  if not command:
    return []

  # Spans whose inner whitespace must not split the command. Computed once,
  # outside the loop below.
  protected_spans = (
      [m.span() for m in _BRACKETS_PATTERN.finditer(command)] +
      [m.span() for m in _QUOTES_PATTERN.finditer(command)])
  whitespace_spans = [m.span() for m in _WHITESPACE_PATTERN.finditer(command)]

  arguments = []
  idx0 = 0
  # The trailing sentinel flushes the last argument, so a command without any
  # whitespace goes through the same quote stripping as multi-token commands.
  for start, end in whitespace_spans + [(len(command), None)]:
    # Skip whitespace stretches enclosed in brackets or quotes.
    if any(span_start < start < span_end
           for span_start, span_end in protected_spans):
      continue

    argument = command[idx0:start]
    # Strip a leading and trailing quote only if they form a real pair; a
    # lone quote character is kept as-is.
    if (len(argument) >= 2 and argument[0] == argument[-1] and
        argument[0] in ("\"", "'")):
      argument = argument[1:-1]
    arguments.append(argument)
    idx0 = end

  return arguments
102
103
def extract_output_file_path(args):
  """Extract output file path from command arguments.

  Recognized redirection forms, where the path is returned and removed from
  the arguments: `... >path`, `... > path`, `... arg>path` and
  `... arg> path`.

  Two forms are not treated as redirection:
  - A last token such as ">10" that parses as an interval and directly
    follows a flag (e.g. `-t >10`) is the flag's value.
  - A single ">" preceded by "=" in the last token (e.g. `--op=>x`).

  The input list is never modified in place.

  Args:
    args: (list of str) command arguments.

  Returns:
    (list of str) Command arguments with the output file path part stripped.
    (str or None) Output file path (if any).

  Raises:
    SyntaxError: If there is no file path after the last ">" character.
  """

  if args and args[-1].endswith(">"):
    raise SyntaxError("Redirect file path is empty")
  elif args and args[-1].startswith(">"):
    try:
      _parse_interval(args[-1])
      is_interval = True
    except ValueError:
      is_interval = False
    if is_interval and len(args) > 1 and args[-2].startswith("-"):
      # E.g. "-t >10": an interval argument to a flag, not a redirection.
      output_file_path = None
    else:
      output_file_path = args[-1][1:]
      args = args[:-1]
  elif len(args) > 1 and args[-2] == ">":
    output_file_path = args[-1]
    args = args[:-2]
  elif args and args[-1].count(">") == 1:
    gt_index = args[-1].index(">")
    if gt_index > 0 and args[-1][gt_index - 1] == "=":
      output_file_path = None
    else:
      output_file_path = args[-1][gt_index + 1:]
      # Build a new list rather than assigning into the caller's list.
      args = list(args[:-1]) + [args[-1][:gt_index]]
  elif len(args) > 1 and args[-2].endswith(">"):
    output_file_path = args[-1]
    args = list(args[:-2]) + [args[-2][:-1]]
  else:
    output_file_path = None

  return args, output_file_path
149
150
def parse_tensor_name_with_slicing(in_str):
  """Split a tensor name from an optional trailing slicing suffix.

  A suffix is recognized only when the string ends with "]" and contains
  exactly one "[".

  Args:
    in_str: (str) Tensor name, optionally followed by a slicing string,
      e.g. "hidden/weights/Variable:0" or "hidden/weights/Variable:0[1, :]".

  Returns:
    (str) The tensor name.
    (str) The slicing string starting at "[", or "" when there is none.
  """

  has_slicing = in_str.endswith("]") and in_str.count("[") == 1
  if not has_slicing:
    return in_str, ""

  bracket_pos = in_str.index("[")
  return in_str[:bracket_pos], in_str[bracket_pos:]
172
173
def validate_slicing_string(slicing_string):
  """Validate a slicing string.

  Check if the input string consists of a pair of enclosing brackets around
  only digits, commas, colons and whitespace, which are valid characters in
  numpy-style array slicing.

  Args:
    slicing_string: (str) Input slicing string to be validated.

  Returns:
    (bool) True if and only if the slicing string is valid.
  """

  # `\Z` rather than `$`: `$` also matches just before a trailing newline, so
  # "[1]\n" would pass validation and then fail to parse.
  return bool(re.search(r"^\[(\d|,|\s|:)+\]\Z", slicing_string))
188
189
def _parse_slices(slicing_string):
  """Turn a bracketed slicing string into a tuple usable as an index.

  The caller is expected to pass a string that already passed validation.
  Each comma-separated dimension becomes either an int (plain index) or a
  `slice` built from up to three colon-separated bounds, where an empty
  bound becomes None.

  Args:
    slicing_string: (str) Slicing string, e.g. "[1, 0:3, ::2]".

  Returns:
    tuple(slice1, slice2, ...)

  Raises:
    ValueError: If a dimension has more than three colon-separated parts, or
      if a bound is not an integer.
  """
  pieces = []
  for dim_spec in slicing_string[1:-1].split(","):
    bounds = dim_spec.split(":")
    if len(bounds) == 1:
      # A plain index must be present; int("") raises ValueError.
      pieces.append(int(bounds[0].strip()))
      continue
    if len(bounds) > 3:
      raise ValueError("Invalid tensor-slicing string.")
    stripped_bounds = (bound.strip() for bound in bounds)
    pieces.append(
        slice(*(int(bound) if bound else None for bound in stripped_bounds)))
  return tuple(pieces)
217
218
def parse_indices(indices_string):
  """Convert a comma-separated index string into a list of ints.

  All whitespace is removed first; a single pair of enclosing brackets is
  optional. For example, both "[1, 2, 3]" and "1,2,3" yield [1, 2, 3].

  Args:
    indices_string: (str) a string representing indices. Can optionally be
      surrounded by a pair of brackets.

  Returns:
    (list of int): Parsed indices.

  Raises:
    ValueError: If any comma-separated element is not an integer literal.
  """

  compact = re.sub(r"\s+", "", indices_string)
  is_bracketed = compact[:1] == "[" and compact[-1:] == "]"
  body = compact[1:-1] if is_bracketed else compact
  return list(map(int, body.split(",")))
241
242
def parse_ranges(range_string):
  """Parse a string representing numerical range(s).

  Args:
    range_string: (str) A string representing a numerical range or a list of
      them. For example:
        "[-1.0,1.0]", "[-inf, 0]", "[[-inf, -1.0], [1.0, inf]]"
      A bare pair without brackets, e.g. "0,1e-8", is accepted as a single
      range. "inf" is replaced by the largest finite float.

  Returns:
    (list of list of float) A list of numerical ranges parsed from the input
      string. Empty for an empty string or "[]".

  Raises:
    ValueError: If the input doesn't represent a range or a list of ranges.
    SyntaxError: If the input is not a syntactically valid Python literal
      (propagated from `ast.literal_eval`).
  """

  range_string = range_string.strip()
  if not range_string:
    return []

  literal_string = range_string
  if "inf" in literal_string:
    literal_string = re.sub(r"inf", repr(sys.float_info.max), literal_string)

  ranges = ast.literal_eval(literal_string)
  # A bare "a, b" evaluates to a tuple rather than a list.
  if isinstance(ranges, tuple):
    ranges = list(ranges)
  if not isinstance(ranges, list):
    raise ValueError("Invalid range string: %s" % range_string)
  if not ranges:
    return []
  if not isinstance(ranges[0], (list, tuple)):
    ranges = [ranges]

  # Verify that ranges is a list of pairs of numbers.
  for item in ranges:
    if not isinstance(item, (list, tuple)):
      raise ValueError("Invalid range: %r" % (item,))
    if len(item) != 2:
      raise ValueError("Incorrect number of elements in range")
    elif not isinstance(item[0], (int, float)):
      raise ValueError("Incorrect type in the 1st element of range: %s" %
                       type(item[0]))
    elif not isinstance(item[1], (int, float)):
      raise ValueError("Incorrect type in the 2nd element of range: %s" %
                       type(item[1]))

  return [list(item) for item in ranges]
282
283
def parse_memory_interval(interval_str):
  """Parse a human-readable memory interval into an `Interval` in bytes.

  Args:
    interval_str: (`str`) A human-readable str representing an interval
      (e.g., "[10kB, 20kB]", "<100M", ">100G"). Only the units "kB", "MB",
      "GB" are supported. The "B" character at the end of the input `str`
      may be omitted.

  Returns:
    `Interval` object where start and end are in bytes. A missing lower
    bound becomes 0 and a missing upper bound becomes float("inf").

  Raises:
    ValueError: if the input is not valid, or if its start exceeds its end.
  """
  bounds = _parse_interval(interval_str)
  start_bytes = (parse_readable_size_str(bounds.start)
                 if bounds.start else 0)
  end_bytes = (parse_readable_size_str(bounds.end)
               if bounds.end else float("inf"))
  if start_bytes > end_bytes:
    raise ValueError(
        "Invalid interval %s. Start of interval must be less than or equal "
        "to end of interval." % interval_str)
  return Interval(start_bytes, bounds.start_included,
                  end_bytes, bounds.end_included)
312
313
def parse_time_interval(interval_str):
  """Parse a human-readable time interval into an `Interval` in microseconds.

  Args:
    interval_str: (`str`) A human-readable str representing an interval
      (e.g., "[10us, 20us]", "<100s", ">100ms"). Supported time suffixes are
      us, ms, s.

  Returns:
    `Interval` object where start and end are in microseconds. A missing
    lower bound becomes 0 and a missing upper bound becomes float("inf").

  Raises:
    ValueError: if the input is not valid, or if its start exceeds its end.
  """
  bounds = _parse_interval(interval_str)
  start_us = parse_readable_time_str(bounds.start) if bounds.start else 0
  end_us = (parse_readable_time_str(bounds.end)
            if bounds.end else float("inf"))
  if start_us > end_us:
    raise ValueError(
        "Invalid interval %s. Start must be before end of interval." %
        interval_str)
  return Interval(start_us, bounds.start_included,
                  end_us, bounds.end_included)
341
342
def _parse_interval(interval_str):
  """Parse a human-readable interval string into an `Interval` of strings.

  The bounds are returned as stripped strings and are not converted to
  numbers here.

  Args:
    interval_str: (`str`) A human-readable str representing an interval
      (e.g., "[1M, 2M]", "<100k", ">100ms"). The items following the ">",
      "<", ">=" and "<=" signs have to start with a number (e.g., 3.0, -2,
      .98). The same requirement applies to the items in the parentheses or
      brackets.

  Returns:
    Interval object where start or end is None if the range is specified as
    "<N" or ">N" respectively. "[" / "]" mark inclusive bounds and "(" / ")"
    mark exclusive ones.

  Raises:
    ValueError: if the input is not valid.
  """
  text = interval_str.strip()

  # One-sided forms: (operator, is_upper_bound, inclusive, error message).
  # "<=" and ">=" are listed before "<" and ">" so the longer operator wins.
  one_sided_forms = (
      ("<=", True, True, "Invalid value string after <= in '%s'"),
      ("<", True, False, "Invalid value string after < in '%s'"),
      (">=", False, True, "Invalid value string after >= in '%s'"),
      (">", False, False, "Invalid value string after > in '%s'"),
  )
  for operator, is_upper_bound, inclusive, error_template in one_sided_forms:
    if not text.startswith(operator):
      continue
    bound = text[len(operator):].strip()
    if not _NUMBER_PATTERN.match(bound):
      raise ValueError(error_template % text)
    if is_upper_bound:
      return Interval(start=None, start_included=False,
                      end=bound, end_included=inclusive)
    return Interval(start=bound, start_included=inclusive,
                    end=None, end_included=False)

  # Two-sided forms: [a, b], (a, b), [a, b) or (a, b].
  opening, closing = text[:1], text[-1:]
  if opening not in ("[", "(") or closing not in ("]", ")"):
    raise ValueError(
        "Invalid interval format: %s. Valid formats are: [min, max], "
        "(min, max), <max, >min" % text)
  endpoints = text[1:-1].split(",")
  if len(endpoints) != 2:
    raise ValueError(
        "Incorrect interval format: %s. Interval should specify two values: "
        "[min, max] or (min, max)." % text)

  low, high = [endpoint.strip() for endpoint in endpoints]
  if not _NUMBER_PATTERN.match(low):
    raise ValueError("Invalid first item in interval: '%s'" % low)
  if not _NUMBER_PATTERN.match(high):
    raise ValueError("Invalid second item in interval: '%s'" % high)

  return Interval(start=low,
                  start_included=(opening == "["),
                  end=high,
                  end_included=(closing == "]"))
407
408
def parse_readable_size_str(size_str):
  """Parse a human-readable byte size such as "1.1kB" into a byte count.

  Only the units "kB", "MB", "GB" are supported (powers of 1024). The "B"
  character at the end of the input `str` may be omitted. A value without a
  unit must be a plain non-negative integer.

  Args:
    size_str: (`str`) A human-readable str representing a number of bytes
      (e.g., "0", "1023", "1.1kB", "24 MB", "23GB", "100 G").

  Returns:
    (`int`) The parsed number of bytes, truncated toward zero.

  Raises:
    ValueError: on failure to parse the input `size_str`.
  """

  text = size_str.strip()
  if text.endswith("B"):
    text = text[:-1]

  if text.isdigit():
    return int(text)

  unit_multipliers = {"k": 1024, "M": 1048576, "G": 1073741824}
  multiplier = unit_multipliers.get(text[-1:])
  if multiplier is None:
    raise ValueError("Failed to parsed human-readable byte size str: \"%s\"" %
                     text)
  return int(float(text[:-1]) * multiplier)
441
442
def parse_readable_time_str(time_str):
  """Parse a time string of the form N, Nus, Nms or Ns into microseconds.

  Args:
    time_str: (`str`) a non-negative number (integer or decimal), optionally
      followed by a 'us', 'ms', or 's' suffix. Without a suffix the value is
      taken to be in microseconds (e.g. 100us, 8ms, 5s, 100).

  Returns:
    (`int`) The time in microseconds, truncated toward zero.

  Raises:
    ValueError: If the numeric part is not a number or is negative.
  """
  def _non_negative_float(text):
    number = float(text)
    if number < 0:
      raise ValueError(
          "Invalid time %s. Time value must be positive." % text)
    return number

  stripped = time_str.strip()
  # "us" and "ms" must be tried before "s", since both also end with "s".
  for suffix, scale in (("us", 1), ("ms", 1e3), ("s", 1e6)):
    if stripped.endswith(suffix):
      return int(_non_negative_float(stripped[:-len(suffix)]) * scale)
  return int(_non_negative_float(stripped))
469
470
def evaluate_tensor_slice(tensor, tensor_slicing):
  """Apply a validated slicing string to a tensor value.

  Args:
    tensor: (numpy ndarray) The tensor value.
    tensor_slicing: (str or None) Slicing of the tensor, e.g., "[:, 1]". If
      None, no slicing will be performed on the tensor.

  Returns:
    (numpy ndarray) The sliced tensor, or `tensor` itself when
    `tensor_slicing` is None.

  Raises:
    ValueError: If tensor_slicing is not a valid numpy ndarray slicing str.
  """

  # Honor the documented None contract; passing None to the validator would
  # raise TypeError from the regex engine.
  if tensor_slicing is None:
    return tensor

  if not validate_slicing_string(tensor_slicing):
    raise ValueError("Invalid tensor-slicing string.")

  return tensor[_parse_slices(tensor_slicing)]
492
493
def get_print_tensor_argparser(description):
  """Build the ArgumentParser shared by tensor-printing commands.

  Used by commands such as print_tensor and print_feed. Usage text is
  suppressed in the generated help.

  Args:
    description: Description of the ArgumentParser.

  Returns:
    An instance of argparse.ArgumentParser.
  """

  parser = argparse.ArgumentParser(
      usage=argparse.SUPPRESS, description=description)

  # (flags, keyword options) for each argument, in registration order.
  argument_specs = [
      (("tensor_name",),
       dict(type=str,
            help="Name of the tensor, followed by any slicing indices, "
            "e.g., hidden1/Wx_plus_b/MatMul:0, "
            "hidden1/Wx_plus_b/MatMul:0[1, :]")),
      (("-n", "--number"),
       dict(dest="number",
            type=int,
            default=-1,
            help="0-based dump number for the specified tensor. "
            "Required for tensor with multiple dumps.")),
      (("-r", "--ranges"),
       dict(dest="ranges",
            type=str,
            default="",
            help="Numerical ranges to highlight tensor elements in. "
            "Examples: -r 0,1e-8, -r [-0.1,0.1], "
            "-r \"[[-inf, -0.1], [0.1, inf]]\"")),
      (("-a", "--all"),
       dict(dest="print_all",
            action="store_true",
            help="Print the tensor in its entirety, i.e., do not use "
            "ellipses.")),
      (("-s", "--numeric_summary"),
       dict(action="store_true",
            help="Include summary for non-empty tensors of numeric (int*, "
            "float*, complex*) and Boolean types.")),
      (("-w", "--write_path"),
       dict(type=str,
            default="",
            help="Path of the numpy file to write the tensor data to, using "
            "numpy.save().")),
  ]
  for flags, options in argument_specs:
    parser.add_argument(*flags, **options)
  return parser
551