1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Logging tensorflow::tfprof::OpLogProto.
16
17OpLogProto is used to add extra model information for offline analysis.
18"""
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import os
24import sys
25
26import six
27from tensorflow.core.profiler import tfprof_log_pb2
28from tensorflow.python.eager import context
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.platform import gfile
32from tensorflow.python.profiler.internal import flops_registry  # pylint: disable=unused-import
33from tensorflow.python.util.tf_export import tf_export
34
# Op type appended to the OpLogEntry of every op that backs a variable found in
# the graph's TRAINABLE_VARIABLES collection.
TRAINABLE_VARIABLES = '_trainable_variables'
# Statistic type passed to ops.get_stats_for_node_def() when querying the
# float operation count of an op.
REGISTERED_FLOP_STATS = 'flops'
37
38
39def _fill_missing_graph_shape(graph, run_meta):
40  """Fill Tensor shapes in 'graph' with run time shape from 'run_meta'."""
41  for dev_stat in run_meta.step_stats.dev_stats:
42    for node_stat in dev_stat.node_stats:
43      if not node_stat.output:
44        continue
45      try:
46        op = graph.get_operation_by_name(node_stat.node_name)
47      except KeyError as e:
48        # Graph doesn't contains the node_stat, usually RecvTensor.
49        continue
50      if len(node_stat.output) != len(op.outputs):
51        # For example, conditional op has only 1 output at run time.
52        continue
53      for (i, node_stat_out) in enumerate(node_stat.output):
54        if op.outputs[i].get_shape().is_fully_defined():
55          continue
56        node_stat_dims = node_stat_out.tensor_description.shape.dim
57        node_stat_shape = tensor_shape.TensorShape(
58            [d.size for d in node_stat_dims])
59        try:
60          op.outputs[i].set_shape(op.outputs[i].get_shape().merge_with(
61              node_stat_shape))
62        except ValueError as e:
63          sys.stderr.write('Node %s incompatible shapes: %s.\n' %
64                           (node_stat.node_name, e))
65  return graph
66
67
68def _str_id(s, str_to_id):
69  """Maps string to id."""
70  num = str_to_id.get(s, None)
71  if num is None:
72    num = len(str_to_id)
73    str_to_id[s] = num
74  return num
75
76
77def _get_logged_ops(graph, run_meta=None, add_trace=True,
78                    add_trainable_var=True):
79  """Extract trainable model parameters and FLOPs for ops from a Graph.
80
81  Args:
82    graph: tf.Graph.
83    run_meta: RunMetadata proto used to complete shape information.
84    add_trace: Whether to add op trace information.
85    add_trainable_var: Whether to assign tf.trainable_variables() op type
86      '_trainable_variables'.
87  Returns:
88    logged_ops: dict mapping from op_name to OpLogEntry.
89    string_to_id: dict mapping from string to id.
90  """
91  if run_meta:
92    graph = _fill_missing_graph_shape(graph, run_meta)
93
94  op_missing_shape = 0
95  logged_ops = {}
96  string_to_id = dict()
97  string_to_id['none'] = len(string_to_id)
98  # TODO(xpan): Work with Profiler more efficiently.
99  for op in graph.get_operations():
100    try:
101      stats = ops.get_stats_for_node_def(
102          graph, op.node_def, REGISTERED_FLOP_STATS)
103    except ValueError:
104      # Catch Exception When shape is incomplete. Skip it.
105      op_missing_shape += 1
106      stats = None
107
108    entry = tfprof_log_pb2.OpLogEntry()
109    entry.name = op.name
110    add_entry = False
111    if stats and stats.value:
112      entry.float_ops = int(stats.value)
113      add_entry = True
114
115    if add_trace:
116      for tb in op.traceback_with_start_lines:
117        trace = entry.code_def.traces.add()
118        trace.file_id = _str_id(tb[0], string_to_id) if tb[0] else 0
119        trace.lineno = tb[1] if tb[1] else -1
120        trace.function_id = _str_id(tb[2], string_to_id) if tb[2] else 0
121        trace.line_id = _str_id(tb[3], string_to_id) if tb[3] else 0
122        trace.func_start_line = tb[4] if tb[4] else -1
123      add_entry = True
124
125    if add_entry:
126      logged_ops[entry.name] = entry
127
128  if add_trainable_var:
129    for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
130      if v.op.name not in logged_ops:
131        entry = tfprof_log_pb2.OpLogEntry()
132        entry.name = v.op.name
133        entry.types.append(TRAINABLE_VARIABLES)
134        logged_ops[entry.name] = entry
135      else:
136        logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)
137
138  if op_missing_shape > 0 and not run_meta:
139    sys.stderr.write('%d ops no flops stats due to incomplete shapes.\n' %
140                     op_missing_shape)
141  return logged_ops, string_to_id
142
143
def merge_default_with_oplog(graph, op_log=None, run_meta=None,
                             add_trace=True, add_trainable_var=True):
  """Merge the tfprof default extra info with caller's op_log.

  For an op that appears both in 'op_log' and in the default info, the
  caller's entry is the base: the default types are appended to it, the
  default float_ops is used only when the caller's value is 0, and the default
  code traces are used only when the caller's entry has none.

  The caller's 'op_log' is not modified; its entries are copied before the
  default info is merged into them.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use
        default graph.
    op_log: OpLogProto proto.
    run_meta: RunMetadata proto used to complete shape information.
    add_trace: Whether to add op trace information.
    add_trainable_var: Whether to assign tf.trainable_variables() op type
      '_trainable_variables'.
  Returns:
    tmp_op_log: Merged OpLogProto proto. It is empty when no graph is given
      and eager execution is enabled.
  """
  if not graph and not context.executing_eagerly():
    graph = ops.get_default_graph()

  tmp_op_log = tfprof_log_pb2.OpLogProto()
  if not graph:
    return tmp_op_log

  logged_ops, string_to_id = _get_logged_ops(
      graph, run_meta, add_trace=add_trace, add_trainable_var=add_trainable_var)

  if not op_log:
    tmp_op_log.log_entries.extend(logged_ops.values())
  else:
    all_ops = {}
    for entry in op_log.log_entries:
      # Merge into a private copy: mutating the caller's entries would leak
      # the default info into 'op_log' and pile up duplicate types whenever
      # the same 'op_log' is passed in again.
      merged_entry = tfprof_log_pb2.OpLogEntry()
      merged_entry.CopyFrom(entry)
      all_ops[entry.name] = merged_entry
    for op_name, entry in six.iteritems(logged_ops):
      merged_entry = all_ops.get(op_name)
      if merged_entry is None:
        all_ops[op_name] = entry
        continue
      merged_entry.types.extend(entry.types)
      if entry.float_ops > 0 and merged_entry.float_ops == 0:
        merged_entry.float_ops = entry.float_ops
      if entry.code_def.traces and not merged_entry.code_def.traces:
        merged_entry.code_def.MergeFrom(entry.code_def)
    tmp_op_log.log_entries.extend(all_ops.values())

  for s, i in six.iteritems(string_to_id):
    tmp_op_log.id_to_string[i] = s
  return tmp_op_log
189
190
@tf_export(v1=['profiler.write_op_log'])
def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True):
  """Write 'op_log', merged with tfprof's default model info, to 'log_dir'.

    The merged log is serialized into a file named 'tfprof_log' inside
    'log_dir'.

    Ops in tf.trainable_variables() are additionally assigned an op type
    called '_trainable_variables'.
    'flops' statistics are logged for ops with op.RegisterStatistics()
    defined. The flops calculation depends on Tensor shapes defined in 'graph',
    which might not be complete. 'run_meta', if provided, completes the shape
    information with best effort.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use
        default graph.
    log_dir: directory to write the log file.
    op_log: (Optional) OpLogProto proto to be written. If not provided, a new
        one is created.
    run_meta: (Optional) RunMetadata proto that helps flops computation using
        run time shape information.
    add_trace: Whether to add python code trace information.
        Used to support "code" view.
  """
  use_default_graph = not (graph or context.executing_eagerly())
  if use_default_graph:
    graph = ops.get_default_graph()
  merged_log = merge_default_with_oplog(graph, op_log, run_meta, add_trace)
  log_path = os.path.join(log_dir, 'tfprof_log')

  with gfile.Open(log_path, 'w') as log_file:
    log_file.write(merged_log.SerializeToString())
219