1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ========================================================================
15"""Tensor Tracer report generation utilities."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import hashlib
23import os
24
25
26from tensorflow.python.platform import gfile
27from tensorflow.python.platform import tf_logging as logging
28from tensorflow.python.tpu import tensor_tracer_pb2
29
# Prefix that `TTReportHandle._write_report` puts in front of every report line.
_TRACER_LOG_PREFIX = ' [>>>TT>>>]'
# Markers written immediately before a section name to open/close that section
# of the text report.
_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:'
_MARKER_SECTION_END = '!!!!!!! section-end:'

# Names of the report sections; each one follows a begin or end marker.
_SECTION_NAME_CONFIG = 'configuration'
_SECTION_NAME_REASON = 'reason'
_SECTION_NAME_OP_LIST = 'op-list'
_SECTION_NAME_TENSOR_LIST = 'tensor-list'
_SECTION_NAME_CACHE_INDEX_MAP = 'cache-index-map'
_SECTION_NAME_GRAPH = 'graph'
_SECTION_NAME_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint'

# Field labels; each is written as '<label> <value>' inside a report section.
_FIELD_NAME_VERSION = 'version:'
_FIELD_NAME_DEVICE = 'device:'
_FIELD_NAME_TRACE_MODE = 'trace-mode:'
_FIELD_NAME_SUBMODE = 'submode:'
_FIELD_NAME_NUM_REPLICAS = 'num-replicas:'
_FIELD_NAME_NUM_REPLICAS_PER_HOST = 'num-replicas-per-host:'
_FIELD_NAME_NUM_HOSTS = 'num-hosts:'
_FIELD_NAME_NUM_OPS = 'number-of-ops:'
_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:'
_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:'
_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'

# Value that `TensorTracerConfig.version` is initialized to.
_CURRENT_VERSION = 'use-outside-compilation'
# File name of the serialized report proto; joined with the trace directory by
# `report_proto_path`.
_TT_REPORT_PROTO = 'tensor_tracer_report.report_pb'
56
57
def report_proto_path(trace_dir):
  """Builds the location of the serialized Tensor Tracer report proto.

  Args:
     trace_dir: String naming the trace directory.

  Returns:
     A string: `trace_dir` joined with the report proto file name.
  """
  proto_path = os.path.join(trace_dir, _TT_REPORT_PROTO)
  return proto_path
68
69
def topological_sort(g):
  """Performs topological sort on the given graph.

  Edges that leave a 'NextIteration' op are ignored, so the back edge of a
  while-loop is not reported as a cycle.

  Args:
     g: the graph.

  Returns:
     A pair (contains_cycle, ops). If no cycle is found, contains_cycle is
     False and ops is the list of operations in topological order. If a cycle
     is found, contains_cycle is True and ops is the set of operations that
     were left with unresolved incoming edges.

  Raises:
     ValueError: if the in-degree of an operation drops below zero, i.e. it is
       reported as a consumer more often than it lists inputs.
  """
  def _is_loop_edge(op):
    """Returns true if the op is the end of a while-loop creating a cycle."""
    return op.type in ['NextIteration']

  def _in_op_degree(op):
    """Returns the number of incoming edges to the given op.

    The edge calculation skips the edges that come from 'NextIteration' ops.
    NextIteration creates a cycle in the graph. We break cycles by treating
    this op as 'sink' and ignoring all outgoing edges from it.
    Args:
      op: Tf.Operation
    Returns:
      the number of incoming edges.
    """
    # A separate loop name is used so the `op` argument is never shadowed.
    producers = op.control_inputs + [in_tensor.op for in_tensor in op.inputs]
    return sum(1 for producer in producers if not _is_loop_edge(producer))

  sorted_ops = []
  op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()}

  # Ops are name-sorted before being pushed so the resulting order is
  # deterministic for a given graph.
  frontier = [op for (op, degree) in op_in_degree.items() if degree == 0]
  frontier.sort(key=lambda op: op.name)
  while frontier:
    op = frontier.pop()
    # Remove the op from graph, and remove its outgoing edges.
    sorted_ops.append(op)
    if _is_loop_edge(op):
      # Outgoing edges of a loop op were never counted, so there is nothing
      # to remove.
      continue
    # pylint: disable=protected-access
    consumers = list(op._control_outputs)
    # pylint: enable=protected-access
    for out_tensor in op.outputs:
      consumers.extend(out_tensor.consumers())
    consumers.sort(key=lambda consumer: consumer.name)
    for consumer in consumers:
      # For each deleted edge shift the bucket of the vertex.
      op_in_degree[consumer] -= 1
      if op_in_degree[consumer] == 0:
        frontier.append(consumer)
      if op_in_degree[consumer] < 0:
        raise ValueError('consumer:%s degree mismatch'%consumer.name)

  # Anything that still has incoming edges could not be scheduled.
  left_ops = set(op for (op, degree) in op_in_degree.items() if degree > 0)
  if left_ops:
    return (True, left_ops)
  else:
    assert len(g.get_operations()) == len(sorted_ops)
    return (False, sorted_ops)
134
135
class TensorTracerConfig(object):
  """Tensor Tracer config object.

  After construction only `version` carries a value; `device_type`,
  `num_replicas`, `num_replicas_per_host` and `num_hosts` start out as None.
  """

  def __init__(self):
    self.version = _CURRENT_VERSION
    # Environment details are unset until assigned.
    self.device_type = self.num_replicas = None
    self.num_replicas_per_host = self.num_hosts = None
146
class TensorTraceOrder(object):
  """Assigns consecutive cache indices to the traced tensors.

  Attributes set at construction time:
    graph_order: the object passed in; its `tensor_to_idx` mapping is consulted.
    traced_tensors: the iterable of tensors passed in.
    tensorname_to_cache_idx: dict from tensor name to its cache index.
    cache_idx_to_tensor_idx: list whose i-th entry is the graph tensor index of
      the tensor that received cache index i.
  """

  def __init__(self, graph_order, traced_tensors):
    self.graph_order = graph_order
    self.traced_tensors = traced_tensors
    self._create_tensor_maps()

  def _create_tensor_maps(self):
    """Builds the name -> cache index and cache index -> tensor index maps.

    Raises:
      ValueError: if a traced tensor name occurs more than once, or is not a
        key of `graph_order.tensor_to_idx`.
      RuntimeError: if the two maps end up with different sizes.
    """
    name_to_cache = {}
    cache_to_tensor = []
    self.tensorname_to_cache_idx = name_to_cache
    self.cache_idx_to_tensor_idx = cache_to_tensor
    for traced in self.traced_tensors:
      name = traced.name
      if name in name_to_cache:
        raise ValueError(
            'Tensor name %s should not be already in '
            'tensorname_to_cache_idx'%name)
      known_tensors = self.graph_order.tensor_to_idx
      if name not in known_tensors:
        raise ValueError(
            'Tensor name %s is not in the tensor_to_idx'%name)
      graph_position = known_tensors[name]
      # The next free cache slot equals the number of names registered so far.
      name_to_cache[name] = len(name_to_cache)
      cache_to_tensor.append(graph_position)
      if len(name_to_cache) != len(cache_to_tensor):
        raise RuntimeError('len(self.tensorname_to_cache_idx) != '
                           'len(self.cache_idx_to_tensor_idx')
176
177
def sort_tensors_and_ops(graph):
  """Returns a wrapper that has consistent tensor and op orders.

  Args:
    graph: the graph whose operations and tensors are to be indexed.

  Returns:
    A `GraphWrapper` namedtuple with the fields graph, operations, op_to_idx,
    tensors, tensor_to_idx, contains_cycle and topological_order_or_cycle.
    `operations` is the topological order when one exists, otherwise the order
    returned by `graph.get_operations()`; `tensors` lists the outputs of those
    operations in that order; the two `*_to_idx` dicts map names to positions.
  """
  wrapper_cls = collections.namedtuple(
      'GraphWrapper',
      ['graph', 'operations', 'op_to_idx', 'tensors', 'tensor_to_idx',
       'contains_cycle', 'topological_order_or_cycle'])
  contains_cycle, topological_order_or_cycle = topological_sort(graph)
  # Fall back to the graph's own op order when no topological order exists.
  if contains_cycle:
    operations = graph.get_operations()
  else:
    operations = topological_order_or_cycle

  op_to_idx = {}
  for position, operation in enumerate(operations):
    op_to_idx[operation.name] = position

  tensors = [output for operation in operations
             for output in operation.outputs]
  tensor_to_idx = {}
  for position, tensor in enumerate(tensors):
    tensor_to_idx[tensor.name] = position

  return wrapper_cls(graph, operations, op_to_idx, tensors, tensor_to_idx,
                     contains_cycle, topological_order_or_cycle)
201
202
class OpenReportFile(object):
  """Context manager for writing report file.

  Entering the context yields the file opened for `report_file_path`, or None
  when no path is configured. The file, if any, is closed on exit.
  """

  def __init__(self, tt_parameters):
    """Opens the report file when a path is configured.

    Args:
      tt_parameters: object whose `report_file_path` attribute names the file
        to write; a falsy value means that no file is opened.
    """
    self._report_file = None
    if tt_parameters.report_file_path:
      # Any error raised while opening propagates to the caller unchanged; the
      # former `except IOError as e: raise e` wrapper added nothing.
      self._report_file = gfile.Open(tt_parameters.report_file_path, 'w')

  def __enter__(self):
    """Returns the open report file, or None if there is none."""
    return self._report_file

  def __exit__(self, unused_type, unused_value, unused_traceback):
    """Closes the report file if one was opened; exceptions are not suppressed."""
    if self._report_file:
      self._report_file.close()
221
222
def proto_fingerprint(message_proto):
  """Returns the SHA-256 hex digest of the serialized `message_proto`.

  Args:
    message_proto: object providing `SerializeToString()`.

  Returns:
    The hexadecimal digest string of the serialized bytes.
  """
  return hashlib.sha256(message_proto.SerializeToString()).hexdigest()
227
228
class TTReportHandle(object):
  """Utility class responsible for creating a tensor tracer report."""

  def __init__(self):
    # Maps an op or tensor name to the explanation recorded for it.
    self.instrument_records = {}
    # Report file that is open while `create_report` runs; None otherwise, in
    # which case `_write_report` sends its output to the log.
    self._report_file = None

  def instrument(self, name, explanation):
    """Records `explanation` for `name`, replacing any earlier record."""
    self.instrument_records[name] = explanation

  def instrument_op(self, op, explanation):
    """Records `explanation` under the name of `op`."""
    self.instrument(op.name, explanation)

  def instrument_tensor(self, tensor, explanation):
    """Records `explanation` under the name of `tensor`."""
    self.instrument(tensor.name, explanation)

  def create_report_proto(self, tt_config, tt_parameters, tensor_trace_order,
                          tensor_trace_points, collected_signature_types):
    """Creates and returns a proto that stores tensor tracer configuration.

    The fingerprint is computed before the graph definition is copied in, so
    it covers the configuration and the tensor definitions only.

    Args:
      tt_config: TensorTracerConfig object holding information about the run
        environment (device, # cores, # hosts), and tensor tracer version
        information.
      tt_parameters: TTParameters objects storing the user provided parameters
        for tensor tracer.
      tensor_trace_order: TensorTraceOrder object storing a topological order of
        the graph.
      tensor_trace_points: Programmatically added trace_points/checkpoints.
        Looked up by tensor name in this method.
      collected_signature_types: The signature types collected, e.g. norm,
        max, min, mean... Accessed through `.items()`; the signature names are
        stored in ascending order of their associated values.
    Returns:
      TensorTracerReport proto.
    """
    report = tensor_tracer_pb2.TensorTracerReport()
    report.config.version = tt_config.version
    report.config.device = tt_config.device_type
    report.config.num_cores = tt_config.num_replicas
    report.config.num_hosts = tt_config.num_hosts
    report.config.num_cores_per_host = tt_config.num_replicas_per_host
    report.config.submode = tt_parameters.submode
    report.config.trace_mode = tt_parameters.trace_mode

    for signature_name, _ in sorted(collected_signature_types.items(),
                                    key=lambda x: x[1]):
      report.config.signatures.append(signature_name)

    for tensor in tensor_trace_order.graph_order.tensors:
      tensor_def = tensor_tracer_pb2.TensorTracerReport.TracedTensorDef()
      tensor_def.name = tensor.name
      if tensor.name in tensor_trace_order.tensorname_to_cache_idx:
        tensor_def.is_traced = True
        tensor_def.cache_index = (
            tensor_trace_order.tensorname_to_cache_idx[tensor.name])
      else:
        # To prevent small changes affecting the fingerprint calculation, avoid
        # writing the untraced tensors to metadata. Fingerprints will be
        # different only when the list of the traced tensors are different.
        if tt_parameters.use_fingerprint_subdir:
          continue
        tensor_def.is_traced = False

      # NOTE(review): tensor_trace_points is indexed by tensor name here, while
      # `_write_trace_points` iterates it as (tensor, checkpoint_name) pairs --
      # confirm which shape callers actually pass.
      if tensor.name in tensor_trace_points:
        tensor_def.trace_point_name = tensor_trace_points[tensor.name]
      # An explanation recorded for the tensor itself takes precedence over
      # one recorded for the op that produces it.
      if tensor.name in self.instrument_records:
        tensor_def.explanation = self.instrument_records[tensor.name]
      elif tensor.op.name in self.instrument_records:
        tensor_def.explanation = self.instrument_records[tensor.op.name]
      report.tensordef[tensor.name].CopyFrom(tensor_def)
    report.fingerprint = proto_fingerprint(report)
    logging.info('TensorTracerProto fingerprint is %s.',
                 report.fingerprint)
    tf_graph = tensor_trace_order.graph_order.graph
    report.graphdef.CopyFrom(tf_graph.as_graph_def())
    return report

  def write_report_proto(self, report_proto, tt_parameters):
    """Writes the given report proto under trace_dir.

    Args:
      report_proto: proto to serialize.
      tt_parameters: object whose `trace_dir` attribute names the output
        directory; the directory is created if needed.
    """
    gfile.MakeDirs(tt_parameters.trace_dir)
    report_path = report_proto_path(tt_parameters.trace_dir)
    with gfile.GFile(report_path, 'wb') as f:
      f.write(report_proto.SerializeToString())

  def create_report(self, tt_config, tt_parameters,
                    tensor_trace_order, tensor_trace_points):
    """Creates a report file and writes the trace information.

    When `tt_parameters.report_file_path` is not set, the report lines are
    logged instead of being written to a file.

    Args:
      tt_config: TensorTracerConfig object describing the run environment.
      tt_parameters: object providing `report_file_path`, `trace_mode` and
        `submode`.
      tensor_trace_order: TensorTraceOrder object for the traced graph.
      tensor_trace_points: iterable of (tensor, checkpoint_name) pairs.
    """
    with OpenReportFile(tt_parameters) as report_file:
      self._report_file = report_file
      try:
        self._write_config_section(tt_config, tt_parameters)
        self._write_op_list_section(tensor_trace_order.graph_order)
        self._write_tensor_list_section(tensor_trace_order.graph_order)
        self._write_trace_points(tensor_trace_points)
        self._write_cache_index_map_section(tensor_trace_order)
        self._write_reason_section()
        self._write_graph_section(tensor_trace_order.graph_order)
      finally:
        # Do not keep a reference to a file that is about to be closed; later
        # `_write_report` calls then fall back to logging.
        self._report_file = None

  def _write_trace_points(self, tensor_trace_points):
    """Writes the list of checkpoints."""
    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
                                  _SECTION_NAME_TENSOR_TRACER_CHECKPOINT))
    for (tensor, checkpoint_name) in tensor_trace_points:
      self._write_report('%s %s\n'%(tensor.name, checkpoint_name))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
                                  _SECTION_NAME_TENSOR_TRACER_CHECKPOINT))

  def _write_report(self, content):
    """Writes the given content to the report.

    The content is prefixed with the tracer log prefix and goes to the open
    report file if there is one, otherwise to the info log.
    """

    line = '%s %s'%(_TRACER_LOG_PREFIX, content)
    if self._report_file:
      self._report_file.write(line)
    else:
      logging.info(line)

  def _write_config_section(self, tt_config, tt_parameters):
    """Writes the config section of the report."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG))
    config_fields = (
        (_FIELD_NAME_VERSION, tt_config.version),
        (_FIELD_NAME_DEVICE, tt_config.device_type),
        (_FIELD_NAME_TRACE_MODE, tt_parameters.trace_mode),
        (_FIELD_NAME_SUBMODE, tt_parameters.submode),
        (_FIELD_NAME_NUM_REPLICAS, tt_config.num_replicas),
        (_FIELD_NAME_NUM_REPLICAS_PER_HOST, tt_config.num_replicas_per_host),
        (_FIELD_NAME_NUM_HOSTS, tt_config.num_hosts),
    )
    for field_name, value in config_fields:
      self._write_report('%s %s\n'%(field_name, value))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG))

  def _write_reason_section(self):
    """Writes the reason section of the report, sorted by recorded name."""

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON))
    for key in sorted(self.instrument_records):
      self._write_report('"%s" %s\n'%(key, self.instrument_records[key]))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON))

  def _write_op_list_section(self, graph_order):
    """Writes the Op-list section of the report.

    Each line holds the op index, name and type, followed by the indices of
    the op's output tensors.

    Raises:
      ValueError: if an output tensor is missing from
        `graph_order.tensor_to_idx`.
    """

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST))
    self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS,
                                  len(graph_order.operations)))
    for i, op in enumerate(graph_order.operations):
      line = '%d "%s" %s'%(i, op.name, op.type)
      for out_tensor in op.outputs:
        if out_tensor.name not in graph_order.tensor_to_idx:
          raise ValueError(
              'out_tensor %s is not in tensor_to_idx'%out_tensor.name)
        line += ' %d'%graph_order.tensor_to_idx[out_tensor.name]
      line += '\n'
      self._write_report(line)
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST))

  def _write_tensor_list_section(self, graph_order):
    """Writes the tensor-list section of the report.

    Each line holds the tensor index and name, followed by the indices of the
    ops that consume the tensor, ordered by op name.

    Raises:
      ValueError: if a consumer op is missing from `graph_order.op_to_idx`.
    """

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
                                  _SECTION_NAME_TENSOR_LIST))
    self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS,
                                  len(graph_order.tensors)))
    for i, tensor in enumerate(graph_order.tensors):
      line = '%d "%s"'%(i, tensor.name)
      # sorted() leaves the list returned by consumers() untouched.
      consumers = sorted(tensor.consumers(), key=lambda op: op.name)
      for consumer_op in consumers:
        if consumer_op.name not in graph_order.op_to_idx:
          raise ValueError(
              'consumer_op %s is not in op_to_idx'%consumer_op.name)
        line += ' %d'%graph_order.op_to_idx[consumer_op.name]
      line += '\n'
      self._write_report(line)
    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
                                  _SECTION_NAME_TENSOR_LIST))

  def _write_cache_index_map_section(self, tensor_trace_order):
    """Writes the mapping from cache index to tensor index to the report."""
    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
                                  _SECTION_NAME_CACHE_INDEX_MAP))
    self._write_report('%s %d\n'%(
        _FIELD_NAME_NUM_CACHE_INDICES,
        len(tensor_trace_order.cache_idx_to_tensor_idx)))
    for cache_idx, tensor_idx in enumerate(
        tensor_trace_order.cache_idx_to_tensor_idx):
      self._write_report('%d %d\n'%(cache_idx, tensor_idx))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END,
                                  _SECTION_NAME_CACHE_INDEX_MAP))

  def _write_graph_section(self, graph_order):
    """Writes the graph section of the report.

    Lists either the topological order or, when sorting failed, the ops that
    `graph_order.topological_order_or_cycle` holds instead.
    """

    self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH))
    self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED,
                                  not graph_order.contains_cycle))
    for i, op in enumerate(graph_order.topological_order_or_cycle):
      self._write_report('%d "%s"\n'%(i, op.name))
    self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH))
431