1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Logging tensorflow::tfprof::OpLogProto.
16
17OpLogProto is used to add extra model information for offline analysis.
18"""
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import os
24import sys
25
26import six
27from tensorflow.core.profiler import tfprof_log_pb2
28from tensorflow.python.eager import context
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.platform import gfile
32from tensorflow.python.profiler.internal import flops_registry  # pylint: disable=unused-import
33from tensorflow.python.util.tf_export import tf_export
34
# Op type attached to the ops of tf.compat.v1.trainable_variables().
TRAINABLE_VARIABLES = '_trainable_variables'
# Statistics name passed to ops.get_stats_for_node_def to obtain FLOP counts.
REGISTERED_FLOP_STATS = 'flops'
37
38
def _fill_missing_graph_shape(graph, run_meta):
  """Fill Tensor shapes in 'graph' with run time shape from 'run_meta'.

  For each node recorded in `run_meta.step_stats` whose graph outputs do not
  have fully defined shapes, the shape observed at run time is merged into
  the corresponding output tensor. This is best effort: nodes missing from
  the graph, nodes whose run-time output count differs from the graph, and
  incompatible shapes are skipped (the latter reported on stderr).

  Args:
    graph: tf.Graph whose output tensor shapes are updated in place.
    run_meta: RunMetadata proto carrying `step_stats`.

  Returns:
    The same `graph` object that was passed in.
  """
  for dev_stat in run_meta.step_stats.dev_stats:
    for node_stat in dev_stat.node_stats:
      if not node_stat.output:
        continue
      try:
        op = graph.get_operation_by_name(node_stat.node_name)
      except KeyError:
        # Graph doesn't contain the node_stat, usually RecvTensor.
        continue
      if len(node_stat.output) != len(op.outputs):
        # For example, conditional op has only 1 output at run time.
        continue
      for output, node_stat_out in zip(op.outputs, node_stat.output):
        graph_shape = output.get_shape()
        if graph_shape.is_fully_defined():
          continue
        if not node_stat_out.HasField('tensor_description'):
          # No run-time shape was recorded; an empty proto would otherwise
          # be read as a scalar shape.
          continue
        try:
          # Construct from the TensorShapeProto itself so `unknown_rank` and
          # -1 (unknown) dims are honored rather than being read as a scalar
          # or rejected by the Dimension constructor outside this handler.
          node_stat_shape = tensor_shape.TensorShape(
              node_stat_out.tensor_description.shape)
          output.set_shape(graph_shape.merge_with(node_stat_shape))
        except ValueError as e:
          sys.stderr.write('Node %s incompatible shapes: %s.\n' %
                           (node_stat.node_name, e))
  return graph
66
67
68def _str_id(s, str_to_id):
69  """Maps string to id."""
70  num = str_to_id.get(s, None)
71  if num is None:
72    num = len(str_to_id)
73    str_to_id[s] = num
74  return num
75
76
def _get_logged_ops(graph, run_meta=None, add_trace=True,
                    add_trainable_var=True):
  """Build OpLogEntry records (FLOPs, traces, variable tags) for graph ops.

  Args:
    graph: tf.Graph whose operations are scanned.
    run_meta: Optional RunMetadata proto. When given, its run-time shapes are
      first merged into `graph` so more ops can get FLOP statistics.
    add_trace: If True, every op gets an entry carrying its Python traceback.
    add_trainable_var: If True, ops of the graph's TRAINABLE_VARIABLES
      collection are tagged with the '_trainable_variables' type.

  Returns:
    A tuple `(logged_ops, string_to_id)`:
      logged_ops: dict mapping op name to its OpLogEntry.
      string_to_id: dict mapping each interned trace string to its id; 'none'
        is pre-registered as id 0, and 0 is also used for missing strings.
  """
  if run_meta:
    graph = _fill_missing_graph_shape(graph, run_meta)

  logged_ops = {}
  string_to_id = {'none': 0}
  ops_without_shape = 0

  # TODO(xpan): Work with Profiler more efficiently.
  for op in graph.get_operations():
    try:
      flop_stats = ops.get_stats_for_node_def(
          graph, op.node_def, REGISTERED_FLOP_STATS)
    except ValueError:
      # Raised when shapes are incomplete; the op just gets no FLOP count.
      ops_without_shape += 1
      flop_stats = None

    entry = tfprof_log_pb2.OpLogEntry()
    entry.name = op.name
    keep_entry = bool(flop_stats and flop_stats.value)
    if keep_entry:
      entry.float_ops = int(flop_stats.value)

    if add_trace:
      for frame in op.traceback:
        filename, lineno, func_name, source_line = (
            frame[0], frame[1], frame[2], frame[3])
        trace = entry.code_def.traces.add()
        trace.file_id = _str_id(filename, string_to_id) if filename else 0
        trace.lineno = lineno or -1
        trace.function_id = (
            _str_id(func_name, string_to_id) if func_name else 0)
        trace.line_id = (
            _str_id(source_line, string_to_id) if source_line else 0)
        # TODO(slebedev): remove this unused field from the proto.
        trace.func_start_line = -1
      keep_entry = True

    if keep_entry:
      logged_ops[entry.name] = entry

  if add_trainable_var:
    for var in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
      var_op_name = var.op.name
      var_entry = logged_ops.get(var_op_name)
      if var_entry is None:
        var_entry = tfprof_log_pb2.OpLogEntry()
        var_entry.name = var_op_name
        logged_ops[var_op_name] = var_entry
      var_entry.types.append(TRAINABLE_VARIABLES)

  # With run_meta the shapes were already completed on a best-effort basis,
  # so the warning is only emitted when no run-time shapes were available.
  if ops_without_shape > 0 and not run_meta:
    sys.stderr.write('%d ops no flops stats due to incomplete shapes.\n' %
                     ops_without_shape)
  return logged_ops, string_to_id
143
144
def merge_default_with_oplog(graph, op_log=None, run_meta=None,
                             add_trace=True, add_trainable_var=True):
  """Merge the tfprof default extra info with caller's op_log.

  The caller's `op_log` is not modified: its entries are copied before the
  graph-derived information (types, FLOPs, traces) is merged into them.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use
        default graph.
    op_log: OpLogProto proto.
    run_meta: RunMetadata proto used to complete shape information.
    add_trace: Whether to add op trace information.
    add_trainable_var: Whether to assign tf.compat.v1.trainable_variables() op
      type '_trainable_variables'.
  Returns:
    tmp_op_log: Merged OpLogProto proto. It is empty if no graph is
      available (eager execution with `graph=None`).
  """
  if not graph and not context.executing_eagerly():
    graph = ops.get_default_graph()

  tmp_op_log = tfprof_log_pb2.OpLogProto()
  if not graph:
    return tmp_op_log

  logged_ops, string_to_id = _get_logged_ops(
      graph, run_meta, add_trace=add_trace, add_trainable_var=add_trainable_var)

  if not op_log:
    tmp_op_log.log_entries.extend(logged_ops.values())
  else:
    all_ops = {}
    for entry in op_log.log_entries:
      # Repeated-field elements are live references into the caller's proto;
      # merge into a copy so op_log is left untouched and repeated calls with
      # the same op_log do not accumulate duplicate types.
      entry_copy = tfprof_log_pb2.OpLogEntry()
      entry_copy.CopyFrom(entry)
      all_ops[entry.name] = entry_copy
    for op_name, entry in six.iteritems(logged_ops):
      if op_name in all_ops:
        merged = all_ops[op_name]
        merged.types.extend(entry.types)
        # Caller-provided FLOPs and traces take precedence when present.
        if entry.float_ops > 0 and merged.float_ops == 0:
          merged.float_ops = entry.float_ops
        if entry.code_def.traces and not merged.code_def.traces:
          merged.code_def.MergeFrom(entry.code_def)
      else:
        all_ops[op_name] = entry
    tmp_op_log.log_entries.extend(all_ops.values())

  # NOTE(review): only strings interned from `graph` are emitted here;
  # op_log.id_to_string is not carried over, so trace ids inside caller
  # entries may not resolve against this table — confirm this is intended.
  for s, i in six.iteritems(string_to_id):
    tmp_op_log.id_to_string[i] = s
  return tmp_op_log
190
191
@tf_export(v1=['profiler.write_op_log'])
def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True):
  """Log provided 'op_log', and add additional model information below.

  The API also assigns ops in tf.compat.v1.trainable_variables() an op type
  called '_trainable_variables'.
  The API also logs 'flops' statistics for ops with op.RegisterStatistics()
  defined. flops calculation depends on Tensor shapes defined in 'graph',
  which might not be complete. 'run_meta', if provided, completes the shape
  information with best effort.

  The merged OpLogProto is serialized to the file `tfprof_log` in `log_dir`.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use
        default graph.
    log_dir: directory to write the log file.
    op_log: (Optional) OpLogProto proto to be written. If not provided, a new
        one is created.
    run_meta: (Optional) RunMetadata proto that helps flops computation using
        run time shape information.
    add_trace: Whether to add python code trace information.
        Used to support "code" view.
  """
  if not graph and not context.executing_eagerly():
    graph = ops.get_default_graph()
  op_log = merge_default_with_oplog(graph, op_log, run_meta, add_trace)

  # SerializeToString() returns bytes, so open the file in binary mode.
  with gfile.Open(os.path.join(log_dir, 'tfprof_log'), 'wb') as log:
    log.write(op_log.SerializeToString())
220