1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ========================================================================
15"""Utilities to handle tensor tracer parameters."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21
22import os
23import os.path
24import re
25
26from tensorflow.python.ops import linalg_ops
27from tensorflow.python.ops import math_ops
28from tensorflow.python.platform import tf_logging as logging
29
TRACE_MODE_PART_TENSOR = 'part-tensor'
TRACE_MODE_FULL_TENSOR = 'full-tensor'
TRACE_MODE_FULL_TENSOR_SUMMARY = 'full_tensor_summary'

TRACE_MODE_NAN_INF = 'nan-inf'
TRACE_MODE_NORM = 'norm'
TRACE_MODE_MAX_ABS = 'max-abs'
TRACE_MODE_SUMMARY = 'summary'
# Summary mode collects a finite set of signatures for each traced tensor
# (such as norm, max, min, mean) and dumps them using tb summaries.

# Full tensor mode dumps the whole tensor values for the traced tensors
# without any processing on them, using tb summaries.

_SUBMODE_BRIEF = 'brief'
_SUBMODE_DETAILED = 'detailed'

# Patterns used to tokenize the TENSOR_TRACER_FLAGS string, tried in this
# order: --name="value", --name='value', --name=value, and a bare --name.
# Flag names may contain neither '=' nor whitespace. Excluding whitespace keeps
# a value-less flag (e.g. "--enable --trace_mode=norm") from being merged with
# the following flag into a single bogus name such as "enable --trace_mode".
_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=\s]+)='([^']*)'")
_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=\s]+)="([^"]*)"')
_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=\s]+)=(\S*)')
_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=\s]+)\s*')

FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS'
FLAG_NAME_ENABLE = 'enable'
FLAG_NAME_TRACE_MODE = 'trace_mode'
FLAG_NAME_TRACE_SCALAR_OPS = 'trace_scalar'
FLAG_NAME_SUBMODE = 'submode'
FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames'
FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes'
FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames'
FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes'
FLAG_NAME_TRACE_LEVEL = 'trace_level'
FLAG_NAME_TRACE_DIR = 'trace_dir'
FLAG_NAME_REPORT_FILE = 'report_file'
FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir'
FLAG_NAME_OP_RANGE = 'op_range'
# Folder to dump the pre (before tensor tracer updates) and post graphs (after
# tensor tracer updates).
FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs'
FLAG_NAME_SUMMARY_SIGNATURES = 'signatures'
FLAG_NAME_SUMMARY_PER_CORE = 'collect_summary_per_core'
FLAG_NAME_TEMP_CACHE_VAR = 'use_temp_cache'
FLAG_NAME_INSPECT_TRACE = 'inspect_trace'
FLAG_NAME_FINGERPRINT_DIR = 'use_fingerprint_subdirectory'
FLAG_FLUSH_SUMMARY = 'flush_summaries'

# Value format of --op_range: "<start>:<end>".
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'

_TT_DEFAULT_TRACE_LEVEL = 3
_TT_PREFIX = 'tensor_tracer'

_TT_NORM = 'norm'
_TT_MAX = 'max'
_TT_MAX_ABS = 'max-abs'
_TT_MIN = 'min'
_TT_MEAN = 'mean'
_TT_VAR = 'var'
_TT_SIZE = 'size'

# Fully-qualified signature names, e.g. 'tensor_tracer_norm'.
TT_SUMMARY_NORM = '%s_%s' % (_TT_PREFIX, _TT_NORM)
TT_SUMMARY_MAX = '%s_%s' % (_TT_PREFIX, _TT_MAX)
TT_SUMMARY_MAX_ABS = '%s_%s' % (_TT_PREFIX, _TT_MAX_ABS)
TT_SUMMARY_MIN = '%s_%s' % (_TT_PREFIX, _TT_MIN)
TT_SUMMARY_MEAN = '%s_%s' % (_TT_PREFIX, _TT_MEAN)
TT_SUMMARY_VAR = '%s_%s' % (_TT_PREFIX, _TT_VAR)
TT_SUMMARY_SIZE = '%s_%s' % (_TT_PREFIX, _TT_SIZE)

TT_SUMMARY_SIGNATURES = (TT_SUMMARY_NORM, TT_SUMMARY_MAX, TT_SUMMARY_MIN,
                         TT_SUMMARY_MEAN, TT_SUMMARY_VAR, TT_SUMMARY_SIZE,
                         TT_SUMMARY_MAX_ABS)
101
102
class TTParameters(object):
  """A class that handles the parameters of Tensor Tracer.

  All parameters are parsed from the TENSOR_TRACER_FLAGS variable of the given
  environment mapping (os.environ by default) when the object is constructed.
  """

  def __init__(self, env=None):
    """Parses and validates all Tensor Tracer flags.

    Args:
      env: a mapping used instead of os.environ to look up the flags. If None,
        os.environ is used. An explicitly passed empty mapping is honored,
        i.e. it means that no flags are set.

    Raises:
      ValueError: if an unknown flag name, an invalid flag value, or an
        inconsistent combination of flags is given.
    """
    # Compare against None so that an explicitly empty mapping is respected
    # rather than silently replaced by the process environment.
    self._env = os.environ if env is None else env
    self._validate_flag_names()
    self.trace_mode = self._get_trace_mode()
    self.submode = self._get_submode()
    self.trace_dir = self._get_trace_dir()
    self.report_file_path = self._get_report_filepath()
    self.op_range = self._get_op_range()
    self.excluded_opname_re_list = self._flag_value_to_re_list(
        FLAG_NAME_EXCLUDED_OPNAMES)
    self.excluded_optype_re_list = self._flag_value_to_re_list(
        FLAG_NAME_EXCLUDED_OPTYPES)

    self.included_opname_re_list = self._flag_value_to_re_list(
        FLAG_NAME_INCLUDED_OPNAMES)
    self.included_optype_re_list = self._flag_value_to_re_list(
        FLAG_NAME_INCLUDED_OPTYPES)

    self.trace_scalar_ops = self.is_flag_on(FLAG_NAME_TRACE_SCALAR_OPS)
    self.use_compact_trace = self.trace_mode in (TRACE_MODE_NAN_INF,
                                                 TRACE_MODE_NORM,
                                                 TRACE_MODE_MAX_ABS,
                                                 TRACE_MODE_SUMMARY)
    self.use_temp_cache_var = self.is_flag_on(FLAG_NAME_TEMP_CACHE_VAR)
    self.inspect_trace = self.is_flag_on(FLAG_NAME_INSPECT_TRACE)
    self.use_fingerprint_subdir = self.is_flag_on(FLAG_NAME_FINGERPRINT_DIR)

    _, self.graph_dump_path = self.get_flag_value(
        FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS)
    self.trace_level = self._get_flag_int_value(FLAG_NAME_TRACE_LEVEL,
                                                _TT_DEFAULT_TRACE_LEVEL)
    self.summary_signatures = self._get_summary_signatures()
    self.collect_summary_per_core = self.is_flag_on(FLAG_NAME_SUMMARY_PER_CORE)
    self.flush_summaries_with_outside_compile = self.is_flag_on(
        FLAG_FLUSH_SUMMARY)
    self._check_flag_errors()

  def _check_flag_errors(self):
    """Raises ValueError for flag combinations that cannot work together."""
    if self.trace_mode in (TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY):
      if not self.trace_dir:
        raise ValueError('trace_dir must be explicitly provided in '
                         'TENSOR_TRACER_FLAGS when summary mode is used.')

  def _get_report_filepath(self):
    """Returns the path of the output report file.

    When --use_test_undeclared_outputs_dir is on, a (relative) report file
    path is resolved against the TEST_UNDECLARED_OUTPUTS_DIR env variable.

    Returns:
      The report file path, or None if --report_file is not given.

    Raises:
      ValueError: if the path is absolute, or the outputs directory env
        variable is not set, while --use_test_undeclared_outputs_dir is on.
    """
    found, report_file_path = self.get_flag_value(FLAG_NAME_REPORT_FILE)
    if found and report_file_path and self.use_test_undeclared_outputs_dir():
      if os.path.isabs(report_file_path):
        raise ValueError('If use_test_undeclared_outputs_dir is set, '
                         'report_file_path cannot be an absolute path (%s)'
                         % report_file_path)
      outputs_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
      if outputs_dir is None:
        # os.path.join(None, ...) would otherwise fail with an opaque
        # TypeError; report the actual misconfiguration instead.
        raise ValueError(
            'If --%s is set, the environment variable %s must be defined.' %
            (FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
             _TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR))
      report_file_path = os.path.join(outputs_dir, report_file_path)
    return report_file_path

  def _get_op_range(self):
    """Returns the index range of the Ops that we will consider tracing.

    Returns:
      A pair of ints (start, end) parsed from --op_range=<start>:<end>, or
      (-1, -1), which means including all ops, if the flag is missing or its
      value does not match that format.
    """
    found, op_range = self.get_flag_value(FLAG_NAME_OP_RANGE)
    if not found or not op_range:
      return (-1, -1)  # this means including all ops.
    match = _OP_RANGE_PAT.match(op_range)
    if not match:
      return (-1, -1)  # this means including all ops.
    return (int(match.group(1)), int(match.group(2)))

  def _get_trace_dir(self):
    """Returns the trace directory.

    Returns:
      The value of --trace_dir, or the TEST_UNDECLARED_OUTPUTS_DIR env
      variable (possibly None) if --use_test_undeclared_outputs_dir is on.

    Raises:
      ValueError: if both --trace_dir and --use_test_undeclared_outputs_dir
        are given.
    """
    found, trace_dir = self.get_flag_value(FLAG_NAME_TRACE_DIR)
    use_outputs_dir = self.use_test_undeclared_outputs_dir()
    if found and trace_dir and use_outputs_dir:
      raise ValueError(
          'Cannot use --%s and --%s at the same time' %
          (FLAG_NAME_TRACE_DIR, FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR))
    if use_outputs_dir:
      trace_dir = self._env.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
    return trace_dir

  def _get_trace_mode(self):
    """Returns the trace mode, defaulting to TRACE_MODE_NORM.

    Raises:
      ValueError: if the given trace mode is not a known mode.
    """
    found, trace_mode = self.get_flag_value(FLAG_NAME_TRACE_MODE)
    if not found or not trace_mode:
      trace_mode = TRACE_MODE_NORM
    valid_trace_modes = [
        TRACE_MODE_NAN_INF, TRACE_MODE_PART_TENSOR, TRACE_MODE_FULL_TENSOR,
        TRACE_MODE_NORM, TRACE_MODE_MAX_ABS,
        TRACE_MODE_SUMMARY, TRACE_MODE_FULL_TENSOR_SUMMARY
    ]
    if trace_mode not in valid_trace_modes:
      raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer. '
                       'Valid trace modes are: %s' % (trace_mode,
                                                      valid_trace_modes))
    return trace_mode

  def is_brief_mode(self):
    """Returns True if the submode is 'brief'."""
    return self.submode == _SUBMODE_BRIEF

  def _get_submode(self):
    """Returns the submode, defaulting to 'detailed'.

    Raises:
      ValueError: if the given submode is not a known submode.
    """
    found, submode = self.get_flag_value(FLAG_NAME_SUBMODE)
    if not found or not submode:
      submode = _SUBMODE_DETAILED
    valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF]
    if submode not in valid_submodes:
      raise ValueError('Invalid submode "%s" given to the Tensor_Tracer. '
                       'Valid submodes are: %s' % (submode, valid_submodes))
    return submode

  @staticmethod
  def match_next_flag(flags, pos):
    """Returns the match for the next TensorTracer flag.

    Args:
       flags: a string that contains the flags.
       pos: where in flags to start the search.

    Returns:
       A pair where the first element is the regular-expression
       match found (None if there is no further flag) and the second
       element indicates if the match has a value.
    """
    # Order matters: quoted forms must be tried before the unquoted one, and
    # the value-less form last.
    for pattern in (_FLAG_DOUBLE_QUOTE_PAT, _FLAG_SINGLE_QUOTE_PAT,
                    _FLAG_NO_QUOTE_PAT):
      match = pattern.match(flags, pos)
      if match:
        return match, True
    match = _FLAG_NO_EQUAL_PAT.match(flags, pos)
    if match:
      # The flag is found but is not given a value.
      return match, False
    # The flag is not found.
    return None, False

  def _validate_flag_names(self):
    """Validates that all TensorTracer flag names passed are known.

    Raises:
      ValueError: if an unknown flag name is found in TENSOR_TRACER_FLAGS.
    """
    valid_flag_names = [
        FLAG_NAME_ENABLE, FLAG_NAME_TRACE_MODE,
        FLAG_NAME_TRACE_SCALAR_OPS,
        FLAG_NAME_SUBMODE, FLAG_NAME_EXCLUDED_OPNAMES,
        FLAG_NAME_EXCLUDED_OPTYPES, FLAG_NAME_INCLUDED_OPNAMES,
        FLAG_NAME_INCLUDED_OPTYPES, FLAG_NAME_TRACE_DIR,
        FLAG_NAME_REPORT_FILE,
        FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
        FLAG_NAME_OP_RANGE,
        FLAG_NAME_DUMP_BEFORE_AFTER_GRAPHS, FLAG_NAME_TRACE_LEVEL,
        FLAG_NAME_SUMMARY_SIGNATURES, FLAG_NAME_SUMMARY_PER_CORE,
        FLAG_NAME_TEMP_CACHE_VAR, FLAG_NAME_FINGERPRINT_DIR,
        FLAG_NAME_INSPECT_TRACE, FLAG_FLUSH_SUMMARY
    ]
    tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return
    pos = 0
    while True:
      match, _ = TTParameters.match_next_flag(tensor_tracer_flags, pos)
      if not match:
        break
      flag_name = match.group(1)
      if flag_name not in valid_flag_names:
        raise ValueError(
            'The flag name "%s" passed via the environment variable "%s" '
            'is invalid. Valid flag names are:'
            '\n%s' % (flag_name, FLAGS_ENV_VAR, valid_flag_names))
      pos = match.end()

  def _get_summary_signatures(self):
    """Verifies and returns the summary signatures.

    Each requested signature may be given with or without the 'tensor_tracer_'
    prefix; unknown ones are logged and skipped, and duplicates are ignored so
    that the returned indices are always 0..len-1.

    Returns:
      A dictionary of the signature identifiers {signature: index} that will be
      computed when trace_mode is summary.
    """
    signatures = self._flag_value_as_list(FLAG_NAME_SUMMARY_SIGNATURES)

    tt_signatures = []
    for signature in signatures:
      signature_with_prefix = '%s_%s' % (_TT_PREFIX, signature)
      if signature in TT_SUMMARY_SIGNATURES:
        canonical = signature
      elif signature_with_prefix in TT_SUMMARY_SIGNATURES:
        canonical = signature_with_prefix
      else:
        logging.warning('Unknown signature:%s. Supported signatures: %s',
                        signature, TT_SUMMARY_SIGNATURES)
        continue
      # A repeated signature would otherwise overwrite its earlier entry in
      # the dict below and leave a gap in the index sequence.
      if canonical not in tt_signatures:
        tt_signatures.append(canonical)
    if not tt_signatures:
      # Default case collects max-abs and norm only.
      return {TT_SUMMARY_MAX_ABS: 0, TT_SUMMARY_NORM: 1}
    return {signature: idx for idx, signature in enumerate(tt_signatures)}

  def get_signature_to_agg_fn_map(self):
    """Returns a map that contains the aggregate function for each signature."""
    return {TRACE_MODE_NORM: linalg_ops.norm,
            TRACE_MODE_MAX_ABS: math_ops.reduce_max,
            TRACE_MODE_NAN_INF: math_ops.reduce_max,
            TT_SUMMARY_NORM: linalg_ops.norm,
            TT_SUMMARY_MAX: math_ops.reduce_max,
            TT_SUMMARY_MAX_ABS:
                lambda t, axis=0: math_ops.reduce_max(math_ops.abs(t),  # pylint: disable=g-long-lambda
                                                      axis=axis),
            TT_SUMMARY_MIN: math_ops.reduce_min,
            TT_SUMMARY_MEAN: math_ops.reduce_mean,
            TT_SUMMARY_VAR: math_ops.reduce_max,  # Simply reduce max variance.
            TT_SUMMARY_SIZE: math_ops.reduce_sum}

  def _flag_value_as_list(self, wanted_flag_name):
    """Returns the string list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma-separated values of the flag as a list of strings; an empty
      list if the flag is missing or given without a value.
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    # flag_value is None for a value-less flag (e.g. "--signatures").
    if not found or not flag_value:
      return []
    return flag_value.split(',')

  def _flag_value_as_int_list(self, wanted_flag_name):
    """Returns the integer list of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      The comma-separated values of the flag converted to ints; an empty list
      if the flag is missing, has no value, or any element is not an int.
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    if not found or not flag_value:
      return []
    try:
      return [int(int_val) for int_val in flag_value.split(',')]
    except ValueError:
      logging.warning('Cannot convert %s to int for flag %s', flag_value,
                      wanted_flag_name)
      return []

  def _get_flag_int_value(self, wanted_flag_name, default_value):
    """Returns the int value of a TensorTracer flag.

    Args:
      wanted_flag_name: the name of the flag we are looking for.
      default_value: the default value for the flag, if not provided.
    Returns:
      The value of the flag as an int, or default_value if the flag is
      missing, has no value, or its value is not an int.
    """
    found, flag_value = self.get_flag_value(wanted_flag_name)
    if not found:
      return default_value
    try:
      return int(flag_value)
    except (TypeError, ValueError):
      # TypeError: value-less flag (flag_value is None);
      # ValueError: non-numeric value.
      logging.warning('Cannot convert %s to int for flag %s', flag_value,
                      wanted_flag_name)
      return default_value

  def get_flag_value(self, wanted_flag_name):
    """Returns the value of a TensorTracer flags.

    Args:
      wanted_flag_name: the name of the flag we are looking for.

    Returns:
      A pair where the first element indicates if the flag is
      found and the second element is the value of the flag (None when the
      flag is given without a value). If the flag is repeated, the first
      occurrence wins.

    Raises:
      RuntimeError: If supposedly deadcode is reached.
    """
    tensor_tracer_flags = self._env.get(FLAGS_ENV_VAR)
    if not tensor_tracer_flags:
      return False, None
    pos = 0
    while True:
      match, has_value = TTParameters.match_next_flag(
          tensor_tracer_flags, pos)
      if not match:
        return False, None
      if match.group(1) == wanted_flag_name:
        return True, (match.group(2) if has_value else None)
      pos = match.end()
    raise RuntimeError('Should not reach here.')

  def _flag_value_to_re_list(self, flag_name):
    """Compiles the comma-separated values of a flag into regex objects.

    Args:
      flag_name: the name of the flag we are looking for.

    Returns:
      A list of compiled regular expressions; empty if the flag is missing or
      has no value.
    """
    found, flag_value = self.get_flag_value(flag_name)
    if not found or not flag_value:
      return []
    return [re.compile(v) for v in flag_value.split(',')]

  def is_flag_on(self, flag_name):
    """Returns True if the given flag is on.

    A flag given without a value (e.g. "--enable") is on; otherwise it is on
    only if its value is one of 1/t/true/y/yes (case-insensitive).
    """
    found, flag_value = self.get_flag_value(flag_name)
    if not found:
      return False
    if flag_value is None:
      return True
    # Depends on the flag value.
    return flag_value.lower() in ['1', 't', 'true', 'y', 'yes']

  def is_enabled(self):
    """Returns True if TensorTracer is enabled."""
    if self.is_flag_on(FLAG_NAME_ENABLE):
      logging.info('Tensor Tracer is enabled with flags %s.',
                   self._env.get(FLAGS_ENV_VAR))
      return True
    return False

  def use_test_undeclared_outputs_dir(self):
    """Decides the output directory of the report and trace files.

    Args:
       None.

    Returns:
       True if the output files should be written to the
       test-undeclared-outputs-directory defined via an
       env variable.
    """
    return self.is_flag_on(FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)