1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Logging tensorflow::tfprof::OpLogProto. 16 17OpLogProto is used to add extra model information for offline analysis. 18""" 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import os 24import sys 25 26import six 27from tensorflow.core.profiler import tfprof_log_pb2 28from tensorflow.python.eager import context 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import tensor_shape 31from tensorflow.python.platform import gfile 32from tensorflow.python.profiler.internal import flops_registry # pylint: disable=unused-import 33from tensorflow.python.util.tf_export import tf_export 34 35TRAINABLE_VARIABLES = '_trainable_variables' 36REGISTERED_FLOP_STATS = 'flops' 37 38 39def _fill_missing_graph_shape(graph, run_meta): 40 """Fill Tensor shapes in 'graph' with run time shape from 'run_meta'.""" 41 for dev_stat in run_meta.step_stats.dev_stats: 42 for node_stat in dev_stat.node_stats: 43 if not node_stat.output: 44 continue 45 try: 46 op = graph.get_operation_by_name(node_stat.node_name) 47 except KeyError as e: 48 # Graph doesn't contains the node_stat, usually RecvTensor. 49 continue 50 if len(node_stat.output) != len(op.outputs): 51 # For example, conditional op has only 1 output at run time. 52 continue 53 for (i, node_stat_out) in enumerate(node_stat.output): 54 if op.outputs[i].get_shape().is_fully_defined(): 55 continue 56 node_stat_dims = node_stat_out.tensor_description.shape.dim 57 node_stat_shape = tensor_shape.TensorShape( 58 [d.size for d in node_stat_dims]) 59 try: 60 op.outputs[i].set_shape(op.outputs[i].get_shape().merge_with( 61 node_stat_shape)) 62 except ValueError as e: 63 sys.stderr.write('Node %s incompatible shapes: %s.\n' % 64 (node_stat.node_name, e)) 65 return graph 66 67 68def _str_id(s, str_to_id): 69 """Maps string to id.""" 70 num = str_to_id.get(s, None) 71 if num is None: 72 num = len(str_to_id) 73 str_to_id[s] = num 74 return num 75 76 77def _get_logged_ops(graph, run_meta=None, add_trace=True, 78 add_trainable_var=True): 79 """Extract trainable model parameters and FLOPs for ops from a Graph. 80 81 Args: 82 graph: tf.Graph. 83 run_meta: RunMetadata proto used to complete shape information. 84 add_trace: Whether to add op trace information. 85 add_trainable_var: Whether to assign tf.trainable_variables() op type 86 '_trainable_variables'. 87 Returns: 88 logged_ops: dict mapping from op_name to OpLogEntry. 89 string_to_id: dict mapping from string to id. 90 """ 91 if run_meta: 92 graph = _fill_missing_graph_shape(graph, run_meta) 93 94 op_missing_shape = 0 95 logged_ops = {} 96 string_to_id = dict() 97 string_to_id['none'] = len(string_to_id) 98 # TODO(xpan): Work with Profiler more efficiently. 99 for op in graph.get_operations(): 100 try: 101 stats = ops.get_stats_for_node_def( 102 graph, op.node_def, REGISTERED_FLOP_STATS) 103 except ValueError: 104 # Catch Exception When shape is incomplete. Skip it. 105 op_missing_shape += 1 106 stats = None 107 108 entry = tfprof_log_pb2.OpLogEntry() 109 entry.name = op.name 110 add_entry = False 111 if stats and stats.value: 112 entry.float_ops = int(stats.value) 113 add_entry = True 114 115 if add_trace: 116 for tb in op.traceback_with_start_lines: 117 trace = entry.code_def.traces.add() 118 trace.file_id = _str_id(tb[0], string_to_id) if tb[0] else 0 119 trace.lineno = tb[1] if tb[1] else -1 120 trace.function_id = _str_id(tb[2], string_to_id) if tb[2] else 0 121 trace.line_id = _str_id(tb[3], string_to_id) if tb[3] else 0 122 trace.func_start_line = tb[4] if tb[4] else -1 123 add_entry = True 124 125 if add_entry: 126 logged_ops[entry.name] = entry 127 128 if add_trainable_var: 129 for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES): 130 if v.op.name not in logged_ops: 131 entry = tfprof_log_pb2.OpLogEntry() 132 entry.name = v.op.name 133 entry.types.append(TRAINABLE_VARIABLES) 134 logged_ops[entry.name] = entry 135 else: 136 logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES) 137 138 if op_missing_shape > 0 and not run_meta: 139 sys.stderr.write('%d ops no flops stats due to incomplete shapes.\n' % 140 op_missing_shape) 141 return logged_ops, string_to_id 142 143 144def merge_default_with_oplog(graph, op_log=None, run_meta=None, 145 add_trace=True, add_trainable_var=True): 146 """Merge the tfprof default extra info with caller's op_log. 147 148 Args: 149 graph: tf.Graph. If None and eager execution is not enabled, use 150 default graph. 151 op_log: OpLogProto proto. 152 run_meta: RunMetadata proto used to complete shape information. 153 add_trace: Whether to add op trace information. 154 add_trainable_var: Whether to assign tf.trainable_variables() op type 155 '_trainable_variables'. 156 Returns: 157 tmp_op_log: Merged OpLogProto proto. 158 """ 159 if not graph and not context.executing_eagerly(): 160 graph = ops.get_default_graph() 161 162 tmp_op_log = tfprof_log_pb2.OpLogProto() 163 if not graph: 164 return tmp_op_log 165 166 logged_ops, string_to_id = _get_logged_ops( 167 graph, run_meta, add_trace=add_trace, add_trainable_var=add_trainable_var) 168 169 if not op_log: 170 tmp_op_log.log_entries.extend(logged_ops.values()) 171 else: 172 all_ops = dict() 173 for entry in op_log.log_entries: 174 all_ops[entry.name] = entry 175 for op_name, entry in six.iteritems(logged_ops): 176 if op_name in all_ops: 177 all_ops[op_name].types.extend(entry.types) 178 if entry.float_ops > 0 and all_ops[op_name].float_ops == 0: 179 all_ops[op_name].float_ops = entry.float_ops 180 if entry.code_def.traces and not all_ops[op_name].code_def.traces: 181 all_ops[op_name].code_def.MergeFrom(entry.code_def) 182 else: 183 all_ops[op_name] = entry 184 tmp_op_log.log_entries.extend(all_ops.values()) 185 186 for s, i in six.iteritems(string_to_id): 187 tmp_op_log.id_to_string[i] = s 188 return tmp_op_log 189 190 191@tf_export(v1=['profiler.write_op_log']) 192def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True): 193 """Log provided 'op_log', and add additional model information below. 194 195 The API also assigns ops in tf.trainable_variables() an op type called 196 '_trainable_variables'. 197 The API also logs 'flops' statistics for ops with op.RegisterStatistics() 198 defined. flops calculation depends on Tensor shapes defined in 'graph', 199 which might not be complete. 'run_meta', if provided, completes the shape 200 information with best effort. 201 202 Args: 203 graph: tf.Graph. If None and eager execution is not enabled, use 204 default graph. 205 log_dir: directory to write the log file. 206 op_log: (Optional) OpLogProto proto to be written. If not provided, an new 207 one is created. 208 run_meta: (Optional) RunMetadata proto that helps flops computation using 209 run time shape information. 210 add_trace: Whether to add python code trace information. 211 Used to support "code" view. 212 """ 213 if not graph and not context.executing_eagerly(): 214 graph = ops.get_default_graph() 215 op_log = merge_default_with_oplog(graph, op_log, run_meta, add_trace) 216 217 with gfile.Open(os.path.join(log_dir, 'tfprof_log'), 'w') as log: 218 log.write(op_log.SerializeToString()) 219