1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Control flow statements: loops, conditionals, etc. 16 17Python 2 compatibility version. Not maintained. 18 19Note: most of these operators accept pairs of get_state/set_state functions, to 20capture mutations that the corresponding code blocks might make. These 21mutations only need to be captured when staging the control flow, and they just 22work when reverting to Python behavior. 23 24__Examples__ 25 26``` 27while cond: 28 self.x += i 29``` 30 31When the functionalized version is executed as a Python loop, it just works: 32 33``` 34def loop_body(): 35 self.x += i # works as expected for Python loops 36``` 37 38But it won't work for TF loops: 39 40``` 41def loop_body(): 42 self.x += i # self.x has the wrong value! 43``` 44 45get_state/set_state allow piping the mutations through the loop variables as 46well, in effect changing the loop body: 47 48``` 49def loop_body(self_x): 50 self.x = self_x # self.x now has the proper value 51 self.x += i # the original block 52 self_x = self.x # write self.x back into the loop vars 53 return self_x 54 55self_x = tf.while_loop(...) 
56self.x = self_x # the result is not properly captured 57``` 58""" 59 60from __future__ import absolute_import 61from __future__ import division 62from __future__ import print_function 63 64import functools 65 66import numpy as np 67 68from tensorflow.python.autograph.operators import py_builtins 69from tensorflow.python.autograph.operators import variables 70from tensorflow.python.autograph.utils import ag_logging 71from tensorflow.python.autograph.utils import misc 72from tensorflow.python.autograph.utils import tensors 73from tensorflow.python.data.experimental.ops import scan_ops 74from tensorflow.python.data.experimental.ops import take_while_ops 75from tensorflow.python.data.ops import dataset_ops 76from tensorflow.python.data.ops import iterator_ops 77from tensorflow.python.framework import constant_op 78from tensorflow.python.framework import dtypes 79from tensorflow.python.framework import func_graph 80from tensorflow.python.framework import ops 81from tensorflow.python.framework import tensor_util 82from tensorflow.python.ops import array_ops 83from tensorflow.python.ops import control_flow_ops 84from tensorflow.python.ops import math_ops 85from tensorflow.python.ops import tensor_array_ops 86from tensorflow.python.ops.ragged import ragged_tensor 87from tensorflow.python.util import lazy_loader 88from tensorflow.python.util import nest 89 90 91# TODO(b/145618471): Remove this dependency. 92# Lazy import to work around circular dependencies 93input_lib = lazy_loader.LazyLoader( 94 'input_lib', globals(), 95 'tensorflow.python.distribute.input_lib') 96 97LIMIT_PYTHON_ITERATIONS = True 98PYTHON_MAX_ITERATIONS = 100000000 # Fails in about one minute for empty loops. 
99WARN_INEFFICIENT_UNROLL = True 100INEFFICIENT_UNROLL_MIN_ITERATIONS = 3000 101INEFFICIENT_UNROLL_MIN_OPS = 1 102 103 104def _disallow_undefs_into_loop(*values): 105 """Ensures that all values in the state are defined when entering a loop.""" 106 undefined = [v for v in values if isinstance(v, variables.Undefined)] 107 if undefined: 108 raise ValueError( 109 '{} must be defined before the loop.'.format( 110 ','.join(s.symbol_name for s in undefined))) 111 for value in values: 112 if isinstance(value, variables.UndefinedReturnValue): 113 # Assumption: the loop will only capture the variable which tracks the 114 # return value if the loop contained a return statement. 115 # TODO(mdan): This should be checked at the place where return occurs. 116 raise ValueError( 117 'return statements are not supported within a TensorFlow loop.') 118 119 120def _is_subshape(left, right): 121 """Returns True if left shape is at least as specific as right shape.""" 122 # TODO(mdan): This code should be in TensorShape. 123 # Note: this is not the same as TensorShape.is_compatible_with, which is 124 # symmetric. 125 # This code also duplicates _ShapeLessThanOrEqual from control_flow_ops.py. 126 if right.dims is None: 127 return True 128 if left.ndims != right.ndims: 129 return False 130 for ldim, rdim in zip(left.dims, right.dims): 131 if rdim.value is not None and ldim.value != rdim.value: 132 return False 133 return True 134 135 136# TODO(mdan): Remove these verifications once TF ops can properly report names. 
def _verify_single_loop_var(
    name, check_shape, init, entry, exit_, shape_invariant):
  """Verifies whether the initial, entry and exit values are consistent.

  Args:
    name: str, the symbol name, used only in error messages.
    check_shape: bool, whether to verify shape consistency as well as dtype.
    init: value of the variable before the loop.
    entry: value of the variable at the start of an iteration.
    exit_: value of the variable at the end of an iteration.
    shape_invariant: optional TensorShape; when set, init/exit are checked
      against it instead of against each other.

  Raises:
    TypeError: if the dtype changes across an iteration.
    ValueError: if the shape changes across an iteration, or violates the
      shape invariant.
  """
  if isinstance(init, (bool, int, float, str, np.ndarray)):
    init = ops.convert_to_tensor_v2(init)
  if isinstance(entry, (bool, int, float, str, np.ndarray)):
    entry = ops.convert_to_tensor_v2(entry)
  # Fix: np.ndarray was missing here, unlike in the init/entry conversions
  # above; an ndarray exit value failed is_tf_type below and silently skipped
  # all verification.
  if isinstance(exit_, (bool, int, float, str, np.ndarray)):
    exit_ = ops.convert_to_tensor_v2(exit_)

  if (not tensor_util.is_tf_type(entry) or
      not tensor_util.is_tf_type(exit_)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if (not hasattr(entry, 'dtype') or
      not hasattr(exit_, 'dtype')):
    return
  if (not hasattr(entry, 'shape') or
      not hasattr(exit_, 'shape')):
    return

  if entry.dtype != exit_.dtype:
    raise TypeError(
        '"{}" has dtype {} before the loop, but dtype {} after one'
        ' iteration. TensorFlow control flow requires it stays the'
        ' same.'.format(
            name,
            entry.dtype.name,
            exit_.dtype.name,
        ))
  if check_shape:
    exit_shape = exit_.shape
    if shape_invariant is None:
      # Without an explicit invariant, the exit shape must be at least as
      # specific as the entry shape (see _is_subshape).
      entry_shape = entry.shape
      if not _is_subshape(exit_shape, entry_shape):
        raise ValueError(
            '"{}" has shape {} before the loop, but shape {} after one'
            ' iteration. Use tf.autograph.experimental.set_loop_options to set'
            ' shape invariants.'.format(name, entry_shape, exit_shape))
    else:
      # With an explicit invariant, both the initial and the exit values must
      # conform to it.
      init_shape = init.shape
      if not _is_subshape(init_shape, shape_invariant):
        raise ValueError(
            '"{}" has shape {} before the loop, which does not conform with'
            ' the shape invariant {}.'.format(name, init_shape,
                                              shape_invariant))
      if not _is_subshape(exit_shape, shape_invariant):
        raise ValueError(
            '"{}" has shape {} after the loop, which does not conform with'
            ' the shape invariant {}.'.format(
                name, exit_shape, shape_invariant))


def _verify_tf_loop_vars(init_vars,
                         iter_entry_vars,
                         iter_exit_vars,
                         symbol_names,
                         opts,
                         check_shapes=True):
  """Verifies loop variables for consistency.

  Checks that each loop variable keeps the same nested structure, dtype and
  (optionally) shape across one iteration of a staged TF loop.
  """
  if check_shapes and 'shape_invariants' in opts:
    shape_invariants = opts['shape_invariants']
  else:
    shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars)

  named_vars = zip(symbol_names, init_vars, iter_entry_vars, iter_exit_vars,
                   shape_invariants)
  for name, init, entry, exit_, invariant in named_vars:
    try:
      nest.assert_same_structure(entry, exit_, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError('"{}" does not have the same nested structure after one'
                      ' iteration.\n\n{}'.format(name, e))
    if invariant is not None:
      try:
        nest.assert_same_structure(init, invariant, expand_composites=False)
      except (ValueError, TypeError) as e:
        raise TypeError('"{}" does not have the same nested structure as its'
                        ' corresponding shape invariant.\n\n{}'.format(name, e))

    nest.map_structure(
        functools.partial(_verify_single_loop_var, name, check_shapes), init,
        entry, exit_, invariant)


def _verify_single_cond_var(name, body_var, orelse_var):
  """Verifies whether body_var and orelse_var are consistent."""
  if isinstance(body_var, (bool, int, float, str)):
    body_var = ops.convert_to_tensor_v2(body_var)

  if isinstance(orelse_var, (bool, int, float, str)):
    orelse_var = ops.convert_to_tensor_v2(orelse_var)

  if (not tensor_util.is_tf_type(body_var) or
      not tensor_util.is_tf_type(orelse_var)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if (not hasattr(body_var, 'dtype') or
      not hasattr(orelse_var, 'dtype')):
    return

  if body_var.dtype != orelse_var.dtype:
    raise TypeError(
        '"{}" has dtype {} in the TRUE branch, but dtype={} in the FALSE'
        ' branch. TensorFlow control flow requires that they are the'
        ' same.'.format(name, body_var.dtype.name,
                        orelse_var.dtype.name))


def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names):
  """Verifies variables manipulated by a conditional for consistency."""
  basic_body_vars, composite_body_vars = body_vars
  basic_orelse_vars, composite_orelse_vars = orelse_vars
  assert isinstance(composite_body_vars, tuple)
  assert isinstance(composite_orelse_vars, tuple)

  # TODO(kkb): Make this more consistent.
  # The basic outputs should always be a tuple.
  if not isinstance(basic_body_vars, tuple):
    basic_body_vars = (basic_body_vars,)
  if not isinstance(basic_orelse_vars, tuple):
    basic_orelse_vars = (basic_orelse_vars,)

  body_vars = basic_body_vars + composite_body_vars
  orelse_vars = basic_orelse_vars + composite_orelse_vars

  named_vars = zip(symbol_names, body_vars, orelse_vars)
  for name, body_var, orelse_var in named_vars:
    try:
      nest.assert_same_structure(
          body_var, orelse_var, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError(
          '"{}" does not have the same nested structure in the TRUE and FALSE'
          ' branches.\n\n{}'.format(name, str(e)))

    nest.map_structure(
        functools.partial(_verify_single_cond_var, name), body_var, orelse_var)
def for_stmt(iter_,
             extra_test,
             body,
             get_state,
             set_state,
             init_vars,
             basic_symbol_names,
             composite_symbol_names,
             opts):
  """Functional form of a for statement.

  The loop operates on a state, which includes all symbols that are
  variant across loop iterations, excluding the iterate as well as the
  variables local to the loop.

  For example, given the loop below that calculates the geometric and
  arithmetic means of some numbers:

    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a

  The state is represented by the variables geo_mean and arith_mean. The
  argument for initial_state may contain the tuple (1, 0), the body will
  include the arguments geo_mean and arith_mean and will return a tuple
  representing the new values for geo_mean and respectively arith_mean.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with the state as arguments, and boolean return type.
      An additional loop condition.
    body: Callable with the iterate and the state as arguments, and state as
      return type. The actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which save values captured by get_state back
      into the Python environment. This is only useful when staging the loop.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """
  # Note: the dispatch order below matters. is_tf_type is checked first
  # because it also covers composite tensors such as ragged tensors; the
  # more specific isinstance checks must only be reached for non-dense types.
  if tensor_util.is_tf_type(iter_):
    if tensors.is_range_tensor(iter_):
      # tf.range iterates are elided entirely (see _tf_range_for_stmt).
      return _tf_range_for_stmt(iter_, extra_test, body, get_state, set_state,
                                init_vars, basic_symbol_names,
                                composite_symbol_names, opts)
    else:
      return _known_len_tf_for_stmt(iter_, extra_test, body, get_state,
                                    set_state, init_vars, basic_symbol_names,
                                    composite_symbol_names, opts)

  if isinstance(iter_, dataset_ops.DatasetV2):
    return _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state,
                                init_vars, basic_symbol_names,
                                composite_symbol_names, opts)

  if isinstance(iter_, iterator_ops.OwnedIterator):
    return _tf_iterator_for_stmt(iter_, extra_test, body, get_state, set_state,
                                 init_vars, basic_symbol_names,
                                 composite_symbol_names, opts)

  if isinstance(iter_, ragged_tensor.RaggedTensor):
    return _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                               init_vars, basic_symbol_names,
                               composite_symbol_names, opts)

  if isinstance(iter_, input_lib.DistributedIterator):
    raise NotImplementedError(
        'distributed iterators not supported yet, use the distributed dataset'
        ' directly')

  if isinstance(iter_, input_lib.DistributedDataset):
    return _tf_distributed_dataset_for_stmt(iter_, extra_test, body, init_vars)

  # Fallback: plain Python iteration.
  return _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars)
get_state, set_state, init_vars) 362 363 364def _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars): 365 """Overload of for_stmt that executes a Python for loop.""" 366 del get_state, set_state 367 state = init_vars 368 369 if extra_test is not None: 370 if extra_test(*state): 371 for target in iter_: 372 state = body(target, *state) 373 if not extra_test(*state): 374 break 375 376 else: 377 for target in iter_: 378 state = body(target, *state) 379 380 return state 381 382 383def _known_len_tf_for_stmt(iter_, 384 extra_test, 385 body, 386 get_state, 387 set_state, 388 init_vars, 389 basic_symbol_names, 390 composite_symbol_names, 391 opts): 392 """Overload of for_stmt that iterates over TF entities that admit a length.""" 393 _disallow_undefs_into_loop(*init_vars) 394 395 n = py_builtins.len_(iter_) 396 # TODO(b/117628877): Revisit performance once XLA has the necessary support. 397 # Note: using a TensorArray creates an extra copy, but can calculate 398 # gradients more efficiently than StridedSlice. 
def _known_len_tf_for_stmt(iter_,
                           extra_test,
                           body,
                           get_state,
                           set_state,
                           init_vars,
                           basic_symbol_names,
                           composite_symbol_names,
                           opts):
  """Overload of for_stmt that iterates over TF entities that admit a length."""
  _disallow_undefs_into_loop(*init_vars)

  n = py_builtins.len_(iter_)
  # TODO(b/117628877): Revisit performance once XLA has the necessary support.
  # Note: using a TensorArray creates an extra copy, but can calculate
  # gradients more efficiently than StridedSlice.
  ta = tensor_array_ops.TensorArray(iter_.dtype, size=n)
  iter_ = ta.unstack(iter_)

  def while_body(iterate_index, *loop_vars):
    """Main loop body."""
    iterate = iter_.read(iterate_index)
    new_vars = body(iterate, *loop_vars)

    loop_vars = (iterate_index + 1,)
    if new_vars:
      loop_vars += new_vars

    return loop_vars

  def while_cond(iterate_index, *loop_vars):
    # When extra_test is set, it must only be evaluated while the index is in
    # range, hence the tf.cond guard.
    if extra_test is not None:
      return control_flow_ops.cond(iterate_index < n,
                                   lambda: extra_test(*loop_vars),
                                   lambda: False)
    return iterate_index < n

  opts['maximum_iterations'] = n

  results = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (array_ops.zeros_like(n),) + init_vars,
      ('<internal iterate>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )

  # Note: the iteration index is not returned by the while loop, however
  # if a symbol with the same name exists outside the loop, it will be captured
  # by the loop variables and ultimately updated correctly.
  if isinstance(results, (tuple, list)):
    assert len(results) >= 1  # Has at least the iterate.
    if len(results) > 1:
      results = results[1:]
  else:
    results = ()

  return results


def _tf_ragged_for_stmt(iter_,
                        extra_test,
                        body,
                        get_state,
                        set_state,
                        init_vars,
                        basic_symbol_names,
                        composite_symbol_names,
                        opts):
  """Overload of for_stmt that iterates over TF ragged tensors."""
  _disallow_undefs_into_loop(*init_vars)

  # TODO(mdan): Move this into len()? Requires eager support.
  if iter_.shape and iter_.shape[0] is not None:
    n = iter_.shape[0]
  else:
    # NOTE(review): row_lengths()[0] is the length of the first row, not the
    # number of rows; nrows() looks like the intended value — confirm.
    n = iter_.row_lengths()[0]

  def while_body(iterate_index, *loop_vars):
    """Main loop body."""
    iterate = iter_[iterate_index]
    new_vars = body(iterate, *loop_vars)

    loop_vars = (iterate_index + 1,)
    if new_vars:
      loop_vars += new_vars

    return loop_vars

  def while_cond(iterate_index, *loop_vars):
    if extra_test is not None:
      return control_flow_ops.cond(
          iterate_index < n,
          lambda: extra_test(*loop_vars),
          lambda: False,
      )
    return iterate_index < n

  # Fix: this assignment previously appeared twice (once before while_body and
  # once here); the redundant first copy was removed.
  opts['maximum_iterations'] = n

  results = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (array_ops.zeros_like(n),) + init_vars,
      ('<internal iterate>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )

  if isinstance(results, (tuple, list)):
    assert len(results) >= 1  # Has at least the iterate.
    if len(results) > 1:
      results = results[1:]
  else:
    results = ()

  return results


def _tf_range_for_stmt(iter_,
                       extra_test,
                       body,
                       get_state,
                       set_state,
                       init_vars,
                       basic_symbol_names,
                       composite_symbol_names,
                       opts):
  """Overload of for_stmt that iterates over a TF range (and elides it)."""
  _disallow_undefs_into_loop(*init_vars)

  # The iterate is reconstructed arithmetically from the range parameters, so
  # the range tensor itself never needs to be materialized.
  start, limit, delta = iter_.op.inputs

  def while_body(iterate, *loop_vars):
    new_vars = body(iterate, *loop_vars)
    loop_vars = (iterate + delta,)

    if new_vars:
      loop_vars += new_vars

    return loop_vars

  def while_cond(iterate, *loop_vars):
    """Cond function for `tf.while_loop`."""
    # Handles both ascending (delta >= 0) and descending (delta < 0) ranges.
    main_test = math_ops.logical_or(
        math_ops.logical_and(delta >= 0, iterate < limit),
        math_ops.logical_and(delta < 0, iterate > limit))
    if extra_test is not None:
      return control_flow_ops.cond(
          main_test,
          lambda: extra_test(*loop_vars),
          lambda: False,
      )
    return main_test

  opts['maximum_iterations'] = math_ops.cast(
      misc.get_range_len(start, limit, delta), dtypes.int32)

  results = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (start,) + init_vars,
      ('<internal iterate>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )

  # Note: the iteration index is not returned by the while loop, however
  # if a symbol with the same name exists outside the loop, it will be captured
  # by the loop variables and ultimately updated correctly.
  if isinstance(results, (tuple, list)):
    assert len(results) >= 1  # Has at least the iterate.
    if len(results) > 1:
      results = results[1:]
  else:
    results = ()

  return results
def _tf_iterator_for_stmt(itr, extra_test, body, get_state, set_state,
                          init_vars, basic_symbol_names,
                          composite_symbol_names, opts):
  """Overload of for_stmt that iterates over TF Iterators. See for_loop."""
  _disallow_undefs_into_loop(*init_vars)

  def while_body_actual(opt_iterate, *loop_vars):
    """Actual main loop body."""
    new_vars = body(opt_iterate.get_value(), *loop_vars)
    # TODO(mdan): Fix this inconsistency in the converter.
    if new_vars is None:
      new_vars = ()
    # Note: this verification duplicates that performed in tf_while_stmt,
    # but needs to be done earlier to prevent the tf.cond inside while_body
    # from blowing up first.
    _verify_tf_loop_vars(init_vars, loop_vars, new_vars,
                         basic_symbol_names + composite_symbol_names, opts)
    return new_vars

  def while_body(has_next, *loop_vars):
    """Main loop body."""
    opt_iterate = itr.get_next_as_optional()
    has_next = opt_iterate.has_value()

    if not init_vars:
      # cond_v2 requires at least one state tensor in V1.
      dummy_state = (constant_op.constant(()),)
    else:
      dummy_state = ()

    # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
    # The body only runs when the iterator actually yielded a value; otherwise
    # the loop vars are passed through unchanged.
    new_vars = control_flow_ops.cond(
        has_next,
        lambda: dummy_state + while_body_actual(opt_iterate, *loop_vars),
        lambda: dummy_state + loop_vars,
    )

    if dummy_state:
      # Drop the dummy state tensor added above.
      new_vars = new_vars[1:]

    return (has_next,) + new_vars

  def while_cond(has_next, *loop_vars):
    if extra_test is not None:
      return control_flow_ops.cond(
          has_next,
          lambda: extra_test(*loop_vars),
          lambda: False,
      )
    return has_next

  final_vars = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (True,) + init_vars,
      ('<internal has_next>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )
  # Strip the internal has_next flag from the returned state.
  return final_vars[1:]


def _tf_dataset_for_stmt(ds, extra_test, body, get_state, set_state, init_vars,
                         basic_symbol_names, composite_symbol_names, opts):
  """Overload of for_stmt that iterates over TF Datasets."""
  _disallow_undefs_into_loop(*init_vars)

  if extra_test is not None:
    assert init_vars, 'Lowering should always add state.'
    return _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
                                             set_state, init_vars,
                                             basic_symbol_names,
                                             composite_symbol_names, opts)

  return _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state,
                                         init_vars, basic_symbol_names,
                                         composite_symbol_names, opts)
def _general_purpose_scan(ds, init_state, body):
  """Variant of Dataset.scan with semantics of general-purpose computation."""
  # Datasets are typically intended for data preprocessing. However, in
  # autograph loops they usually appear as general-purpose computations (for
  # example, a custom training loop). These two use cases require significantly
  # different optimization policies, the most important of which is the device
  # placement. The flag override for use_default_device below instructs the
  # runtime to treat the computation as general-purpose, rather than data
  # preprocessing.
  # TODO(mdan): s/use_default_device/specialize_for_input_pipeline.
  # TODO(mdan): Don't use private symbols.
  return scan_ops._ScanDataset(ds, init_state, body, use_default_device=False)  # pylint:disable=protected-access


def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
                                      set_state, init_vars, basic_symbol_names,
                                      composite_symbol_names, opts):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt."""

  # TODO(mdan): Simplify this - following it is extremely difficult.

  init_state = get_state()
  aug_init_vars = init_vars, init_state

  def scan_body(aug_vars, iterate):
    """The main loop body wrapper. Only calculates the stop condition."""
    loop_vars, state = aug_vars

    def true_fn():
      """Main path - stop condition is not set."""
      # Restore composite symbol state before running the user body, so that
      # mutations it makes are captured by get_state below.
      set_state(state)
      new_vars = body(iterate, *loop_vars)
      new_state = get_state()
      _verify_tf_loop_vars(
          init_vars + init_state,
          loop_vars + state,
          new_vars + new_state,
          basic_symbol_names + composite_symbol_names,
          opts,
          check_shapes=False)
      return new_vars, new_state

    # Once extra_cond is False, the state is simply carried through unchanged
    # until take_while (below) terminates iteration.
    extra_cond = extra_test(*loop_vars)
    new_vars, new_state = control_flow_ops.cond(
        extra_cond,
        true_fn,
        lambda: (loop_vars, state),
    )

    scan_outputs = new_vars, new_state, extra_cond
    # Note: new_aug_vars is the actual state of scan; scan_outputs is its output
    # (hence the redundancy).
    # get_state will pull any mutations that body may have made.
    new_aug_vars = new_vars, new_state
    return new_aug_vars, scan_outputs

  def take_while_predicate(unused_loop_vars, unused_state, extra_cond):
    return extra_cond

  def reduce_body(unused_aug_vars, scan_outputs):
    output_aug_vars, output_state, extra_cond = scan_outputs
    del extra_cond
    return output_aug_vars, output_state

  ds = _general_purpose_scan(ds, aug_init_vars, scan_body)
  ds = ds.apply(take_while_ops.take_while(take_while_predicate))
  final_aug_vars = ds.reduce(aug_init_vars, reduce_body)
  final_vars, final_state = final_aug_vars
  set_state(final_state)
  return final_vars
def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars,
                                    basic_symbol_names, composite_symbol_names,
                                    opts):
  """Overload of _dataset_for_stmt without early stopping. See for_stmt."""
  init_state = get_state()
  assert isinstance(init_vars, tuple)
  assert isinstance(init_state, tuple)

  symbol_names = basic_symbol_names + composite_symbol_names

  # Workaround for Dataset.reduce not allowing empty state tensors - create
  # a dummy state variable that remains unused.
  # TODO(mdan): reduce should allow and match empty structures.
  no_vars = not init_vars
  no_state = not init_state

  if no_vars:
    init_vars = (constant_op.constant(0),)
    symbol_names = ('<internal dummy>',) + symbol_names
  if no_state:
    init_state = (constant_op.constant(0),)
    symbol_names = symbol_names + ('<internal dummy>',)

  def scan_body(aug_vars, iterate):
    """The main loop body wrapper."""
    loop_vars, state = aug_vars
    if not no_state:
      set_state(state)

    if no_vars:
      body(iterate)
      new_vars = loop_vars
    else:
      new_vars = body(iterate, *loop_vars)

    if no_state:
      new_state = state
    else:
      new_state = get_state()

    _verify_tf_loop_vars(
        init_vars + init_state,
        loop_vars + state,
        new_vars + new_state,
        symbol_names,
        opts,
        check_shapes=False)

    scan_outputs = new_vars, new_state
    # Note: new_aug_vars is the actual state of scan; scan_outputs is its output
    # (hence the redundancy).
    # get_state will pull any mutations that body may have made.
    new_aug_vars = new_vars, new_state
    return new_aug_vars, scan_outputs

  def reduce_body(unused_aug_vars, scan_outputs):
    output_aug_vars, output_state = scan_outputs
    return output_aug_vars, output_state

  # Fix: this previously read `init_vars, get_state()`, which bypassed the
  # dummy-state substitution above: when no_state is True, get_state() returns
  # the real (empty) state, while scan_body and the verification expect the
  # dummy-augmented init_state.
  aug_vars = init_vars, init_state
  ds = _general_purpose_scan(ds, aug_vars, scan_body)
  final_vars, final_state = ds.reduce(aug_vars, reduce_body)
  set_state(final_state)

  if no_vars:
    return ()
  return final_vars
in distributed input loops.') 802 803 def reduce_body(state, iterate): 804 new_state = body(iterate, *state) 805 return new_state 806 807 if init_state: 808 return iter_.reduce(init_state, reduce_body) 809 810 def reduce_body_with_dummy_state(state, iterate): 811 reduce_body((), iterate) 812 return state 813 iter_.reduce((constant_op.constant(0),), reduce_body_with_dummy_state) 814 return () 815 816 817def while_stmt(test, 818 body, 819 get_state, 820 set_state, 821 init_vars, 822 basic_symbol_names, 823 composite_symbol_names, 824 opts): 825 """Functional form of a while statement. 826 827 The loop operates on a so-called state, which includes all symbols that are 828 variant across loop iterations. In what follows we refer to state as either 829 a tuple of entities that represent an actual state, or a list of arguments 830 of the corresponding types. 831 832 Args: 833 test: Callable with the state as arguments, and boolean return type. The 834 loop condition. 835 body: Callable with the state as arguments, and state as return type. The 836 actual loop body. 837 get_state: Additional callable which can capture additional state (such as 838 the values of composite symbols). This is only useful when staging the 839 loop. 840 set_state: Additional callable which save values captured by get_state back 841 into the Python environment. This is only useful when staging the loop. 842 init_vars: Tuple containing the initial state. 843 basic_symbol_names: Tuple containing basic loop var names. 844 composite_symbol_names: Tuple containing composite loop var names. 845 opts: Optional dict of extra loop parameters. 846 847 Returns: 848 Tuple containing the final state. 849 """ 850 851 # Evaluate the initial test once in order to do the dispatch. The evaluation 852 # is isolated to minimize unwanted side effects. 853 # TODO(mdan): Do a full iteration - some state types might lower to Tensor. 
854 with func_graph.FuncGraph('tmp').as_default(): 855 init_test = test(*init_vars) 856 857 # TensorFlow: Multiple evaluations are acceptable in this case, so we're fine 858 # with the re-evaluation of `test` that `_tf_while_stmt` will make. 859 if tensors.is_dense_tensor(init_test): 860 return _tf_while_stmt(test, body, get_state, set_state, init_vars, 861 basic_symbol_names, composite_symbol_names, opts) 862 863 # Normal Python: We already consumed one evaluation of `test`; consistently, 864 # unroll one iteration before dispatching to a normal loop. 865 # TODO(mdan): Push the "init_test" value via opts into _py_while_stmt? 866 if not init_test: 867 return init_vars 868 init_vars = body(*init_vars) 869 870 return _py_while_stmt(test, body, get_state, set_state, init_vars, opts) 871 872 873def _shape_invariants_mapping_to_positional_list(mapping, keys): 874 # The keys are not expected to be hashable. 875 mapping = {id(k): (k, v) for k, v in mapping} 876 result = [] 877 for k in keys: 878 map_key, map_val = mapping.get(id(k), (None, None)) 879 result.append(map_val if map_key is k else None) 880 return tuple(result) 881 882 883def _tf_while_stmt(test, body, get_state, set_state, init_vars, 884 basic_symbol_names, composite_symbol_names, opts): 885 """Overload of while_stmt that stages a TF while_stmt.""" 886 _disallow_undefs_into_loop(*init_vars) 887 888 aug_init_vars = init_vars + get_state() 889 890 # TODO(mdan): Simplify this. 
def _tf_while_stmt(test, body, get_state, set_state, init_vars,
                   basic_symbol_names, composite_symbol_names, opts):
  """Overload of while_stmt that stages a TF while_stmt."""
  _disallow_undefs_into_loop(*init_vars)

  # The "augmented" loop vars are the basic loop vars followed by the values
  # of any composite symbols captured by get_state.
  aug_init_vars = init_vars + get_state()

  # TODO(mdan): Simplify this.
  loop_vars_slice = slice(len(init_vars))
  state_slice = slice(len(init_vars), None)

  def aug_test(*aug_loop_vars):
    state = aug_loop_vars[state_slice]
    set_state(state)
    return test(*aug_loop_vars[loop_vars_slice])

  def aug_body(*aug_loop_vars):
    """Main loop body."""
    state = aug_loop_vars[state_slice]
    # Push the captured composite state back into the Python environment
    # before running the user body, so it observes the current values.
    set_state(state)
    loop_vars = body(*aug_loop_vars[loop_vars_slice])
    new_state = loop_vars + get_state()
    _verify_tf_loop_vars(aug_init_vars, aug_loop_vars, new_state,
                         basic_symbol_names + composite_symbol_names, opts)

    return new_state

  # Non-v2 while_loop unpacks the results when there is only one return value.
  # This enforces consistency across versions.
  opts['return_same_structure'] = True

  if 'shape_invariants' in opts:
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        opts['shape_invariants'], aug_init_vars)

  final_aug_vars = control_flow_ops.while_loop(aug_test, aug_body,
                                               aug_init_vars, **opts)
  final_state = final_aug_vars[state_slice]
  set_state(final_state)
  return final_aug_vars[loop_vars_slice]


class _PythonLoopChecker(object):
  """Verifies Python loops for TF-specific limits."""

  def __init__(self):
    self.iterations = 0
    self.check_inefficient_unroll = WARN_INEFFICIENT_UNROLL

    # Triggered when we decided to test the op counts.
    self.check_op_count_after_iteration = False

  def _get_ops(self):
    return ops.get_default_graph().get_operations()

  def _check_unroll_limits(self):
    if LIMIT_PYTHON_ITERATIONS and self.iterations > PYTHON_MAX_ITERATIONS:
      raise ValueError('iteration limit exceeded')

  def _stop_checking_inefficient_unroll(self):
    self.check_inefficient_unroll = False
    self.ops_before_iteration = None

  # Fix: renamed from _verify_ineffcient_unroll (typo). The method is private
  # to this private class; its only call site (after_iteration) is updated.
  def _verify_inefficient_unroll(self):
    """Checks for possibly-inefficient creation of ops in a Python loop."""
    assert self.ops_before_iteration is not None
    ops_after_iteration = self._get_ops()
    new_ops = tuple(
        op for op in ops_after_iteration if op not in self.ops_before_iteration)

    if len(new_ops) < INEFFICIENT_UNROLL_MIN_OPS:
      return False

    # TODO(mdan): Add location information.
    ag_logging.warn(
        'TensorFlow ops are being created in a Python loop with large number'
        ' of iterations. This can lead to slow startup. Did you mean to use a'
        ' TensorFlow loop? For example, `while True:` is a Python loop, and'
        ' `while tf.constant(True):` is a TensorFlow loop. The following'
        ' ops were created after iteration %s: %s', self.iterations, new_ops)
    return True

  def before_iteration(self):
    """Called before each iteration in a Python loop."""
    if (self.check_inefficient_unroll and
        self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS):
      self.ops_before_iteration = self._get_ops()
      self.check_op_count_after_iteration = True

  def after_iteration(self):
    """Called after each iteration in a Python loop."""
    self.iterations += 1

    self._check_unroll_limits()

    if self.check_inefficient_unroll and self.check_op_count_after_iteration:
      did_warn = self._verify_inefficient_unroll()
      if did_warn:
        self._stop_checking_inefficient_unroll()  # Only warn once.
      elif self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS + 3:
        # Once deciding to check the op counts, only do it for a few iterations.
        self._stop_checking_inefficient_unroll()


def _py_while_stmt(test, body, get_state, set_state, init_vars, opts):
  """Overload of while_stmt that executes a Python while loop."""
  del opts, get_state, set_state

  if __debug__:
    checker = _PythonLoopChecker()

  loop_vars = init_vars
  while test(*loop_vars):

    if __debug__:
      checker.before_iteration()

    loop_vars = body(*loop_vars)

    if __debug__:
      checker.after_iteration()

  return loop_vars
1035 composite_symbol_names: Tuple containing composite loop var names. 1036 1037 Returns: 1038 Tuple containing the statement outputs. 1039 """ 1040 # Note: tf.cond doesn't support SparseTensor. 1041 if tensors.is_dense_tensor(cond): 1042 return tf_if_stmt(cond, body, orelse, get_state, set_state, 1043 basic_symbol_names, composite_symbol_names) 1044 else: 1045 return _py_if_stmt(cond, body, orelse) 1046 1047 1048def tf_if_stmt(cond, body, orelse, get_state, set_state, basic_symbol_names, 1049 composite_symbol_names): 1050 """Overload of if_stmt that stages a TF cond.""" 1051 body = _wrap_disallow_undefs_from_cond(body, branch_name='if') 1052 orelse = _wrap_disallow_undefs_from_cond(orelse, branch_name='else') 1053 body = _isolate_state(body, get_state, set_state) 1054 orelse = _isolate_state(orelse, get_state, set_state) 1055 1056 # `state` currently includes the values of any composite symbols (e.g. `a.b`) 1057 # composites modified by the loop. `final_vars` includes the values of basic 1058 # symbols (e.g. `a`) which cannot be passed by reference and must be returned. 1059 # See _isolate_state. 1060 # TODO(mdan): We should minimize calls to get/set_state. 
1061 1062 body_branch = 0 1063 orelse_branch = 1 1064 result = [None, None] 1065 1066 def error_checking_body(): 1067 result[body_branch] = body() 1068 if result[orelse_branch] is not None: 1069 _verify_tf_cond_vars(result[body_branch], result[orelse_branch], 1070 basic_symbol_names + composite_symbol_names) 1071 return result[body_branch] 1072 1073 def error_checking_orelse(): 1074 result[orelse_branch] = orelse() 1075 if result[body_branch] is not None: 1076 _verify_tf_cond_vars(result[body_branch], result[orelse_branch], 1077 basic_symbol_names + composite_symbol_names) 1078 return result[orelse_branch] 1079 1080 final_vars, final_state = control_flow_ops.cond(cond, error_checking_body, 1081 error_checking_orelse) 1082 1083 set_state(final_state) 1084 1085 return final_vars 1086 1087 1088def _isolate_state(func, get_state, set_state): 1089 """Wraps func to (best-effort) isolate state mutations that func may do. 1090 1091 The simplest example of state mutation is mutation of variables (via e.g. 1092 attributes), or modification of globals. 1093 1094 This allows us to more safely execute this function without worrying about 1095 side effects when the function wasn't normally expected to execute. For 1096 example, staging requires that the function is executed ahead of time, and 1097 we need to ensure its effects are not observed during normal execution. 1098 1099 Args: 1100 func: () -> Any 1101 get_state: () -> Any, returns the current state 1102 set_state: (Any) -> None, resets the state to the specified values. 1103 Typically the result of an earlier call to `get_state`. 1104 1105 Returns: 1106 Tuple[Any, Any], where the first element is the return value of `func`, 1107 and the second is the final state values. 1108 """ 1109 1110 def wrapper(): 1111 init_state = get_state() 1112 new_vars = func() 1113 # TODO(mdan): These should be copies, lest set_state might affect them. 
1114 new_state = get_state() 1115 set_state(init_state) 1116 return new_vars, new_state 1117 1118 return wrapper 1119 1120 1121def _wrap_disallow_undefs_from_cond(func, branch_name): 1122 """Wraps conditional branch to disallow returning undefined symbols.""" 1123 1124 def wrapper(): 1125 """Calls function and raises an error if undefined symbols are returned.""" 1126 results = func() 1127 1128 if isinstance(results, tuple): 1129 results_tuple = results 1130 else: 1131 results_tuple = results, 1132 undefined = [v for v in results_tuple if isinstance(v, variables.Undefined)] 1133 if undefined: 1134 raise ValueError( 1135 'The following symbols must also be initialized in the {} branch: {}.' 1136 ' Alternatively, you may initialize them before the if' 1137 ' statement.'.format(branch_name, 1138 tuple(s.symbol_name for s in undefined))) 1139 1140 for result in results_tuple: 1141 if isinstance(result, variables.UndefinedReturnValue): 1142 raise ValueError( 1143 'A value must also be returned from the {} branch. If a value is ' 1144 'returned from one branch of a conditional a value must be ' 1145 'returned from all branches.'.format(branch_name)) 1146 1147 return results 1148 1149 return wrapper 1150 1151 1152def _py_if_stmt(cond, body, orelse): 1153 """Overload of if_stmt that executes a Python if statement.""" 1154 return body() if cond else orelse() 1155