# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logging tensorflow::tfprof::OpLogProto.

OpLogProto is used to add extra model information for offline analysis.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

import six
from tensorflow.core.profiler import tfprof_log_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile
from tensorflow.python.profiler.internal import flops_registry  # pylint: disable=unused-import
from tensorflow.python.util.tf_export import tf_export

TRAINABLE_VARIABLES = '_trainable_variables'
REGISTERED_FLOP_STATS = 'flops'


def _fill_missing_graph_shape(graph, run_meta):
  """Fills Tensor shapes in 'graph' with run-time shapes from 'run_meta'."""
  for dev_stat in run_meta.step_stats.dev_stats:
    for node_stat in dev_stat.node_stats:
      if not node_stat.output:
        continue
      try:
        op = graph.get_operation_by_name(node_stat.node_name)
      except KeyError:
        # The graph doesn't contain the node_stat, usually RecvTensor.
        continue
      if len(node_stat.output) != len(op.outputs):
        # For example, a conditional op has only 1 output at run time.
        continue
      for (i, node_stat_out) in enumerate(node_stat.output):
        if op.outputs[i].get_shape().is_fully_defined():
          continue
        node_stat_dims = node_stat_out.tensor_description.shape.dim
        node_stat_shape = tensor_shape.TensorShape(
            [d.size for d in node_stat_dims])
        try:
          op.outputs[i].set_shape(op.outputs[i].get_shape().merge_with(
              node_stat_shape))
        except ValueError as e:
          sys.stderr.write('Node %s incompatible shapes: %s.\n' %
                           (node_stat.node_name, e))
  return graph


def _str_id(s, str_to_id):
  """Maps a string to an id, interning it in 'str_to_id' if needed."""
  num = str_to_id.get(s, None)
  if num is None:
    num = len(str_to_id)
    str_to_id[s] = num
  return num


def _get_logged_ops(graph, run_meta=None, add_trace=True,
                    add_trainable_var=True):
  """Extracts trainable model parameters and FLOPs for ops from a Graph.

  Args:
    graph: tf.Graph.
    run_meta: RunMetadata proto used to complete shape information.
    add_trace: Whether to add op trace information.
    add_trainable_var: Whether to assign tf.compat.v1.trainable_variables() op
      type '_trainable_variables'.
  Returns:
    logged_ops: dict mapping from op_name to OpLogEntry.
    string_to_id: dict mapping from string to id.
  """
  if run_meta:
    graph = _fill_missing_graph_shape(graph, run_meta)

  op_missing_shape = 0
  logged_ops = {}
  string_to_id = {}
  string_to_id['none'] = len(string_to_id)
  # TODO(xpan): Work with Profiler more efficiently.
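  # For every op in the graph, query the registered 'flops' statistic for its
  # NodeDef and, if requested, record its Python traceback so the profiler's
  # "code" view can attribute the op back to source lines.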
  for op in graph.get_operations():
    try:
      stats = ops.get_stats_for_node_def(
          graph, op.node_def, REGISTERED_FLOP_STATS)
    except ValueError:
      # The shape is incomplete, so no flops statistic is available; skip it.
      op_missing_shape += 1
      stats = None

    entry = tfprof_log_pb2.OpLogEntry()
    entry.name = op.name
    add_entry = False
    if stats and stats.value:
      entry.float_ops = int(stats.value)
      add_entry = True

    if add_trace:
      for tb in op.traceback:
        trace = entry.code_def.traces.add()
        trace.file_id = _str_id(tb[0], string_to_id) if tb[0] else 0
        trace.lineno = tb[1] if tb[1] else -1
        trace.function_id = _str_id(tb[2], string_to_id) if tb[2] else 0
        trace.line_id = _str_id(tb[3], string_to_id) if tb[3] else 0
        # TODO(slebedev): remove this unused field from the proto.
        trace.func_start_line = -1
      add_entry = True

    if add_entry:
      logged_ops[entry.name] = entry

  if add_trainable_var:
    for v in graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES):
      if v.op.name not in logged_ops:
        entry = tfprof_log_pb2.OpLogEntry()
        entry.name = v.op.name
        entry.types.append(TRAINABLE_VARIABLES)
        logged_ops[entry.name] = entry
      else:
        logged_ops[v.op.name].types.append(TRAINABLE_VARIABLES)

  if op_missing_shape > 0 and not run_meta:
    sys.stderr.write(
        '%d ops have no flops stats due to incomplete shapes.\n' %
        op_missing_shape)
  return logged_ops, string_to_id


def merge_default_with_oplog(graph, op_log=None, run_meta=None,
                             add_trace=True, add_trainable_var=True):
  """Merges the tfprof default extra info with the caller's op_log.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use the
      default graph.
    op_log: OpLogProto proto.
    run_meta: RunMetadata proto used to complete shape information.
    add_trace: Whether to add op trace information.
    add_trainable_var: Whether to assign tf.compat.v1.trainable_variables() op
      type '_trainable_variables'.
  Returns:
    tmp_op_log: Merged OpLogProto proto.
  """
  if not graph and not context.executing_eagerly():
    graph = ops.get_default_graph()

  tmp_op_log = tfprof_log_pb2.OpLogProto()
  if not graph:
    return tmp_op_log

  logged_ops, string_to_id = _get_logged_ops(
      graph, run_meta, add_trace=add_trace,
      add_trainable_var=add_trainable_var)

  if not op_log:
    tmp_op_log.log_entries.extend(logged_ops.values())
  else:
    all_ops = {}
    for entry in op_log.log_entries:
      all_ops[entry.name] = entry
    for op_name, entry in six.iteritems(logged_ops):
      if op_name in all_ops:
        all_ops[op_name].types.extend(entry.types)
        if entry.float_ops > 0 and all_ops[op_name].float_ops == 0:
          all_ops[op_name].float_ops = entry.float_ops
        if entry.code_def.traces and not all_ops[op_name].code_def.traces:
          all_ops[op_name].code_def.MergeFrom(entry.code_def)
      else:
        all_ops[op_name] = entry
    tmp_op_log.log_entries.extend(all_ops.values())

  for s, i in six.iteritems(string_to_id):
    tmp_op_log.id_to_string[i] = s
  return tmp_op_log
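

# A minimal usage sketch of the merge above (illustrative only; 'my_op' and
# 'my_type' are hypothetical names):
#
#   user_log = tfprof_log_pb2.OpLogProto()
#   entry = user_log.log_entries.add()
#   entry.name = 'my_op'
#   entry.types.append('my_type')
#
#   merged = merge_default_with_oplog(
#       ops.get_default_graph(), op_log=user_log)
#   # 'merged' keeps the caller's 'my_type' annotation on 'my_op' and adds the
#   # default '_trainable_variables' types and 'flops' stats collected above.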


@tf_export(v1=['profiler.write_op_log'])
def write_op_log(graph, log_dir, op_log=None, run_meta=None, add_trace=True):
  """Logs the provided 'op_log' and adds additional model information below.

  The API also assigns ops in tf.compat.v1.trainable_variables() an op type
  called '_trainable_variables'.
  The API also logs 'flops' statistics for ops with op.RegisterStatistics()
  defined. The flops calculation depends on Tensor shapes defined in 'graph',
  which might not be complete. 'run_meta', if provided, completes the shape
  information on a best-effort basis.

  Args:
    graph: tf.Graph. If None and eager execution is not enabled, use the
      default graph.
    log_dir: directory to write the log file to.
    op_log: (Optional) OpLogProto proto to be written. If not provided, a new
      one is created.
    run_meta: (Optional) RunMetadata proto that helps flops computation with
      run-time shape information.
    add_trace: Whether to add Python code trace information,
      used to support the "code" view.
  """
  if not graph and not context.executing_eagerly():
    graph = ops.get_default_graph()
  op_log = merge_default_with_oplog(graph, op_log, run_meta, add_trace)

  with gfile.Open(os.path.join(log_dir, 'tfprof_log'), 'w') as log:
    log.write(op_log.SerializeToString())
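

# Example usage (a minimal sketch assuming TF1-style graph execution;
# 'train_op' and './profile_logs' are hypothetical):
#
#   import tensorflow as tf
#
#   with tf.compat.v1.Session() as sess:
#     run_meta = tf.compat.v1.RunMetadata()
#     sess.run(
#         train_op,
#         options=tf.compat.v1.RunOptions(
#             trace_level=tf.compat.v1.RunOptions.FULL_TRACE),
#         run_metadata=run_meta)
#     # Writes a 'tfprof_log' file under log_dir for offline profiler analysis.
#     tf.compat.v1.profiler.write_op_log(
#         sess.graph, log_dir='./profile_logs', run_meta=run_meta)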