1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ========================================================================
15"""A utility to trace tensor values on TPU."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import os.path
23import re
24import sys
25
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import graph_io
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_util
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import control_flow_util
34from tensorflow.python.ops import gen_math_ops
35from tensorflow.python.ops import init_ops
36from tensorflow.python.ops import linalg_ops
37from tensorflow.python.ops import logging_ops
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import state_ops
40from tensorflow.python.ops import variable_scope
41from tensorflow.python.platform import gfile
42from tensorflow.python.platform import tf_logging as logging
43from tensorflow.python.tpu import tpu
44from tensorflow.python.tpu.ops import tpu_ops
45
# Prefix placed in front of every line written to the report.
_TRACER_LOG_PREFIX = ' [>>>TT>>>]'

# Device types accepted by TensorTracer.check_device_type.
_DEVICE_TYPE_TPU = 'tpu'
_DEVICE_TYPE_CPU = 'cpu'

# Trace modes accepted by TensorTracer.check_trace_mode.
_TRACE_MODE_NAN_INF = 'nan-inf'
_TRACE_MODE_PART_TENSOR = 'part-tensor'
# NOTE(review): presumably the number of elements kept in part-tensor mode;
# confirm against the code that reads TensorTracer._part_tensor_size.
_TRACE_MODE_PART_TENSOR_SIZE = 3
_TRACE_MODE_FULL_TENSOR = 'full-tensor'
_TRACE_MODE_NORM = 'norm'
_TRACE_MODE_MAX_ABS = 'max-abs'

# Submodes accepted by TensorTracer.check_submode.
_SUBMODE_BRIEF = 'brief'
_SUBMODE_DETAILED = 'detailed'

# Reason strings describing why an op/tensor is traced or not traced.
_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
_REASON_UNSAFE_OP = 'not-traced-unsafe-op'
_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op'
_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar'
_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op'
_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch'
_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape'
_REASON_SCALAR_GET_TRACED = 'traced-scalar'
_REASON_TENSOR_GET_TRACED = 'traced-tensor'
_REASON_USER_INCLUDED = 'traced-user-included'
_REASON_USER_EXCLUDED = 'not-traced-user-excluded'
_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path'
_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor'
_REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op'

# Markers that open and close a section of the report.
_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:'
_MARKER_SECTION_END = '!!!!!!! section-end:'

# Names of the report sections; each is written after a begin/end marker.
_SECTION_NAME_CONFIG = 'configuration'
_SECTION_NAME_REASON = 'reason'
_SECTION_NAME_OP_LIST = 'op-list'
_SECTION_NAME_TENSOR_LIST = 'tensor-list'
_SECTION_NAME_CACHE_INDEX_MAP = 'cache-index-map'
_SECTION_NAME_GRAPH = 'graph'

# Field labels written inside the report sections.
_FIELD_NAME_VERSION = 'version:'
_FIELD_NAME_DEVICE = 'device:'
_FIELD_NAME_TRACE_MODE = 'trace-mode:'
_FIELD_NAME_SUBMODE = 'submode:'
_FIELD_NAME_NUM_REPLICAS = 'num-replicas:'
_FIELD_NAME_NUM_REPLICAS_PER_HOST = 'num-replicas-per-host:'
_FIELD_NAME_NUM_HOSTS = 'num-hosts:'
_FIELD_NAME_NUM_OPS = 'number-of-ops:'
_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:'
_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:'
_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'

# Environment variable that carries the TensorTracer flags.
_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS'
# Patterns used to parse one flag at a time out of the flags string:
# --name='value', --name="value", --name=value and a bare --name.
_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'")
_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"')
_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)')
_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=]+)\s*')

# Names of the flags recognized by TensorTracer.validate_flag_names.
_FLAG_NAME_ENABLE = 'enable'
_FLAG_NAME_TRACE_MODE = 'trace_mode'
_FLAG_NAME_USE_COMPACT_TRACE = 'compact_trace'
_FLAG_NAME_SUBMODE = 'submode'
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops'
_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames'
_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes'
_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames'
_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes'
_FLAG_NAME_TRACE_DIR = 'trace_dir'
_FLAG_NAME_REPORT_FILE = 'report_file'
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir'
_FLAG_NAME_OP_RANGE = 'op_range'
# Folder in which to dump the graph before and after the tensor tracer
# updates it.
_FLAG_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs'
# Format of the op_range flag value: "<first-index>:<last-index>".
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
# NOTE(review): not used in the visible part of the file; presumably the
# prefix that marks an output stream as a file path -- confirm at use site.
_OUTPUT_STREAM_ESCAPE = 'file://'
# Environment variable naming the test-undeclared-outputs directory.
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'
# Graph collection that receives the (tensor, checkpoint_name) pairs added by
# tensor_tracepoint.
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint'
# NOTE(review): trace file naming constants; their use is outside the visible
# part of the file.
_TRACE_FILE_NAME = 'trace.all'
_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.'
# Value every entry of the tensor-values cache is initialized with.
_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
# Graph collection that holds the tensor-values cache variable.
_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage'
# Name of the tensor-values cache variable.
_TENSOR_VALUES_CACHE = 'tensor_values_cache'
_REPLICA_ID_TAG = '#replica-id: '
122
123
def tensor_tracepoint(tensor, checkpoint_name):
  """Adds a checkpoint with the given checkpoint name for the given tensor.

  The (tensor, checkpoint_name) pair is appended to the tensor tracer
  collection of the graph that owns the tensor, which marks the tensor as one
  to be traced by the tensor tracer.

  Args:
     tensor: the tensor object for which the tracing is requested.
     checkpoint_name: a string name for the checkpoint. This name has to be
       unique if used within model comparison; tensors that share the same
       checkpoint identifier are compared in model comparison.
  Returns:
    The provided tensor.
  """

  # The original also called tensor.graph.get_collection(...) and discarded
  # the result; add_to_collection alone is what registers the tracepoint.
  tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION,
                                 (tensor, checkpoint_name))
  return tensor
143
144
def keras_layer_tracepoint(layer, checkpoint_name):
  """An interface for adding the tensor outputs of a keras layer.

  Encapsulates tensor_tracepoint. A single tensor output is registered under
  `checkpoint_name`; for an iterable of outputs, each element that is a tensor
  is registered under `<checkpoint_name>_<index>`, where index is the position
  of the element in the iterable.

  Args:
     layer: A keras layer.
     checkpoint_name: a string name for the checkpoint. This name has to be
       unique if used within model comparison; tensors that share the same
       checkpoint identifier are compared in model comparison.

  Returns:
    The provided layer.
  """
  try:
    outputs = layer.output
    if tensor_util.is_tensor(outputs):
      tensor_tracepoint(outputs, '%s' % (checkpoint_name))
    else:
      for idx, output_tensor in enumerate(outputs):
        # Test the element itself. Testing `outputs` here can never succeed,
        # because this branch is only reached when `outputs` is not a tensor.
        if tensor_util.is_tensor(output_tensor):
          tensor_tracepoint(output_tensor, '%s_%d' % (checkpoint_name, idx))
  except (AttributeError, RuntimeError):
    # Best effort: these errors are deliberately ignored and the layer is
    # returned without tracepoints (presumably raised when the layer has no
    # usable output yet -- TODO confirm).
    pass
  return layer
174
175
def _trace_files_need_precreated(output_dir):
  """Tells whether users have to pre-create the trace files themselves.

  Args:
    output_dir: the output directory path, as a string.

  Returns:
    True if output_dir begins with the five characters '/', 'c', 'n', 's',
    '/' (in that order); False otherwise.
  """
  if not output_dir.startswith('/'):
    return False
  # A path shorter than five characters yields a shorter tuple, which can
  # never equal the four expected characters.
  return tuple(output_dir[1:5]) == ('c', 'n', 's', '/')
192
193
def _get_tensor_values_cache(graph=None):
  """Looks up the variable used as the tensor-value cache.

  Args:
    graph: the graph to search; the default graph is used when this is falsy.

  Returns:
    The single entry of the tensor tracer storage collection.

  Raises:
    RuntimeError: if the collection is empty or holds more than one entry.
  """
  target_graph = graph or ops.get_default_graph()
  stored = target_graph.get_collection(_TENSOR_TRACER_STORAGE)
  if not stored:
    raise RuntimeError('%s has not been created'%_TENSOR_VALUES_CACHE)
  if len(stored) != 1:
    raise RuntimeError('Multiple %s created'%_TENSOR_VALUES_CACHE)
  return stored[0]
206
207
def _create_tensor_values_cache(graph, num_tensors):
  """Builds the variable that caches intermediate tensor values.

  Args:
    graph: the graph to create the variable in; the default graph is used
      when this is falsy.
    num_tensors: the length of the one-dimensional float32 cache variable.

  Returns:
    The non-trainable resource variable, registered in both the tensor tracer
    storage collection and the global variables collection.
  """
  target_graph = graph or ops.get_default_graph()
  # Create in proper graph and base name_scope.
  with target_graph.as_default() as g, g.name_scope(None):
    initial_value = init_ops.constant_initializer(
        _COMPACT_TRACE_ENTRY_INIT_VALUE)
    owning_collections = [_TENSOR_TRACER_STORAGE,
                          ops.GraphKeys.GLOBAL_VARIABLES]
    cache_variable = variable_scope.get_variable(
        _TENSOR_VALUES_CACHE,
        shape=[num_tensors],
        dtype=dtypes.float32,
        initializer=initial_value,
        trainable=False,
        use_resource=True,
        collections=owning_collections)
    return cache_variable
222
223
224class TensorTracer(object):
225  """A software construct for tracing tensor values in a TF graph on TPU.
226
227  This utility is disabled by default. It can be enabled by setting
228  the TENSOR_TRACER_FLAGS env variable as:
229    export TENSOR_TRACER_FLAGS="--enable=1"
230  If it is enabled, it will trace the output tensor values of
231  selected Ops in the graph. It has two outputs: (1) the traces and (2)
232  a report. The traces are dumped to a specified local file on the TPU
233  host. The report is printed to the log.info of the TPU job.
234  By passing options via the env variable, users can change:
235     (1) the trace mode (e.g., detecting NaN/Inf, printing partial or
236         full tensor values)
237     (2) which Ops to be traced (via op.name or op.type)
238     (3) output trace file path.
239  """
240  # The set of graphs that are rewritten by tensor tracer.
241  _traced_graphs = set()
242  @staticmethod
243  def _match_next_flag(flags, pos):
244    """Returns the match for the next TensorTracer flag.
245
246    Args:
247       flags: a string that contains the flags.
248       pos: where in flags to start the search.
249
250    Returns:
251       A pair where the first element is the regular-expression
252       match found and the second element indicates if the match
253       has a value.
254    """
255
256    match = _FLAG_DOUBLE_QUOTE_PAT.match(flags, pos)
257    if match:
258      return match, True
259    match = _FLAG_SINGLE_QUOTE_PAT.match(flags, pos)
260    if match:
261      return match, True
262    match = _FLAG_NO_QUOTE_PAT.match(flags, pos)
263    if match:
264      return match, True
265    match = _FLAG_NO_EQUAL_PAT.match(flags, pos)
266    if match:
267      # The flag is found but is not given a value.
268      return match, False
269    # The flag is not found.
270    return None, False
271
272  @staticmethod
273  def validate_flag_names():
274    """Validates if the TensorTrace flags passed are valid."""
275    valid_flag_names = [_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE,
276                        _FLAG_NAME_USE_COMPACT_TRACE,
277                        _FLAG_NAME_SUBMODE,
278                        _FLAG_NAME_EXCLUDED_OPNAMES,
279                        _FLAG_NAME_EXCLUDED_OPTYPES,
280                        _FLAG_NAME_INCLUDED_OPNAMES,
281                        _FLAG_NAME_INCLUDED_OPTYPES,
282                        _FLAG_NAME_TRACE_DIR,
283                        _FLAG_NAME_REPORT_FILE,
284                        _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
285                        _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS,
286                        _FLAG_NAME_OP_RANGE,
287                        _FLAG_DUMP_BEFORE_AFTER_GRAPHS]
288    tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR)
289    if not tensor_tracer_flags:
290      return
291    pos = 0
292    while True:
293      match, _ = TensorTracer._match_next_flag(tensor_tracer_flags, pos)
294      if not match:
295        break
296      flag_name = match.group(1)
297      if flag_name not in valid_flag_names:
298        raise ValueError(
299            'The flag name "%s" passed via the environment variable "%s" '
300            'is invalid. Valid flag names are:'
301            '\n%s'%(flag_name, _FLAGS_ENV_VAR, valid_flag_names))
302      pos = match.end()
303
304  @staticmethod
305  def print_flag_values():
306    """Prints all TensorTracer flags passed via environment variables."""
307
308    tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR)
309    if not tensor_tracer_flags:
310      return 'Env variable "%s" is not set'%_FLAGS_ENV_VAR
311    result = 'Env variable "%s" is set to "%s"\n'%(_FLAGS_ENV_VAR,
312                                                   tensor_tracer_flags)
313    result += 'Individual flag value:\n'
314    pos = 0
315    while True:
316      match, has_value = TensorTracer._match_next_flag(
317          tensor_tracer_flags, pos)
318      if not match:
319        break
320      flag_name = match.group(1)
321      if has_value:
322        flag_value = match.group(2)
323      else:
324        flag_value = None
325      result += '  %s: %s\n'%(flag_name, flag_value)
326      pos = match.end()
327    result += '\n'
328    return result
329
330  @staticmethod
331  def get_flag_value(wanted_flag_name):
332    """Returns the value of a TensorTracer flags.
333
334    Args:
335      wanted_flag_name: the name the the flag we are looking for.
336
337    Returns:
338      A pair where the first element indicates if the flag is
339      found and the second element is the value of the flag.
340
341    Raises:
342      RuntimeError: If supposedly deadcode is reached.
343    """
344
345    tensor_tracer_flags = os.getenv(_FLAGS_ENV_VAR)
346    if not tensor_tracer_flags:
347      return False, None
348    pos = 0
349    while True:
350      match, has_value = TensorTracer._match_next_flag(
351          tensor_tracer_flags, pos)
352      if not match:
353        return False, None
354      flag_name = match.group(1)
355      if has_value:
356        flag_value = match.group(2)
357      else:
358        flag_value = None
359      if flag_name == wanted_flag_name:
360        return True, flag_value
361      pos = match.end()
362    raise RuntimeError('Should not reach here.')
363
364  @staticmethod
365  def flag_value_to_re_list(flag_name):
366    """Converts list of strings to compiled RE."""
367
368    re_list = []
369    found, flag_value = TensorTracer.get_flag_value(flag_name)
370    if not found or not flag_value:
371      return re_list
372    list_of_values = flag_value.split()
373    for v in list_of_values:
374      r = re.compile(v)
375      re_list.append(r)
376    return re_list
377
378  @staticmethod
379  def _is_flag_on(flag_name):
380    """Returns True if the given flag is on."""
381
382    found, flag_value = TensorTracer.get_flag_value(flag_name)
383    if not found:
384      return False
385    if flag_value is None:
386      return True
387    # Depends on the flag value.
388    flag_value = flag_value.lower()
389    enabled = flag_value in ['1', 't', 'true', 'y', 'yes']
390    return enabled
391
392  @staticmethod
393  def is_enabled():
394    """Returns True if TensorTracer is enabled."""
395
396    return TensorTracer._is_flag_on(_FLAG_NAME_ENABLE)
397
398  @staticmethod
399  def use_test_undeclared_outputs_dir():
400    """Decides the output directory of the report and trace files.
401
402    Args:
403       None.
404
405    Returns:
406       True if the output files should be written to the
407       test-undeclared-outputs-directory defined via an
408       env variable.
409    """
410
411    return TensorTracer._is_flag_on(
412        _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)
413
414  @staticmethod
415  def use_compact_trace():
416    return TensorTracer._is_flag_on(
417        _FLAG_NAME_USE_COMPACT_TRACE)
418
419  @staticmethod
420  def check_device_type(device_type):
421    """Checks if the given device type is valid."""
422
423    if device_type not in [_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU]:
424      raise ValueError('Invalid device_type "%s"'%device_type)
425
426  @staticmethod
427  def check_trace_mode(trace_mode):
428    """Checks if the given trace mode is valid."""
429
430    valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR,
431                         _TRACE_MODE_FULL_TENSOR, _TRACE_MODE_NORM,
432                         _TRACE_MODE_MAX_ABS]
433    if trace_mode not in valid_trace_modes:
434      raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.'
435                       'Valid trace modes are: %s'%(trace_mode,
436                                                    valid_trace_modes))
437
438  @staticmethod
439  def check_submode(submode):
440    """Checks if the given submode is valid."""
441
442    if not submode:
443      return
444    valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF]
445    if submode not in valid_submodes:
446      raise ValueError('Invalid submode "%s" given to the Tensor_Tracer.'
447                       'Valid submodes are: %s'%(submode,
448                                                 valid_submodes))
449
450  @staticmethod
451  def loop_cond_op(op):
452    return op.type in ('LoopCond', 'RefLoopCond')
453
454  @staticmethod
455  def while_loop_op(op):
456    """Returns true if op is one of the special ops of in a while loop.
457
458    Args:
459       op: A tf.Operation.
460
461    Returns:
462       True if the given op is one of [Switch, Merge, Enter, Exit,
463       NextIteration, LoopCond], which are all building blocks for TF while
464       loops.
465    """
466    return  (control_flow_util.IsLoopSwitch(op) or
467             control_flow_util.IsLoopMerge(op) or
468             control_flow_util.IsLoopEnter(op) or
469             control_flow_util.IsLoopExit(op) or
470             TensorTracer.loop_cond_op(op) or
471             op.type in ('RefNextIteration', 'NextIteration'))
472
473  @staticmethod
474  def unsafe_op(op):
475    """Returns True if this op is not safe to be traced."""
476
477    if control_flow_util.IsInCond(op):
478      return True
479    # Reasons for not including following op types:
480    #    Assign: cause incorrect result with CPU tracing.
481    if op.type in ['Assign']:
482      return True
483    return False
484
485  @staticmethod
486  def device_mismatch(device_type, op):
487    if device_type == _DEVICE_TYPE_TPU:
488      # pylint: disable=protected-access
489      return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr
490      # pylint: enable=protected-access
491    return False
492
493  @staticmethod
494  def unsafe_scalar_trace(op):
495    """Return true if scalar output tensor from Op is not safe to be traced."""
496
497    # Tracing the following causes cycle in the graph on TPU.
498    if op.type in ['LoopCond', 'Enter', 'Merge', 'Const',
499                   'Switch', 'Less', 'ReadVariableOp']:
500      return True
501    # Tracing the following will cause casting-issue
502    # with the norm tracing mode or other compilation issues on CPU.
503    if op.type in ['VarHandleOp', 'IteratorToStringHandle',
504                   'IteratorGetNext', 'OneShotIterator',
505                   'IteratorV2', 'MakeIterator',
506                   'BatchDatasetV2', 'MapDataset',
507                   'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset',
508                   'Placeholder', 'PlaceholderWithDefault', 'StridedSlice']:
509      return True
510    return False
511
512  @staticmethod
513  def less_interesting_op(op):
514    """Returns True if the given Op is not an interesting one to be traced."""
515
516    found, _ = TensorTracer.get_flag_value(
517        _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS)
518    if found:
519      # users force to include all ops.
520      return False
521    # Following ops are highly unlikey to cause bugs.
522    return op.type in ['Const', 'Identity', 'Cast', 'Shape']
523
524  @staticmethod
525  def reason(op_idx, details):
526    """Returns reason why the Op at op_idx is traced or not."""
527
528    return '%d %s'%(op_idx, details)
529
530  @staticmethod
531  def topological_sort(g):
532    """Performs topological sort on the given graph.
533
534    Args:
535       g: the graph.
536
537    Returns:
538       A pair where the first element indicates if the topological
539       sort succeeded (True if there is no cycle found; False if a
540       cycle is found) and the second element is either the sorted
541       list of nodes or the cycle of nodes found.
542    """
543    def _is_loop_edge(op):
544      """Returns true if the op is the end of a while-loop creating a cycle."""
545      return op.type in ['NextIteration']
546
547    def _in_op_degree(op):
548      """Returns the number of incoming edges to the given op.
549
550      The edge calculation skips the edges that come from 'NextIteration' ops.
551      NextIteration creates a cycle in the graph. We break cycles by treating
552      this op as 'sink' and ignoring all outgoing edges from it.
553      Args:
554        op: Tf.Operation
555      Returns:
556        the number of incoming edges.
557      """
558      count = 0
559      for op in op.control_inputs + [in_tensor.op for in_tensor in op.inputs]:
560        if not _is_loop_edge(op):
561          count += 1
562      return count
563
564    sorted_ops = []
565    op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()}
566
567    frontier = [op for (op, degree) in op_in_degree.items() if degree == 0]
568    while frontier:
569      op = frontier.pop()
570      # Remove the op from graph, and remove its outgoing edges.
571      sorted_ops.append(op)
572      if _is_loop_edge(op):
573        continue
574      # pylint: disable=protected-access
575      consumers = list(op._control_outputs)
576      # pylint: enable=protected-access
577      for out_tensor in op.outputs:
578        consumers += [consumer_op for consumer_op in out_tensor.consumers()]
579
580      for consumer in consumers:
581        # For each deleted edge shift the bucket of the vertex.
582        op_in_degree[consumer] -= 1
583        if op_in_degree[consumer] == 0:
584          frontier.append(consumer)
585        if op_in_degree[consumer] < 0:
586          raise ValueError('consumer:%s degree mismatch'%consumer.name)
587
588    left_ops = set([op for (op, degree) in op_in_degree.items() if degree > 0])
589    if left_ops:
590      return (False, left_ops)
591    else:
592      assert len(g.get_operations()) == len(sorted_ops)
593      return (True, sorted_ops)
594
595  @staticmethod
596  def _make_op_and_tensor_maps(op_list):
597    """Creates various maps and lists from op_list.
598
599    Args:
600       op_list: a list of Ops
601
602    Returns:
603       opname_idx_map: a map from Op's name to its index in op_list.
604       tensor_list: a list of output tensors of the Ops in op_list.
605       tensorname_idx_map: a map from output tensor name to its index
606                           in tensor_list.
607    """
608
609    opname_idx_map = {}
610    tensor_list = []
611    tensorname_idx_map = {}
612    for op_id, op in enumerate(op_list):
613      if op.name in opname_idx_map:
614        raise ValueError('Duplicated Op name: %s'%op.name)
615      opname_idx_map[op.name] = op_id
616      for output_tensor in op.outputs:
617        if output_tensor.name not in tensorname_idx_map:
618          tensor_list.append(output_tensor)
619          tensorname_idx_map[output_tensor.name] = len(tensor_list)-1
620    return (opname_idx_map, tensor_list, tensorname_idx_map)
621
622  def __init__(self):
623    """Initializes a TensorTracer.
624
625    Sets the various member fields from the flags (if given) or the defaults.
626    """
627    self._version = 'use-outside-compilation'
628    self._device_type = None
629    TensorTracer.validate_flag_names()
630    found, self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE)
631    if not found or not self._trace_mode:
632      self._trace_mode = _TRACE_MODE_NAN_INF
633    TensorTracer.check_trace_mode(self._trace_mode)
634    found, self._submode = TensorTracer.get_flag_value(_FLAG_NAME_SUBMODE)
635    if not found or not self._submode:
636      self._submode = _SUBMODE_DETAILED
637    TensorTracer.check_submode(self._submode)
638    self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE
639    self._instrument_records = {}
640    self._set_trace_dir()
641    self._set_report_file()
642    self._set_op_range()
643    self._set_excluded_opnames()
644    self._set_excluded_optypes()
645    self._set_included_opnames()
646    self._set_included_optypes()
647    self._num_replicas = None
648    self._num_replicas_per_host = None
649    self._num_hosts = None
650    self._replica_id = None
651    _, self._graph_dump_path = TensorTracer.get_flag_value(
652        _FLAG_DUMP_BEFORE_AFTER_GRAPHS)
653
654  def _add_replica_id_to_graph(self):
655    """Adds nodes for computing the replica ID to the graph."""
656
657    if self._num_replicas:
658      with ops.control_dependencies(None):
659        # Uses None as dependency to run outside of TPU graph rewrites.
660        self._replica_id = tpu_ops.tpu_replicated_input(
661            list(range(self._num_replicas)),
662            name='tt_replica_id')
663    else:
664      self._replica_id = 'unknown'
665
666  def _set_trace_dir(self):
667    found, self._trace_dir = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_DIR)
668    if found and self._trace_dir \
669       and TensorTracer.use_test_undeclared_outputs_dir():
670      raise ValueError('Cannot not use --%s and --%s at the same time'
671                       %(_FLAG_NAME_TRACE_DIR,
672                         _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR))
673    if TensorTracer.use_test_undeclared_outputs_dir():
674      self._trace_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
675
676  def _set_report_file(self):
677    """Sets the path of the output report file."""
678
679    found, self._report_file_path = TensorTracer.get_flag_value(
680        _FLAG_NAME_REPORT_FILE)
681    if found and self._report_file_path \
682       and TensorTracer.use_test_undeclared_outputs_dir():
683      if os.path.isabs(self._report_file_path):
684        raise ValueError('If use_test_undeclared_outputs_dir is set,'
685                         'report_file_path cannot be an absolute path (%s)'
686                         %self._report_file_path)
687      outputs_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
688      self._report_file_path = os.path.join(outputs_dir,
689                                            self._report_file_path)
690    if not self._report_file_path:
691      self._report_file = None
692      return
693    try:
694      self._report_file = gfile.Open(self._report_file_path, 'w')
695    except IOError as e:
696      raise e
697
698  def _close_report_file(self):
699    if self._report_file:
700      self._report_file.close()
701
702  def _set_op_range(self):
703    """Sets the index range of the Ops that we will consider tracing."""
704
705    found, op_range = TensorTracer.get_flag_value(_FLAG_NAME_OP_RANGE)
706    if not found or not op_range:
707      self._op_range = (-1, -1)  # this means including all ops.
708      return
709    match = _OP_RANGE_PAT.match(op_range)
710    if not match:
711      self._op_range = (-1, -1)  # this means including all ops.
712      return
713    self._op_range = (int(match.group(1)), int(match.group(2)))
714
715  def _inside_op_range(self, idx):
716    """Return True if the given index is inside the selected range."""
717
718    if idx < self._op_range[0]:
719      return False
720    return self._op_range[1] < 0 or idx <= self._op_range[1]
721
722  def _set_excluded_opnames(self):
723    self._excluded_opname_re_list = TensorTracer.flag_value_to_re_list(
724        _FLAG_NAME_EXCLUDED_OPNAMES)
725
726  def _set_excluded_optypes(self):
727    self._excluded_optype_re_list = TensorTracer.flag_value_to_re_list(
728        _FLAG_NAME_EXCLUDED_OPTYPES)
729
730  def _set_included_opnames(self):
731    self._included_opname_re_list = TensorTracer.flag_value_to_re_list(
732        _FLAG_NAME_INCLUDED_OPNAMES)
733
734  def _set_included_optypes(self):
735    self._included_optype_re_list = TensorTracer.flag_value_to_re_list(
736        _FLAG_NAME_INCLUDED_OPTYPES)
737
738  def _is_user_included_op(self, op):
739    for opname_re in self._included_opname_re_list:
740      if opname_re.match(op.name):
741        return True
742    for optype_re in self._included_optype_re_list:
743      if optype_re.match(op.type):
744        return True
745    return False
746
747  def _is_user_excluded_op(self, op):
748    for opname_re in self._excluded_opname_re_list:
749      if opname_re.match(op.name):
750        return True
751    for optype_re in self._excluded_optype_re_list:
752      if optype_re.match(op.type):
753        return True
754    return False
755
756  def _use_tensor_values_cache(self):
757    """Returns True if immediate tensors should be first saved to a cache."""
758
759    if self._trace_mode not in set([_TRACE_MODE_NAN_INF,
760                                    _TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS]):
761      return False
762    if self._trace_dir and _trace_files_need_precreated(self._trace_dir):
763      return True
764    if TensorTracer.use_compact_trace():
765      return True
766    return False
767
768  def _save_tensor_value_to_cache_op(self, graph, cache_idx, updates):
769    """Returns an Op that will save the given updates to an entry in the cache."""
770
771    cache = _get_tensor_values_cache(graph)
772    indices = constant_op.constant([cache_idx])
773    return state_ops.scatter_update(cache, indices, updates).op
774
775  def _write_report(self, content):
776    """Writes the given content to the report."""
777
778    line = '%s %s'%(_TRACER_LOG_PREFIX, content)
779    if self._report_file:
780      self._report_file.write(line)
781    else:
782      logging.info(line)
783
784  def _write_config_section(self):
785    """Writes the config section of the report."""
786
787    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG))
788    self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, self._version))
789    self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, self._device_type))
790    self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE, self._trace_mode))
791    self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE, self._submode))
792    self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS, self._num_replicas))
793    self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS_PER_HOST,
794                                  self._num_replicas_per_host))
795    self._write_report('%s %s\n'%(_FIELD_NAME_NUM_HOSTS, self._num_hosts))
796    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG))
797
798  def _write_reason_section(self):
799    """Writes the reason section of the report."""
800
801    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON))
802    for key in sorted(self._instrument_records):
803      self._write_report('"%s" %s\n'%(key, self._instrument_records[key]))
804    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON))
805
806  def _write_op_list_section(self, op_list):
807    """Writes the Op-list section of the report."""
808
809    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST))
810    self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list)))
811    for i in range(0, len(op_list)):
812      op = op_list[i]
813      line = '%d "%s" %s'%(i, op.name, op.type)
814      for out_tensor in op.outputs:
815        if out_tensor.name not in self._tensorname_idx_map:
816          raise ValueError(
817              'out_tensor %s is not in tensorname_idx_map'%out_tensor.name)
818        line += ' %d'%self._tensorname_idx_map[out_tensor.name]
819      line += '\n'
820      self._write_report(line)
821    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST))
822
823  def _write_tensor_list_section(self, tensor_list, opname_idx_map):
824    """Writes the tensor-list section of the report."""
825
826    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
827                                  _SECTION_NAME_TENSOR_LIST))
828    self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS, len(tensor_list)))
829    for i in range(0, len(tensor_list)):
830      tensor = tensor_list[i]
831      line = '%d "%s"'%(i, tensor.name)
832      for consumer_op in tensor.consumers():
833        if consumer_op.name not in opname_idx_map:
834          raise ValueError(
835              'consumer_op %s is not in opname_idx_map'%consumer_op.name)
836        line += ' %d'%opname_idx_map[consumer_op.name]
837      line += '\n'
838      self._write_report(line)
839    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
840                                  _SECTION_NAME_TENSOR_LIST))
841
842  def _write_cache_index_map_section(self):
843    """Writes the mapping from cache index to tensor index to the report."""
844
845    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
846                                  _SECTION_NAME_CACHE_INDEX_MAP))
847    self._write_report('%s %d\n'%(_FIELD_NAME_NUM_CACHE_INDICES,
848                                  len(self._cache_idx_to_tensor_idx)))
849    for cache_idx in range(0, len(self._cache_idx_to_tensor_idx)):
850      tensor_idx = self._cache_idx_to_tensor_idx[cache_idx]
851      line = '%d %d\n'%(cache_idx, tensor_idx)
852      self._write_report(line)
853    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
854                                  _SECTION_NAME_CACHE_INDEX_MAP))
855
856  def _write_graph_section(self, succeed, sorted_or_cycle):
857    """Writes the graph section of the report."""
858
859    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH))
860    self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED,
861                                  succeed))
862    l = list(sorted_or_cycle)
863    for i in range(0, len(l)):
864      self._write_report('%d "%s"\n'%(i, l[i].name))
865    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH))
866
867  def _preprocess_traced_tensor(self, tensor):
868    """Computes NAN/Norm/Max on TPUs before sending to CPU.
869
870    Args:
871      tensor: The tensor to be traced.
872    Returns:
873      A tensor that should be input to the trace_function.
874    Raises:
875      RuntimeError: If the trace mode is invalid.
876    """
877
878    def _detect_nan_inf(tensor):
879      """Trace function for detecting any NaN/Inf in the tensor."""
880
881      if tensor.dtype.is_floating:
882        mask = math_ops.reduce_any(
883            gen_math_ops.logical_or(
884                gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor)))
885        output_tensor = control_flow_ops.cond(mask,
886                                              lambda: constant_op.constant(1.0),
887                                              lambda: constant_op.constant(0.0))
888      else:
889        output_tensor = constant_op.constant(0.0)
890      # The shape has to be 1. Set it if it does not have the information.
891      output_tensor = array_ops.reshape(output_tensor, [1])
892      return output_tensor
893
894    def _show_norm(tensor):
895      tensor = math_ops.cast(tensor, dtypes.float32)
896      output_tensor = linalg_ops.norm(tensor)
897      # The shape has to be 1. Set it if it does not have the information.
898      output_tensor = array_ops.reshape(output_tensor, [1])
899      return output_tensor
900
901    def _show_max_abs(tensor):
902      tensor = math_ops.cast(tensor, dtypes.float32)
903      output_tensor = math_ops.reduce_max(math_ops.abs(tensor))
904      zero = constant_op.constant(0, dtypes.float32)
905      output_tensor = gen_math_ops.maximum(zero, output_tensor)
906      # The shape has to be 1. Set it if it does not have the information.
907      output_tensor = array_ops.reshape(output_tensor, [1])
908      return output_tensor
909
910    if self._trace_mode == _TRACE_MODE_NAN_INF:
911      return _detect_nan_inf(tensor)
912    if self._trace_mode == _TRACE_MODE_PART_TENSOR:
913      return tensor
914    if self._trace_mode == _TRACE_MODE_FULL_TENSOR:
915      return tensor
916    if self._trace_mode == _TRACE_MODE_NORM:
917      return _show_norm(tensor)
918    if self._trace_mode == _TRACE_MODE_MAX_ABS:
919      return _show_max_abs(tensor)
920    raise RuntimeError(
921        'Tensor trace fun for %s is not yet implemented' % self._trace_mode)
922
923  def _make_tensor_trace_fun(self, tensor_name):
924    """Makes the tensor tracing function called by outside compilation.
925
926    Args:
927      tensor_name: name of the tensor being traced.
928
929    Returns:
930      A function to be passed as the first argument to outside compilation.
931
932    Raises:
933      RuntimeError: If the trace mode is invalid.
934    """
935
936    def _print_tensor(tensor_name, num_elements, tensor, output_tensor):
937      """Prints a tensor value to a file.
938
939      Args:
940        tensor_name: name of the tensor being traced.
941        num_elements: number of elements to print (-1 means print all).
942        tensor: the tensor needs to be returned.
943        output_tensor: the tensor needs to be printed.
944
945      Returns:
946        The same tensor passed via the "tensor" argument.
947
948      Raises:
949        ValueError: If tensor_name is not already in
950                    self._tensorname_idx_map.
951      """
952
953      if self._submode == _SUBMODE_BRIEF:
954        if tensor_name not in self._tensorname_idx_map:
955          raise ValueError(
956              'Tensor name %s is not in the tensorname_idx_map'%tensor_name)
957        msg = '%d'%self._tensorname_idx_map[tensor_name]
958      else:
959        msg = '"%s"'%tensor_name
960
961      if self._trace_dir:
962        output_path = os.path.join(self._trace_dir, _TRACE_FILE_NAME)
963        output_stream = _OUTPUT_STREAM_ESCAPE + output_path
964      else:
965        output_stream = sys.stderr
966      return logging_ops.print_v2(msg, array_ops.shape(output_tensor),
967                                  '@', self._replica_id,
968                                  '\n', output_tensor, '\n',
969                                  summarize=num_elements,
970                                  output_stream=output_stream)
971
972    def _show_part_tensor(tensor):
973      """Trace function for printing part of the tensor."""
974
975      return _print_tensor(tensor_name, self._part_tensor_size,
976                           tensor, tensor)
977
978    def _show_full_tensor(tensor):
979      """Trace function for printing the entire tensor."""
980
981      return _print_tensor(tensor_name, -1, tensor, tensor)
982
983    if self._trace_mode == _TRACE_MODE_PART_TENSOR:
984      return _show_part_tensor
985    # The input tensor has a shape of "[1]" for _TRACE_MODE_NAN_INF,
986    # _TRACE_MODE_NORM, and _TRACE_MODE_MAX_ABS, as related computations are
987    # performed within TPUs and only their results are transferred to CPU.
988    # Simply, print the full tensor for these trace modes.
989    if self._trace_mode in [
990        _TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_FULL_TENSOR,
991        _TRACE_MODE_MAX_ABS
992    ]:
993      return _show_full_tensor
994
995    raise RuntimeError('Tensor trace fun for %s is not yet implemented'
996                       %self._trace_mode)
997
998  def _skip_op(self, op_id, op, user_included, user_excluded,
999               in_exec_path=True):
1000    """Returns True if we should not trace Op."""
1001
1002    if TensorTracer.while_loop_op(op):
1003      self._instrument_records[op.name] = TensorTracer.reason(
1004          op_id, _REASON_WHILELOOP_OP)
1005      return True
1006    if TensorTracer.unsafe_op(op):
1007      self._instrument_records[op.name] = TensorTracer.reason(
1008          op_id, _REASON_UNSAFE_OP)
1009      return True
1010    if TensorTracer.device_mismatch(self._device_type, op):
1011      self._instrument_records[op.name] = TensorTracer.reason(
1012          op_id, _REASON_DEVICE_MISMATCH)
1013      return True
1014    if not in_exec_path:
1015      self._instrument_records[op.name] = TensorTracer.reason(
1016          op_id, _REASON_NOT_EXECUTED)
1017      return True
1018
1019    if not self._inside_op_range(op_id):
1020      self._instrument_records[op.name] = TensorTracer.reason(
1021          op_id, _REASON_OUTSIDE_OP_RANGE)
1022      return True
1023    if TensorTracer.less_interesting_op(op):
1024      self._instrument_records[op.name] = TensorTracer.reason(
1025          op_id, _REASON_LESS_INTERESTING_OP)
1026      return True
1027    if user_included:
1028      self._instrument_records[op.name] = TensorTracer.reason(
1029          op_id, _REASON_USER_INCLUDED)
1030      return False
1031    if user_excluded:
1032      self._instrument_records[op.name] = TensorTracer.reason(
1033          op_id, _REASON_USER_EXCLUDED)
1034      return True
1035    return False
1036
1037  def _skip_tensor(self, op_id, out_tensor, user_included,
1038                   user_excluded):
1039    """Returns True if we should not trace out_tensor."""
1040
1041    # Skips a tensor if the tensor has a non-numeric type.
1042    #   Note: we cannot use check_ops.is_numeric_tensor(out_tensor)
1043    #         because it also excludes tensors with dtypes, bool, and
1044    #         float32_ref, which we actually want to trace.
1045    non_numeric_tensor_types = set([dtypes.variant, dtypes.resource,
1046                                    dtypes.string])
1047    if out_tensor.dtype in non_numeric_tensor_types:
1048      self._instrument_records[out_tensor.name] = TensorTracer.reason(
1049          op_id, _REASON_NON_NUMERIC_TENSOR)
1050      return True
1051    # Skip a tensor if it feeds a special while loop op.
1052    if [consumer for consumer in out_tensor.consumers() if
1053        TensorTracer.while_loop_op(consumer)]:
1054      self._instrument_records[out_tensor.name] = TensorTracer.reason(
1055          op_id, _REASON_FEEDS_WHILELOOP_OP)
1056      return True
1057    if user_included:
1058      self._instrument_records[out_tensor.name] = TensorTracer.reason(
1059          op_id, _REASON_USER_INCLUDED)
1060      return False
1061    if user_excluded:
1062      self._instrument_records[out_tensor.name] = TensorTracer.reason(
1063          op_id, _REASON_USER_EXCLUDED)
1064      return True
1065    if not out_tensor.get_shape().is_fully_defined():
1066      # If trace mode is nan-inf, norm or max, then the tensor will be reduced
1067      # to a scalar before the outside compilation call.
1068      if self._trace_mode in [
1069          _TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS
1070      ]:
1071        self._instrument_records[out_tensor.name] = TensorTracer.reason(
1072            op_id, _REASON_TENSOR_GET_TRACED)
1073        return False
1074      else:
1075        self._instrument_records[out_tensor.name] = TensorTracer.reason(
1076            op_id, _REASON_DYNAMIC_SHAPE)
1077        return True
1078    rank = len(out_tensor.shape)
1079    if rank < 1:
1080      # scalar
1081      if TensorTracer.unsafe_scalar_trace(out_tensor.op):
1082        self._instrument_records[out_tensor.name] = TensorTracer.reason(
1083            op_id, _REASON_UNSAFE_SCALAR)
1084        return True
1085      else:
1086        self._instrument_records[out_tensor.name] = TensorTracer.reason(
1087            op_id, _REASON_SCALAR_GET_TRACED)
1088        return False
1089    else:
1090      # tensor
1091      self._instrument_records[out_tensor.name] = TensorTracer.reason(
1092          op_id, _REASON_TENSOR_GET_TRACED)
1093      return False
1094
1095  def _filter_execution_path_operations(self, operations, fetches):
1096    """Returns the set of ops in the execution path to compute given fetches."""
1097
1098    # If no fetch provided, then return all operations.
1099    if fetches is None:
1100      return set(operations)
1101    # Convert to list, if a single element is provided.
1102    if not isinstance(fetches, (list, tuple)):
1103      fetches = [fetches]
1104    # If a tensor is given as fetch, convert it to op.
1105    op_fetches = []
1106    for fetch in fetches:
1107      if isinstance(fetch, ops.Operation):
1108        op_fetches.append(fetch)
1109      elif isinstance(fetch, ops.Tensor):
1110        op_fetches.append(fetch.op)
1111      else:
1112        raise RuntimeError('Given fetch:%s is neither a tensor nor an op.'
1113                           %fetch)
1114
1115    execution_path_operations = set(op_fetches)
1116    traverse_stack = list(op_fetches)
1117    while True:
1118      if not traverse_stack:
1119        break
1120      head_op = traverse_stack.pop()
1121      input_ops = [tensor_input.op for tensor_input in head_op.inputs]
1122      input_ops.extend(head_op.control_inputs)
1123
1124      for input_op in input_ops:
1125        if input_op not in execution_path_operations:
1126          # Filter out loop condition operations, tracing them causes a cycle.
1127          # Trace only the loop-body.
1128          if TensorTracer.loop_cond_op(input_op):
1129            continue
1130          execution_path_operations.add(input_op)
1131          traverse_stack.append(input_op)
1132    return execution_path_operations
1133
1134  def _determine_traced_tensors(self, graph, ops_in_exec_path):
1135    """Determines the tensors that will be traced."""
1136
1137    self._traced_tensorname_to_cache_idx_map = {}
1138    self._cache_idx_to_tensor_idx = []
1139    operations = graph.get_operations()
1140    checkpoint_operations = self._get_checkpoints(graph)
1141    for op_id, op in enumerate(operations):
1142      if checkpoint_operations and op.name not in checkpoint_operations:
1143        continue
1144      user_included = self._is_user_included_op(op)
1145      user_excluded = self._is_user_excluded_op(op)
1146      in_exec_path = op in ops_in_exec_path
1147      if self._skip_op(op_id, op, user_included, user_excluded, in_exec_path):
1148        continue
1149      for i in range(len(op.outputs)):
1150        out_tensor = op.outputs[i]
1151        if self._skip_tensor(op_id, out_tensor, user_included,
1152                             user_excluded):
1153          continue
1154        tensor_name = out_tensor.name
1155        if tensor_name in self._traced_tensorname_to_cache_idx_map:
1156          raise ValueError(
1157              'Tensor name %s should not be already in '
1158              'traced_tensorname_to_cache_idx_map'%tensor_name)
1159        if tensor_name not in self._tensorname_idx_map:
1160          raise ValueError(
1161              'Tensor name %s is not in the tensorname_idx_map'%tensor_name)
1162        tensor_idx = self._tensorname_idx_map[tensor_name]
1163        cache_idx = len(self._traced_tensorname_to_cache_idx_map)
1164        self._traced_tensorname_to_cache_idx_map[tensor_name] = cache_idx
1165        self._cache_idx_to_tensor_idx.append(tensor_idx)
1166        if len(self._traced_tensorname_to_cache_idx_map) != len(
1167            self._cache_idx_to_tensor_idx):
1168          raise RuntimeError('len(self._traced_tensorname_to_cache_idx_map) != '
1169                             'len(self._cache_idx_to_tensor_idx')
1170
1171  def _check_trace_files(self):
1172    """Checks if any requirements for trace files are satisfied."""
1173
1174    if not self._trace_dir:
1175      # traces will be written to stderr. No need to check trace files.
1176      return
1177    if _trace_files_need_precreated(self._trace_dir):
1178      for replica_id in range(0, self._num_replicas):
1179        trace_file_path = os.path.join(
1180            self._trace_dir,
1181            _COMPACT_TRACE_FILE_PREFIX) + '%d'%replica_id
1182        if not gfile.Exists(trace_file_path):
1183          raise RuntimeError(
1184              '%s must be pre-created with the '
1185              'appropriate properties.'%trace_file_path)
1186    else:
1187      if not gfile.Exists(self._trace_dir):
1188        gfile.MkDir(self._trace_dir)
1189        if not gfile.Exists(self._trace_dir):
1190          raise RuntimeError('Failed to create %s'%self._trace_dir)
1191
  def _pre_tracing(self, graph, fetches):
    """Work needs to be done prior to TPU or CPU tracing.

    Checks the trace files, writes the config, op-list, tensor-list and
    cache-index-map sections of the report, determines which tensors get
    traced, and topologically sorts the graph. The order of the calls below
    matters: the report sections are written in this sequence and the sort
    runs before any node is added to the graph.

    Args:
      graph: the graph to be traced.
      fetches: the fetches used to narrow down the execution path; forwarded
        to _filter_execution_path_operations (None selects all operations).

    Returns:
      A tuple (ops_in_exec_path, succeed, sorted_or_cycle): the set returned by
      _filter_execution_path_operations followed by the two values returned by
      TensorTracer.topological_sort.
    """

    self._check_trace_files()
    operations = graph.get_operations()
    # Build the name->index maps used by the report sections written below.
    (opname_idx_map, tensor_list, self._tensorname_idx_map) = (
        TensorTracer._make_op_and_tensor_maps(operations))
    self._write_config_section()
    self._write_op_list_section(operations)
    self._write_tensor_list_section(tensor_list, opname_idx_map)
    # Filter out the operations that won't be executed.
    # if fetches=None, then ops_in_exec_path = set(operations)
    ops_in_exec_path = self._filter_execution_path_operations(operations,
                                                              fetches)
    # Populates self._cache_idx_to_tensor_idx, which the next section reports.
    self._determine_traced_tensors(graph, ops_in_exec_path)
    self._write_cache_index_map_section()
    # Does the topological sort before adding any nodes to the graph.
    (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph)
    if self._use_tensor_values_cache():
      # The cache is sized by the number of traced tensors.
      _create_tensor_values_cache(graph,
                                  len(self._cache_idx_to_tensor_idx))
    return (ops_in_exec_path, succeed, sorted_or_cycle)
1214
  def _post_tracing(self, succeed, sorted_or_cycle):
    """Work needs to be done after TPU or CPU tracing.

    Writes the reason section and the graph section of the report, in that
    order, and then calls _close_report_file.

    Args:
      succeed: topological-sort outcome, forwarded to _write_graph_section.
      sorted_or_cycle: the ops produced by the topological sort, forwarded to
        _write_graph_section.
    """

    self._write_reason_section()
    self._write_graph_section(succeed, sorted_or_cycle)
    self._close_report_file()
1221
1222  def _get_checkpoints(self, graph):
1223    """Returns the list of Ops that produce the tensors traced with API.
1224
1225    Args:
1226      graph: the graph of Ops.
1227
1228    Returns:
1229      A set of operation names which should be traced.
1230    """
1231
1232    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
1233                                  _TENSOR_TRACER_CHECKPOINT))
1234    checkpoint_operations = set()
1235    tensor_tracer_variables = graph.get_collection(_TENSOR_TRACER_COLLECTION)
1236    for (tensor, checkpoint_name) in tensor_tracer_variables:
1237      self._write_report('%s %s\n'%(tensor.name, checkpoint_name))
1238      checkpoint_operations.add(tensor.op.name)
1239    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
1240                                  _TENSOR_TRACER_CHECKPOINT))
1241    return checkpoint_operations
1242
  def _generate_flush_cache_op(self, graph, start_replica, on_tpu):
    """Generates an Op that will flush the cache to file.

    The returned op is a `case` over the eight consecutive replica ids starting
    at start_replica: the branch whose id equals self._replica_id prints the
    cache and then resets it, and when none matches nothing is flushed.

    Args:
      graph: the graph of Ops
      start_replica: the ID of the first replica being flushed by this Op.
      on_tpu: if the graph is executed on TPU.

    Returns:
      The Op to flush the cache to file.
    """
    def _make_flush_fun(replica_id):
      """Makes a function for flushing the cache for the given replica."""

      def _fun():
        """A function that flushes the cache to a file.

        Returns:
          The first output of the flush op, created under a control dependency
          on the assignment that resets the cache.
        """

        def _flush_fun(cache):
          """Flushes the cache to a file.

          Args:
            cache: the value of the tensor values cache to print.

          Returns:
            The op of a constant that depends on the print op.
          """

          # replica_id may be given either as a string or as an integer.
          if isinstance(replica_id, str):
            replica_id_str = replica_id
          else:
            replica_id_str = '%d'%replica_id
          # With a trace dir the output goes to a per-replica file, otherwise
          # to stderr.
          if self._trace_dir:
            output_path = os.path.join(self._trace_dir,
                                       _COMPACT_TRACE_FILE_PREFIX) \
                                       + replica_id_str
            output_stream = _OUTPUT_STREAM_ESCAPE + output_path
          else:
            output_stream = sys.stderr
          new_step_line = _REPLICA_ID_TAG + replica_id_str
          print_op = logging_ops.print_v2(
              new_step_line, '\n',
              cache, '\n',
              summarize=-1,
              output_stream=output_stream)
          # The returned op can only run after the print op.
          with ops.control_dependencies([print_op]):
            return constant_op.constant(0).op

        cache = _get_tensor_values_cache(graph)
        if on_tpu:
          flush_op = tpu.outside_compilation(_flush_fun, cache.value())
        else:
          flush_op = _flush_fun(cache.value())
        # After the flush, overwrite the whole cache with the initial entry
        # value.
        with ops.control_dependencies([flush_op]):
          reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
                                             dtype=cache.dtype,
                                             shape=cache.shape)
          assign_op = state_ops.assign(cache, reset_value).op
          with ops.control_dependencies([assign_op]):
            return flush_op.outputs[0]

      return _fun

    def _f(replica_id):
      return _make_flush_fun(replica_id)
    def _eq(x):
      return math_ops.equal(x, self._replica_id)
    def _do_nothing():
      return constant_op.constant(0)

    # One branch per replica id in [start_replica, start_replica + 7]; the
    # predicate of each branch compares that id with self._replica_id.
    return control_flow_ops.case({\
                                  _eq(start_replica): _f(start_replica), \
                                  _eq(start_replica+1): _f(start_replica+1), \
                                  _eq(start_replica+2): _f(start_replica+2), \
                                  _eq(start_replica+3): _f(start_replica+3), \
                                  _eq(start_replica+4): _f(start_replica+4), \
                                  _eq(start_replica+5): _f(start_replica+5), \
                                  _eq(start_replica+6): _f(start_replica+6), \
                                  _eq(start_replica+7): _f(start_replica+7), \
    },
                                 default=_do_nothing,
                                 exclusive=True).op
1317
1318  def _flush_tensor_values_cache(self, graph, tensor_fetches, op_fetches,
1319                                 on_tpu):
1320    """Flushes the intermediate tensor values in the graph to the cache.
1321
1322    Args:
1323      graph: the graph of Ops
1324      tensor_fetches: list of tensor results returned by the model_fn.
1325      op_fetches: list of ops that are returned by the model_fn, e.g., train_op.
1326      on_tpu: if the graph is executed on TPU.
1327
1328    Returns:
1329      An identical copy of tensor_fetches.
1330    """
1331    # Add a dependency to op and tensor fetches to make sure that all tracing
1332    # ops are executed before flushing trace results.
1333    with ops.control_dependencies(op_fetches +
1334                                  [tensor.op for tensor in tensor_fetches]):
1335      flush_cache_op_list = []
1336      for host in range(self._num_hosts):
1337        start_replica = host * 8
1338        flush_op = self._generate_flush_cache_op(graph, start_replica, on_tpu)
1339        flush_cache_op_list.append(flush_op)
1340      return control_flow_ops.tuple(tensor_fetches,
1341                                    control_inputs=flush_cache_op_list)
1342
1343  def _process_tensor_fetches(self, tensor_fetches):
1344    """Check that tensor_fetches is not empty and have valid tensors."""
1345    # If none or empty list.
1346    if tensor_fetches is None:
1347      raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be '
1348                         'None.')
1349    if not isinstance(tensor_fetches, (list, tuple)):
1350      tensor_fetches = [tensor_fetches]
1351    elif not tensor_fetches:
1352      raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be '
1353                         'empty list.')
1354    fetches = []
1355    for fetch in tensor_fetches:
1356      if isinstance(fetch, ops.Tensor):
1357        fetches.append(fetch)
1358      else:
1359        raise RuntimeError('Given tensor_fetch:%s is not a tensor.' % fetch)
1360    return fetches
1361
1362  def _process_op_fetches(self, op_fetches):
1363    """Check that op_fetches have valid ops."""
1364    if op_fetches is None:
1365      return []
1366
1367    if not isinstance(op_fetches, (list, tuple)):
1368      op_fetches = [op_fetches]
1369
1370    fetches = []
1371    for fetch in op_fetches:
1372      if isinstance(fetch, ops.Operation):
1373        fetches.append(fetch)
1374      else:
1375        logging.warning('Ignoring the given op_fetch:%s, which is not an op.' %
1376                        fetch)
1377    return fetches
1378
1379  def _convert_fetches_to_input_format(self, input_fetches, current_fetches):
1380    """Changes current_fetches' format, so that it matches input_fetches."""
1381    if isinstance(input_fetches, ops.Tensor):
1382      if len(current_fetches) != 1:
1383        raise RuntimeError('Tensor tracer input/output fetches do not match.')
1384      return current_fetches[0]
1385    else:
1386      if len(current_fetches) != len(current_fetches):
1387        raise RuntimeError('Tensor tracer input/output fetches do not match.')
1388      elif isinstance(input_fetches, tuple):
1389        return tuple(current_fetches)
1390      else:
1391        return current_fetches
1392
1393  def _get_op_control_flow_context(self, op):
1394    """Returns the control flow of the given op.
1395
1396    Args:
1397      op: tf.Operation for which the control flow context is requested.
1398    Returns:
1399      op_control_flow_context: which the is control flow context of the given
1400      op. If the operation type is LoopExit, returns the outer control flow
1401      context.
1402    """
1403    # pylint: disable=protected-access
1404    op_control_flow_context = op._control_flow_context
1405    # pylint: enable=protected-access
1406    if control_flow_util.IsLoopExit(op):
1407      op_control_flow_context = op_control_flow_context.outer_context
1408    return op_control_flow_context
1409
  def _trace_execution(self, graph,
                       tensor_fetches,
                       op_fetches=None,
                       on_tpu=True):
    """Common tracing function for both CPU and TPUs.

    The caller function should set _device_type, _num_replicas,
    _num_replicas_per_host, _num_hosts and _replica_id before calling
    _trace_execution.

    For every traced tensor on the execution path a trace op is created. It is
    either attached as a control input to the tensor's consumers or, for
    fetched tensors, added as a control input of the returned fetches.

    Args:
      graph: the graph of Ops executed on the TPU.
      tensor_fetches: a (list,tuple,or a single object) of tensor fetches
        returned by model_fn given to session.run. Function must be provided
        with at least one tensor to fetch.
      op_fetches: A list of op fetches returned by model_fn given to
        session.run. op_fetches and tensor_fetches are used to determine the
        nodes that will be executed. Can be None.
      on_tpu: True if executing on TPU.

    Returns:
      tensor_fetches: the fetches, in the same format (single tensor, tuple or
                      list) as the tensor_fetches argument, carrying the
                      additional tracing dependencies.
    Raises:
      RuntimeError: If tensor_fetches is None or empty.
    """
    def _cast_unsupported_dtypes(tensor):
      """Casts tensor to a supported type."""

      if tensor.dtype.__eq__(dtypes.int64):
        # outside-compilation doesn't support int64 input yet.
        return math_ops.cast(tensor, dtypes.int32)
      if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__(
          dtypes.float16):
        # Since host can't handle bf16, convert tensor to f32.
        return math_ops.cast(tensor, dtypes.float32)
      return tensor

    TensorTracer.check_device_type(self._device_type)
    # Check in_tensor_fetches, and op_fetches and convert them to lists.
    processed_t_fetches = self._process_tensor_fetches(tensor_fetches)
    op_fetches = self._process_op_fetches(op_fetches)
    all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches]

    # Filter the set of ops that will be executed, and topological sort.
    (exec_op_set, succeed, sorted_or_cycle) = self._pre_tracing(graph,
                                                                all_fetches)

    tensor_fetch_set = set(processed_t_fetches)
    # Trace ops of fetched tensors; they have no consumer to attach to.
    tracing_ops = []

    # Remember the current control flow context; it is changed per op in the
    # loop below and restored afterwards.
    # pylint: disable=protected-access
    current_control_flow_context = graph._get_control_flow_context()
    # pylint: enable=protected-access

    # Trace ops only if they are in the execution path.
    # NOTE(review): exec_op_set is a set, so the order in which trace ops are
    # created here is whatever the set iteration yields - confirm that this
    # is acceptable.
    for op in exec_op_set:
      for i in range(len(op.outputs)):
        out_tensor = op.outputs[i]
        tensor_name = out_tensor.name
        if tensor_name not in self._traced_tensorname_to_cache_idx_map:
          continue
        # Create the list of consumers before calling _preprocess_traced_tensor.
        # Otherwise, adding control input below, will introduce a cycle in the
        # graph.
        consumers = out_tensor.consumers()
        # Not all consumers may be in the exec path. Filter out the consumers
        # to keep the graph simpler.
        consumers = [cop for cop in consumers if cop in exec_op_set]

        # If there is no consumer of the tensor, there is no need to trace it;
        # unless the tensor itself is one of the fetches.
        is_a_fetched_tensor = out_tensor in tensor_fetch_set
        if (not consumers) and (not is_a_fetched_tensor):
          continue

        # Create the tracing ops inside the control flow context of the op
        # that produced the tensor.
        op_control_flow_context = self._get_op_control_flow_context(op)
        # pylint: disable=protected-access
        graph._set_control_flow_context(op_control_flow_context)
        # pylint: enable=protected-access
        processed_out_tensor = self._preprocess_traced_tensor(out_tensor)

        if on_tpu:
          processed_out_tensor = _cast_unsupported_dtypes(processed_out_tensor)

        # Either store the value in the cache, or print it right away (through
        # outside compilation when running on TPU).
        if self._use_tensor_values_cache():
          cache_idx = self._traced_tensorname_to_cache_idx_map[tensor_name]
          trace_op = self._save_tensor_value_to_cache_op(graph,
                                                         cache_idx,
                                                         processed_out_tensor)
        elif on_tpu:
          trace_op = tpu.outside_compilation(
              self._make_tensor_trace_fun(tensor_name), processed_out_tensor)
        else:
          trace_fun = self._make_tensor_trace_fun(tensor_name)
          trace_op = trace_fun(processed_out_tensor)

        if is_a_fetched_tensor:
          tracing_ops.append(trace_op)
          continue
        # Add it to all consumers, as some consumers may not be executed if they
        # are in a control flow.
        for consumer_op in consumers:
          # pylint: disable=protected-access
          consumer_op._add_control_input(trace_op)
          # pylint: enable=protected-access

    # Restore the control flow context that was active before the loop.
    # pylint: disable=protected-access
    graph._set_control_flow_context(current_control_flow_context)
    # pylint: enable=protected-access
    if tracing_ops:
      # If we are tracing a fetched tensor, their dependency is stored in
      # tracing_ops.
      processed_t_fetches = control_flow_ops.tuple(processed_t_fetches,
                                                   control_inputs=tracing_ops)
    if self._use_tensor_values_cache():
      processed_t_fetches = self._flush_tensor_values_cache(graph,
                                                            processed_t_fetches,
                                                            op_fetches,
                                                            on_tpu=on_tpu)
    self._post_tracing(succeed, sorted_or_cycle)
    # processed_t_fetches is a list at this point. Convert it to the same
    # format as given in tensor_fetches.
    return self._convert_fetches_to_input_format(tensor_fetches,
                                                 processed_t_fetches)
1536
1537  def trace_tpu(self, graph,
1538                tensor_fetches,
1539                op_fetches=None,
1540                num_replicas=None,
1541                num_replicas_per_host=None,
1542                num_hosts=None):
1543    """Traces the tensors generated by TPU Ops in a TF graph.
1544
1545    Args:
1546      graph: the graph of Ops executed on the TPU.
1547      tensor_fetches: a (list,tuple,or a single object) of tensor fetches
1548        returned by model_fn given to session.run. Function must be provided
1549        with as least one tensor to fetch.
1550      op_fetches: A list of op fetches returned by model_fn given to
1551        session.run. op_fetches and tensor_fetches are used to determine the
1552        nodes that will be executed. Can be None.
1553      num_replicas: number of replicas used on the TPU.
1554      num_replicas_per_host: number of replicas per TPU host.
1555      num_hosts: total number of TPU hosts.
1556
1557    Returns:
1558      tensor_fetches: an exact copy of tensor_fetches that has additional
1559                      dependencies.
1560    Raises:
1561      RuntimeError: If num_replicas_per_host > 8.
1562      RuntimeError: If tensor_fetches is None or empty.
1563    """
1564
1565    if graph in TensorTracer._traced_graphs:
1566      logging.warning('Graph is already rewritten with tensor tracer, ignoring '
1567                      'multiple calls.')
1568      return tensor_fetches
1569    else:
1570      TensorTracer._traced_graphs.add(graph)
1571    self._device_type = _DEVICE_TYPE_TPU
1572    self._num_replicas = num_replicas
1573    self._num_replicas_per_host = num_replicas_per_host
1574    self._num_hosts = num_hosts
1575    if self._num_replicas is not None:
1576      if self._num_replicas_per_host is None:
1577        self._num_replicas_per_host = 8
1578      if self._num_hosts is None:
1579        self._num_hosts = num_replicas // self._num_replicas_per_host + \
1580            (num_replicas % self._num_replicas_per_host > 0)
1581
1582    if self._num_replicas_per_host > 8:
1583      # Checks for the assumption in _generate_flush_cache_op().
1584      raise RuntimeError('num_replicas_per_host (%d) is '
1585                         'greater than 8'%self._num_replicas_per_host)
1586    if self._graph_dump_path:
1587      graph_io.write_graph(graph, self._graph_dump_path,
1588                           'graph_before_tt.pbtxt')
1589    with graph.as_default():
1590      self._add_replica_id_to_graph()
1591      tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
1592                                             on_tpu=True)
1593    if self._graph_dump_path:
1594      graph_io.write_graph(graph, self._graph_dump_path,
1595                           'graph_after_tt.pbtxt')
1596    return tensor_fetches
1597
1598  def trace_cpu(self, graph, tensor_fetches, op_fetches=None):
1599    """Traces the tensors generated by CPU Ops in a TF graph.
1600
1601    Args:
1602      graph: the graph of Ops executed on the CPU.
1603      tensor_fetches: a (list,tuple,or a single object) of tensor fetches
1604        returned by model_fn given to session.run. Function must be provided
1605        with as least one tensor to fetch.
1606      op_fetches: A list of op fetches returned by model_fn given to
1607        session.run. op_fetches and tensor_fetches are used to determine the
1608        nodes that will be executed. Can be None.
1609
1610    Returns:
1611      tensor_fetches: an exact copy of tensor_fetches that has additional
1612                      dependencies.
1613    Raises:
1614      RuntimeError: If tensor_fetches is None or empty.
1615    """
1616
1617    if graph in TensorTracer._traced_graphs:
1618      logging.warning('Graph is already rewritten with tensor tracer, ignoring '
1619                      'multiple calls.')
1620      return tensor_fetches
1621    else:
1622      TensorTracer._traced_graphs.add(graph)
1623
1624    self._device_type = _DEVICE_TYPE_CPU
1625    self._num_replicas = 1
1626    self._num_replicas_per_host = 1
1627    self._num_hosts = 1
1628    self._replica_id = 0
1629    if self._graph_dump_path:
1630      graph_io.write_graph(graph, self._graph_dump_path,
1631                           'graph_before_tt.pbtxt')
1632    with graph.as_default():
1633      tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
1634                                             on_tpu=False)
1635    if self._graph_dump_path:
1636      graph_io.write_graph(graph, self._graph_dump_path,
1637                           'graph_after_tt.pbtxt')
1638    return tensor_fetches
1639
1640
1641