1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ======================================================================== 15"""A utility to trace tensor values on TPU.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import os.path 23import re 24import sys 25 26from tensorflow.python.framework import constant_op 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import graph_io 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import tensor_util 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import control_flow_ops 33from tensorflow.python.ops import control_flow_util 34from tensorflow.python.ops import gen_math_ops 35from tensorflow.python.ops import init_ops 36from tensorflow.python.ops import linalg_ops 37from tensorflow.python.ops import logging_ops 38from tensorflow.python.ops import math_ops 39from tensorflow.python.ops import state_ops 40from tensorflow.python.ops import variable_scope 41from tensorflow.python.platform import gfile 42from tensorflow.python.platform import tf_logging as logging 43from tensorflow.python.tpu import tpu 44from tensorflow.python.tpu.ops import tpu_ops 45 46_TRACER_LOG_PREFIX = ' [>>>TT>>>]' 47_DEVICE_TYPE_TPU = 'tpu' 48_DEVICE_TYPE_CPU = 'cpu' 49_TRACE_MODE_NAN_INF = 'nan-inf' 50_TRACE_MODE_PART_TENSOR = 'part-tensor' 51_TRACE_MODE_PART_TENSOR_SIZE = 3 52_TRACE_MODE_FULL_TENSOR = 'full-tensor' 53_TRACE_MODE_NORM = 'norm' 54_TRACE_MODE_MAX_ABS = 'max-abs' 55_SUBMODE_BRIEF = 'brief' 56_SUBMODE_DETAILED = 'detailed' 57_REASON_OUTSIDE_OP_RANGE = 'not-traced-outside-op-range' 58_REASON_UNSAFE_OP = 'not-traced-unsafe-op' 59_REASON_WHILELOOP_OP = 'not-traced-special-whileloop-op' 60_REASON_UNSAFE_SCALAR = 'not-traced-unsafe-scalar' 61_REASON_LESS_INTERESTING_OP = 'not-traced-less-interesting-op' 62_REASON_DEVICE_MISMATCH = 'not-traced-device-mismatch' 63_REASON_DYNAMIC_SHAPE = 'not-traced-dynamic-shape' 64_REASON_SCALAR_GET_TRACED = 'traced-scalar' 65_REASON_TENSOR_GET_TRACED = 'traced-tensor' 66_REASON_USER_INCLUDED = 'traced-user-included' 67_REASON_USER_EXCLUDED = 'not-traced-user-excluded' 68_REASON_NOT_EXECUTED = 'not-traced-not-in-exec-path' 69_REASON_NON_NUMERIC_TENSOR = 'not-traced-non-numeric-tensor' 70_REASON_FEEDS_WHILELOOP_OP = 'not-traced-feeds-special-whileloop-op' 71_MARKER_SECTION_BEGIN = '!!!!!!! section-begin:' 72_MARKER_SECTION_END = '!!!!!!! section-end:' 73_SECTION_NAME_CONFIG = 'configuration' 74_SECTION_NAME_REASON = 'reason' 75_SECTION_NAME_OP_LIST = 'op-list' 76_SECTION_NAME_TENSOR_LIST = 'tensor-list' 77_SECTION_NAME_CACHE_INDEX_MAP = 'cache-index-map' 78_SECTION_NAME_GRAPH = 'graph' 79_FIELD_NAME_VERSION = 'version:' 80_FIELD_NAME_DEVICE = 'device:' 81_FIELD_NAME_TRACE_MODE = 'trace-mode:' 82_FIELD_NAME_SUBMODE = 'submode:' 83_FIELD_NAME_NUM_REPLICAS = 'num-replicas:' 84_FIELD_NAME_NUM_REPLICAS_PER_HOST = 'num-replicas-per-host:' 85_FIELD_NAME_NUM_HOSTS = 'num-hosts:' 86_FIELD_NAME_NUM_OPS = 'number-of-ops:' 87_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:' 88_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:' 89_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:' 90_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS' 91_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'") 92_FLAG_DOUBLE_QUOTE_PAT = re.compile(r'\s*--([^=]+)="([^"]*)"') 93_FLAG_NO_QUOTE_PAT = re.compile(r'\s*--([^=]+)=(\S*)') 94_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=]+)\s*') 95_FLAG_NAME_ENABLE = 'enable' 96_FLAG_NAME_TRACE_MODE = 'trace_mode' 97_FLAG_NAME_USE_COMPACT_TRACE = 'compact_trace' 98_FLAG_NAME_SUBMODE = 'submode' 99_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops' 100_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames' 101_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes' 102_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames' 103_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes' 104_FLAG_NAME_TRACE_DIR = 'trace_dir' 105_FLAG_NAME_REPORT_FILE = 'report_file' 106_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir' 107_FLAG_NAME_OP_RANGE = 'op_range' 108# Folder to dump the pre (before tensor tracer updates) and post graphs (after 109# tensor tracer updates). 110_FLAG_DUMP_BEFORE_AFTER_GRAPHS = 'dump_graphs' 111_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)') 112_OUTPUT_STREAM_ESCAPE = 'file://' 113_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR' 114_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables' 115_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint' 116_TRACE_FILE_NAME = 'trace.all' 117_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.' 118_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0 119_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage' 120_TENSOR_VALUES_CACHE = 'tensor_values_cache' 121_REPLICA_ID_TAG = '#replica-id: ' 122 123 124def tensor_tracepoint(tensor, checkpoint_name): 125 """Adds a checkpoint with the given checkpoint name for the given tensor. 126 127 The tensor will be added to the list of tensors that will be traced by the 128 tensor tracer. 129 130 Args: 131 tensor: the tensor object for which the tracing is requested. 132 checkpoint_name: a string name for the checkpoint. This name has to be a 133 unique name if used within model comparison. The tensors that have the same 134 checkpoint identifier is compared in model comparison. 135 Returns: 136 The provided tensor. 137 """ 138 139 tensor.graph.get_collection(_TENSOR_TRACER_COLLECTION) 140 tensor.graph.add_to_collection(_TENSOR_TRACER_COLLECTION, 141 (tensor, checkpoint_name)) 142 return tensor 143 144 145def keras_layer_tracepoint(layer, checkpoint_name): 146 """An interface for adding the tensor outputs of a keras layer. 147 148 Encapsulates tensor_tracepoint. 149 150 Args: 151 layer: A keras layer. 152 checkpoint_name: a string name for the checkpoint. This name has to be a 153 unique name if used within model comparison. The tensors that have the same 154 checkpoint identifier is compared in model comparison. 155 156 Returns: 157 The provided layer. 158 """ 159 try: 160 outputs = layer.output 161 if tensor_util.is_tensor(outputs): 162 tensor_tracepoint(outputs, '%s' % (checkpoint_name)) 163 else: 164 idx = 0 165 for output_tensor in outputs: 166 if tensor_util.is_tensor(outputs): 167 tensor_tracepoint(output_tensor, '%s_%d' % (checkpoint_name, idx)) 168 idx += 1 169 except AttributeError: 170 pass 171 except RuntimeError: 172 pass 173 return layer 174 175 176def _trace_files_need_precreated(output_dir): 177 """Return True if trace files must be pre-created by users.""" 178 179 if not output_dir.startswith('/'): 180 return False 181 if len(output_dir) < 5: 182 return False 183 if output_dir[2] != 'n': 184 return False 185 if output_dir[3] != 's': 186 return False 187 if output_dir[1] != 'c': 188 return False 189 if output_dir[4] != '/': 190 return False 191 return True 192 193 194def _get_tensor_values_cache(graph=None): 195 """Returns the variable that implements tensor-value caching.""" 196 197 graph = graph or ops.get_default_graph() 198 collection = graph.get_collection(_TENSOR_TRACER_STORAGE) 199 if len(collection) == 1: 200 return collection[0] 201 elif not collection: 202 raise RuntimeError('%s has not been created'%_TENSOR_VALUES_CACHE) 203 else: 204 raise RuntimeError('Multiple %s created'%_TENSOR_VALUES_CACHE) 205 return None 206 207 208def _create_tensor_values_cache(graph, num_tensors): 209 """Creates a variable as the cache to store intermediate tensor values.""" 210 graph = graph or ops.get_default_graph() 211 # Create in proper graph and base name_scope. 212 with graph.as_default() as g, g.name_scope(None): 213 return variable_scope.get_variable( 214 _TENSOR_VALUES_CACHE, 215 shape=[num_tensors], 216 dtype=dtypes.float32, 217 initializer=init_ops.constant_initializer( 218 _COMPACT_TRACE_ENTRY_INIT_VALUE), 219 trainable=False, 220 use_resource=True, 221 collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.GLOBAL_VARIABLES]) 222 223 224class TensorTracer(object): 225 """A software construct for tracing tensor values in a TF graph on TPU. 226 227 This utility is disabled by default. It can be enabled by setting 228 the TENSOR_TRACER_FLAGS env variable as: 229 export TENSOR_TRACER_FLAGS="--enable=1" 230 If it is enabled, it will trace the output tensor values of 231 selected Ops in the graph. It has two outputs: (1) the traces and (2) 232 a report. The traces are dumped to a specified local file on the TPU 233 host. The report is printed to the log.info of the TPU job. 234 By passing options via the env variable, users can change: 235 (1) the trace mode (e.g., detecting NaN/Inf, printing partial or 236 full tensor values) 237 (2) which Ops to be traced (via op.name or op.type) 238 (3) output trace file path. 239 """ 240 # The set of graphs that are rewritten by tensor tracer. 241 _traced_graphs = set() 242 @staticmethod 243 def _match_next_flag(flags, pos): 244 """Returns the match for the next TensorTracer flag. 245 246 Args: 247 flags: a string that contains the flags. 248 pos: where in flags to start the search. 249 250 Returns: 251 A pair where the first element is the regular-expression 252 match found and the second element indicates if the match 253 has a value. 254 """ 255 256 match = _FLAG_DOUBLE_QUOTE_PAT.match(flags, pos) 257 if match: 258 return match, True 259 match = _FLAG_SINGLE_QUOTE_PAT.match(flags, pos) 260 if match: 261 return match, True 262 match = _FLAG_NO_QUOTE_PAT.match(flags, pos) 263 if match: 264 return match, True 265 match = _FLAG_NO_EQUAL_PAT.match(flags, pos) 266 if match: 267 # The flag is found but is not given a value. 268 return match, False 269 # The flag is not found. 270 return None, False 271 272 @staticmethod 273 def validate_flag_names(): 274 """Validates if the TensorTrace flags passed are valid.""" 275 valid_flag_names = [_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE, 276 _FLAG_NAME_USE_COMPACT_TRACE, 277 _FLAG_NAME_SUBMODE, 278 _FLAG_NAME_EXCLUDED_OPNAMES, 279 _FLAG_NAME_EXCLUDED_OPTYPES, 280 _FLAG_NAME_INCLUDED_OPNAMES, 281 _FLAG_NAME_INCLUDED_OPTYPES, 282 _FLAG_NAME_TRACE_DIR, 283 _FLAG_NAME_REPORT_FILE, 284 _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR, 285 _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS, 286 _FLAG_NAME_OP_RANGE, 287 _FLAG_DUMP_BEFORE_AFTER_GRAPHS] 288 tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) 289 if not tensor_tracer_flags: 290 return 291 pos = 0 292 while True: 293 match, _ = TensorTracer._match_next_flag(tensor_tracer_flags, pos) 294 if not match: 295 break 296 flag_name = match.group(1) 297 if flag_name not in valid_flag_names: 298 raise ValueError( 299 'The flag name "%s" passed via the environment variable "%s" ' 300 'is invalid. Valid flag names are:' 301 '\n%s'%(flag_name, _FLAGS_ENV_VAR, valid_flag_names)) 302 pos = match.end() 303 304 @staticmethod 305 def print_flag_values(): 306 """Prints all TensorTracer flags passed via environment variables.""" 307 308 tensor_tracer_flags = os.environ.get(_FLAGS_ENV_VAR) 309 if not tensor_tracer_flags: 310 return 'Env variable "%s" is not set'%_FLAGS_ENV_VAR 311 result = 'Env variable "%s" is set to "%s"\n'%(_FLAGS_ENV_VAR, 312 tensor_tracer_flags) 313 result += 'Individual flag value:\n' 314 pos = 0 315 while True: 316 match, has_value = TensorTracer._match_next_flag( 317 tensor_tracer_flags, pos) 318 if not match: 319 break 320 flag_name = match.group(1) 321 if has_value: 322 flag_value = match.group(2) 323 else: 324 flag_value = None 325 result += ' %s: %s\n'%(flag_name, flag_value) 326 pos = match.end() 327 result += '\n' 328 return result 329 330 @staticmethod 331 def get_flag_value(wanted_flag_name): 332 """Returns the value of a TensorTracer flags. 333 334 Args: 335 wanted_flag_name: the name the the flag we are looking for. 336 337 Returns: 338 A pair where the first element indicates if the flag is 339 found and the second element is the value of the flag. 340 341 Raises: 342 RuntimeError: If supposedly deadcode is reached. 343 """ 344 345 tensor_tracer_flags = os.getenv(_FLAGS_ENV_VAR) 346 if not tensor_tracer_flags: 347 return False, None 348 pos = 0 349 while True: 350 match, has_value = TensorTracer._match_next_flag( 351 tensor_tracer_flags, pos) 352 if not match: 353 return False, None 354 flag_name = match.group(1) 355 if has_value: 356 flag_value = match.group(2) 357 else: 358 flag_value = None 359 if flag_name == wanted_flag_name: 360 return True, flag_value 361 pos = match.end() 362 raise RuntimeError('Should not reach here.') 363 364 @staticmethod 365 def flag_value_to_re_list(flag_name): 366 """Converts list of strings to compiled RE.""" 367 368 re_list = [] 369 found, flag_value = TensorTracer.get_flag_value(flag_name) 370 if not found or not flag_value: 371 return re_list 372 list_of_values = flag_value.split() 373 for v in list_of_values: 374 r = re.compile(v) 375 re_list.append(r) 376 return re_list 377 378 @staticmethod 379 def _is_flag_on(flag_name): 380 """Returns True if the given flag is on.""" 381 382 found, flag_value = TensorTracer.get_flag_value(flag_name) 383 if not found: 384 return False 385 if flag_value is None: 386 return True 387 # Depends on the flag value. 388 flag_value = flag_value.lower() 389 enabled = flag_value in ['1', 't', 'true', 'y', 'yes'] 390 return enabled 391 392 @staticmethod 393 def is_enabled(): 394 """Returns True if TensorTracer is enabled.""" 395 396 return TensorTracer._is_flag_on(_FLAG_NAME_ENABLE) 397 398 @staticmethod 399 def use_test_undeclared_outputs_dir(): 400 """Decides the output directory of the report and trace files. 401 402 Args: 403 None. 404 405 Returns: 406 True if the output files should be written to the 407 test-undeclared-outputs-directory defined via an 408 env variable. 409 """ 410 411 return TensorTracer._is_flag_on( 412 _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR) 413 414 @staticmethod 415 def use_compact_trace(): 416 return TensorTracer._is_flag_on( 417 _FLAG_NAME_USE_COMPACT_TRACE) 418 419 @staticmethod 420 def check_device_type(device_type): 421 """Checks if the given device type is valid.""" 422 423 if device_type not in [_DEVICE_TYPE_TPU, _DEVICE_TYPE_CPU]: 424 raise ValueError('Invalid device_type "%s"'%device_type) 425 426 @staticmethod 427 def check_trace_mode(trace_mode): 428 """Checks if the given trace mode is valid.""" 429 430 valid_trace_modes = [_TRACE_MODE_NAN_INF, _TRACE_MODE_PART_TENSOR, 431 _TRACE_MODE_FULL_TENSOR, _TRACE_MODE_NORM, 432 _TRACE_MODE_MAX_ABS] 433 if trace_mode not in valid_trace_modes: 434 raise ValueError('Invalid trace mode "%s" given to the Tensor_Tracer.' 435 'Valid trace modes are: %s'%(trace_mode, 436 valid_trace_modes)) 437 438 @staticmethod 439 def check_submode(submode): 440 """Checks if the given submode is valid.""" 441 442 if not submode: 443 return 444 valid_submodes = [_SUBMODE_DETAILED, _SUBMODE_BRIEF] 445 if submode not in valid_submodes: 446 raise ValueError('Invalid submode "%s" given to the Tensor_Tracer.' 447 'Valid submodes are: %s'%(submode, 448 valid_submodes)) 449 450 @staticmethod 451 def loop_cond_op(op): 452 return op.type in ('LoopCond', 'RefLoopCond') 453 454 @staticmethod 455 def while_loop_op(op): 456 """Returns true if op is one of the special ops of in a while loop. 457 458 Args: 459 op: A tf.Operation. 460 461 Returns: 462 True if the given op is one of [Switch, Merge, Enter, Exit, 463 NextIteration, LoopCond], which are all building blocks for TF while 464 loops. 465 """ 466 return (control_flow_util.IsLoopSwitch(op) or 467 control_flow_util.IsLoopMerge(op) or 468 control_flow_util.IsLoopEnter(op) or 469 control_flow_util.IsLoopExit(op) or 470 TensorTracer.loop_cond_op(op) or 471 op.type in ('RefNextIteration', 'NextIteration')) 472 473 @staticmethod 474 def unsafe_op(op): 475 """Returns True if this op is not safe to be traced.""" 476 477 if control_flow_util.IsInCond(op): 478 return True 479 # Reasons for not including following op types: 480 # Assign: cause incorrect result with CPU tracing. 481 if op.type in ['Assign']: 482 return True 483 return False 484 485 @staticmethod 486 def device_mismatch(device_type, op): 487 if device_type == _DEVICE_TYPE_TPU: 488 # pylint: disable=protected-access 489 return tpu._TPU_REPLICATE_ATTR not in op.node_def.attr 490 # pylint: enable=protected-access 491 return False 492 493 @staticmethod 494 def unsafe_scalar_trace(op): 495 """Return true if scalar output tensor from Op is not safe to be traced.""" 496 497 # Tracing the following causes cycle in the graph on TPU. 498 if op.type in ['LoopCond', 'Enter', 'Merge', 'Const', 499 'Switch', 'Less', 'ReadVariableOp']: 500 return True 501 # Tracing the following will cause casting-issue 502 # with the norm tracing mode or other compilation issues on CPU. 503 if op.type in ['VarHandleOp', 'IteratorToStringHandle', 504 'IteratorGetNext', 'OneShotIterator', 505 'IteratorV2', 'MakeIterator', 506 'BatchDatasetV2', 'MapDataset', 507 'FixedLengthRecordDataset', 'TakeDataset', 'ZipDataset', 508 'Placeholder', 'PlaceholderWithDefault', 'StridedSlice']: 509 return True 510 return False 511 512 @staticmethod 513 def less_interesting_op(op): 514 """Returns True if the given Op is not an interesting one to be traced.""" 515 516 found, _ = TensorTracer.get_flag_value( 517 _FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS) 518 if found: 519 # users force to include all ops. 520 return False 521 # Following ops are highly unlikey to cause bugs. 522 return op.type in ['Const', 'Identity', 'Cast', 'Shape'] 523 524 @staticmethod 525 def reason(op_idx, details): 526 """Returns reason why the Op at op_idx is traced or not.""" 527 528 return '%d %s'%(op_idx, details) 529 530 @staticmethod 531 def topological_sort(g): 532 """Performs topological sort on the given graph. 533 534 Args: 535 g: the graph. 536 537 Returns: 538 A pair where the first element indicates if the topological 539 sort succeeded (True if there is no cycle found; False if a 540 cycle is found) and the second element is either the sorted 541 list of nodes or the cycle of nodes found. 542 """ 543 def _is_loop_edge(op): 544 """Returns true if the op is the end of a while-loop creating a cycle.""" 545 return op.type in ['NextIteration'] 546 547 def _in_op_degree(op): 548 """Returns the number of incoming edges to the given op. 549 550 The edge calculation skips the edges that come from 'NextIteration' ops. 551 NextIteration creates a cycle in the graph. We break cycles by treating 552 this op as 'sink' and ignoring all outgoing edges from it. 553 Args: 554 op: Tf.Operation 555 Returns: 556 the number of incoming edges. 557 """ 558 count = 0 559 for op in op.control_inputs + [in_tensor.op for in_tensor in op.inputs]: 560 if not _is_loop_edge(op): 561 count += 1 562 return count 563 564 sorted_ops = [] 565 op_in_degree = {op: _in_op_degree(op) for op in g.get_operations()} 566 567 frontier = [op for (op, degree) in op_in_degree.items() if degree == 0] 568 while frontier: 569 op = frontier.pop() 570 # Remove the op from graph, and remove its outgoing edges. 571 sorted_ops.append(op) 572 if _is_loop_edge(op): 573 continue 574 # pylint: disable=protected-access 575 consumers = list(op._control_outputs) 576 # pylint: enable=protected-access 577 for out_tensor in op.outputs: 578 consumers += [consumer_op for consumer_op in out_tensor.consumers()] 579 580 for consumer in consumers: 581 # For each deleted edge shift the bucket of the vertex. 582 op_in_degree[consumer] -= 1 583 if op_in_degree[consumer] == 0: 584 frontier.append(consumer) 585 if op_in_degree[consumer] < 0: 586 raise ValueError('consumer:%s degree mismatch'%consumer.name) 587 588 left_ops = set([op for (op, degree) in op_in_degree.items() if degree > 0]) 589 if left_ops: 590 return (False, left_ops) 591 else: 592 assert len(g.get_operations()) == len(sorted_ops) 593 return (True, sorted_ops) 594 595 @staticmethod 596 def _make_op_and_tensor_maps(op_list): 597 """Creates various maps and lists from op_list. 598 599 Args: 600 op_list: a list of Ops 601 602 Returns: 603 opname_idx_map: a map from Op's name to its index in op_list. 604 tensor_list: a list of output tensors of the Ops in op_list. 605 tensorname_idx_map: a map from output tensor name to its index 606 in tensor_list. 607 """ 608 609 opname_idx_map = {} 610 tensor_list = [] 611 tensorname_idx_map = {} 612 for op_id, op in enumerate(op_list): 613 if op.name in opname_idx_map: 614 raise ValueError('Duplicated Op name: %s'%op.name) 615 opname_idx_map[op.name] = op_id 616 for output_tensor in op.outputs: 617 if output_tensor.name not in tensorname_idx_map: 618 tensor_list.append(output_tensor) 619 tensorname_idx_map[output_tensor.name] = len(tensor_list)-1 620 return (opname_idx_map, tensor_list, tensorname_idx_map) 621 622 def __init__(self): 623 """Initializes a TensorTracer. 624 625 Sets the various member fields from the flags (if given) or the defaults. 626 """ 627 self._version = 'use-outside-compilation' 628 self._device_type = None 629 TensorTracer.validate_flag_names() 630 found, self._trace_mode = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_MODE) 631 if not found or not self._trace_mode: 632 self._trace_mode = _TRACE_MODE_NAN_INF 633 TensorTracer.check_trace_mode(self._trace_mode) 634 found, self._submode = TensorTracer.get_flag_value(_FLAG_NAME_SUBMODE) 635 if not found or not self._submode: 636 self._submode = _SUBMODE_DETAILED 637 TensorTracer.check_submode(self._submode) 638 self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE 639 self._instrument_records = {} 640 self._set_trace_dir() 641 self._set_report_file() 642 self._set_op_range() 643 self._set_excluded_opnames() 644 self._set_excluded_optypes() 645 self._set_included_opnames() 646 self._set_included_optypes() 647 self._num_replicas = None 648 self._num_replicas_per_host = None 649 self._num_hosts = None 650 self._replica_id = None 651 _, self._graph_dump_path = TensorTracer.get_flag_value( 652 _FLAG_DUMP_BEFORE_AFTER_GRAPHS) 653 654 def _add_replica_id_to_graph(self): 655 """Adds nodes for computing the replica ID to the graph.""" 656 657 if self._num_replicas: 658 with ops.control_dependencies(None): 659 # Uses None as dependency to run outside of TPU graph rewrites. 660 self._replica_id = tpu_ops.tpu_replicated_input( 661 list(range(self._num_replicas)), 662 name='tt_replica_id') 663 else: 664 self._replica_id = 'unknown' 665 666 def _set_trace_dir(self): 667 found, self._trace_dir = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_DIR) 668 if found and self._trace_dir \ 669 and TensorTracer.use_test_undeclared_outputs_dir(): 670 raise ValueError('Cannot not use --%s and --%s at the same time' 671 %(_FLAG_NAME_TRACE_DIR, 672 _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)) 673 if TensorTracer.use_test_undeclared_outputs_dir(): 674 self._trace_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) 675 676 def _set_report_file(self): 677 """Sets the path of the output report file.""" 678 679 found, self._report_file_path = TensorTracer.get_flag_value( 680 _FLAG_NAME_REPORT_FILE) 681 if found and self._report_file_path \ 682 and TensorTracer.use_test_undeclared_outputs_dir(): 683 if os.path.isabs(self._report_file_path): 684 raise ValueError('If use_test_undeclared_outputs_dir is set,' 685 'report_file_path cannot be an absolute path (%s)' 686 %self._report_file_path) 687 outputs_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR) 688 self._report_file_path = os.path.join(outputs_dir, 689 self._report_file_path) 690 if not self._report_file_path: 691 self._report_file = None 692 return 693 try: 694 self._report_file = gfile.Open(self._report_file_path, 'w') 695 except IOError as e: 696 raise e 697 698 def _close_report_file(self): 699 if self._report_file: 700 self._report_file.close() 701 702 def _set_op_range(self): 703 """Sets the index range of the Ops that we will consider tracing.""" 704 705 found, op_range = TensorTracer.get_flag_value(_FLAG_NAME_OP_RANGE) 706 if not found or not op_range: 707 self._op_range = (-1, -1) # this means including all ops. 708 return 709 match = _OP_RANGE_PAT.match(op_range) 710 if not match: 711 self._op_range = (-1, -1) # this means including all ops. 712 return 713 self._op_range = (int(match.group(1)), int(match.group(2))) 714 715 def _inside_op_range(self, idx): 716 """Return True if the given index is inside the selected range.""" 717 718 if idx < self._op_range[0]: 719 return False 720 return self._op_range[1] < 0 or idx <= self._op_range[1] 721 722 def _set_excluded_opnames(self): 723 self._excluded_opname_re_list = TensorTracer.flag_value_to_re_list( 724 _FLAG_NAME_EXCLUDED_OPNAMES) 725 726 def _set_excluded_optypes(self): 727 self._excluded_optype_re_list = TensorTracer.flag_value_to_re_list( 728 _FLAG_NAME_EXCLUDED_OPTYPES) 729 730 def _set_included_opnames(self): 731 self._included_opname_re_list = TensorTracer.flag_value_to_re_list( 732 _FLAG_NAME_INCLUDED_OPNAMES) 733 734 def _set_included_optypes(self): 735 self._included_optype_re_list = TensorTracer.flag_value_to_re_list( 736 _FLAG_NAME_INCLUDED_OPTYPES) 737 738 def _is_user_included_op(self, op): 739 for opname_re in self._included_opname_re_list: 740 if opname_re.match(op.name): 741 return True 742 for optype_re in self._included_optype_re_list: 743 if optype_re.match(op.type): 744 return True 745 return False 746 747 def _is_user_excluded_op(self, op): 748 for opname_re in self._excluded_opname_re_list: 749 if opname_re.match(op.name): 750 return True 751 for optype_re in self._excluded_optype_re_list: 752 if optype_re.match(op.type): 753 return True 754 return False 755 756 def _use_tensor_values_cache(self): 757 """Returns True if immediate tensors should be first saved to a cache.""" 758 759 if self._trace_mode not in set([_TRACE_MODE_NAN_INF, 760 _TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS]): 761 return False 762 if self._trace_dir and _trace_files_need_precreated(self._trace_dir): 763 return True 764 if TensorTracer.use_compact_trace(): 765 return True 766 return False 767 768 def _save_tensor_value_to_cache_op(self, graph, cache_idx, updates): 769 """Returns an Op that will save the given updates to an entry in the cache.""" 770 771 cache = _get_tensor_values_cache(graph) 772 indices = constant_op.constant([cache_idx]) 773 return state_ops.scatter_update(cache, indices, updates).op 774 775 def _write_report(self, content): 776 """Writes the given content to the report.""" 777 778 line = '%s %s'%(_TRACER_LOG_PREFIX, content) 779 if self._report_file: 780 self._report_file.write(line) 781 else: 782 logging.info(line) 783 784 def _write_config_section(self): 785 """Writes the config section of the report.""" 786 787 self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_CONFIG)) 788 self._write_report('%s %s\n'%(_FIELD_NAME_VERSION, self._version)) 789 self._write_report('%s %s\n'%(_FIELD_NAME_DEVICE, self._device_type)) 790 self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE, self._trace_mode)) 791 self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE, self._submode)) 792 self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS, self._num_replicas)) 793 self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS_PER_HOST, 794 self._num_replicas_per_host)) 795 self._write_report('%s %s\n'%(_FIELD_NAME_NUM_HOSTS, self._num_hosts)) 796 self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG)) 797 798 def _write_reason_section(self): 799 """Writes the reason section of the report.""" 800 801 self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_REASON)) 802 for key in sorted(self._instrument_records): 803 self._write_report('"%s" %s\n'%(key, self._instrument_records[key])) 804 self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_REASON)) 805 806 def _write_op_list_section(self, op_list): 807 """Writes the Op-list section of the report.""" 808 809 self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_OP_LIST)) 810 self._write_report('%s %d\n'%(_FIELD_NAME_NUM_OPS, len(op_list))) 811 for i in range(0, len(op_list)): 812 op = op_list[i] 813 line = '%d "%s" %s'%(i, op.name, op.type) 814 for out_tensor in op.outputs: 815 if out_tensor.name not in self._tensorname_idx_map: 816 raise ValueError( 817 'out_tensor %s is not in tensorname_idx_map'%out_tensor.name) 818 line += ' %d'%self._tensorname_idx_map[out_tensor.name] 819 line += '\n' 820 self._write_report(line) 821 self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_OP_LIST)) 822 823 def _write_tensor_list_section(self, tensor_list, opname_idx_map): 824 """Writes the tensor-list section of the report.""" 825 826 self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, 827 _SECTION_NAME_TENSOR_LIST)) 828 self._write_report('%s %d\n'%(_FIELD_NAME_NUM_TENSORS, len(tensor_list))) 829 for i in range(0, len(tensor_list)): 830 tensor = tensor_list[i] 831 line = '%d "%s"'%(i, tensor.name) 832 for consumer_op in tensor.consumers(): 833 if consumer_op.name not in opname_idx_map: 834 raise ValueError( 835 'consumer_op %s is not in opname_idx_map'%consumer_op.name) 836 line += ' %d'%opname_idx_map[consumer_op.name] 837 line += '\n' 838 self._write_report(line) 839 self._write_report('%s %s\n'%(_MARKER_SECTION_END, 840 _SECTION_NAME_TENSOR_LIST)) 841 842 def _write_cache_index_map_section(self): 843 """Writes the mapping from cache index to tensor index to the report.""" 844 845 self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, 846 _SECTION_NAME_CACHE_INDEX_MAP)) 847 self._write_report('%s %d\n'%(_FIELD_NAME_NUM_CACHE_INDICES, 848 len(self._cache_idx_to_tensor_idx))) 849 for cache_idx in range(0, len(self._cache_idx_to_tensor_idx)): 850 tensor_idx = self._cache_idx_to_tensor_idx[cache_idx] 851 line = '%d %d\n'%(cache_idx, tensor_idx) 852 self._write_report(line) 853 self._write_report('%s %s\n'%(_MARKER_SECTION_END, 854 _SECTION_NAME_CACHE_INDEX_MAP)) 855 856 def _write_graph_section(self, succeed, sorted_or_cycle): 857 """Writes the graph section of the report.""" 858 859 self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, _SECTION_NAME_GRAPH)) 860 self._write_report('%s %s\n'%(_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED, 861 succeed)) 862 l = list(sorted_or_cycle) 863 for i in range(0, len(l)): 864 self._write_report('%d "%s"\n'%(i, l[i].name)) 865 self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_GRAPH)) 866 867 def _preprocess_traced_tensor(self, tensor): 868 """Computes NAN/Norm/Max on TPUs before sending to CPU. 869 870 Args: 871 tensor: The tensor to be traced. 872 Returns: 873 A tensor that should be input to the trace_function. 874 Raises: 875 RuntimeError: If the trace mode is invalid. 876 """ 877 878 def _detect_nan_inf(tensor): 879 """Trace function for detecting any NaN/Inf in the tensor.""" 880 881 if tensor.dtype.is_floating: 882 mask = math_ops.reduce_any( 883 gen_math_ops.logical_or( 884 gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor))) 885 output_tensor = control_flow_ops.cond(mask, 886 lambda: constant_op.constant(1.0), 887 lambda: constant_op.constant(0.0)) 888 else: 889 output_tensor = constant_op.constant(0.0) 890 # The shape has to be 1. Set it if it does not have the information. 891 output_tensor = array_ops.reshape(output_tensor, [1]) 892 return output_tensor 893 894 def _show_norm(tensor): 895 tensor = math_ops.cast(tensor, dtypes.float32) 896 output_tensor = linalg_ops.norm(tensor) 897 # The shape has to be 1. Set it if it does not have the information. 898 output_tensor = array_ops.reshape(output_tensor, [1]) 899 return output_tensor 900 901 def _show_max_abs(tensor): 902 tensor = math_ops.cast(tensor, dtypes.float32) 903 output_tensor = math_ops.reduce_max(math_ops.abs(tensor)) 904 zero = constant_op.constant(0, dtypes.float32) 905 output_tensor = gen_math_ops.maximum(zero, output_tensor) 906 # The shape has to be 1. Set it if it does not have the information. 907 output_tensor = array_ops.reshape(output_tensor, [1]) 908 return output_tensor 909 910 if self._trace_mode == _TRACE_MODE_NAN_INF: 911 return _detect_nan_inf(tensor) 912 if self._trace_mode == _TRACE_MODE_PART_TENSOR: 913 return tensor 914 if self._trace_mode == _TRACE_MODE_FULL_TENSOR: 915 return tensor 916 if self._trace_mode == _TRACE_MODE_NORM: 917 return _show_norm(tensor) 918 if self._trace_mode == _TRACE_MODE_MAX_ABS: 919 return _show_max_abs(tensor) 920 raise RuntimeError( 921 'Tensor trace fun for %s is not yet implemented' % self._trace_mode) 922 923 def _make_tensor_trace_fun(self, tensor_name): 924 """Makes the tensor tracing function called by outside compilation. 925 926 Args: 927 tensor_name: name of the tensor being traced. 928 929 Returns: 930 A function to be passed as the first argument to outside compilation. 931 932 Raises: 933 RuntimeError: If the trace mode is invalid. 934 """ 935 936 def _print_tensor(tensor_name, num_elements, tensor, output_tensor): 937 """Prints a tensor value to a file. 938 939 Args: 940 tensor_name: name of the tensor being traced. 941 num_elements: number of elements to print (-1 means print all). 942 tensor: the tensor needs to be returned. 943 output_tensor: the tensor needs to be printed. 944 945 Returns: 946 The same tensor passed via the "tensor" argument. 947 948 Raises: 949 ValueError: If tensor_name is not already in 950 self._tensorname_idx_map. 951 """ 952 953 if self._submode == _SUBMODE_BRIEF: 954 if tensor_name not in self._tensorname_idx_map: 955 raise ValueError( 956 'Tensor name %s is not in the tensorname_idx_map'%tensor_name) 957 msg = '%d'%self._tensorname_idx_map[tensor_name] 958 else: 959 msg = '"%s"'%tensor_name 960 961 if self._trace_dir: 962 output_path = os.path.join(self._trace_dir, _TRACE_FILE_NAME) 963 output_stream = _OUTPUT_STREAM_ESCAPE + output_path 964 else: 965 output_stream = sys.stderr 966 return logging_ops.print_v2(msg, array_ops.shape(output_tensor), 967 '@', self._replica_id, 968 '\n', output_tensor, '\n', 969 summarize=num_elements, 970 output_stream=output_stream) 971 972 def _show_part_tensor(tensor): 973 """Trace function for printing part of the tensor.""" 974 975 return _print_tensor(tensor_name, self._part_tensor_size, 976 tensor, tensor) 977 978 def _show_full_tensor(tensor): 979 """Trace function for printing the entire tensor.""" 980 981 return _print_tensor(tensor_name, -1, tensor, tensor) 982 983 if self._trace_mode == _TRACE_MODE_PART_TENSOR: 984 return _show_part_tensor 985 # The input tensor has a shape of "[1]" for _TRACE_MODE_NAN_INF, 986 # _TRACE_MODE_NORM, and _TRACE_MODE_MAX_ABS, as related computations are 987 # performed within TPUs and only their results are transferred to CPU. 988 # Simply, print the full tensor for these trace modes. 989 if self._trace_mode in [ 990 _TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_FULL_TENSOR, 991 _TRACE_MODE_MAX_ABS 992 ]: 993 return _show_full_tensor 994 995 raise RuntimeError('Tensor trace fun for %s is not yet implemented' 996 %self._trace_mode) 997 998 def _skip_op(self, op_id, op, user_included, user_excluded, 999 in_exec_path=True): 1000 """Returns True if we should not trace Op.""" 1001 1002 if TensorTracer.while_loop_op(op): 1003 self._instrument_records[op.name] = TensorTracer.reason( 1004 op_id, _REASON_WHILELOOP_OP) 1005 return True 1006 if TensorTracer.unsafe_op(op): 1007 self._instrument_records[op.name] = TensorTracer.reason( 1008 op_id, _REASON_UNSAFE_OP) 1009 return True 1010 if TensorTracer.device_mismatch(self._device_type, op): 1011 self._instrument_records[op.name] = TensorTracer.reason( 1012 op_id, _REASON_DEVICE_MISMATCH) 1013 return True 1014 if not in_exec_path: 1015 self._instrument_records[op.name] = TensorTracer.reason( 1016 op_id, _REASON_NOT_EXECUTED) 1017 return True 1018 1019 if not self._inside_op_range(op_id): 1020 self._instrument_records[op.name] = TensorTracer.reason( 1021 op_id, _REASON_OUTSIDE_OP_RANGE) 1022 return True 1023 if TensorTracer.less_interesting_op(op): 1024 self._instrument_records[op.name] = TensorTracer.reason( 1025 op_id, _REASON_LESS_INTERESTING_OP) 1026 return True 1027 if user_included: 1028 self._instrument_records[op.name] = TensorTracer.reason( 1029 op_id, _REASON_USER_INCLUDED) 1030 return False 1031 if user_excluded: 1032 self._instrument_records[op.name] = TensorTracer.reason( 1033 op_id, _REASON_USER_EXCLUDED) 1034 return True 1035 return False 1036 1037 def _skip_tensor(self, op_id, out_tensor, user_included, 1038 user_excluded): 1039 """Returns True if we should not trace out_tensor.""" 1040 1041 # Skips a tensor if the tensor has a non-numeric type. 1042 # Note: we cannot use check_ops.is_numeric_tensor(out_tensor) 1043 # because it also excludes tensors with dtypes, bool, and 1044 # float32_ref, which we actually want to trace. 1045 non_numeric_tensor_types = set([dtypes.variant, dtypes.resource, 1046 dtypes.string]) 1047 if out_tensor.dtype in non_numeric_tensor_types: 1048 self._instrument_records[out_tensor.name] = TensorTracer.reason( 1049 op_id, _REASON_NON_NUMERIC_TENSOR) 1050 return True 1051 # Skip a tensor if it feeds a special while loop op. 1052 if [consumer for consumer in out_tensor.consumers() if 1053 TensorTracer.while_loop_op(consumer)]: 1054 self._instrument_records[out_tensor.name] = TensorTracer.reason( 1055 op_id, _REASON_FEEDS_WHILELOOP_OP) 1056 return True 1057 if user_included: 1058 self._instrument_records[out_tensor.name] = TensorTracer.reason( 1059 op_id, _REASON_USER_INCLUDED) 1060 return False 1061 if user_excluded: 1062 self._instrument_records[out_tensor.name] = TensorTracer.reason( 1063 op_id, _REASON_USER_EXCLUDED) 1064 return True 1065 if not out_tensor.get_shape().is_fully_defined(): 1066 # If trace mode is nan-inf, norm or max, then the tensor will be reduced 1067 # to a scalar before the outside compilation call. 1068 if self._trace_mode in [ 1069 _TRACE_MODE_NAN_INF, _TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS 1070 ]: 1071 self._instrument_records[out_tensor.name] = TensorTracer.reason( 1072 op_id, _REASON_TENSOR_GET_TRACED) 1073 return False 1074 else: 1075 self._instrument_records[out_tensor.name] = TensorTracer.reason( 1076 op_id, _REASON_DYNAMIC_SHAPE) 1077 return True 1078 rank = len(out_tensor.shape) 1079 if rank < 1: 1080 # scalar 1081 if TensorTracer.unsafe_scalar_trace(out_tensor.op): 1082 self._instrument_records[out_tensor.name] = TensorTracer.reason( 1083 op_id, _REASON_UNSAFE_SCALAR) 1084 return True 1085 else: 1086 self._instrument_records[out_tensor.name] = TensorTracer.reason( 1087 op_id, _REASON_SCALAR_GET_TRACED) 1088 return False 1089 else: 1090 # tensor 1091 self._instrument_records[out_tensor.name] = TensorTracer.reason( 1092 op_id, _REASON_TENSOR_GET_TRACED) 1093 return False 1094 1095 def _filter_execution_path_operations(self, operations, fetches): 1096 """Returns the set of ops in the execution path to compute given fetches.""" 1097 1098 # If no fetch provided, then return all operations. 1099 if fetches is None: 1100 return set(operations) 1101 # Convert to list, if a single element is provided. 1102 if not isinstance(fetches, (list, tuple)): 1103 fetches = [fetches] 1104 # If a tensor is given as fetch, convert it to op. 1105 op_fetches = [] 1106 for fetch in fetches: 1107 if isinstance(fetch, ops.Operation): 1108 op_fetches.append(fetch) 1109 elif isinstance(fetch, ops.Tensor): 1110 op_fetches.append(fetch.op) 1111 else: 1112 raise RuntimeError('Given fetch:%s is neither a tensor nor an op.' 1113 %fetch) 1114 1115 execution_path_operations = set(op_fetches) 1116 traverse_stack = list(op_fetches) 1117 while True: 1118 if not traverse_stack: 1119 break 1120 head_op = traverse_stack.pop() 1121 input_ops = [tensor_input.op for tensor_input in head_op.inputs] 1122 input_ops.extend(head_op.control_inputs) 1123 1124 for input_op in input_ops: 1125 if input_op not in execution_path_operations: 1126 # Filter out loop condition operations, tracing them causes a cycle. 1127 # Trace only the loop-body. 1128 if TensorTracer.loop_cond_op(input_op): 1129 continue 1130 execution_path_operations.add(input_op) 1131 traverse_stack.append(input_op) 1132 return execution_path_operations 1133 1134 def _determine_traced_tensors(self, graph, ops_in_exec_path): 1135 """Determines the tensors that will be traced.""" 1136 1137 self._traced_tensorname_to_cache_idx_map = {} 1138 self._cache_idx_to_tensor_idx = [] 1139 operations = graph.get_operations() 1140 checkpoint_operations = self._get_checkpoints(graph) 1141 for op_id, op in enumerate(operations): 1142 if checkpoint_operations and op.name not in checkpoint_operations: 1143 continue 1144 user_included = self._is_user_included_op(op) 1145 user_excluded = self._is_user_excluded_op(op) 1146 in_exec_path = op in ops_in_exec_path 1147 if self._skip_op(op_id, op, user_included, user_excluded, in_exec_path): 1148 continue 1149 for i in range(len(op.outputs)): 1150 out_tensor = op.outputs[i] 1151 if self._skip_tensor(op_id, out_tensor, user_included, 1152 user_excluded): 1153 continue 1154 tensor_name = out_tensor.name 1155 if tensor_name in self._traced_tensorname_to_cache_idx_map: 1156 raise ValueError( 1157 'Tensor name %s should not be already in ' 1158 'traced_tensorname_to_cache_idx_map'%tensor_name) 1159 if tensor_name not in self._tensorname_idx_map: 1160 raise ValueError( 1161 'Tensor name %s is not in the tensorname_idx_map'%tensor_name) 1162 tensor_idx = self._tensorname_idx_map[tensor_name] 1163 cache_idx = len(self._traced_tensorname_to_cache_idx_map) 1164 self._traced_tensorname_to_cache_idx_map[tensor_name] = cache_idx 1165 self._cache_idx_to_tensor_idx.append(tensor_idx) 1166 if len(self._traced_tensorname_to_cache_idx_map) != len( 1167 self._cache_idx_to_tensor_idx): 1168 raise RuntimeError('len(self._traced_tensorname_to_cache_idx_map) != ' 1169 'len(self._cache_idx_to_tensor_idx') 1170 1171 def _check_trace_files(self): 1172 """Checks if any requirements for trace files are satisfied.""" 1173 1174 if not self._trace_dir: 1175 # traces will be written to stderr. No need to check trace files. 1176 return 1177 if _trace_files_need_precreated(self._trace_dir): 1178 for replica_id in range(0, self._num_replicas): 1179 trace_file_path = os.path.join( 1180 self._trace_dir, 1181 _COMPACT_TRACE_FILE_PREFIX) + '%d'%replica_id 1182 if not gfile.Exists(trace_file_path): 1183 raise RuntimeError( 1184 '%s must be pre-created with the ' 1185 'appropriate properties.'%trace_file_path) 1186 else: 1187 if not gfile.Exists(self._trace_dir): 1188 gfile.MkDir(self._trace_dir) 1189 if not gfile.Exists(self._trace_dir): 1190 raise RuntimeError('Failed to create %s'%self._trace_dir) 1191 1192 def _pre_tracing(self, graph, fetches): 1193 """Work needs to be done prior to TPU or CPU tracing.""" 1194 1195 self._check_trace_files() 1196 operations = graph.get_operations() 1197 (opname_idx_map, tensor_list, self._tensorname_idx_map) = ( 1198 TensorTracer._make_op_and_tensor_maps(operations)) 1199 self._write_config_section() 1200 self._write_op_list_section(operations) 1201 self._write_tensor_list_section(tensor_list, opname_idx_map) 1202 # Filter out the operations that won't be executed. 1203 # if fetches=None, then ops_in_exec_path = set(operations) 1204 ops_in_exec_path = self._filter_execution_path_operations(operations, 1205 fetches) 1206 self._determine_traced_tensors(graph, ops_in_exec_path) 1207 self._write_cache_index_map_section() 1208 # Does the topological sort before adding any nodes to the graph. 1209 (succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph) 1210 if self._use_tensor_values_cache(): 1211 _create_tensor_values_cache(graph, 1212 len(self._cache_idx_to_tensor_idx)) 1213 return (ops_in_exec_path, succeed, sorted_or_cycle) 1214 1215 def _post_tracing(self, succeed, sorted_or_cycle): 1216 """Work needs to be done after TPU or CPU tracing.""" 1217 1218 self._write_reason_section() 1219 self._write_graph_section(succeed, sorted_or_cycle) 1220 self._close_report_file() 1221 1222 def _get_checkpoints(self, graph): 1223 """Returns the list of Ops that produce the tensors traced with API. 1224 1225 Args: 1226 graph: the graph of Ops. 1227 1228 Returns: 1229 A set of operation names which should be traced. 1230 """ 1231 1232 self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN, 1233 _TENSOR_TRACER_CHECKPOINT)) 1234 checkpoint_operations = set() 1235 tensor_tracer_variables = graph.get_collection(_TENSOR_TRACER_COLLECTION) 1236 for (tensor, checkpoint_name) in tensor_tracer_variables: 1237 self._write_report('%s %s\n'%(tensor.name, checkpoint_name)) 1238 checkpoint_operations.add(tensor.op.name) 1239 self._write_report('%s %s\n'%(_MARKER_SECTION_END, 1240 _TENSOR_TRACER_CHECKPOINT)) 1241 return checkpoint_operations 1242 1243 def _generate_flush_cache_op(self, graph, start_replica, on_tpu): 1244 """Generates an Op that will flush the cache to file. 1245 1246 Args: 1247 graph: the graph of Ops 1248 start_replica: the ID of the first replica being flushed by this Op. 1249 on_tpu: if the graph is executed on TPU. 1250 1251 Returns: 1252 The Op to flush the cache to file. 1253 """ 1254 def _make_flush_fun(replica_id): 1255 """Makes a function for flushing the cache for the given replica.""" 1256 1257 def _fun(): 1258 """A function that flushes the cache to a file.""" 1259 1260 def _flush_fun(cache): 1261 """Flushes the cache to a file.""" 1262 1263 if isinstance(replica_id, str): 1264 replica_id_str = replica_id 1265 else: 1266 replica_id_str = '%d'%replica_id 1267 if self._trace_dir: 1268 output_path = os.path.join(self._trace_dir, 1269 _COMPACT_TRACE_FILE_PREFIX) \ 1270 + replica_id_str 1271 output_stream = _OUTPUT_STREAM_ESCAPE + output_path 1272 else: 1273 output_stream = sys.stderr 1274 new_step_line = _REPLICA_ID_TAG + replica_id_str 1275 print_op = logging_ops.print_v2( 1276 new_step_line, '\n', 1277 cache, '\n', 1278 summarize=-1, 1279 output_stream=output_stream) 1280 with ops.control_dependencies([print_op]): 1281 return constant_op.constant(0).op 1282 1283 cache = _get_tensor_values_cache(graph) 1284 if on_tpu: 1285 flush_op = tpu.outside_compilation(_flush_fun, cache.value()) 1286 else: 1287 flush_op = _flush_fun(cache.value()) 1288 with ops.control_dependencies([flush_op]): 1289 reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE, 1290 dtype=cache.dtype, 1291 shape=cache.shape) 1292 assign_op = state_ops.assign(cache, reset_value).op 1293 with ops.control_dependencies([assign_op]): 1294 return flush_op.outputs[0] 1295 1296 return _fun 1297 1298 def _f(replica_id): 1299 return _make_flush_fun(replica_id) 1300 def _eq(x): 1301 return math_ops.equal(x, self._replica_id) 1302 def _do_nothing(): 1303 return constant_op.constant(0) 1304 1305 return control_flow_ops.case({\ 1306 _eq(start_replica): _f(start_replica), \ 1307 _eq(start_replica+1): _f(start_replica+1), \ 1308 _eq(start_replica+2): _f(start_replica+2), \ 1309 _eq(start_replica+3): _f(start_replica+3), \ 1310 _eq(start_replica+4): _f(start_replica+4), \ 1311 _eq(start_replica+5): _f(start_replica+5), \ 1312 _eq(start_replica+6): _f(start_replica+6), \ 1313 _eq(start_replica+7): _f(start_replica+7), \ 1314 }, 1315 default=_do_nothing, 1316 exclusive=True).op 1317 1318 def _flush_tensor_values_cache(self, graph, tensor_fetches, op_fetches, 1319 on_tpu): 1320 """Flushes the intermediate tensor values in the graph to the cache. 1321 1322 Args: 1323 graph: the graph of Ops 1324 tensor_fetches: list of tensor results returned by the model_fn. 1325 op_fetches: list of ops that are returned by the model_fn, e.g., train_op. 1326 on_tpu: if the graph is executed on TPU. 1327 1328 Returns: 1329 An identical copy of tensor_fetches. 1330 """ 1331 # Add a dependency to op and tensor fetches to make sure that all tracing 1332 # ops are executed before flushing trace results. 1333 with ops.control_dependencies(op_fetches + 1334 [tensor.op for tensor in tensor_fetches]): 1335 flush_cache_op_list = [] 1336 for host in range(self._num_hosts): 1337 start_replica = host * 8 1338 flush_op = self._generate_flush_cache_op(graph, start_replica, on_tpu) 1339 flush_cache_op_list.append(flush_op) 1340 return control_flow_ops.tuple(tensor_fetches, 1341 control_inputs=flush_cache_op_list) 1342 1343 def _process_tensor_fetches(self, tensor_fetches): 1344 """Check that tensor_fetches is not empty and have valid tensors.""" 1345 # If none or empty list. 1346 if tensor_fetches is None: 1347 raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be ' 1348 'None.') 1349 if not isinstance(tensor_fetches, (list, tuple)): 1350 tensor_fetches = [tensor_fetches] 1351 elif not tensor_fetches: 1352 raise RuntimeError('tensor_fetches provided to tensor_tracer cannot be ' 1353 'empty list.') 1354 fetches = [] 1355 for fetch in tensor_fetches: 1356 if isinstance(fetch, ops.Tensor): 1357 fetches.append(fetch) 1358 else: 1359 raise RuntimeError('Given tensor_fetch:%s is not a tensor.' % fetch) 1360 return fetches 1361 1362 def _process_op_fetches(self, op_fetches): 1363 """Check that op_fetches have valid ops.""" 1364 if op_fetches is None: 1365 return [] 1366 1367 if not isinstance(op_fetches, (list, tuple)): 1368 op_fetches = [op_fetches] 1369 1370 fetches = [] 1371 for fetch in op_fetches: 1372 if isinstance(fetch, ops.Operation): 1373 fetches.append(fetch) 1374 else: 1375 logging.warning('Ignoring the given op_fetch:%s, which is not an op.' % 1376 fetch) 1377 return fetches 1378 1379 def _convert_fetches_to_input_format(self, input_fetches, current_fetches): 1380 """Changes current_fetches' format, so that it matches input_fetches.""" 1381 if isinstance(input_fetches, ops.Tensor): 1382 if len(current_fetches) != 1: 1383 raise RuntimeError('Tensor tracer input/output fetches do not match.') 1384 return current_fetches[0] 1385 else: 1386 if len(current_fetches) != len(current_fetches): 1387 raise RuntimeError('Tensor tracer input/output fetches do not match.') 1388 elif isinstance(input_fetches, tuple): 1389 return tuple(current_fetches) 1390 else: 1391 return current_fetches 1392 1393 def _get_op_control_flow_context(self, op): 1394 """Returns the control flow of the given op. 1395 1396 Args: 1397 op: tf.Operation for which the control flow context is requested. 1398 Returns: 1399 op_control_flow_context: which the is control flow context of the given 1400 op. If the operation type is LoopExit, returns the outer control flow 1401 context. 1402 """ 1403 # pylint: disable=protected-access 1404 op_control_flow_context = op._control_flow_context 1405 # pylint: enable=protected-access 1406 if control_flow_util.IsLoopExit(op): 1407 op_control_flow_context = op_control_flow_context.outer_context 1408 return op_control_flow_context 1409 1410 def _trace_execution(self, graph, 1411 tensor_fetches, 1412 op_fetches=None, 1413 on_tpu=True): 1414 """Commong tracing function for both CPU and TPUs. 1415 1416 The caller function should set _device_type, _num_replicas, 1417 _num_replicas_per_host, _num_hosts and _replica_id before calling 1418 _trace_execution. 1419 1420 1421 Args: 1422 graph: the graph of Ops executed on the TPU. 1423 tensor_fetches: a (list,tuple,or a single object) of tensor fetches 1424 returned by model_fn given to session.run. Function must be provided 1425 with as least one tensor to fetch. 1426 op_fetches: A list of op fetches returned by model_fn given to 1427 session.run. op_fetches and tensor_fetches are used to determine the 1428 nodes that will be executed. Can be None. 1429 on_tpu: True if executing on TPU. 1430 1431 Returns: 1432 tensor_fetches: an exact copy of tensor_fetches that has additional 1433 dependencies. 1434 Raises: 1435 RuntimeError: If tensor_fetches is None or empty. 1436 """ 1437 def _cast_unsupported_dtypes(tensor): 1438 """Casts tensor to a supported type.""" 1439 1440 if tensor.dtype.__eq__(dtypes.int64): 1441 # outside-compilation doesn't support int64 input yet. 1442 return math_ops.cast(tensor, dtypes.int32) 1443 if tensor.dtype.__eq__(dtypes.bfloat16) or tensor.dtype.__eq__( 1444 dtypes.float16): 1445 # Since host can't handle bf16, convert tensor to f32. 1446 return math_ops.cast(tensor, dtypes.float32) 1447 return tensor 1448 1449 TensorTracer.check_device_type(self._device_type) 1450 # Check in_tensor_fetches, and op_fetches and convert them to lists. 1451 processed_t_fetches = self._process_tensor_fetches(tensor_fetches) 1452 op_fetches = self._process_op_fetches(op_fetches) 1453 all_fetches = op_fetches + [tensor.op for tensor in processed_t_fetches] 1454 1455 # Filter the set of ops that will be executed, and topological sort. 1456 (exec_op_set, succeed, sorted_or_cycle) = self._pre_tracing(graph, 1457 all_fetches) 1458 1459 tensor_fetch_set = set(processed_t_fetches) 1460 tracing_ops = [] 1461 1462 # pylint: disable=protected-access 1463 current_control_flow_context = graph._get_control_flow_context() 1464 # pylint: enable=protected-access 1465 1466 # Trace ops only if they are in the execution path. 1467 for op in exec_op_set: 1468 for i in range(len(op.outputs)): 1469 out_tensor = op.outputs[i] 1470 tensor_name = out_tensor.name 1471 if tensor_name not in self._traced_tensorname_to_cache_idx_map: 1472 continue 1473 # Create the list of consumers before calling _preprocess_traced_tensor. 1474 # Otherwise, adding control input below, will introduce a cycle in the 1475 # graph. 1476 consumers = out_tensor.consumers() 1477 # Not all consumers may be in the exec path. Filter out the consumers 1478 # to keep the graph simpler. 1479 consumers = [cop for cop in consumers if cop in exec_op_set] 1480 1481 # If there is no consumer of the tensor, there is no need to trace it; 1482 # unless the tensor itself is one of the fetches. 1483 is_a_fetched_tensor = out_tensor in tensor_fetch_set 1484 if (not consumers) and (not is_a_fetched_tensor): 1485 continue 1486 1487 op_control_flow_context = self._get_op_control_flow_context(op) 1488 # pylint: disable=protected-access 1489 graph._set_control_flow_context(op_control_flow_context) 1490 # pylint: enable=protected-access 1491 processed_out_tensor = self._preprocess_traced_tensor(out_tensor) 1492 1493 if on_tpu: 1494 processed_out_tensor = _cast_unsupported_dtypes(processed_out_tensor) 1495 1496 if self._use_tensor_values_cache(): 1497 cache_idx = self._traced_tensorname_to_cache_idx_map[tensor_name] 1498 trace_op = self._save_tensor_value_to_cache_op(graph, 1499 cache_idx, 1500 processed_out_tensor) 1501 elif on_tpu: 1502 trace_op = tpu.outside_compilation( 1503 self._make_tensor_trace_fun(tensor_name), processed_out_tensor) 1504 else: 1505 trace_fun = self._make_tensor_trace_fun(tensor_name) 1506 trace_op = trace_fun(processed_out_tensor) 1507 1508 if is_a_fetched_tensor: 1509 tracing_ops.append(trace_op) 1510 continue 1511 # Add it to all consumers, as some consumers may not be executed if they 1512 # are in a control flow. 1513 for consumer_op in consumers: 1514 # pylint: disable=protected-access 1515 consumer_op._add_control_input(trace_op) 1516 # pylint: enable=protected-access 1517 1518 # pylint: disable=protected-access 1519 graph._set_control_flow_context(current_control_flow_context) 1520 # pylint: enable=protected-access 1521 if tracing_ops: 1522 # If we are tracing a fetched tensor, their dependency is stored in 1523 # tracing_ops. 1524 processed_t_fetches = control_flow_ops.tuple(processed_t_fetches, 1525 control_inputs=tracing_ops) 1526 if self._use_tensor_values_cache(): 1527 processed_t_fetches = self._flush_tensor_values_cache(graph, 1528 processed_t_fetches, 1529 op_fetches, 1530 on_tpu=on_tpu) 1531 self._post_tracing(succeed, sorted_or_cycle) 1532 # processed_t_fetches is a list at this point. Convert it to the same 1533 # format as given in tensor_fetches. 1534 return self._convert_fetches_to_input_format(tensor_fetches, 1535 processed_t_fetches) 1536 1537 def trace_tpu(self, graph, 1538 tensor_fetches, 1539 op_fetches=None, 1540 num_replicas=None, 1541 num_replicas_per_host=None, 1542 num_hosts=None): 1543 """Traces the tensors generated by TPU Ops in a TF graph. 1544 1545 Args: 1546 graph: the graph of Ops executed on the TPU. 1547 tensor_fetches: a (list,tuple,or a single object) of tensor fetches 1548 returned by model_fn given to session.run. Function must be provided 1549 with as least one tensor to fetch. 1550 op_fetches: A list of op fetches returned by model_fn given to 1551 session.run. op_fetches and tensor_fetches are used to determine the 1552 nodes that will be executed. Can be None. 1553 num_replicas: number of replicas used on the TPU. 1554 num_replicas_per_host: number of replicas per TPU host. 1555 num_hosts: total number of TPU hosts. 1556 1557 Returns: 1558 tensor_fetches: an exact copy of tensor_fetches that has additional 1559 dependencies. 1560 Raises: 1561 RuntimeError: If num_replicas_per_host > 8. 1562 RuntimeError: If tensor_fetches is None or empty. 1563 """ 1564 1565 if graph in TensorTracer._traced_graphs: 1566 logging.warning('Graph is already rewritten with tensor tracer, ignoring ' 1567 'multiple calls.') 1568 return tensor_fetches 1569 else: 1570 TensorTracer._traced_graphs.add(graph) 1571 self._device_type = _DEVICE_TYPE_TPU 1572 self._num_replicas = num_replicas 1573 self._num_replicas_per_host = num_replicas_per_host 1574 self._num_hosts = num_hosts 1575 if self._num_replicas is not None: 1576 if self._num_replicas_per_host is None: 1577 self._num_replicas_per_host = 8 1578 if self._num_hosts is None: 1579 self._num_hosts = num_replicas // self._num_replicas_per_host + \ 1580 (num_replicas % self._num_replicas_per_host > 0) 1581 1582 if self._num_replicas_per_host > 8: 1583 # Checks for the assumption in _generate_flush_cache_op(). 1584 raise RuntimeError('num_replicas_per_host (%d) is ' 1585 'greater than 8'%self._num_replicas_per_host) 1586 if self._graph_dump_path: 1587 graph_io.write_graph(graph, self._graph_dump_path, 1588 'graph_before_tt.pbtxt') 1589 with graph.as_default(): 1590 self._add_replica_id_to_graph() 1591 tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches, 1592 on_tpu=True) 1593 if self._graph_dump_path: 1594 graph_io.write_graph(graph, self._graph_dump_path, 1595 'graph_after_tt.pbtxt') 1596 return tensor_fetches 1597 1598 def trace_cpu(self, graph, tensor_fetches, op_fetches=None): 1599 """Traces the tensors generated by CPU Ops in a TF graph. 1600 1601 Args: 1602 graph: the graph of Ops executed on the CPU. 1603 tensor_fetches: a (list,tuple,or a single object) of tensor fetches 1604 returned by model_fn given to session.run. Function must be provided 1605 with as least one tensor to fetch. 1606 op_fetches: A list of op fetches returned by model_fn given to 1607 session.run. op_fetches and tensor_fetches are used to determine the 1608 nodes that will be executed. Can be None. 1609 1610 Returns: 1611 tensor_fetches: an exact copy of tensor_fetches that has additional 1612 dependencies. 1613 Raises: 1614 RuntimeError: If tensor_fetches is None or empty. 1615 """ 1616 1617 if graph in TensorTracer._traced_graphs: 1618 logging.warning('Graph is already rewritten with tensor tracer, ignoring ' 1619 'multiple calls.') 1620 return tensor_fetches 1621 else: 1622 TensorTracer._traced_graphs.add(graph) 1623 1624 self._device_type = _DEVICE_TYPE_CPU 1625 self._num_replicas = 1 1626 self._num_replicas_per_host = 1 1627 self._num_hosts = 1 1628 self._replica_id = 0 1629 if self._graph_dump_path: 1630 graph_io.write_graph(graph, self._graph_dump_path, 1631 'graph_before_tt.pbtxt') 1632 with graph.as_default(): 1633 tensor_fetches = self._trace_execution(graph, tensor_fetches, op_fetches, 1634 on_tpu=False) 1635 if self._graph_dump_path: 1636 graph_io.write_graph(graph, self._graph_dump_path, 1637 'graph_after_tt.pbtxt') 1638 return tensor_fetches 1639 1640 1641