1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Classes and functions that help to inspect Python source w.r.t. TF graphs.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import os 23import re 24 25import numpy as np 26 27from tensorflow.python.debug.lib import profiling 28 29 30_TENSORFLOW_BASEDIR = os.path.dirname( 31 os.path.dirname(os.path.dirname(os.path.dirname( 32 os.path.normpath(os.path.abspath(__file__)))))) 33 34UNCOMPILED_SOURCE_SUFFIXES = (".py") 35COMPILED_SOURCE_SUFFIXES = (".pyc", ".pyo") 36 37 38def _norm_abs_path(file_path): 39 return os.path.normpath(os.path.abspath(file_path)) 40 41 42def is_extension_uncompiled_python_source(file_path): 43 _, extension = os.path.splitext(file_path) 44 return extension.lower() in UNCOMPILED_SOURCE_SUFFIXES 45 46 47def is_extension_compiled_python_source(file_path): 48 _, extension = os.path.splitext(file_path) 49 return extension.lower() in COMPILED_SOURCE_SUFFIXES 50 51 52def _convert_watch_key_to_tensor_name(watch_key): 53 return watch_key[:watch_key.rfind(":")] 54 55 56def guess_is_tensorflow_py_library(py_file_path): 57 """Guess whether a Python source file is a part of the tensorflow library. 

  Special cases:
    1) Returns False for unit-test files in the library (*_test.py),
    2) Returns False for files under python/debug/examples.

  Args:
    py_file_path: full path of the Python source file in question.

  Returns:
    (`bool`) Whether the file is a part of the tensorflow library.

  Raises:
    ValueError: if the extension name of py_file_path does not indicate a Python
      source file (compiled or uncompiled).
  """
  if (not is_extension_uncompiled_python_source(py_file_path) and
      not is_extension_compiled_python_source(py_file_path)):
    raise ValueError(
        "Input file path (%s) is not a Python source file." % py_file_path)
  py_file_path = _norm_abs_path(py_file_path)

  # The file counts as part of the library iff it lives under the tensorflow
  # base directory and matches neither special case documented above.
  return (py_file_path.startswith(_TENSORFLOW_BASEDIR) and
          not py_file_path.endswith("_test.py") and
          not os.path.dirname(py_file_path).endswith(
              os.path.normpath("python/debug/examples")))


def load_source(source_file_path):
  """Load the lines of a Python source file.

  Args:
    source_file_path: (`str`) Path to the source file to load.

  Returns:
    A tuple:
      1) (`list` of `str`) The lines of the file, from splitting its content
         on newline characters.
      2) (`int`) A display width for 1-based line numbers, computed as
         ceil(log10(number of lines)) plus a padding of 3.
  """
  with open(source_file_path, "rU") as f:
    source_text = f.read()
  source_lines = source_text.split("\n")
  # Column width for rendering line numbers next to the source lines.
  line_num_width = int(np.ceil(np.log10(len(source_lines)))) + 3
  return source_lines, line_num_width


def annotate_source(dump,
                    source_file_path,
                    do_dumped_tensors=False,
                    file_stack_top=False,
                    min_line=None,
                    max_line=None):
  """Annotate a Python source file with a list of ops created at each line.

  (The annotation doesn't change the source file itself.)

  Args:
    dump: (`DebugDumpDir`) A `DebugDumpDir` object of which the Python graph
      has been loaded.
    source_file_path: (`str`) Path to the source file being annotated.
    do_dumped_tensors: (`bool`) Whether dumped Tensors, instead of ops are to be
      used to annotate the source file.
    file_stack_top: (`bool`) Whether only the top stack trace in the
      specified source file is to be annotated.
    min_line: (`None` or `int`) The 1-based line to start annotating the source
      file from (inclusive).
    max_line: (`None` or `int`) The 1-based line number to end the annotation
      at (exclusive).

  Returns:
    A `dict` mapping 1-based line number to a list of op name(s) created at
      that line, or tensor names if `do_dumped_tensors` is True.

  Raises:
    ValueError: If the dump object does not have a Python graph set.
  """

  py_graph = dump.python_graph
  if not py_graph:
    raise ValueError("Cannot perform source annotation due to a lack of set "
                     "Python graph in the dump object")

  source_file_path = _norm_abs_path(source_file_path)

  line_to_op_names = {}
  for op in py_graph.get_operations():
    # Walk the op's creation traceback in reverse order.
    # NOTE(review): presumably this visits the innermost frame first --
    # confirm against `node_traceback`'s frame ordering.
    for file_path, line_number, _, _ in reversed(dump.node_traceback(op.name)):
      # Skip frames outside the requested [min_line, max_line) window.
      if (min_line is not None and line_number < min_line or
          max_line is not None and line_number >= max_line):
        continue

      # Only annotate frames that belong to the requested source file.
      if _norm_abs_path(file_path) != source_file_path:
        continue

      if do_dumped_tensors:
        watch_keys = dump.debug_watch_keys(op.name)
        # Convert watch keys to unique Tensor names.
        items_to_append = list(
            set(map(_convert_watch_key_to_tensor_name, watch_keys)))
      else:
        items_to_append = [op.name]

      if line_number in line_to_op_names:
        line_to_op_names[line_number].extend(items_to_append)
      else:
        line_to_op_names[line_number] = items_to_append

      if file_stack_top:
        # Only the first frame that matched this file is annotated.
        break

  return line_to_op_names


def list_source_files_against_dump(dump,
                                   path_regex_whitelist=None,
                                   node_name_regex_whitelist=None):
  """Generate a list of source files with information regarding ops and tensors.

  Args:
    dump: (`DebugDumpDir`) A `DebugDumpDir` object of which the Python graph
      has been loaded.
    path_regex_whitelist: A regular-expression filter for source file path.
    node_name_regex_whitelist: A regular-expression filter for node names.

  Returns:
    A list of tuples regarding the Python source files involved in constructing
    the ops and tensors contained in `dump`. Each tuple is:
      (source_file_path, is_tf_library, num_nodes, num_tensors, num_dumps,
       first_line)

      is_tf_library: (`bool`) A guess of whether the file belongs to the
        TensorFlow Python library.
      num_nodes: How many nodes were created by lines of this source file.
        These include nodes with dumps and those without.
      num_tensors: How many Tensors were created by lines of this source file.
        These include Tensors with dumps and those without.
      num_dumps: How many debug Tensor dumps were from nodes (and Tensors)
        that were created by this source file.
      first_line: The first line number (1-based) that created any nodes or
        Tensors in this source file.

    The list is sorted by ascending order of source_file_path.

  Raises:
    ValueError: If the dump object does not have a Python graph set.
  """

  py_graph = dump.python_graph
  if not py_graph:
    raise ValueError("Cannot generate source list due to a lack of set "
                     "Python graph in the dump object")

  # Per-file accumulators: node names, tensor names and smallest line number.
  path_to_node_names = collections.defaultdict(set)
  path_to_tensor_names = collections.defaultdict(set)
  path_to_first_line = {}
  # Cache of dump counts, so each tensor is counted at most once.
  tensor_name_to_num_dumps = {}

  # Compile the optional whitelist regexes once, outside the op loop.
  path_regex = (re.compile(path_regex_whitelist)
                if path_regex_whitelist else None)
  node_name_regex = (re.compile(node_name_regex_whitelist)
                     if node_name_regex_whitelist else None)

  # Paths already known to be filtered out (by regex) or nonexistent on disk;
  # caching them avoids repeating the regex match and os.path.isfile call.
  to_skip_file_paths = set()
  for op in py_graph.get_operations():
    if node_name_regex and not node_name_regex.match(op.name):
      continue

    for file_path, line_number, _, _ in dump.node_traceback(op.name):
      file_path = _norm_abs_path(file_path)
      if (file_path in to_skip_file_paths or
          path_regex and not path_regex.match(file_path) or
          not os.path.isfile(file_path)):
        to_skip_file_paths.add(file_path)
        continue

      path_to_node_names[file_path].add(op.name)
      # Track the smallest (earliest) line number seen for this file.
      if file_path in path_to_first_line:
        if path_to_first_line[file_path] > line_number:
          path_to_first_line[file_path] = line_number
      else:
        path_to_first_line[file_path] = line_number

      for output_tensor in op.outputs:
        tensor_name = output_tensor.name
        path_to_tensor_names[file_path].add(tensor_name)

    # A watch key has the form "<node_name>:<output_slot>:<debug_op>"; count
    # the dumps for each watched tensor once, caching by tensor name.
    watch_keys = dump.debug_watch_keys(op.name)
    for watch_key in watch_keys:
      node_name, output_slot, debug_op = watch_key.split(":")
      tensor_name = "%s:%s" % (node_name, output_slot)
      if tensor_name not in tensor_name_to_num_dumps:
        tensor_name_to_num_dumps[tensor_name] = len(
            dump.get_tensors(node_name, int(output_slot), debug_op))

  # Total dump count per file is the sum over the file's tensors.
  path_to_num_dumps = {}
  for path in path_to_tensor_names:
    path_to_num_dumps[path] = sum(
        tensor_name_to_num_dumps.get(tensor_name, 0)
        for tensor_name in path_to_tensor_names[path])

  # Assemble the output tuples in the order documented in the docstring.
  output = []
  for file_path in path_to_node_names:
    output.append((
        file_path,
        guess_is_tensorflow_py_library(file_path),
        len(path_to_node_names.get(file_path, {})),
        len(path_to_tensor_names.get(file_path, {})),
        path_to_num_dumps.get(file_path, 0),
        path_to_first_line[file_path]))

  return sorted(output, key=lambda x: x[0])


def annotate_source_against_profile(profile_data,
                                    source_file_path,
                                    node_name_filter=None,
                                    op_type_filter=None,
                                    min_line=None,
                                    max_line=None):
  """Annotate a Python source file with profiling information at each line.

  (The annotation doesn't change the source file itself.)

  Args:
    profile_data: (`list` of `ProfileDatum`) A list of `ProfileDatum`.
    source_file_path: (`str`) Path to the source file being annotated.
    node_name_filter: Regular expression to filter by node name.
    op_type_filter: Regular expression to filter by op type.
    min_line: (`None` or `int`) The 1-based line to start annotating the source
      file from (inclusive).
    max_line: (`None` or `int`) The 1-based line number to end the annotation
      at (exclusive).

  Returns:
    A `dict` mapping 1-based line number to a `profiling.AggregateProfile`
      aggregating the profile data at that line.
  """

  source_file_path = _norm_abs_path(source_file_path)

  # Compile the optional filter regexes once, outside the loop.
  node_name_regex = re.compile(node_name_filter) if node_name_filter else None
  op_type_regex = re.compile(op_type_filter) if op_type_filter else None

  line_to_profile_summary = {}
  for profile_datum in profile_data:
    # Skip profile entries with no associated source file.
    if not profile_datum.file_path:
      continue

    # Only entries from the requested source file are annotated.
    if _norm_abs_path(profile_datum.file_path) != source_file_path:
      continue

    # Skip entries outside the requested [min_line, max_line) window.
    if (min_line is not None and profile_datum.line_number < min_line or
        max_line is not None and profile_datum.line_number >= max_line):
      continue

    if (node_name_regex and
        not node_name_regex.match(profile_datum.node_exec_stats.node_name)):
      continue

    if op_type_regex and not op_type_regex.match(profile_datum.op_type):
      continue

    # Aggregate all matching profile data per line number.
    if profile_datum.line_number not in line_to_profile_summary:
      line_to_profile_summary[profile_datum.line_number] = (
          profiling.AggregateProfile(profile_datum))
    else:
      line_to_profile_summary[profile_datum.line_number].add(profile_datum)

  return line_to_profile_summary