1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ========================================================================
15"""A utility to trace tensor values on TPU."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import operator
22
23import os
24import os.path
25import sys
26
27import numpy as np
28import six
29
30from tensorflow.core.framework import summary_pb2
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import func_graph
34from tensorflow.python.framework import function
35from tensorflow.python.framework import graph_io
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import tensor_util
38from tensorflow.python.lib.io import file_io
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import control_flow_ops
41from tensorflow.python.ops import control_flow_util
42from tensorflow.python.ops import gen_math_ops
43from tensorflow.python.ops import init_ops
44from tensorflow.python.ops import linalg_ops
45from tensorflow.python.ops import logging_ops
46from tensorflow.python.ops import math_ops
47from tensorflow.python.ops import nn_impl
48from tensorflow.python.ops import state_ops
49from tensorflow.python.ops import string_ops
50from tensorflow.python.ops import summary_ops_v2 as summary
51from tensorflow.python.ops import variable_scope
52from tensorflow.python.platform import analytics
53from tensorflow.python.platform import gfile
54from tensorflow.python.platform import remote_utils
55from tensorflow.python.platform import tf_logging as logging
56from tensorflow.python.summary import summary_iterator
57from tensorflow.python.tpu import tensor_tracer_flags
58from tensorflow.python.tpu import tensor_tracer_report
59from tensorflow.python.tpu import tpu
60from tensorflow.python.tpu.ops import tpu_ops
61from tensorflow.python.training import training_util
62
# Device types accepted by TensorTracer.check_device_type.
_DEVICE_TYPE_TPU = 'tpu'
_DEVICE_TYPE_CPU = 'cpu'
# NOTE(review): presumably the number of elements kept per tensor in
# part-tensor trace mode -- confirm against its use.
_TRACE_MODE_PART_TENSOR_SIZE = 3

# Explanation strings describing why an op is, or is not, traced. Values
# starting with 'not-traced-' describe skipped ops; values starting with
# 'traced-' describe ops that are traced.
_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range'
_REASON_UNSAFE_OP = 'not-traced-unsafe-op'
_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op'
_REASON_CONTROLFLOW_OP = 'not-traced-control-flow-op'
_REASON_IN_CONTROL_FLOW = 'not-traced-in-control-flow'
_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar'
_REASON_SKIP_SCALAR = 'not-traced-scalar'
_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op'
_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch'
_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape'
_REASON_SCALAR_GET_TRACED = 'traced-scalar'
_REASON_TENSOR_GET_TRACED = 'traced-tensor'
_REASON_USER_INCLUDED = 'traced-user-included'
_REASON_USER_EXCLUDED = 'not-traced-user-excluded'
_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path'
_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor'
_REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op'

# Names of collections, files and variables used by tensor tracer.
_OUTPUT_STREAM_ESCAPE = 'file://'
# Graph collection that trace_tensor() appends (tensor, tracepoint_name)
# pairs to.
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
TENSOR_TRACER_SUMMARY_COLLECTION = 'tensor_tracer_summary_writers'
_TRACE_FILE_NAME = 'trace.all'
_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.'
# Initial value of every entry in a newly created cache variable (converted
# to int for integer caches).
_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
# Collection that cache variables are added to when they are created.
_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage'
# Name prefix of the cache variables.
_TT_SNAPSHOT = 'tensor_tracer_snapshot'
_REPLICA_ID_TAG = '#replica-id: '
_SKIP_REPORT_FILE = 'None'  # Do not write report proto if --report_file=None

# Local aliases for the summary signature names defined in
# tensor_tracer_flags.
_TT_SUMMARY_NORM = tensor_tracer_flags.TT_SUMMARY_NORM
_TT_SUMMARY_MAX = tensor_tracer_flags.TT_SUMMARY_MAX
_TT_SUMMARY_MAX_ABS = tensor_tracer_flags.TT_SUMMARY_MAX_ABS
_TT_SUMMARY_MIN = tensor_tracer_flags.TT_SUMMARY_MIN
_TT_SUMMARY_MEAN = tensor_tracer_flags.TT_SUMMARY_MEAN
_TT_SUMMARY_VAR = tensor_tracer_flags.TT_SUMMARY_VAR
_TT_SUMMARY_SIZE = tensor_tracer_flags.TT_SUMMARY_SIZE

# _TT_SUMMARY_TAG is also the cache name looked up by
# _save_tensor_value_to_cache_op.
_TT_SUMMARY_TAG = 'tensor_tracer_summary'
_TT_TENSORBOARD_PLUGIN_NAME = 'tensor_tracer'
_TT_HOSTCALL_KEY = 'tensor_tracer_host_call'
_TT_EVENT_FILE_SUFFIX = '.tensor_tracer'

# NOTE(review): presumably the max_queue of the summary writer -- confirm.
_TT_SUMMARY_MAX_QUEUE = 10
110
111
def set_parameters(tensor_tracer_params=None):
  """Enables tensor tracer and stores its parameters in the environment.

  The parameters are serialized as space separated '--key=value' flags,
  preceded by the enable flag, and written to the environment variable named
  by tensor_tracer_flags.FLAGS_ENV_VAR. Any previous value of that variable is
  overwritten.

  Example usage:
    tensor_tracer_parameters = {'trace_dir': '/usr/tmp/trace_dir',
                                'trace_mode': 'norm',
                                'report_file': '/usr/tmp/trace_dir/report.all'}
    tensor_tracer.set_parameters(tensor_tracer_parameters)

  This only sets up the parameters for tensor tracer. A call to tensor tracer
  as below is still necessary to enable debugging on CPUs and GPUs. On TPUs
  the call below can be skipped, as it is hooked into tpu.rewrite.
    tt = tensor_tracer.TensorTracer()
    loss = tt.trace_cpu(tf.get_default_graph(), tensor_fetches=loss)

  Args:
    tensor_tracer_params: Tensor tracer parameter dictionary. When it is None
      or empty, only the enable flag is set. Examples of the parameters are
      given below; see tensor_tracer_report.py for all parameters.
        - enable: If set, tensor tracer will be enabled. Calling
          set_parameters automatically adds this parameter.
        - trace_mode: The trace_mode to be used by tensor tracer. These include:
          - summary: Collects multiple statistics for traced tensors, and writes
            them to a summary file that can be visualized using tensorboard.
            This mode currently only works for TPUEstimator. It can also be
            used for other models, but outfeed must be handled by the user.
          - norm: Collects the norm of each traced tensor and writes them into
            a text file pointed by 'trace_dir' flag. (Default mode).
          - nan-inf: Checks the existence of NaNs and Infs in the tensor, and
            writes a boolean value to a text file pointed by 'trace_dir' flag.
            Note that 'norm' mode can also capture this information with more
            numerical info.
          - max-abs: Collects the absolute max for each traced tensor and
            writes it into a text file pointed by 'trace_dir' flag.
          - full-tensor: Writes the full tensor content of the traced tensors
            into a text file pointed by 'trace_dir' flag.
          - part-tensor: Writes a part of the tensor content of the traced
            tensors into a text file pointed by 'trace_dir' flag.
          - full_tensor_summary: Writes the full tensors as binary event files.
            The outputs can be read using: trace =
              tensor_tracer.read_tensor_tracer_event_file(event_file_path)

        - report_file: Path to the metadata file that is written during graph
          construction. If not set, metadata will be printed to stdout during
          graph construction.
        - trace_dir: Path where the execution traces will be written during the
          graph execution. If not set, trace will be printed to stderr.
        - trace_level: Tensor tracer aims to trace everything it can. This
          introduces some overhead on graph execution and graph compilation
          times. Using the trace_level parameter, it is possible to trace
          operations based on their priorities. For example:
          - trace_level=7 is the highest trace_level, in which every op is
            traced.
          - trace_level=6 will skip constant operations such as tf.constant.
          - trace_level=5 will skip less important ops such as tf.identities.
          - The default trace_level=3 will skip concat ops, or random number
            generators.
          - To reduce the graph compile time overhead, trace_level can be set
            to 0, which will skip additions, subtractions, and
            multiplications as well.
        - excluded_opnames: If set, any matching op name will not be traced.
          excluded_opnames can be set as a regular expression. E.g,
          excluded_opnames=.* will exclude everything.
        - excluded_optypes: If set, any matching op type will not be traced.
          excluded_optypes can be set as a regular expression. E.g,
          excluded_optypes=.* will exclude everything. excluded_optypes=MatMul
          will exclude all MatMul ops from tracing.
        - included_opnames: If set, any matching op name will be forced to be
          traced. included_opnames can be set as a regular expression. E.g,
          '--included_opnames=some_op --excluded_opnames=.*' will only trace
          some_op.
        - included_optypes: If set, any matching op type will be forced to be
          traced. included_optypes can be set as a regular expression. E.g,
          '--included_optypes=some_op_type --excluded_optypes=.*' will trace
          only the ops with type 'some_op_type'
        - flush_summaries: If summary mode is used, flush_summaries=1 will
          flush summaries using outside compilation. Note that, if used with
          low level APIs, flush_summaries=1 is necessary to obtain results.
        Advanced Flags:
        - trace_scalar: Scalar values are not traced by default. If this flag is
          set, scalar values will also be traced.
        - op_range: In the form of '%d:%d' that limits the tracing to the ops
          within this limit. --op_range='5:10' will trace only the ops that have
          topological order between 5-10.
        - submode: 'brief' or 'detailed'. If the trace mode is not compact,
          brief mode will print only the id of each traced tensor to save some
          space. 'detailed' mode prints the full tensor name.
        - use_fingerprint_subdirectory: The trace directory will be chosen
          using the fingerprint of the trace metadata under the provided
          trace_dir.
  """
  # The enable flag always comes first, followed by one flag per parameter in
  # the iteration order of the dictionary.
  flag_parts = ['--%s=1' % tensor_tracer_flags.FLAG_NAME_ENABLE]
  if tensor_tracer_params:
    flag_parts.extend(
        '--%s=%s' % (key, value)
        for key, value in tensor_tracer_params.items())
  os.environ[tensor_tracer_flags.FLAGS_ENV_VAR] = ' '.join(flag_parts)
205
206
def op_priority(op_type):
  """Maps an op type to its tracing priority.

  An op whose priority is k is traced only if trace_level>=k, so ops with a
  lower priority value are traced at more trace levels.

  Args:
    op_type: String name of the operation type.
  Returns:
    Integer priority between 0 and 7. Op types that are not listed in any of
    the groups below get priority 0.
  """
  priority_groups = (
      # Lowest priority ops, e.g., constant ops across different steps.
      # They will be traced only if trace_level>=7.
      (7, ('Const', 'Shape', 'BroadcastGradientArgs', 'Range',
           'VariableShape', 'Fill', 'OneHot', 'ShapeN')),
      # Operations without numerical effects.
      # They will be traced only if trace_level>=6.
      (6, ('Identity', 'Cast', 'Reshape', 'ExpandDims', 'StopGradient',
           'PreventGradient', 'Squeeze')),
      # Operations that merge or slice an input, traced if trace_level>=5.
      (5, ('ConcatV2', 'Concat', 'StridedSlice', 'Slice', 'Pack', 'Tile',
           'CollectivePermute', 'SplitV')),
      # Operations less likely to provide useful information,
      # traced if trace_level>=4.
      (4, ('Pad', 'RandomUniformInt', 'GreaterEqual')),
      # Add operations that are less likely to create any issues,
      # traced if trace_level>=3 (default=3).
      (3, ('Sum', 'AddV2', 'Add', 'AddN', 'BiasAdd', 'CrossReplicaSum')),
      # Sub operations that are less likely to create any issues,
      # traced if trace_level>=2.
      (2, ('Neg', 'Sub')),
      # Multiplication and some other operations, traced if trace_level>=1.
      (1, ('Mul', 'Square', 'MatMul', 'RandomUniform', 'Select',
           'Maximum', 'Mean', 'Variance')),
  )
  for priority, op_types in priority_groups:
    if op_type in op_types:
      return priority
  return 0
248
249
def read_tensor_tracer_event_file(event_file):
  """Loads the tensors stored in an event file written by tensor tracer.

  This can be used to read the full tensors written into binary event files
  by TensorTracer with trace_mode=full_tensor_summary.

  Example usage:
    result_dict = tensor_tracer.read_tensor_tracer_event_file(event_file_path)
    for step, tensor_dict in result_dict.items():
      for tensor_name, full_tensor_content in tensor_dict.items():
        logging.info('%s: %s', tensor_name, full_tensor_content)

  Args:
    event_file: Path to the event file that contains only tensor tracer events.
  Returns:
    An event dictionary in the form of
    {step_number: {tensor_name: tensor_content}}
    where tensor_content is a numpy array reshaped to the recorded tensor
    shape.
  Raises:
    ValueError: If an event that has a summary does not contain exactly one
      summary value.
  """
  tensors_by_step = {}
  for event in summary_iterator.summary_iterator(event_file):
    # Events without a summary are skipped. The first event of a file is an
    # event with file_version: "brain.Event:2".
    if not event.HasField('summary'):
      continue
    tensors_for_step = tensors_by_step.setdefault(event.step, {})

    summary_values = event.summary.value
    if len(summary_values) != 1:
      raise ValueError('Single step contains %d summary values,'
                       ' expected 1.' % len(summary_values))
    summary_value = summary_values[0]
    tensor_proto = summary_value.tensor

    # The tag of the summary value is used as the tensor name.
    numpy_dtype = dtypes.DType(tensor_proto.dtype).as_numpy_dtype()
    tensor_shape = [dim.size for dim in tensor_proto.tensor_shape.dim]
    flat_content = np.frombuffer(tensor_proto.tensor_content, numpy_dtype)
    tensors_for_step[summary_value.tag] = flat_content.reshape(tensor_shape)
  return tensors_by_step
292
293
def trace_tensor(tensor, tracepoint_name=None):
  """Programmatic interface to trace a tensor with Tensor Tracer.

  Tensor Tracer, by default, traces all tensors in the execution. This function
  can be used to limit traced tensors. If this function is called for a subset
  of the tensors, only those will be traced.

  For example, Tensor Tracer will only trace c below.
    c = tf.MatMul(a, b)
    tensor_tracer.trace_tensor(c)
    d = tf.add(c, 1)
  Args:
     tensor: the tensor object for which the tracing is requested.
     tracepoint_name: an optional tensor tracepoint name string. A tracepoint
       name is a Tensor Tracer internal name for the tensor. It is useful when
       comparing equivalent traces from different models that have different
       tensor namings. Equivalent tensors (with different names) can be mapped
       to each other by assigning a common tracepoint_name. Defaults to
       tensor.name.

  Returns:
    The provided tensor.
  """
  if tracepoint_name is None:
    tracepoint_name = tensor.name
  # Register the request in the graph that owns the tensor. The request is
  # stored as a (tensor, tracepoint_name) pair.
  tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION,
                                 (tensor, tracepoint_name))
  return tensor
322
323
def keras_layer_tracepoint(layer, checkpoint_name):
  """An interface for adding the tensor outputs of a keras layer.

  Encapsulates trace_tensor. A layer with a single tensor output is registered
  under checkpoint_name; for a layer with several outputs, every output that
  is a tensor is registered under '<checkpoint_name>_<index>', where index is
  the position of the output.

  Args:
     layer: A keras layer.
     checkpoint_name: a string name for the checkpoint. This name has to be a
     unique name if used within model comparison. The tensors that have the same
     checkpoint identifier are compared in model comparison.

  Returns:
    The provided layer.
  """
  try:
    outputs = layer.output
    if tensor_util.is_tf_type(outputs):
      trace_tensor(outputs, '%s' % (checkpoint_name))
    else:
      for idx, output_tensor in enumerate(outputs):
        # Test the individual element. The container itself is already known
        # not to be a tensor in this branch, so testing it would skip every
        # output.
        if tensor_util.is_tf_type(output_tensor):
          trace_tensor(output_tensor, '%s_%d' % (checkpoint_name, idx))
  except (AttributeError, RuntimeError):
    # Best effort: a layer whose outputs cannot be read is returned untraced.
    # NOTE(review): presumably raised by layer.output for layers without a
    # defined output -- confirm.
    pass
  return layer
353
354
355class TensorTracer(object):
356  """A software construct for tracing tensor values in a TF graph.
357
358  This utility is disabled by default. It is hooked into tpu.rewrite, so it can
359  easily be enabled on TPUs by setting the TENSOR_TRACER_FLAGS env variable as
360  below without a code change.
361    export TENSOR_TRACER_FLAGS="--enable=1"
362
  Below is a usage example to enable it on CPUs or GPUs, or for more advanced
  use cases on TPUs.

    a = x + 1
    b = a * 2
    rs = tf.reduce_sum(b)
    tensor_tracer.set_parameters({'trace_dir': 'path/to/trace_dir',
                                  'report_file': 'path/to/report/file'})
371    tt = tensor_tracer.TensorTracer()
372    if on_tpu:
373      rs = tt.trace_tpu(tf.get_default_graph(),
374                          tensor_fetches=rs)
375    else:
376      rs = tt.trace_cpu(tf.get_default_graph(),
377                          tensor_fetches=rs)
378    session.run(rs)
379
380  If it is enabled, it will trace the output tensor values of
381  selected Ops in the graph. It has two outputs: (1) the traces and (2)
382  a report. The traces are dumped to a specified directory during the graph
383  execution, while the report is dumped during the graph construction.
384  By passing options via the env variable, users can change:
385     (1) the trace mode (e.g., detecting NaN/Inf, printing partial or
386         full tensor values)
387     (2) which Ops to be traced (via op.name or op.type)
388     (3) output trace file path.
389
390  """
391  # The set of graphs that are rewritten by tensor tracer.
392  _traced_graphs = set()
393
394  @staticmethod
395  def is_enabled():
396    """Returns True if TensorTracer is enabled."""
397    return tensor_tracer_flags.TTParameters().is_enabled()
398
399  @staticmethod
400  def check_device_type(device_type):
401    """Checks if the given device type is valid."""
402
403    if device_type not in (_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU):
404      raise ValueError('Invalid device_type "%s"'%device_type)
405
406  @staticmethod
407  def check_trace_mode(device_type, trace_mode):
408    """Checks if the given trace mode work on the given device type.
409
410    Args:
411      device_type: Device type, TPU, GPU, CPU.
412      trace_mode: Tensor tracer trace mode.
413    Raises:
414      ValueError: If the given trace mode is not supported for the device.
415    """
416    if trace_mode == tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY:
417      if device_type != _DEVICE_TYPE_TPU:
418        raise ValueError('Device_type "%s" is not yet supported for '
419                         'trace mode "%s"' % (device_type, trace_mode))
420
421  @staticmethod
422  def loop_cond_op(op):
423    return op.type in ('LoopCond', 'RefLoopCond')
424
425  @staticmethod
426  def while_loop_op(op):
427    """Returns true if op is one of the special ops of in a while loop.
428
429    Args:
430       op: A tf.Operation.
431
432    Returns:
433       True if the given op is one of [Switch, Merge, Enter, Exit,
434       NextIteration, LoopCond], which are all building blocks for TF while
435       loops.
436    """
437    return  (control_flow_util.IsLoopSwitch(op) or
438             control_flow_util.IsLoopMerge(op) or
439             control_flow_util.IsLoopEnter(op) or
440             control_flow_util.IsLoopExit(op) or
441             TensorTracer.loop_cond_op(op) or
442             op.type in ('RefNextIteration', 'NextIteration'))
443
444  @staticmethod
445  def control_flow_op(op):
446    """Returns true if op is one of the special ops of in a while loop.
447
448    Args:
449       op: A tf.Operation.
450
451    Returns:
452       True if the given op is one of [Switch, Merge, Enter, Exit,
453       NextIteration, LoopCond], which are all building blocks for TF while
454       loops.
455    """
456    return  (control_flow_util.IsSwitch(op) or
457             control_flow_util.IsMerge(op))
458
459  @staticmethod
460  def unsafe_op(op):
461    """Returns True if this op is not safe to be traced."""
462
463    # Reasons for not including following op types:
464    #    Assign: cause incorrect result with CPU tracing.
465    if op.type == 'Assign':
466      return True
467    return False
468
469  @staticmethod
470  def device_mismatch(device_type, op):
471    if device_type == _DEVICE_TYPE_TPU:
472      # pylint: disable=protected-access
473      return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr
474      # pylint: enable=protected-access
475    return False
476
477  @staticmethod
478  def unsafe_scalar_trace(op):
479    """Return true if scalar output tensor from Op is not safe to be traced."""
480
481    # Tracing the following causes cycle in the graph on TPU.
482    if op.type in ('LoopCond', 'Enter', 'Merge', 'Const',
483                   'Switch', 'Less', 'ReadVariableOp'):
484      return True
485    # Tracing the following will cause casting-issue
486    # with the norm tracing mode or other compilation issues on CPU.
487    if op.type in ('VarHandleOp', 'IteratorToStringHandle',
488                   'IteratorGetNext', 'OneShotIterator',
489                   'IteratorV2', 'MakeIterator',
490                   'BatchDatasetV2', 'MapDataset',
491                   'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset',
492                   'Placeholder', 'PlaceholderWithDefault', 'StridedSlice'):
493      return True
494    return False
495
496  def _is_interesting_op(self, op):
497    """Returns True if the given op is not an interesting one to be traced."""
498    return op_priority(op.type) <= self._parameters.trace_level
499
500  @staticmethod
501  def reason(op_idx, details):
502    """Returns reason why the Op at op_idx is traced or not."""
503
504    return '%d %s'%(op_idx, details)
505
506  def __init__(self):
507    """Initializes a TensorTracer.
508
509    Sets the various member fields from the flags (if given) or the defaults.
510    """
511    self._replica_id = None
512    self._tt_config = tensor_tracer_report.TensorTracerConfig()
513    self._parameters = None
514    self._host_call_fn = {}
515    self._cache_variables = {}
516    self._traced_op_names = set()
517    self._report_proto = None
518    self._temp_cache_var = []
519    self._report_proto_path = ''
520    self._outmost_context = None
521
522  def report_proto(self):
523    """Getter for tensor_tracer.proto object for summary and full_tensor_summary modes.
524
525    Returns:
526      A tensor_tracer.proto object.
527    Raises:
528      ValueError if called before tracing happens, or when trace mode is not
529      summary or full_tensor_summary.
530    """
531    if self._report_proto:
532      return self._report_proto
533    else:
534      raise ValueError('Call to report_proto must be done after tracing.'
535                       'Report proto only exists for '
536                       'trace_mode=[summary|full_tensor_summary]')
537
538  def report_proto_path(self):
539    """Getter for path where tensor_tracer.proto object should be written.
540
541    Returns:
542      A string path.
543    """
544    return self._report_proto_path
545
546  def _get_all_cache_variables(self):
547    return self._cache_variables
548
549  def _create_or_get_tensor_values_cache(self, cache_name, graph=None,
550                                         shape=None, dtype=dtypes.float32):
551    """Creates a variable as the cache to store intermediate tensor values.
552
553    Args:
554      cache_name: Name to be given to the cache (an instance of tf.variable).
555      graph: Tensorflow graph.
556      shape: A list of dimensions.
557      dtype: Data type of created cache.
558    Returns:
559      A ref to newly created or existing cache with the given dimensions.
560    Raises:
561      ValueError: If missing a parameter to create the cache.
562    """
563    def _escape_namescopes(variable_name):
564      # TODO(deveci): This might cause name collisions as in "foo/bar/mytensor"
565      # and "foo_bar/mytensor".
566      return variable_name.replace('/', '_').replace(':', '_')
567
568    if cache_name not in self._cache_variables:
569      if graph is None:
570        raise ValueError('Graph must be provided at cache creation.')
571      if shape is None:
572        raise ValueError('shape must be provided at cache creation.')
573      graph = graph or ops.get_default_graph()
574      if dtype.is_integer:
575        init_val = int(_COMPACT_TRACE_ENTRY_INIT_VALUE)
576      else:
577        init_val = _COMPACT_TRACE_ENTRY_INIT_VALUE
578
579      # Create in proper graph and base name_scope.
580      with graph.as_default() as g, g.name_scope(None):
581        self._cache_variables[cache_name] = variable_scope.get_variable(
582            _TT_SNAPSHOT + '_' + _escape_namescopes(cache_name),
583            shape=shape, dtype=dtype,
584            initializer=init_ops.constant_initializer(init_val),
585            trainable=False,
586            use_resource=True,
587            collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.LOCAL_VARIABLES])
588    return self._cache_variables[cache_name]
589
590  def _add_replica_id_to_graph(self):
591    """Adds nodes for computing the replica ID to the graph."""
592
593    if self._tt_config.num_replicas:
594      with ops.control_dependencies(None):
595        # Uses None as dependency to run outside of TPU graph rewrites.
596        self._replica_id = tpu_ops.tpu_replicated_input(
597            list(range(self._tt_config.num_replicas)),
598            name='tt_replica_id')
599    else:
600      self._replica_id = 'unknown'
601
602  def _inside_op_range(self, idx):
603    """Return True if the given index is inside the selected range."""
604
605    if idx < self._parameters.op_range[0]:
606      return False
607    return (self._parameters.op_range[1] < 0 or
608            idx <= self._parameters.op_range[1])
609
610  def _is_user_included_op(self, op):
611    """Checks whether the op is included in the tensor tracer flags.
612
613    Args:
614      op: tf Operation
615    Returns:
616      True, if the op is included.
617      An op is included if:
618      - Its op name is given in included_opnames
619      - Its op type is given in included_optypes
620      - The op is at most _trace_ops_before_included hops before an included op
621      - The op is at most _trace_ops_after_included hops after an included op
622    """
623    for opname_re in self._parameters.included_opname_re_list:
624      if opname_re.match(op.name):
625        return True
626
627    for optype_re in self._parameters.included_optype_re_list:
628      if optype_re.match(op.type):
629        return True
630    return False
631
632  def _is_user_excluded_op(self, op):
633    for opname_re in self._parameters.excluded_opname_re_list:
634      if opname_re.match(op.name):
635        return True
636    for optype_re in self._parameters.excluded_optype_re_list:
637      if optype_re.match(op.type):
638        return True
639    return False
640
641  def _signature_types(self):
642    """Returns a dictionary holding the order of signatures in the cache for the selected trace mode."""
643    if self._parameters.trace_mode in set([
644        tensor_tracer_flags.TRACE_MODE_NAN_INF,
645        tensor_tracer_flags.TRACE_MODE_NORM,
646        tensor_tracer_flags.TRACE_MODE_MAX_ABS]):
647      return {self._parameters.trace_mode: 0}
648    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
649      return self._parameters.summary_signatures
650    return {}
651
652  def _num_signature_dimensions(self):
653    return len(self._signature_types())
654
655  def _use_temp_cache(self):
656    """Returns true if the intermediate values should be stacked instead of being stored in a tf.Variable.
657
658    Returns:
659      A boolean, denoting whether to use a temporary cache or not.
660    """
661    # If full tensors need to be stored tf.variables, then do not use temp
662    # variables to store them.
663    if self._use_tensor_buffer():
664      return False
665    if self._use_tensor_values_cache():
666      return self._parameters.use_temp_cache_var
667    else:
668      # Temporary caches only replaces tf.Variables caches. If no cache is used
669      # return False.
670      return False
671
672  def _use_tensor_values_cache(self):
673    """Returns True if immediate tensors should be first saved to a cache."""
674    return self._parameters.use_compact_trace
675
676  def _use_tensor_buffer(self):
677    """Returns true if the whole tensor needs to be cached/buffered in memory."""
678    return (self._parameters.trace_mode ==
679            tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)
680
681  def _merge_tensor_signatures(self, signatures):
682    """Returns a tensor that merges the given signatures.
683
684    Args:
685      signatures: A dictionary of the signature updates from signature name to
686      a tensor of dimension [1].
687    Returns:
688      A tensor that concats the signature values in a predefined order.
689    """
690    sorted_update = []
691    if self._num_signature_dimensions() > 1:
692      signature_indices = self._signature_types()
693      for _, val in sorted(signatures.items(),
694                           key=lambda item: signature_indices[item[0]]):
695        sorted_update.append(val)
696      updates = array_ops.stack(
697          sorted_update, axis=0, name='merge_single_op_signatures')
698    elif self._num_signature_dimensions() == 1:
699      # Avoid stack operation if there is only a single signature.
700      (_, val), = signatures.items()
701      updates = val
702    else:
703      raise ValueError('Cannot merge 0 signatures.')
704    return updates
705
706  def _save_tensor_value_to_tmp_cache(self, cache_idx, updates):
707    """Returns an op that will save the given updates to an entry in the cache.
708
709    Args:
710      cache_idx: The cache index of the tensor within the cache.
711      updates: A dictionary of the signature updates from signature name to
712      a tensor of dimension [1].
713    """
714    updates = self._merge_tensor_signatures(updates)
715    updates = array_ops.reshape(updates,
716                                [self._num_signature_dimensions()])
717    self._temp_cache_var[cache_idx] = updates
718
719  def _save_tensor_value_to_cache_op(self, cache_idx, updates):
720    """Returns an op that will save the given updates to an entry in the cache.
721
722    Args:
723      cache_idx: The cache index of the tensor within the cache.
724      updates: A dictionary of the signature updates.
725    Returns:
726      Cache update operation.
727    """
728    # state_ops.scatter_update allows updates only along the first dimension.
729    # Make a compact array by concatenating different signatures, and update
730    # them all together.
731    updates = self._merge_tensor_signatures(updates)
732    updates = array_ops.reshape(updates,
733                                [1, self._num_signature_dimensions()])
734    indices = constant_op.constant([cache_idx])
735    cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG)
736    return state_ops.scatter_update(cache, indices, updates).op
737
738  def _snapshot_tensor(self, tensor):
739    """Creates a new tf.Variable and a new tf.Operation that assigns the value of the tensor to this variable.
740
741    Args:
742      tensor: tensor whose values will be stored in a new tf.Variable.
743    Returns:
744      An assignment operation.
745    """
746
747    snapshot_variable = self._create_or_get_tensor_values_cache(
748        tensor.name, tensor.op.graph,
749        tensor.shape.as_list(), tensor.dtype)
750    return state_ops.assign(snapshot_variable, tensor).op
751
752  def _preprocess_traced_tensor(self, tensor):
753    """Computes NAN/Norm/Max on TPUs before sending to CPU.
754
755    Args:
756      tensor: The tensor to be traced.
757    Returns:
758      A tensor that should be input to the trace_function.
759    Raises:
760      RuntimeError: If the trace mode is invalid.
761    """
762
763    def _detect_nan_inf(tensor):
764      """Trace function for detecting any NaN/Inf in the tensor."""
765
766      if tensor.dtype.is_floating:
767        mask = math_ops.reduce_any(
768            gen_math_ops.logical_or(
769                gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor)))
770        output_tensor = control_flow_ops.cond(
771            mask,
772            lambda: constant_op.constant([1.0]),
773            lambda: constant_op.constant([0.0]))
774      else:
775        output_tensor = constant_op.constant([0.0])
776      return output_tensor
777
778    def _compute_signature(tensor, tf_op, cast_to_f32=True):
779      if cast_to_f32:
780        tensor = math_ops.cast(tensor, dtypes.float32)
781      output_tensor = tf_op(tensor)
782      # Return type should be scalar. Set it if it does not have the
783      # information.
784      if not output_tensor.get_shape().is_fully_defined():
785        output_tensor = array_ops.reshape(output_tensor, [])
786      return output_tensor
787
788    def _show_size(tensor):
789      # In order to check the size of a tensor.
790      # Not all sizes are known at the compile time, also, different replicas
791      # sometimes get different sizes of tensors.
792      # Collect it here to be used in merging replica data.
793      tsize = _compute_signature(tensor, array_ops.size, cast_to_f32=False)
794      # Cast to float32, so that it can be placed into same cache with other
795      # signatures.
796      return math_ops.cast(tsize, dtypes.float32)
797
798    def _show_max(tensor, cast_to_f32=True):
799      # returns -inf for empty tensor
800      return _compute_signature(tensor, math_ops.reduce_max, cast_to_f32)
801
802    def _show_min(tensor, cast_to_f32=True):
803      # returns inf for empty tensor
804      return _compute_signature(tensor, math_ops.reduce_min, cast_to_f32)
805
806    def _show_norm(tensor, cast_to_f32=True):
807      # returns 0 for empty tensor
808      return _compute_signature(tensor, linalg_ops.norm, cast_to_f32)
809
810    def _show_mean_and_variance(tensor, cast_to_f32=True):
811      """Returns the mean and variance of the given tensor."""
812      if cast_to_f32:
813        tensor = math_ops.cast(tensor, dtypes.float32)
814      # returns nan for empty tensor
815      mean, var = nn_impl.moments(array_ops.reshape(tensor, [-1]), axes=[0])
816      # The shape has to be 1. Set it if it does not have the information.
817      if not mean.get_shape().is_fully_defined():
818        mean = array_ops.reshape(mean, [])
819      if not var.get_shape().is_fully_defined():
820        var = array_ops.reshape(var, [])
821      return mean, var
822
823    def _show_max_abs(tensor, cast_to_f32=True):
824      return _compute_signature(
825          tensor, lambda t: math_ops.reduce_max(math_ops.abs(t)), cast_to_f32)
826
827    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
828      return {self._parameters.trace_mode: _detect_nan_inf(tensor)}
829    if (self._parameters.trace_mode ==
830        tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
831      return {self._parameters.trace_mode: tensor}
832    if (self._parameters.trace_mode in (
833        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
834        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)):
835      return {self._parameters.trace_mode: tensor}
836    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NORM:
837      return {self._parameters.trace_mode: array_ops.reshape(
838          _show_norm(tensor), [1])}
839    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_MAX_ABS:
840      return {self._parameters.trace_mode: _show_max_abs(tensor)}
841
842    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
843      tensor = math_ops.cast(tensor, dtypes.float32)
844      result_dict = {}
845      # Call mean and variance computation here to avoid adding the same nodes
846      # twice.
847      if (_TT_SUMMARY_MEAN in self._signature_types() or
848          _TT_SUMMARY_VAR in self._signature_types()):
849        mean, variance = _show_mean_and_variance(tensor, cast_to_f32=False)
850
851      for signature_name, _ in sorted(self._signature_types().items(),
852                                      key=lambda x: x[1]):
853        if signature_name == _TT_SUMMARY_NORM:
854          signature_result_tensor = _show_norm(tensor, cast_to_f32=False)
855        elif signature_name == _TT_SUMMARY_MAX:
856          signature_result_tensor = _show_max(tensor, cast_to_f32=False)
857        elif signature_name == _TT_SUMMARY_MAX_ABS:
858          signature_result_tensor = _show_max_abs(tensor, cast_to_f32=False)
859        elif signature_name == _TT_SUMMARY_MIN:
860          signature_result_tensor = _show_min(tensor, cast_to_f32=False)
861        elif signature_name == _TT_SUMMARY_SIZE:
862          signature_result_tensor = _show_size(tensor)
863        elif signature_name == _TT_SUMMARY_MEAN:
864          signature_result_tensor = mean
865        elif signature_name == _TT_SUMMARY_VAR:
866          signature_result_tensor = variance
867        else:
868          raise ValueError('Unknown signature type :%s.' % signature_name)
869
870        result_dict[signature_name] = signature_result_tensor
871      return result_dict
872
873    raise RuntimeError(
874        'Tensor trace fun for %s is not yet implemented'
875        % self._parameters.trace_mode)
876
877  def _make_tensor_trace_fun(self, tensor_name, tensor_trace_order):
878    """Makes the tensor tracing function called by outside compilation.
879
880    Args:
881      tensor_name: name of the tensor being traced.
882      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
883    Returns:
884      A function to be passed as the first argument to outside compilation.
885
886    Raises:
887      RuntimeError: If the trace mode is invalid.
888    """
889
890    def _print_tensor(tensor_name, num_elements, tensor, output_tensor):
891      """Prints a tensor value to a file.
892
893      Args:
894        tensor_name: name of the tensor being traced.
895        num_elements: number of elements to print (-1 means print all).
896        tensor: the tensor needs to be returned.
897        output_tensor: the tensor needs to be printed.
898
899      Returns:
900        The same tensor passed via the "tensor" argument.
901
902      Raises:
903        ValueError: If tensor_name is not already in
904                    tensor_trace_order.tensorname_to_cache_idx.
905      """
906
907      if self._parameters.is_brief_mode():
908        if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
909          raise ValueError(
910              'Tensor name %s is not in the tensorname_to_cache_idx' %
911              tensor_name)
912        msg = '%d' % tensor_trace_order.tensorname_to_cache_idx[tensor_name]
913      else:
914        msg = '"%s"' % tensor_name
915
916      if self._parameters.trace_dir:
917        output_path = os.path.join(
918            self._parameters.trace_dir,
919            _TRACE_FILE_NAME + self._get_outfile_suffix())
920        output_stream = _OUTPUT_STREAM_ESCAPE + output_path
921      else:
922        output_stream = sys.stderr
923      return logging_ops.print_v2(msg, array_ops.shape(output_tensor),
924                                  '@', self._replica_id,
925                                  '\n', output_tensor, '\n',
926                                  summarize=num_elements,
927                                  output_stream=output_stream)
928
929    def _show_part_tensor(tensor):
930      """Trace function for printing part of the tensor."""
931
932      return _print_tensor(tensor_name, _TRACE_MODE_PART_TENSOR_SIZE,
933                           tensor, tensor)
934
935    def _show_full_tensor(tensor):
936      """Trace function for printing the entire tensor."""
937
938      return _print_tensor(tensor_name, -1, tensor, tensor)
939
940    if (self._parameters.trace_mode ==
941        tensor_tracer_flags.TRACE_MODE_PART_TENSOR):
942      return _show_part_tensor
943    # The input tensor has a shape of "[1]" for TRACE_MODE_NAN_INF,
944    # TRACE_MODE_NORM, and TRACE_MODE_MAX_ABS, as related computations are
945    # performed within TPUs and only their results are transferred to CPU.
946    # Simply, print the full tensor for these trace modes.
947    if self._parameters.trace_mode in (
948        tensor_tracer_flags.TRACE_MODE_NAN_INF,
949        tensor_tracer_flags.TRACE_MODE_NORM,
950        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR,
951        tensor_tracer_flags.TRACE_MODE_MAX_ABS,
952        tensor_tracer_flags.TRACE_MODE_SUMMARY
953        ):
954      return _show_full_tensor
955
956    raise RuntimeError('Tensor trace fun for %s is not yet implemented'
957                       %self._parameters.trace_mode)
958
959  def _is_in_control_flow(self, op):
960    """Returns true if the given op is inside a tf.cond or in tf.while_loop.
961
962    Args:
963      op: A tensorflow op that should be checked whether in control flow or not.
964    Returns:
965      A boolean value whether the op is in control flow or not.
966    """
967    return control_flow_util.IsInCond(op)
968
969  def _is_in_outmost_while_loop(self, op):
970    """Returns true if the op is at the same level with the training loop.
971
972    Returns false if the op is in an inner while loop or if it is outside of the
973    training loop.
974    Args:
975      op: tf.Operation
976
977    Returns:
978      A boolean.
979    """
980    ctxt = self._get_op_control_flow_context(op)
981    outer_while_context = control_flow_util.GetContainingWhileContext(ctxt)
982    return outer_while_context == control_flow_util.GetContainingWhileContext(
983        self._outmost_context)
984
985  def _should_trace_in_control_flow(self):
986    """Returns false incase it is not safe to trace ops in tf.cond or tf.while_loop."""
987    # As different from the other trace modes, TRACE_MODE_OPTIONAL_SUMMARY
988    # forces the execution of the traced tensors. We should not trace the ops
989    # that may not be executed due to control flow.
990    if self._use_temp_cache():
991      return False
992    elif self._tt_config.device_type == _DEVICE_TYPE_TPU:
993      # On TPUs do not trace in control flow unless we use caches to store
994      # intermediate values as calling outside compilation within an inner loop
995      # causes errors.
996      return self._use_tensor_values_cache() or self._use_tensor_buffer()
997    return True
998
999  def _skip_op(self, op_id, op, ops_in_exec_path, report_handler):
1000    """Returns True if we should not trace Op.
1001
1002    Args:
1003      op_id: Topological index of the op.
1004      op: tf.Operation
1005      ops_in_exec_path: Set of operations that are in the execution path.
1006      report_handler: An instance of tensor_tracer_report.TTReportHandle.
1007    Returns:
1008      True if the op should not be traced, false otherwise.
1009    """
1010    if TensorTracer.while_loop_op(op):
1011      report_handler.instrument_op(
1012          op, TensorTracer.reason(op_id, _REASON_WHILELOOP_OP))
1013      return True
1014    if TensorTracer.control_flow_op(op):
1015      report_handler.instrument_op(
1016          op, TensorTracer.reason(op_id, _REASON_CONTROLFLOW_OP))
1017      return True
1018    if TensorTracer.unsafe_op(op):
1019      report_handler.instrument_op(
1020          op, TensorTracer.reason(op_id, _REASON_UNSAFE_OP))
1021      return True
1022    if TensorTracer.device_mismatch(self._tt_config.device_type, op):
1023      report_handler.instrument_op(
1024          op, TensorTracer.reason(op_id, _REASON_DEVICE_MISMATCH))
1025      return True
1026    if op not in ops_in_exec_path:
1027      report_handler.instrument_op(
1028          op, TensorTracer.reason(op_id, _REASON_NOT_EXECUTED))
1029      return True
1030    # TensorTracer will not trace the operations that are in an inner while loop
1031    # or tf.cond when a temporary cache is used. Temporary cache adds direct
1032    # data dependencies to traced operations, and needs a static number of
1033    # traced operations. For these cases,
1034    # - We do not know the number of slots required when there are inner while
1035    # loops. TensorTracer can only trace the result of a while loop.
1036    # - We do not know ahead of time which branch of the tf.cond
1037    # will be taken, so we avoid introducing data dependencies for the
1038    # operations inside a tf.cond.
1039    # - We also cannot have a data dependency to an operation in a different
1040    # while context.
1041    if self._is_in_control_flow(op) or not self._is_in_outmost_while_loop(op):
1042      if not self._should_trace_in_control_flow():
1043        report_handler.instrument_op(
1044            op, TensorTracer.reason(op_id, _REASON_IN_CONTROL_FLOW))
1045        return True
1046    if self._is_user_included_op(op):
1047      report_handler.instrument_op(
1048          op, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
1049      return False
1050
1051    if not self._inside_op_range(op_id):
1052      report_handler.instrument_op(
1053          op, TensorTracer.reason(op_id, _REASON_OUTSIDE_OP_RANGE))
1054      return True
1055    if not self._is_interesting_op(op):
1056      report_handler.instrument_op(
1057          op, TensorTracer.reason(op_id, _REASON_LESS_INTERESTING_OP))
1058      return True
1059    if self._is_user_excluded_op(op):
1060      report_handler.instrument_op(
1061          op, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
1062      return True
1063    return False
1064
1065  def _skip_tensor(self, op_id, out_tensor, report_handler):
1066    """Returns True if we should not trace out_tensor.
1067
1068    Args:
1069      op_id: Topological index of the op producing tensor.
1070      out_tensor: tf.Tensor
1071      report_handler: An instance of tensor_tracer_report.TTReportHandle.
1072    Returns:
1073      True if the tensor should not be traced, false otherwise.
1074    """
1075
1076    # Skips a tensor if the tensor has a non-numeric type.
1077    #   Note: we cannot use check_ops.is_numeric_tensor(out_tensor)
1078    #         because it also excludes tensors with dtypes, bool, and
1079    #         float32_ref, which we actually want to trace.
1080    non_numeric_tensor_types = set([dtypes.variant, dtypes.resource,
1081                                    dtypes.string])
1082    if out_tensor.dtype in non_numeric_tensor_types:
1083
1084      report_handler.instrument_tensor(
1085          out_tensor, TensorTracer.reason(op_id, _REASON_NON_NUMERIC_TENSOR))
1086      return True
1087    # Skip a tensor if it feeds a special while loop op.
1088    if [consumer for consumer in out_tensor.consumers() if
1089        TensorTracer.while_loop_op(consumer)]:
1090      report_handler.instrument_tensor(
1091          out_tensor, TensorTracer.reason(op_id, _REASON_FEEDS_WHILELOOP_OP))
1092      return True
1093    if self._is_user_included_op(out_tensor.op):
1094      report_handler.instrument_tensor(
1095          out_tensor, TensorTracer.reason(op_id, _REASON_USER_INCLUDED))
1096      return False
1097    if self._is_user_excluded_op(out_tensor.op):
1098      report_handler.instrument_tensor(
1099          out_tensor, TensorTracer.reason(op_id, _REASON_USER_EXCLUDED))
1100      return True
1101    if not out_tensor.get_shape().is_fully_defined():
1102      # If trace mode is nan-inf, norm or max, then the tensor will be reduced
1103      # to a scalar before the outside compilation call.
1104      if self._parameters.trace_mode in (
1105          tensor_tracer_flags.TRACE_MODE_NAN_INF,
1106          tensor_tracer_flags.TRACE_MODE_NORM,
1107          tensor_tracer_flags.TRACE_MODE_MAX_ABS,
1108          tensor_tracer_flags.TRACE_MODE_SUMMARY
1109          ):
1110        report_handler.instrument_tensor(
1111            out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
1112        return False
1113      else:
1114        report_handler.instrument_tensor(
1115            out_tensor, TensorTracer.reason(op_id, _REASON_DYNAMIC_SHAPE))
1116        return True
1117    rank = len(out_tensor.shape)
1118    if rank < 1:
1119      # scalar
1120      if self._parameters.trace_scalar_ops:
1121        if TensorTracer.unsafe_scalar_trace(out_tensor.op):
1122          report_handler.instrument_tensor(
1123              out_tensor, TensorTracer.reason(op_id, _REASON_UNSAFE_SCALAR))
1124          return True
1125        else:
1126          report_handler.instrument_tensor(
1127              out_tensor, TensorTracer.reason(op_id, _REASON_SCALAR_GET_TRACED))
1128          return False
1129      else:
1130        report_handler.instrument_tensor(
1131            out_tensor, TensorTracer.reason(op_id, _REASON_SKIP_SCALAR))
1132        return True
1133    else:
1134      # tensor
1135      report_handler.instrument_tensor(
1136          out_tensor, TensorTracer.reason(op_id, _REASON_TENSOR_GET_TRACED))
1137      return False
1138
1139  def _filter_execution_path_operations(self, operations, fetches):
1140    """Returns the set of ops in the execution path to compute given fetches."""
1141
1142    # If no fetch provided, then return all operations.
1143    if fetches is None:
1144      return set(operations)
1145    # Convert to list, if a single element is provided.
1146    if not isinstance(fetches, (list, tuple)):
1147      fetches = [fetches]
1148    # If a tensor is given as fetch, convert it to op.
1149    op_fetches = []
1150    for fetch in fetches:
1151      if isinstance(fetch, ops.Operation):
1152        op_fetches.append(fetch)
1153      elif isinstance(fetch, ops.Tensor):
1154        op_fetches.append(fetch.op)
1155      else:
1156        raise RuntimeError('Given fetch:%s is neither a tensor nor an op.'
1157                           %fetch)
1158
1159    execution_path_operations = set(op_fetches)
1160    traverse_stack = list(op_fetches)
1161    while True:
1162      if not traverse_stack:
1163        break
1164      head_op = traverse_stack.pop()
1165      input_ops = [tensor_input.op for tensor_input in head_op.inputs]
1166      input_ops.extend(head_op.control_inputs)
1167
1168      for input_op in input_ops:
1169        if input_op not in execution_path_operations:
1170          # Filter out loop condition operations, tracing them causes a cycle.
1171          # Trace only the loop-body.
1172          if TensorTracer.loop_cond_op(input_op):
1173            continue
1174          execution_path_operations.add(input_op)
1175          traverse_stack.append(input_op)
1176    return execution_path_operations
1177
1178  def _determine_and_instrument_traced_tensors(self, graph_order,
1179                                               ops_in_exec_path,
1180                                               tensor_trace_points,
1181                                               report_handler):
1182    """Determines the tensors to trace and instruments the trace details.
1183
1184    Args:
1185      graph_order: graph_order tuple containing graph (tf.graph), operations
1186        (list of operations), op_to_idx (op id mapping), (tensors) list of
1187        tensors, tensor_to_idx (tensor id mapping), contains_cycle (whether
1188        there is a cycle in the graph), topological_order_or_cycle (list of ops
1189        in topological order or list of ops creating a cycle).
1190      ops_in_exec_path: Set of ops in the execution path.
1191      tensor_trace_points: Collection of programatic tensor trace points.
1192      report_handler: An instance of tensor_tracer_report.TTReportHandle.
1193    Returns:
1194      List of tensors to be traced.
1195    """
1196
1197    traced_tensors = []
1198    checkpoint_operations = set([tensor.op
1199                                 for (tensor, _) in tensor_trace_points])
1200    for op_id, op in enumerate(graph_order.operations):
1201      if checkpoint_operations and op not in checkpoint_operations:
1202        continue
1203      if self._skip_op(op_id, op, ops_in_exec_path, report_handler):
1204        continue
1205      for i in range(len(op.outputs)):
1206        out_tensor = op.outputs[i]
1207        if not self._skip_tensor(op_id, out_tensor, report_handler):
1208          traced_tensors.append(out_tensor)
1209    return traced_tensors
1210
1211  def _check_trace_files(self):
1212    """Checks if any requirements for trace files are satisfied."""
1213
1214    if not self._parameters.trace_dir:
1215      # traces will be written to stderr. No need to check trace files.
1216      return
1217    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_SUMMARY:
1218      # Output files are handled by tf.summary operations, no need to precreate
1219      # them.
1220      return
1221    if not gfile.Exists(self._parameters.trace_dir):
1222      file_io.recursive_create_dir(self._parameters.trace_dir)
1223      if not gfile.Exists(self._parameters.trace_dir):
1224        raise RuntimeError('Failed to create %s'%self._parameters.trace_dir)
1225
1226  def _create_temp_cache(self, num_traced_tensors, num_signatures):
1227    """Creates a temporary cache with the given dimensions.
1228
1229    Fills the self._temp_cache_var with num_traced_tensors tf.constant() ops
1230    that have shape of [num_signatures].
1231    Args:
1232      num_traced_tensors: Int, denoting total number of traced tensors.
1233      num_signatures: Int, denoting the number of statistics collected per
1234        tensors.
1235    """
1236    init_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
1237                                      dtype=dtypes.float32,
1238                                      shape=[num_signatures])
1239    self._temp_cache_var = [init_value for _ in range(num_traced_tensors)]
1240
1241  def _determine_trace_and_create_report(self, graph, ops_in_exec_path):
1242    """Work needs to be done prior to TPU or CPU tracing.
1243
1244    Args:
1245      graph: tf.graph
1246      ops_in_exec_path: Set of operations in the execution path.
1247    Returns:
1248      An instance of tensor_tracer_report.TensorTraceOrder, containing list of
1249      tensors to be traced with their topological order information.
1250    """
1251
1252    self._check_trace_files()
1253
1254    graph_order = tensor_tracer_report.sort_tensors_and_ops(graph)
1255    tensor_trace_points = graph.get_collection(_TENSOR_TRACER_COLLECTION)
1256
1257    report_handler = tensor_tracer_report.TTReportHandle()
1258    traced_tensors = self._determine_and_instrument_traced_tensors(
1259        graph_order, ops_in_exec_path, tensor_trace_points, report_handler)
1260    logging.info('TensorTracer is tracing %d tensors.', len(traced_tensors))
1261
1262    tensor_trace_order = tensor_tracer_report.TensorTraceOrder(graph_order,
1263                                                               traced_tensors)
1264    num_signatures = self._num_signature_dimensions()
1265    # Create a cache variable if compact_tracing is used.
1266    if num_signatures and self._use_tensor_values_cache():
1267      if self._use_temp_cache():
1268        self._create_temp_cache(len(traced_tensors), num_signatures)
1269      else:
1270        self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG,
1271                                                graph,
1272                                                [len(traced_tensors),
1273                                                 num_signatures])
1274    if self._parameters.trace_mode in (
1275        tensor_tracer_flags.TRACE_MODE_SUMMARY,
1276        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY):
1277      self._report_proto = report_handler.create_report_proto(
1278          self._tt_config, self._parameters, tensor_trace_order,
1279          tensor_trace_points, self._signature_types())
1280      if self._parameters.use_fingerprint_subdir:
1281        self._parameters.trace_dir = os.path.join(
1282            self._parameters.trace_dir, self._report_proto.fingerprint)
1283        logging.info('TensorTracer updating trace_dir to %s',
1284                     self._parameters.trace_dir)
1285      self._report_proto_path = tensor_tracer_report.report_proto_path(
1286          self._parameters.trace_dir)
1287      if self._parameters.report_file_path != _SKIP_REPORT_FILE:
1288        report_handler.write_report_proto(self._report_proto, self._parameters)
1289    else:
1290      report_handler.create_report(self._tt_config, self._parameters,
1291                                   tensor_trace_order, tensor_trace_points)
1292    return tensor_trace_order
1293
1294  def _create_host_call(self):
1295    return self._parameters.trace_mode in (
1296        tensor_tracer_flags.TRACE_MODE_SUMMARY,
1297        tensor_tracer_flags.TRACE_MODE_FULL_TENSOR_SUMMARY)
1298
1299  def _inspect_summary_cache(self, cache, replica_id, step_num, output_stream,
1300                             tensor_trace_order):
1301    """Generates a print operation to print trace inspection.
1302
1303    Args:
1304      cache: Tensor storing the trace results for the step.
1305      replica_id: Tensor storing the replica id of the running core.
1306      step_num: Step number.
1307      output_stream: Where to print the outputs, e.g., file path, or sys.stderr.
1308      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
1309
1310    Returns:
1311      The Op to flush the cache to file.
1312    """
1313    def _inspect_tensor(tensor):
1314      """Returns the text to be printed for inspection output."""
1315      if (self._parameters.trace_mode ==
1316          tensor_tracer_flags.TRACE_MODE_NAN_INF):
1317        return control_flow_ops.cond(
1318            math_ops.greater(tensor, 0.0),
1319            lambda: 'has NaNs/Infs!',
1320            lambda: 'has no NaNs or Infs.')
1321      else:
1322        return tensor
1323
1324    # Check if the cache includes any nan or inf
1325    if self._parameters.trace_mode == tensor_tracer_flags.TRACE_MODE_NAN_INF:
1326      # Cache has 1s or 0s if the mode is NaN_INF
1327      step_has_nan_or_inf = math_ops.greater(math_ops.reduce_sum(cache), 0.0)
1328    else:
1329      # Cache has the actual numerics for other modes.
1330      step_has_nan_or_inf = math_ops.reduce_any(
1331          gen_math_ops.logical_or(
1332              gen_math_ops.is_nan(cache), gen_math_ops.is_inf(cache)))
1333
1334    # Summarizing message for each step.
1335    step_error_message = control_flow_ops.cond(
1336        step_has_nan_or_inf,
1337        lambda: 'NaNs or Infs in the step!',
1338        lambda: 'No numerical issues have been found for the step.')
1339
1340    # No need to print core numbers if the cache is merged already.
1341    if self._parameters.collect_summary_per_core:
1342      stats = ['\n\n', 'core:', replica_id, ',', 'step:', step_num, '-->',
1343               step_error_message,
1344               'Printing tensors for mode:%s...' % self._parameters.trace_mode]
1345    else:
1346      stats = ['\n\n', 'step:', step_num, '-->', step_error_message,
1347               'Printing tensors for mode:%s...' % self._parameters.trace_mode]
1348
1349    for tensor_name, cache_idx in sorted(
1350        tensor_trace_order.tensorname_to_cache_idx.items(),
1351        key=lambda item: item[1]):
1352      if self._parameters.collect_summary_per_core:
1353        stats.extend([
1354            '\n', 'core:', replica_id, ',', 'step:', step_num, ',',
1355            tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])])
1356      else:
1357        stats.extend([
1358            '\n', 'step:', step_num, ',',
1359            tensor_name, '-->', _inspect_tensor(cache[cache_idx, 0])])
1360    return logging_ops.print_v2(*stats, summarize=-1,
1361                                output_stream=output_stream)
1362
1363  def _get_outfile_suffix(self):
1364    if remote_utils.is_remote_path(self._parameters.trace_dir):
1365      return remote_utils.get_appendable_file_encoding()
1366    else:
1367      return ''
1368
1369  def _generate_flush_cache_op(self, num_replicas, on_tpu, tensor_trace_order):
1370    """Generates an Op that will flush the cache to file.
1371
1372    Args:
1373      num_replicas: total number of replicas.
1374      on_tpu: if the graph is executed on TPU.
1375      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
1376
1377    Returns:
1378      The Op to flush the cache to file.
1379    """
1380
1381    def _flush_fun(cache, replica_id, step_num):
1382      """Flushes the cache to a file corresponding to replica_id."""
1383
1384      def _f(file_index):
1385        """Generates a func that flushes the cache to a file."""
1386        def _print_cache():
1387          """Flushes the cache to a file."""
1388          replica_str = ('%d' % file_index)
1389          if self._parameters.trace_dir:
1390            output_path = (os.path.join(self._parameters.trace_dir,
1391                                        _COMPACT_TRACE_FILE_PREFIX)
1392                           + replica_str + self._get_outfile_suffix())
1393            output_stream = _OUTPUT_STREAM_ESCAPE + output_path
1394          else:
1395            output_stream = sys.stderr
1396
1397          new_step_line = _REPLICA_ID_TAG + replica_str
1398          print_ops = []
1399          if self._parameters.inspect_trace:
1400            if self._num_signature_dimensions() > 1:
1401              raise ValueError('Inspecting multi signatures are not supported.')
1402            print_ops.append(self._inspect_summary_cache(
1403                cache=cache, replica_id=replica_id, step_num=step_num,
1404                output_stream=output_stream,
1405                tensor_trace_order=tensor_trace_order))
1406          else:
1407            for i in range(self._num_signature_dimensions()):
1408              print_ops.append(logging_ops.print_v2(
1409                  new_step_line, '\n',
1410                  cache[:, i], '\n',
1411                  summarize=-1,
1412                  output_stream=output_stream))
1413          with ops.control_dependencies(print_ops):
1414            return constant_op.constant(0).op
1415        return _print_cache
1416
1417      def _eq(file_index):
1418        return math_ops.equal(replica_id, file_index)
1419
1420      flush_op_cases = {}
1421      flush_op_cases[_eq(0)] = _f(0)
1422      for i in range(1, num_replicas):
1423        if on_tpu and not self._parameters.collect_summary_per_core:
1424          # If this is the case, the cache is already merged for all cores.
1425          # Only first core flushes the cache.
1426          flush_op_cases[_eq(i)] = control_flow_ops.no_op
1427        else:
1428          flush_op_cases[_eq(i)] = _f(i)
1429      # Each replica needs to determine where to write their output.
1430      # To do this, we check if replica_id is 0, then 1, ..., and then
1431      # num_replicas - 1 statically; and return the corresponding static file
1432      # name. We cannot simply set the file name in python, as replica_id is
1433      # only known during tf runtime, and we cannot create dynamic filenames.
1434      return control_flow_ops.case(flush_op_cases, exclusive=True)
1435
1436    cache = self._create_or_get_tensor_values_cache(_TT_SUMMARY_TAG)
1437    if self._use_temp_cache():
1438      cache_val = cache
1439    else:
1440      cache_val = cache.value()
1441
1442    if on_tpu:
1443      # If we do not need to collect traces for all cores, merge and aggregate
1444      # per core trace.
1445      if not self._parameters.collect_summary_per_core:
1446        cache_val = self.merge_caches_on_tpu(cache_val)
1447        cache_val = self.aggregate_global_cache(cache_val)[0]
1448
1449      flush_op = tpu.outside_compilation(
1450          _flush_fun, cache_val, self._replica_id,
1451          array_ops.identity(training_util.get_or_create_global_step()))
1452    else:
1453      flush_op = _flush_fun(cache_val, self._replica_id,
1454                            training_util.get_or_create_global_step())
1455    if self._use_temp_cache():
1456      with ops.control_dependencies([flush_op]):
1457        return constant_op.constant(0).op
1458    else:
1459      # Re-initialize the local cache variable.
1460      with ops.control_dependencies([flush_op]):
1461        reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
1462                                           dtype=cache.dtype,
1463                                           shape=cache.shape)
1464        assign_op = state_ops.assign(cache, reset_value).op
1465        with ops.control_dependencies([assign_op]):
1466          return constant_op.constant(0).op
1467
1468  def _flush_tensor_values_cache(self, tensor_fetches, op_fetches, on_tpu,
1469                                 tensor_trace_order):
1470    """Flushes the intermediate tensor values in the graph to the cache.
1471
1472    Args:
1473      tensor_fetches: list of tensor results returned by the model_fn.
1474      op_fetches: list of ops that are returned by the model_fn, e.g., train_op.
1475      on_tpu: if the graph is executed on TPU.
1476      tensor_trace_order: TensorTraceOrder object holding tensorname to id map.
1477
1478    Returns:
1479      An identical copy of tensor_fetches.
1480    """
1481    # Add a dependency to op and tensor fetches to make sure that all tracing
1482    # ops are executed before flushing trace results.
1483    with ops.control_dependencies(op_fetches +
1484                                  [tensor.op for tensor in tensor_fetches]):
1485      flush_cache_op = self._generate_flush_cache_op(
1486          self._tt_config.num_replicas, on_tpu, tensor_trace_order)
1487      return control_flow_ops.tuple(tensor_fetches,
1488                                    control_inputs=[flush_cache_op])
1489
1490  def _process_tensor_fetches(self, tensor_fetches):
1491    """Check that tensor_fetches is not empty and have valid tensors."""
1492    # If none or empty list.
1493    if tensor_fetches is None:
1494      raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be '
1495                         'None.')
1496    if not isinstance(tensor_fetches, (list, tuple)):
1497      tensor_fetches = [tensor_fetches]
1498    elif not tensor_fetches:
1499      raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be '
1500                         'empty list.')
1501    fetches = []
1502    for fetch in tensor_fetches:
1503      if isinstance(fetch, ops.Tensor):
1504        fetches.append(fetch)
1505      else:
1506        raise RuntimeError('Given tensor_fetch:%s is not a tensor.' % fetch)
1507    return fetches
1508
1509  def _process_op_fetches(self, op_fetches):
1510    """Check that op_fetches have valid ops."""
1511    if op_fetches is None:
1512      return []
1513
1514    if not isinstance(op_fetches, (list, tuple)):
1515      op_fetches = [op_fetches]
1516
1517    fetches = []
1518    for fetch in op_fetches:
1519      if isinstance(fetch, ops.Operation):
1520        fetches.append(fetch)
1521      elif isinstance(fetch, ops.Tensor):
1522        fetches.append(fetch.op)
1523      else:
1524        logging.warning('Ignoring the given op_fetch:%s, which is not an op.' %
1525                        fetch)
1526    return fetches
1527
1528  def _convert_fetches_to_input_format(self, input_fetches, current_fetches):
1529    """Changes current_fetches' format, so that it matches input_fetches."""
1530    if isinstance(input_fetches, ops.Tensor):
1531      if len(current_fetches) != 1:
1532        raise RuntimeError('Tensor tracer input/output fetches do not match.')
1533      return current_fetches[0]
1534    else:
1535      if len(current_fetches) != len(current_fetches):
1536        raise RuntimeError('Tensor tracer input/output fetches do not match.')
1537      elif isinstance(input_fetches, tuple):
1538        return tuple(current_fetches)
1539      else:
1540        return current_fetches
1541
1542  def _get_op_control_flow_context(self, op):
1543    """Returns the control flow of the given op.
1544
1545    Args:
1546      op: tf.Operation for which the control flow context is requested.
1547    Returns:
1548      op_control_flow_context: which the is control flow context of the given
1549      op. If the operation type is LoopExit, returns the outer control flow
1550      context.
1551    """
1552    # pylint: disable=protected-access
1553    op_control_flow_context = op._control_flow_context
1554    # pylint: enable=protected-access
1555    if control_flow_util.IsLoopExit(op):
1556      op_control_flow_context = op_control_flow_context.outer_context
1557    return op_control_flow_context
1558
1559  def merge_caches_on_tpu(self, local_tpu_cache_tensor):
1560    """Merges the given caches on tpu.
1561
1562    Args:
1563      local_tpu_cache_tensor: A local tensor that needs to be merged
1564        by concanting data from other tpu cores.
1565    Returns:
1566      A merged tf.Tensor.
1567    Raises:
1568      RuntimeError: if there is no aggregate function defined for a signature.
1569    """
1570    x = array_ops.broadcast_to(
1571        local_tpu_cache_tensor,
1572        shape=[self._tt_config.num_replicas] +
1573        local_tpu_cache_tensor.shape.as_list())
1574    return tpu_ops.all_to_all(
1575        x, concat_dimension=0, split_dimension=0,
1576        split_count=self._tt_config.num_replicas)
1577
1578  def aggregate_global_cache(self, global_tt_summary_cache):
1579    """Merges the given caches on tpu.
1580
1581    Args:
1582      global_tt_summary_cache: The global tensor tracer summary cache tensor
1583        with shape (num_cores, num_traced_tensors, num_traced_signatures). First
1584        dimension corresponds to core_id, where global_tpu_cache_tensor[i]
1585        correspond to the local cache from core-i.
1586    Returns:
1587      An aggregated tf.Tensor.
1588    Raises:
1589      RuntimeError: if there is no aggregate function defined for a signature.
1590    """
1591
1592    # Merge only statistics tensor, if it is any other tensor we simply,
1593    # concatenate them.
1594    agg_fn_map = self._parameters.get_signature_to_agg_fn_map()
1595    signature_idx_map = self._signature_types()
1596    aggregation_result = []
1597    for signature, idx in sorted(signature_idx_map.items(),
1598                                 key=operator.itemgetter(1)):
1599      if signature not in agg_fn_map:
1600        raise RuntimeError('No aggregation function is defined for '
1601                           'signature %s.' % signature)
1602      # The dimensions of the statistics tensor is
1603      # num_cores x num_traced_tensors x num_signatures
1604      # value[:,:,idx] will return the portion of the tensor related
1605      # to signature.
1606      signature_tensor = global_tt_summary_cache[:, :, idx]
1607      # Merge it along the first (core) axis.
1608      agg_fn = agg_fn_map[signature]
1609      agg_tensor = agg_fn(signature_tensor, axis=0)
1610      aggregation_result.append(agg_tensor)
1611    # Merge results corresponding to different signatures
1612
1613    merged_signatures = array_ops.stack(aggregation_result)
1614    # merged_signatures has dimensions
1615    # num_signatures x num_traced_tensors, transpose it so that it
1616    # will match with the original structure
1617    # num_traced_tensors x num_signatures.
1618    transposed_signatures = array_ops.transpose(merged_signatures)
1619    # Expand 1 more dimension so that it will match with the expected
1620    # structure num_cores x num_traced_tensors x num_signatures.
1621    return array_ops.expand_dims(transposed_signatures, axis=0)
1622
1623  def _prepare_host_call_fn(self, processed_t_fetches, op_fetches):
1624    """Creates a host call function that will write the cache as tb summary.
1625
1626    Args:
1627      processed_t_fetches: List of tensor provided to session.run.
1628      op_fetches: List of operations provided to session.run.
1629    Raises:
1630      ValueError if trace_dir is not set.
1631    """
1632    if self._parameters.trace_dir is None:
1633      raise ValueError('Provide a trace_dir for tensor tracer in summary mode. '
1634                       '--trace_dir=/model/dir')
1635
1636    def _write_cache(step, event_file_suffix=None, **kwargs):
1637      """Writes the given caches as tensor summary.
1638
1639      Args:
1640        step: Step tensor with dimension [num_cores].
1641        event_file_suffix: Event filename suffix tensor.
1642        **kwargs: The dictionary of tensors that needs to be written as
1643          summaries. Key and value pairs within kwargs correspond to the tag
1644          name, and tensor content that will be written using summary.write.
1645          The trace_modes that use this function are:
1646            - summary: In summary mode, kwargs includes a single (tag, content)
1647            pair which are, _TT_SUMMARY_TAG and a tf.float32 signature_cache
1648            variable. The dimension of the signature_cache is:
1649              num_cores x num_traced_tensors x num_signatures.
1650            - full_tensor_summary: kwargs will include all traced tensors. Tag
1651            and content correspond to the name of the tensor, and its actual
1652            content.
1653      Returns:
1654        A tf.Operation that needs to be executed for the host call dependencies.
1655      Raises:
1656        RuntimeError: if there is no aggregate function defined for a signature.
1657      """
1658      file_suffix = _TT_EVENT_FILE_SUFFIX
1659      if event_file_suffix is not None:
1660        file_suffix = string_ops.string_join([file_suffix, event_file_suffix],
1661                                             separator='.')
1662      # TODO(deveci): Parametrize max_queue, so that flushing op can be called
1663      # less frequently.
1664      # Setting max_queue to 100 appears to be safe even when the number of
1665      # iterations are much lower, as the destructor of the writer flushes it.
1666      summary_write_ops = []
1667      summary_writer = summary.create_file_writer_v2(
1668          self._parameters.trace_dir,
1669          filename_suffix=file_suffix,
1670          max_queue=_TT_SUMMARY_MAX_QUEUE)
1671      ops.get_default_graph().add_to_collection(
1672          TENSOR_TRACER_SUMMARY_COLLECTION, summary_writer)
1673      with summary_writer.as_default():
1674        summary_metadata = summary_pb2.SummaryMetadata(
1675            plugin_data=summary_pb2.SummaryMetadata.PluginData(
1676                plugin_name=_TT_TENSORBOARD_PLUGIN_NAME))
1677        for key, value in kwargs.items():
1678          # Check whether we need to compute aggregated statistics that merge
1679          # all cores statistics.
1680          if not self._parameters.collect_summary_per_core:
1681            # Merge only statistics tensor, if it is any other tensor we simply,
1682            # concatenate them.
1683            # Also, if there is only a single core (first dim. is 0), then skip
1684            # aggregation.
1685            if key == _TT_SUMMARY_TAG and value.shape.as_list()[0] != 1:
1686              value = self.aggregate_global_cache(value)
1687
1688          with ops.control_dependencies([summary_writer.init()]):
1689            summary_write_ops.append(summary.write(
1690                _TT_SUMMARY_TAG + '/' + key, value, metadata=summary_metadata,
1691                step=step[0]))
1692      return control_flow_ops.group(summary_write_ops)
1693
1694    step = array_ops.reshape(training_util.get_or_create_global_step(), [1])
1695    self._host_call_fn = {}
1696
1697    host_call_deps = op_fetches + [tensor.op for tensor in processed_t_fetches]
1698
1699    caches_to_write = {}
1700    with ops.control_dependencies(host_call_deps):
1701      all_caches = self._get_all_cache_variables()
1702      for cache_name, cache_variable in all_caches.items():
1703        # Increase the cache rank by 1, so that when host call concatenates
1704        # tensors from different replicas, we can identify them with [core_id].
1705        new_cache_shape = [1]
1706        new_cache_shape.extend(cache_variable.shape.as_list())
1707        cache = array_ops.reshape(cache_variable, new_cache_shape)
1708        caches_to_write[cache_name] = cache
1709    # Add step to parameter dictionary.
1710    caches_to_write['step'] = step
1711    # Other options without adding step to parameter dictionary are
1712    #  * host_call_fn = (_write_cache(step, caches_to_write)) : fails as it
1713    #    considers caches_to_write as a single parameter, rather than a keyword
1714    #    parameters.
1715    #  * host_call_fn = (_write_cache(step, **caches_to_write)) : fails with
1716    #    a syntax error.
1717    self._host_call_fn[_TT_HOSTCALL_KEY] = (_write_cache, caches_to_write)
1718
  def host_call_deps_and_fn(self):
    """Returns the host call dictionary stored in self._host_call_fn.

    _prepare_host_call_fn fills this dictionary with a single entry, keyed by
    _TT_HOSTCALL_KEY, holding a (function, keyword-arguments) pair.
    _trace_execution deletes that entry again on the paths where it issues the
    write itself.
    """
    return self._host_call_fn
1721
  def get_traced_op_names(self):
    """Returns the set of traced op names.

    _trace_execution adds the name of each op on the execution path that has
    at least one output tensor listed in the tensor trace order.
    """
    return self._traced_op_names
1725
  def _trace_execution(self, graph,
                       tensor_fetches,
                       op_fetches=None,
                       on_tpu=True):
    """Common tracing function for both CPU and TPUs.

    The caller function should set device_type, num_replicas,
    num_replicas_per_host, num_hosts and replica_id before calling
    _trace_execution.

    The ops on the execution path of the fetches are visited in op-name order.
    For every traced output tensor a trace op (or a temporary cache entry) is
    created, and the trace op is added as a control input to the consumers of
    the tensor, or to the returned fetches when the tensor itself is fetched.
    Afterwards, depending on the configuration, a host call is prepared, the
    summary is written directly or through outside compilation, or the cache
    is flushed.

    Args:
      graph: the graph of Ops executed on the TPU.
      tensor_fetches: a (list,tuple,or a single object) of tensor fetches
        returned by model_fn given to session.run. Function must be provided
        with at least one tensor to fetch.
      op_fetches: A list of op fetches returned by model_fn given to
        session.run. op_fetches and tensor_fetches are used to determine the
        nodes that will be executed. Can be None.
      on_tpu: True if executing on TPU.

    Returns:
      tensor_fetches: an exact copy of tensor_fetches that has additional
                      dependencies.
    Raises:
      RuntimeError: If tensor_fetches is None or empty.
      RuntimeError: If a tensor has multiple stats outside of compact mode.
      ValueError: If flushing with outside compilation is requested but the
        caches to write do not contain both the summary cache and the step.
    """
    def _cast_unsupported_dtypes(tensor):
      """Casts tensor to a supported type."""

      if tensor.dtype.__eq__(dtypes.int64):
        # outside-compilation doesn't support int64 input yet.
        return math_ops.cast(tensor, dtypes.int32)
      if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__(
          dtypes.float16):
        # Since host can't handle bf16, convert tensor to f32.
        return math_ops.cast(tensor, dtypes.float32)
      return tensor

    trace_mode = self._parameters.trace_mode
    device_type = self._tt_config.device_type
    # Remember the context that is active on entry, so that it can be restored
    # after temporarily switching to the context of a traced op.
    # pylint: disable=protected-access
    self._outmost_context = graph._get_control_flow_context()
    # pylint: enable=protected-access

    analytics.track_usage('tensor_tracer', [trace_mode, device_type])
    TensorTracer.check_device_type(device_type)
    TensorTracer.check_trace_mode(device_type, trace_mode)
    # Check in_tensor_fetches, and op_fetches and convert them to lists.
    processed_t_fetches = self._process_tensor_fetches(tensor_fetches)
    op_fetches = self._process_op_fetches(op_fetches)
    all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches]

    # Filter out the operations that won't be executed.
    # if fetches=None, then ops_in_exec_path = set(operations)
    exec_op_set = self._filter_execution_path_operations(graph.get_operations(),
                                                         all_fetches)
    # Write report file, and determine the traced tensors.
    tensor_trace_order = self._determine_trace_and_create_report(
        graph, exec_op_set)

    tensor_fetch_set = set(processed_t_fetches)
    # Trace ops of fetched tensors; attached to the returned fetches below.
    tracing_ops = []

    # Visit the ops in a fixed, name-based order.
    sorted_exec_op_list = list(exec_op_set)
    sorted_exec_op_list.sort(key=lambda op: op.name)
    # Trace ops only if they are in the execution path.
    for op in sorted_exec_op_list:
      for i in range(len(op.outputs)):
        out_tensor = op.outputs[i]
        tensor_name = out_tensor.name
        if tensor_name not in tensor_trace_order.tensorname_to_cache_idx:
          continue
        self._traced_op_names.add(op.name)
        # Create the list of consumers before calling _preprocess_traced_tensor.
        # Otherwise, adding control input below, will introduce a cycle in the
        # graph.
        consumers = out_tensor.consumers()
        # Not all consumers may be in the exec path. Filter out the consumers
        # to keep the graph simpler.
        consumers = [cop for cop in consumers if cop in exec_op_set]

        # If there is no consumer of the tensor, there is no need to trace it;
        # unless the tensor itself is one of the fetches.
        is_a_fetched_tensor = out_tensor in tensor_fetch_set
        if (not consumers) and (not is_a_fetched_tensor):
          continue

        # Build the tracing ops inside the control flow context of the traced
        # op; the outmost context is restored further below.
        op_control_flow_context = self._get_op_control_flow_context(op)
        if op_control_flow_context:
          # pylint: disable=protected-access
          graph._set_control_flow_context(op_control_flow_context)
          # pylint: enable=protected-access

        processed_tensors = self._preprocess_traced_tensor(out_tensor)

        if on_tpu:
          for signature in processed_tensors.keys():
            processed_tensors[signature] = _cast_unsupported_dtypes(
                processed_tensors[signature])

        if self._use_tensor_values_cache():
          # Use a small cache (either temp cache or tf local variable) to store
          # the characteristics of the tensor.
          if self._use_temp_cache():
            cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
            self._save_tensor_value_to_tmp_cache(cache_idx, processed_tensors)
            # No trace op is attached to the graph on the temp cache path.
            trace_op = None
          else:
            cache_idx = tensor_trace_order.tensorname_to_cache_idx[tensor_name]
            trace_op = self._save_tensor_value_to_cache_op(cache_idx,
                                                           processed_tensors)
        elif self._use_tensor_buffer():
          if len(processed_tensors) != 1:
            raise RuntimeError('Multiple stats are only allowed in compact '
                               'mode.')
          processed_out_tensor = list(processed_tensors.values())[0]
          # Store the whole tensor in a buffer.
          trace_op = self._snapshot_tensor(processed_out_tensor)
        else:

          def tpu_wrap_trace_fn(tensor, out_tensor_name):
            """Wraps the trace_fn with outside compilation if on TPUs."""
            tensor_trace_fn = self._make_tensor_trace_fun(out_tensor_name,
                                                          tensor_trace_order)
            if on_tpu:
              return tpu.outside_compilation(tensor_trace_fn, tensor)
            else:
              return tensor_trace_fn(tensor)

          if len(processed_tensors) != 1:
            raise RuntimeError('Multiple stats are only allowed in compact '
                               'mode.')
          # Collecting multiple statistics are only supported in the summary
          # mode that uses compact format(self._use_tensor_values_cache = true).
          # Non-compact mode currently allows single stat per tensor.
          processed_out_tensor = six.next(six.itervalues(processed_tensors))
          trace_op = tpu_wrap_trace_fn(processed_out_tensor, tensor_name)

        if op_control_flow_context:
          # pylint: disable=protected-access
          graph._set_control_flow_context(self._outmost_context)
          # pylint: enable=protected-access
        if trace_op:
          if is_a_fetched_tensor:
            # The trace op of a fetched tensor is attached to the returned
            # fetches instead of to the consumers of the tensor.
            tracing_ops.append(trace_op)
            continue
          # Add it to all consumers, as some consumers may not be executed if
          # they are in a control flow.
          for consumer_op in consumers:
            # pylint: disable=protected-access
            consumer_op._add_control_input(trace_op)
            # pylint: enable=protected-access

    # pylint: disable=protected-access
    graph._set_control_flow_context(self._outmost_context)
    # pylint: enable=protected-access
    if tracing_ops:
      # If we are tracing a fetched tensor, their dependency is stored in
      # tracing_ops.
      processed_t_fetches = control_flow_ops.tuple(processed_t_fetches,
                                                   control_inputs=tracing_ops)
    if self._use_tensor_values_cache() or self._use_tensor_buffer():
      if self._use_temp_cache():
        # Create the temporary tf cache variable by concatenating all
        # statistics.
        self._cache_variables[_TT_SUMMARY_TAG] = array_ops.stack(
            self._temp_cache_var, axis=0, name='stack_all_op_signatures')
      if self._create_host_call():
        self._prepare_host_call_fn(processed_t_fetches, op_fetches)
        if not on_tpu:
          # Off TPU the write function is called directly, and the host call
          # entry is removed afterwards.
          write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
          cache_write_op = write_cache(**caches_to_write)
          processed_t_fetches = control_flow_ops.tuple(
              processed_t_fetches, control_inputs=[cache_write_op])
          del self._host_call_fn[_TT_HOSTCALL_KEY]
        elif self._parameters.flush_summaries_with_outside_compile:
          write_cache, caches_to_write = self._host_call_fn[_TT_HOSTCALL_KEY]
          if (_TT_SUMMARY_TAG in caches_to_write and 'step' in caches_to_write):
            step = caches_to_write['step']
            tensor_tracer_summary = caches_to_write[_TT_SUMMARY_TAG]
            # Index 0 drops the leading dimension of size one that
            # _prepare_host_call_fn added to the cache.
            tt_core_summary = self.merge_caches_on_tpu(tensor_tracer_summary[0])
            if not self._parameters.collect_summary_per_core:
              tt_core_summary = self.aggregate_global_cache(tt_core_summary)

            def write_if_core_0(step, replica_id, tt_summary):
              """Writes the summary only when replica_id equals 0."""

              return control_flow_ops.cond(
                  math_ops.equal(replica_id, 0),
                  lambda: write_cache(step=step, event_file_suffix=None,  # pylint: disable=g-long-lambda
                                      tensor_tracer_summary=tt_summary),
                  control_flow_ops.no_op)

            write_op = tpu.outside_compilation(write_if_core_0, step=step,
                                               replica_id=self._replica_id,
                                               tt_summary=tt_core_summary)
            processed_t_fetches = control_flow_ops.tuple(
                processed_t_fetches, control_inputs=[write_op])
            del self._host_call_fn[_TT_HOSTCALL_KEY]
          else:
            raise ValueError('Outside compiled flush in only supported for '
                             'summary mode')
      else:
        processed_t_fetches = self._flush_tensor_values_cache(
            processed_t_fetches, op_fetches, on_tpu=on_tpu,
            tensor_trace_order=tensor_trace_order)

    # processed_t_fetches is a list at this point. Convert it to the same
    # format as given in tensor_fetches.
    return self._convert_fetches_to_input_format(tensor_fetches,
                                                 processed_t_fetches)
1937
1938  def trace_tpu(self, graph,
1939                tensor_fetches,
1940                op_fetches=None,
1941                num_replicas=None,
1942                num_replicas_per_host=None,
1943                num_hosts=None):
1944    """Traces the tensors generated by TPU Ops in a TF graph.
1945
1946    Args:
1947      graph: the graph of Ops executed on the TPU.
1948      tensor_fetches: a (list,tuple,or a single object) of tensor fetches
1949        returned by model_fn given to session.run. Function must be provided
1950        with as least one tensor to fetch.
1951      op_fetches: A list of op fetches returned by model_fn given to
1952        session.run. op_fetches and tensor_fetches are used to determine the
1953        nodes that will be executed. Can be None.
1954      num_replicas: number of replicas used on the TPU.
1955      num_replicas_per_host: number of replicas per TPU host.
1956      num_hosts: total number of TPU hosts.
1957
1958    Returns:
1959      tensor_fetches: an exact copy of tensor_fetches that has additional
1960                      dependencies.
1961    Raises:
1962      RuntimeError: If num_replicas_per_host > 8.
1963      RuntimeError: If tensor_fetches is None or empty.
1964    """
1965    if isinstance(graph, func_graph.FuncGraph) or isinstance(
1966        graph, function._FuncGraph):  # pylint: disable=protected-access
1967      logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. '
1968                      'Ignoring tracing.')
1969      return tensor_fetches
1970
1971    if graph in TensorTracer._traced_graphs:
1972      logging.warning('Graph is already rewritten with tensor tracer, ignoring '
1973                      'multiple calls.')
1974      return tensor_fetches
1975    else:
1976      TensorTracer._traced_graphs.add(graph)
1977    # Reset the parameters in case parameters are changed.
1978    self._parameters = tensor_tracer_flags.TTParameters()
1979    self._tt_config.device_type = _DEVICE_TYPE_TPU
1980    self._tt_config.num_replicas = num_replicas
1981    self._tt_config.num_replicas_per_host = num_replicas_per_host
1982    self._tt_config.num_hosts = num_hosts
1983    if self._tt_config.num_replicas is not None:
1984      if self._tt_config.num_replicas_per_host is None:
1985        self._tt_config.num_replicas_per_host = 8
1986      if self._tt_config.num_hosts is None:
1987        self._tt_config.num_hosts = (
1988            num_replicas // self._tt_config.num_replicas_per_host +
1989            (num_replicas % self._tt_config.num_replicas_per_host > 0))
1990
1991    if self._parameters.graph_dump_path:
1992      graph_io.write_graph(graph, self._parameters.graph_dump_path,
1993                           'graph_before_tt.pbtxt')
1994    with graph.as_default():
1995      self._add_replica_id_to_graph()
1996      tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
1997                                             on_tpu=True)
1998    if self._parameters.graph_dump_path:
1999      graph_io.write_graph(graph, self._parameters.graph_dump_path,
2000                           'graph_after_tt.pbtxt')
2001    return tensor_fetches
2002
2003  def trace_cpu(self, graph, tensor_fetches, op_fetches=None):
2004    """Traces the tensors generated by CPU Ops in a TF graph.
2005
2006    Args:
2007      graph: the graph of Ops executed on the CPU.
2008      tensor_fetches: a (list,tuple,or a single object) of tensor fetches
2009        returned by model_fn given to session.run. Function must be provided
2010        with as least one tensor to fetch.
2011      op_fetches: A list of op fetches returned by model_fn given to
2012        session.run. op_fetches and tensor_fetches are used to determine the
2013        nodes that will be executed. Can be None.
2014
2015    Returns:
2016      tensor_fetches: an exact copy of tensor_fetches that has additional
2017                      dependencies.
2018    Raises:
2019      RuntimeError: If tensor_fetches is None or empty.
2020    """
2021    if isinstance(graph, func_graph.FuncGraph) or isinstance(
2022        graph, function._FuncGraph):  # pylint: disable=protected-access
2023      logging.warning('Tensor Tracer is not supported for tracing FuncGraphs. '
2024                      'Ignoring tracing.')
2025      return tensor_fetches
2026
2027    if graph in TensorTracer._traced_graphs:
2028      logging.warning('Graph is already rewritten with tensor tracer, ignoring '
2029                      'multiple calls.')
2030      return tensor_fetches
2031    else:
2032      TensorTracer._traced_graphs.add(graph)
2033    # Reset the parameters in case parameters are changed.
2034    self._parameters = tensor_tracer_flags.TTParameters()
2035
2036    self._tt_config.device_type = _DEVICE_TYPE_CPU
2037    self._tt_config.num_replicas = 1
2038    self._tt_config.num_replicas_per_host = 1
2039    self._tt_config.num_hosts = 1
2040    self._replica_id = 0
2041    if self._parameters.graph_dump_path:
2042      graph_io.write_graph(graph, self._parameters.graph_dump_path,
2043                           'graph_before_tt.pbtxt')
2044    with graph.as_default():
2045      tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches,
2046                                             on_tpu=False)
2047    if self._parameters.graph_dump_path:
2048      graph_io.write_graph(graph, self._parameters.graph_dump_path,
2049                           'graph_after_tt.pbtxt')
2050    return tensor_fetches
2051