1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Timeline visualization for TensorFlow using Chrome Trace Format."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23import json
24import re
25
# The timeline target is usually imported as part of the BUILD target
# "platform_test", which also includes the "platform" dependency.
# This is why the logging import here is okay.
29from tensorflow.python.platform import tf_logging as logging
30
31
class AllocationMaximum(collections.namedtuple(
    'AllocationMaximum', ('timestamp', 'num_bytes', 'tensors'))):
  """Stores the maximum allocation for a given allocator within the timeline.

  Parameters:
    timestamp: `tensorflow::Env::NowMicros()` when this maximum was reached.
    num_bytes: the total memory used at this time.
    tensors: the set of tensors allocated at this time.
  """
  # All state lives in the tuple fields, so suppress the per-instance
  # __dict__ that a plain subclass of a namedtuple would otherwise get.
  __slots__ = ()
42
43
class StepStatsAnalysis(collections.namedtuple(
    'StepStatsAnalysis', ('chrome_trace', 'allocator_maximums'))):
  """Stores the step stats analysis output.

  Parameters:
    chrome_trace: The trace formatter object holding the generated Chrome
      trace events; call its `format_to_string()` method to obtain the JSON
      text.
    allocator_maximums: A dict mapping allocator names to AllocationMaximum.
  """
  # All state lives in the tuple fields, so suppress the per-instance
  # __dict__ that a plain subclass of a namedtuple would otherwise get.
  __slots__ = ()
53
54
class _ChromeTraceFormatter(object):
  """Collects trace events and serializes them in Chrome Trace Format.

  Metadata events (process and thread names) and regular events are kept in
  two separate lists; on output the metadata events come first.
  """

  def __init__(self, show_memory=False):
    """Creates an empty formatter.

    Args:
      show_memory: (Optional.) Stored on the instance; no method of this
        class reads it.
    """
    self._metadata = []
    self._events = []
    self._show_memory = show_memory

  def _create_event(self, ph, category, name, pid, tid, timestamp):
    """Builds the fields common to every non-metadata trace event.

    For details of the file format, see:
    https://github.com/catapult-project/catapult/blob/master/tracing/README.md

    Args:
      ph:  The type of event - usually a single character.
      category: The event category as a string.
      name:  The event name as a string.
      pid:  Identifier of the process generating this event as an integer.
      tid:  Identifier of the thread generating this event as an integer.
      timestamp:  The timestamp of this event as a long integer.

    Returns:
      A new JSON-compatible dict with keys 'ph', 'cat', 'name', 'pid', 'tid'
      and 'ts'.
    """
    return {
        'ph': ph,
        'cat': category,
        'name': name,
        'pid': pid,
        'tid': tid,
        'ts': timestamp,
    }

  def emit_pid(self, name, pid):
    """Records a 'process_name' metadata event.

    Args:
      name:  The process name as a string.
      pid:  Identifier of the process as an integer.
    """
    self._metadata.append({
        'name': 'process_name',
        'ph': 'M',
        'pid': pid,
        'args': {'name': name},
    })

  def emit_tid(self, name, pid, tid):
    """Records a 'thread_name' metadata event.

    Args:
      name:  The thread name as a string.
      pid:  Identifier of the process as an integer.
      tid:  Identifier of the thread as an integer.
    """
    self._metadata.append({
        'name': 'thread_name',
        'ph': 'M',
        'pid': pid,
        'tid': tid,
        'args': {'name': name},
    })

  def emit_region(self, timestamp, duration, pid, tid, category, name, args):
    """Records a region ('X') event spanning `duration` from `timestamp`.

    Args:
      timestamp:  The start timestamp of this region as a long integer.
      duration:  The duration of this region as a long integer.
      pid:  Identifier of the process generating this event as an integer.
      tid:  Identifier of the thread generating this event as an integer.
      category: The event category as a string.
      name:  The event name as a string.
      args:  A JSON-compatible dictionary of event arguments.  It is stored
        by reference, not copied.
    """
    region = self._create_event('X', category, name, pid, tid, timestamp)
    region.update([('dur', duration), ('args', args)])
    self._events.append(region)

  def emit_obj_create(self, category, name, timestamp, pid, tid, object_id):
    """Records an object creation ('N') event.

    Args:
      category: The event category as a string.
      name:  The event name as a string.
      timestamp:  The timestamp of this event as a long integer.
      pid:  Identifier of the process generating this event as an integer.
      tid:  Identifier of the thread generating this event as an integer.
      object_id: Identifier of the object as an integer.
    """
    created = self._create_event('N', category, name, pid, tid, timestamp)
    created['id'] = object_id
    self._events.append(created)

  def emit_obj_delete(self, category, name, timestamp, pid, tid, object_id):
    """Records an object deletion ('D') event.

    Args:
      category: The event category as a string.
      name:  The event name as a string.
      timestamp:  The timestamp of this event as a long integer.
      pid:  Identifier of the process generating this event as an integer.
      tid:  Identifier of the thread generating this event as an integer.
      object_id: Identifier of the object as an integer.
    """
    deleted = self._create_event('D', category, name, pid, tid, timestamp)
    deleted['id'] = object_id
    self._events.append(deleted)

  def emit_obj_snapshot(self, category, name, timestamp, pid, tid, object_id,
                        snapshot):
    """Records an object snapshot ('O') event.

    Args:
      category: The event category as a string.
      name:  The event name as a string.
      timestamp:  The timestamp of this event as a long integer.
      pid:  Identifier of the process generating this event as an integer.
      tid:  Identifier of the thread generating this event as an integer.
      object_id: Identifier of the object as an integer.
      snapshot:  A JSON-compatible representation of the object.
    """
    snap = self._create_event('O', category, name, pid, tid, timestamp)
    snap.update([('id', object_id), ('args', {'snapshot': snapshot})])
    self._events.append(snap)

  def emit_flow_start(self, name, timestamp, pid, tid, flow_id):
    """Records the start ('s') of a 'DataFlow' arrow.

    The trace viewer draws an arrow from this event to the flow end event
    carrying the same 'flow_id'.

    Args:
      name:  The event name as a string.
      timestamp:  The timestamp of this event as a long integer.
      pid:  Identifier of the process generating this event as an integer.
      tid:  Identifier of the thread generating this event as an integer.
      flow_id: Identifier of the flow as an integer.
    """
    flow = self._create_event('s', 'DataFlow', name, pid, tid, timestamp)
    flow['id'] = flow_id
    self._events.append(flow)

  def emit_flow_end(self, name, timestamp, pid, tid, flow_id):
    """Records the end ('t') of a 'DataFlow' arrow.

    The trace viewer draws an arrow to this event from the flow start event
    carrying the same 'flow_id'.

    Args:
      name:  The event name as a string.
      timestamp:  The timestamp of this event as a long integer.
      pid:  Identifier of the process generating this event as an integer.
      tid:  Identifier of the thread generating this event as an integer.
      flow_id: Identifier of the flow as an integer.
    """
    flow = self._create_event('t', 'DataFlow', name, pid, tid, timestamp)
    flow['id'] = flow_id
    self._events.append(flow)

  def emit_counter(self, category, name, pid, timestamp, counter, value):
    """Records a counter ('C') event holding a single counter value.

    Args:
      category: The event category as a string.
      name:  The event name as a string.
      pid:  Identifier of the process generating this event as an integer.
      timestamp:  The timestamp of this event as a long integer.
      counter: Name of the counter as a string.
      value:  Value of the counter as an integer.
    """
    # A single counter is just the one-entry case of emit_counters.
    self.emit_counters(category, name, pid, timestamp, {counter: value})

  def emit_counters(self, category, name, pid, timestamp, counters):
    """Records a counter ('C') event for every entry of 'counters'.

    Args:
      category: The event category as a string.
      name:  The event name as a string.
      pid:  Identifier of the process generating this event as an integer.
      timestamp:  The timestamp of this event as a long integer.
      counters: Dictionary of counter values.  A shallow copy is stored, so
        later changes to the caller's dict do not affect the trace.
    """
    # Counter events are always attributed to thread 0.
    sample = self._create_event('C', category, name, pid, 0, timestamp)
    sample['args'] = counters.copy()
    self._events.append(sample)

  def format_to_string(self, pretty=False):
    """Serializes all recorded events to a JSON string.

    Args:
      pretty: (Optional.)  If True, produce human-readable JSON output.

    Returns:
      A JSON-formatted string in Chrome Trace format, with the metadata
      events listed before the other events.
    """
    trace = {'traceEvents': self._metadata + self._events}
    if not pretty:
      return json.dumps(trace, separators=(',', ':'))
    return json.dumps(trace, indent=4, separators=(',', ': '))
263
264
class _TensorTracker(object):
  """Records the identity and reference history of a single Tensor."""

  def __init__(self, name, object_id, timestamp, pid, allocator, num_bytes):
    """Starts tracking a tensor with empty ref and unref histories.

    This class is not thread safe and is intended only for internal use by
    the 'Timeline' class in this file.

    Args:
      name:  The name of the Tensor as a string.
      object_id:  Chrome Trace object identifier assigned for this Tensor.
      timestamp:  The creation timestamp of this event as a long integer.
      pid:  Process identifier of the associated device, as an integer.
      allocator:  Name of the allocator used to create the Tensor.
      num_bytes:  Number of bytes allocated (long integer).
    """
    # Identity of the tensor.
    self._name = name
    self._object_id = object_id
    self._pid = pid
    # Allocation details.
    self._create_time = timestamp
    self._allocator = allocator
    self._num_bytes = num_bytes
    # Timestamps appended by add_ref() / add_unref(), in call order.
    self._ref_times = []
    self._unref_times = []

  @property
  def name(self):
    """The tensor's name."""
    return self._name

  @property
  def pid(self):
    """Integer ID of the process associated with this tensor."""
    return self._pid

  @property
  def create_time(self):
    """Creation timestamp of this tensor (long integer)."""
    return self._create_time

  @property
  def object_id(self):
    """Chrome Trace object identifier of this tensor (integer)."""
    return self._object_id

  @property
  def num_bytes(self):
    """Number of bytes allocated for this tensor (long integer)."""
    return self._num_bytes

  @property
  def allocator(self):
    """Name of the allocator that created this tensor (string)."""
    return self._allocator

  @property
  def last_unref(self):
    """Largest unreference timestamp recorded so far (long integer).

    Raises:
      ValueError: If add_unref() has never been called on this tracker.
    """
    latest = max(self._unref_times)
    return latest

  def add_ref(self, timestamp):
    """Records a reference to this tensor.

    Args:
      timestamp:  Timestamp of object reference as an integer.
    """
    self._ref_times.append(timestamp)

  def add_unref(self, timestamp):
    """Records an unreference of this tensor.

    Args:
      timestamp:  Timestamp of object unreference as an integer.
    """
    self._unref_times.append(timestamp)
344
345
class Timeline(object):
  """A class for visualizing execution timelines of TensorFlow steps."""

  # Op timeline labels are expected to look like: name = op(arg, arg, ...).
  # Compiled once here rather than on every label that gets parsed.
  _OP_LABEL_PATTERN = re.compile(r'(.*) = (.*)\((.*)\)')

  def __init__(self, step_stats, graph=None):
    """Constructs a new Timeline.

    A 'Timeline' is used for visualizing the execution of a TensorFlow
    computation.  It shows the timings and concurrency of execution at
    the granularity of TensorFlow Ops.
    This class is not thread safe.

    Args:
      step_stats: The 'StepStats' proto recording execution times.
      graph: (Optional) The 'Graph' that was executed.
    """
    self._origin_step_stats = step_stats
    self._step_stats = None
    self._graph = graph
    self._reset_analysis_state()

  def _reset_analysis_state(self):
    """Discards everything accumulated by a previous analysis run.

    Called from the constructor and at the start of every analysis so that
    analyzing the same Timeline more than once does not duplicate trace
    events or keep allocating fresh pids and flow ids.
    """
    self._chrome_trace = _ChromeTraceFormatter()
    self._next_pid = 0
    self._allocators_pid = None  # Assigned by _allocate_pids().
    self._device_pids = {}  # device name -> pid for compute activity.
    self._tensor_pids = {}  # device name -> pid for tensors.
    self._tensors = {}  # tensor_name -> TensorTracker
    self._next_flow_id = 0
    self._flow_starts = {}  # tensor_name -> (timestamp, pid, tid)
    self._alloc_times = {}  # tensor_name -> ( time, allocator, size )
    self._allocator_maximums = {}  # allocator name => maximum bytes long

  def _alloc_pid(self):
    """Allocate a process Id."""
    pid = self._next_pid
    self._next_pid += 1
    return pid

  def _alloc_flow_id(self):
    """Allocate a flow Id."""
    flow_id = self._next_flow_id
    self._next_flow_id += 1
    return flow_id

  def _parse_op_label(self, label):
    """Parses the fields in a node timeline label.

    Args:
      label: A label of the form 'name = op(arg, arg, ...)'.

    Returns:
      A (name, op, inputs) tuple where 'inputs' is a list of strings.  When
      the label does not have the expected form the result is
      ('unknown', 'unknown', []).
    """
    match = self._OP_LABEL_PATTERN.match(label)
    if match is None:
      return 'unknown', 'unknown', []
    nn, op, inputs = match.groups()
    # An empty argument list must yield [] rather than [''].
    inputs = inputs.split(', ') if inputs else []
    return nn, op, inputs

  def _parse_kernel_label(self, label, node_name):
    """Parses the fields in a kernel timeline label.

    Args:
      label: A label of the form 'retval (arg) detail @@annotation'.  When it
        contains a non-empty '@@...#' annotation, the annotation text replaces
        'node_name'.
      node_name: Fallback node name, expected to have the form 'name:op'.

    Returns:
      A (name, op) tuple; 'op' is 'unknown' when the node name has no ':'.
    """
    start = label.find('@@')
    end = label.find('#')
    if start >= 0 and end >= 0 and start + 2 < end:
      node_name = label[start + 2:end]
    # Node names should always have the form 'name:op'.  The extra element
    # supplies the default op when the ':' is missing.
    fields = node_name.split(':') + ['unknown']
    name, op = fields[:2]
    return name, op

  def _assign_lanes(self):
    """Assigns non-overlapping lanes for the activities on each device.

    Each node is placed in the first lane whose last activity ended strictly
    before the node starts; a new lane is opened when there is none.  The
    lane index is written to the node's 'thread_id' field.
    """
    for device_stats in self._step_stats.dev_stats:
      # TODO(pbar): Genuine thread IDs in NodeExecStats might be helpful.
      lanes = [0]  # lanes[i] is the end time of the last activity in lane i.
      for ns in device_stats.node_stats:
        node_end = ns.all_start_micros + ns.all_end_rel_micros
        for lane, lane_end in enumerate(lanes):
          if ns.all_start_micros > lane_end:
            lanes[lane] = node_end
            break
        else:
          # Every existing lane is still busy: open a new one.
          lane = len(lanes)
          lanes.append(node_end)
        ns.thread_id = lane

  def _emit_op(self, nodestats, pid, is_gputrace):
    """Generates a Chrome Trace event to show Op execution.

    Args:
      nodestats: The 'NodeExecStats' proto recording op execution.
      pid: The pid assigned for the device where this op ran.
      is_gputrace: If True then this op came from the GPUTracer.
    """
    node_name = nodestats.node_name
    start = nodestats.all_start_micros
    duration = nodestats.all_end_rel_micros
    tid = nodestats.thread_id
    inputs = []
    if is_gputrace:
      node_name, op = self._parse_kernel_label(nodestats.timeline_label,
                                               node_name)
    elif node_name == 'RecvTensor':
      # RPC tracing does not use the standard timeline_label format.
      op = 'RecvTensor'
    else:
      _, op, inputs = self._parse_op_label(nodestats.timeline_label)
    args = {'name': node_name, 'op': op}
    for i, iname in enumerate(inputs):
      args['input%d' % i] = iname
    self._chrome_trace.emit_region(start, duration, pid, tid, 'Op', op, args)

  def _emit_tensor_snapshot(self, tensor, timestamp, pid, tid, value):
    """Generate Chrome Trace snapshot event for a computed Tensor.

    Args:
      tensor: A 'TensorTracker' object.
      timestamp:  The timestamp of this snapshot as a long integer.
      pid: The pid assigned for showing the device where this op ran.
      tid: The tid of the thread computing the tensor snapshot.
      value: The node output whose 'tensor_description' is rendered (with
        double quotes removed) into the snapshot.
    """
    desc = str(value.tensor_description).replace('"', '')
    snapshot = {'tensor_description': desc}
    self._chrome_trace.emit_obj_snapshot('Tensor', tensor.name, timestamp, pid,
                                         tid, tensor.object_id, snapshot)

  def _produce_tensor(self, name, timestamp, tensors_pid, allocator, num_bytes):
    """Creates and registers a tracker for a newly produced tensor.

    The tracker's object id is the number of tensors registered so far.  An
    existing tracker with the same name is replaced.

    Args:
      name: The name of the tensor as a string.
      timestamp: The creation timestamp of the tensor.
      tensors_pid: The pid used to show tensors of the producing device.
      allocator: Name of the allocator used to create the tensor.
      num_bytes: Number of bytes allocated for the tensor.

    Returns:
      The new 'TensorTracker' object.
    """
    object_id = len(self._tensors)
    tensor = _TensorTracker(name, object_id, timestamp, tensors_pid, allocator,
                            num_bytes)
    self._tensors[name] = tensor
    return tensor

  def _is_gputrace_device(self, device_name):
    """Returns true if this device is part of the GPUTracer logging."""
    return '/stream:' in device_name or '/memcpy' in device_name

  def _allocate_pids(self):
    """Allocate fake process ids for each device in the StepStats."""
    self._allocators_pid = self._alloc_pid()
    self._chrome_trace.emit_pid('Allocators', self._allocators_pid)

    # Add processes in the Chrome trace to show compute and data activity.
    for dev_stats in self._step_stats.dev_stats:
      device_pid = self._alloc_pid()
      self._device_pids[dev_stats.device] = device_pid
      tensors_pid = self._alloc_pid()
      self._tensor_pids[dev_stats.device] = tensors_pid
      self._chrome_trace.emit_pid(dev_stats.device + ' Compute', device_pid)
      self._chrome_trace.emit_pid(dev_stats.device + ' Tensors', tensors_pid)

  def _analyze_tensors(self, show_memory):
    """Analyze tensor references to track dataflow.

    Registers a tracker for every node output and remembers where a dataflow
    arrow leaving that output should start.

    Args:
      show_memory: If True, also emit object creation and snapshot events for
        each output tensor.
    """
    for dev_stats in self._step_stats.dev_stats:
      device_pid = self._device_pids[dev_stats.device]
      tensors_pid = self._tensor_pids[dev_stats.device]
      for node_stats in dev_stats.node_stats:
        tid = node_stats.thread_id
        node_name = node_stats.node_name
        start_time = node_stats.all_start_micros
        end_time = node_stats.all_start_micros + node_stats.all_end_rel_micros
        for index, output in enumerate(node_stats.output):
          # The first output is named after the node itself; later outputs
          # get a ':<index>' suffix.
          if index:
            output_name = '%s:%d' % (node_name, index)
          else:
            output_name = node_name

          allocation = output.tensor_description.allocation_description
          num_bytes = allocation.requested_bytes
          allocator_name = allocation.allocator_name
          tensor = self._produce_tensor(output_name, start_time, tensors_pid,
                                        allocator_name, num_bytes)
          tensor.add_ref(start_time)
          tensor.add_unref(end_time)
          self._flow_starts[output_name] = (end_time, device_pid, tid)

          if show_memory:
            self._chrome_trace.emit_obj_create('Tensor', output_name,
                                               start_time, tensors_pid, tid,
                                               tensor.object_id)
            self._emit_tensor_snapshot(tensor, end_time - 1, tensors_pid, tid,
                                       output)

  def _show_compute(self, show_dataflow):
    """Visualize the computation activity.

    Emits a region for every node and extends the lifetime of each input
    tensor to cover its consumers.

    Args:
      show_dataflow: If True, also emit flow events connecting the producer
        and consumer of each tensor.
    """
    for dev_stats in self._step_stats.dev_stats:
      device_name = dev_stats.device
      device_pid = self._device_pids[device_name]
      is_gputrace = self._is_gputrace_device(device_name)

      for node_stats in dev_stats.node_stats:
        tid = node_stats.thread_id
        start_time = node_stats.all_start_micros
        end_time = node_stats.all_start_micros + node_stats.all_end_rel_micros
        self._emit_op(node_stats, device_pid, is_gputrace)

        # These nodes do not carry a standard op label, so there are no
        # inputs to connect.
        if is_gputrace or node_stats.node_name == 'RecvTensor':
          continue

        _, _, inputs = self._parse_op_label(node_stats.timeline_label)
        for input_name in inputs:
          if input_name not in self._tensors:
            # This can happen when partitioning has inserted a Send/Recv.
            # We remove the numeric suffix so that the dataflow appears to
            # come from the original node.  Ideally, the StepStats would
            # contain logging for the Send and Recv nodes.
            index = input_name.rfind('/_')
            if index > 0:
              input_name = input_name[:index]

          if input_name in self._tensors:
            tensor = self._tensors[input_name]
            tensor.add_ref(start_time)
            tensor.add_unref(end_time - 1)

            if show_dataflow:
              # We use a different flow ID for every graph edge.
              create_time, create_pid, create_tid = self._flow_starts[
                  input_name]
              # Don't add flows when producer and consumer ops are on the same
              # pid/tid since the horizontal arrows clutter the visualization.
              if create_pid != device_pid or create_tid != tid:
                flow_id = self._alloc_flow_id()
                self._chrome_trace.emit_flow_start(input_name, create_time,
                                                   create_pid, create_tid,
                                                   flow_id)
                self._chrome_trace.emit_flow_end(input_name, start_time,
                                                 device_pid, tid, flow_id)
          else:
            logging.vlog(1, 'Can\'t find tensor %s - removed by CSE?',
                         input_name)

  def _show_memory_counters(self):
    """Produce a counter series for each memory allocator.

    Also emits an object deletion event for every tracked tensor and stores
    the per-allocator peak usage in 'self._allocator_maximums'.
    """
    # Iterate over all tensor trackers to build a list of allocations and
    # frees for each allocator. Then sort the lists and emit a cumulative
    # counter series for each allocator.
    allocations = {}
    for name, tensor in self._tensors.items():
      self._chrome_trace.emit_obj_delete('Tensor', name, tensor.last_unref,
                                         tensor.pid, 0, tensor.object_id)
      num_bytes = tensor.num_bytes
      alloc_list = allocations.setdefault(tensor.allocator, [])
      # The trailing flag marks frees explicitly.  The sign of the byte delta
      # cannot be used for that: a zero-byte tensor has a delta of 0 for both
      # events and would otherwise never leave the live set.  Being last in
      # the tuple, the flag only breaks ties, ordering the allocation of such
      # a tensor before its free.
      alloc_list.append((tensor.create_time, num_bytes, name, False))
      alloc_list.append((tensor.last_unref, -num_bytes, name, True))

    alloc_maxes = {}

    # Generate a counter series showing total allocations for each allocator.
    for allocator, alloc_list in allocations.items():
      # Sorting the tuples orders events by time; at equal times frees
      # (negative deltas) come before allocations.
      alloc_list.sort()
      total_bytes = 0
      alloc_tensor_set = set()
      alloc_maxes[allocator] = AllocationMaximum(
          timestamp=0, num_bytes=0, tensors=set())
      for time, num_bytes, name, is_free in alloc_list:
        total_bytes += num_bytes
        if is_free:
          alloc_tensor_set.discard(name)
        else:
          alloc_tensor_set.add(name)

        if total_bytes > alloc_maxes[allocator].num_bytes:
          # Snapshot the live set; it keeps changing as the loop proceeds.
          alloc_maxes[allocator] = AllocationMaximum(
              timestamp=time,
              num_bytes=total_bytes,
              tensors=set(alloc_tensor_set))

        self._chrome_trace.emit_counter('Memory', allocator,
                                        self._allocators_pid, time, allocator,
                                        total_bytes)
    self._allocator_maximums = alloc_maxes

  def _preprocess_op_time(self, op_time):
    """Update the start and end time of ops in step stats.

    For "schedule" the original step stats are used as they are.  For any
    other value a deep copy is adjusted, so the step stats passed to the
    constructor keep their recorded times.

    Args:
      op_time: How the execution time of op is shown in timeline. Possible
        values are "schedule", "gpu" and "all". "schedule" will show op from
        the time it is scheduled to the end of the scheduling. Notice by the
        end of its scheduling its async kernels may not start yet. It is shown
        using the default value from step_stats. "gpu" will show op with the
        execution time of its kernels on GPU. "all" will show op from the
        start of its scheduling to the end of its last kernel. Any other value
        is handled like "all".
    """
    if op_time == 'schedule':
      self._step_stats = self._origin_step_stats
      return
    self._step_stats = copy.deepcopy(self._origin_step_stats)
    # Separate job task and gpu tracer stream
    stream_all_stats = []
    job_stats = []
    for stats in self._step_stats.dev_stats:
      if '/stream:all' in stats.device:
        stream_all_stats.append(stats)
      elif '/job' in stats.device:
        job_stats.append(stats)

    # Record the start time of the first kernel and the end time of
    # the last gpu kernel for all ops.
    op_gpu_start = {}
    op_gpu_end = {}
    for stats in stream_all_stats:
      for kernel in stats.node_stats:
        name, _ = self._parse_kernel_label(kernel.timeline_label,
                                           kernel.node_name)
        start = kernel.all_start_micros
        end = kernel.all_start_micros + kernel.all_end_rel_micros
        if name in op_gpu_start:
          op_gpu_start[name] = min(op_gpu_start[name], start)
          op_gpu_end[name] = max(op_gpu_end[name], end)
        else:
          op_gpu_start[name] = start
          op_gpu_end[name] = end

    # Update the start and end time of each op according to the op_time
    for stats in job_stats:
      for op in stats.node_stats:
        if op.node_name in op_gpu_start:
          # Compute the end before the start is possibly moved below.
          end = max(op_gpu_end[op.node_name],
                    op.all_start_micros + op.all_end_rel_micros)
          if op_time == 'gpu':
            op.all_start_micros = op_gpu_start[op.node_name]
          op.all_end_rel_micros = end - op.all_start_micros

  def analyze_step_stats(self,
                         show_dataflow=True,
                         show_memory=True,
                         op_time='schedule'):
    """Analyze the step stats and format it into Chrome Trace Format.

    Every call starts from a clean state, so the result never contains events
    left over from an earlier analysis of the same Timeline.

    Args:
      show_dataflow: (Optional.) If True, add flow events to the trace
        connecting producers and consumers of tensors.
      show_memory: (Optional.) If True, add object snapshot events to the trace
        showing the sizes and lifetimes of tensors.
      op_time: (Optional.) How the execution time of op is shown in timeline.
        Possible values are "schedule", "gpu" and "all". "schedule" will show op
        from the time it is scheduled to the end of the scheduling. Notice by
        the end of its scheduling its async kernels may not start yet. It is
        shown using the default value from step_stats. "gpu" will show op with
        the execution time of its kernels on GPU. "all" will show op from the
        start of its scheduling to the end of its last kernel.

    Returns:
      A 'StepStatsAnalysis' object.
    """
    self._reset_analysis_state()
    self._preprocess_op_time(op_time)
    self._allocate_pids()
    self._assign_lanes()
    self._analyze_tensors(show_memory)
    self._show_compute(show_dataflow)
    if show_memory:
      self._show_memory_counters()
    return StepStatsAnalysis(
        chrome_trace=self._chrome_trace,
        allocator_maximums=self._allocator_maximums)

  def generate_chrome_trace_format(self,
                                   show_dataflow=True,
                                   show_memory=False,
                                   op_time='schedule'):
    """Produces a trace in Chrome Trace Format.

    Args:
      show_dataflow: (Optional.) If True, add flow events to the trace
        connecting producers and consumers of tensors.
      show_memory: (Optional.) If True, add object snapshot events to the trace
        showing the sizes and lifetimes of tensors.
      op_time: (Optional.) How the execution time of op is shown in timeline.
        Possible values are "schedule", "gpu" and "all".
        "schedule" will show op from the time it is scheduled to the end of
          the scheduling.
          Notice by the end of its scheduling its async kernels may not start
          yet. It is shown using the default value from step_stats.
        "gpu" will show op with the execution time of its kernels on GPU.
        "all" will show op from the start of its scheduling to the end of
          its last kernel.

    Returns:
      A JSON formatted string in Chrome Trace format.
    """
    step_stats_analysis = self.analyze_step_stats(
        show_dataflow=show_dataflow, show_memory=show_memory, op_time=op_time)

    return step_stats_analysis.chrome_trace.format_to_string(pretty=True)
735