1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Control flow statements: loops, conditionals, etc.
16
17Python 2 compatibility version. Not maintained.
18
19Note: most of these operators accept pairs of get_state/set_state functions, to
20capture mutations that the corresponding code blocks might make. These
21mutations only need to be captured when staging the control flow, and they just
22work when reverting to Python behavior.
23
24__Examples__
25
26```
27while cond:
28  self.x += i
29```
30
31When the functionalized version is executed as a Python loop, it just works:
32
33```
34def loop_body():
35  self.x += i     # works as expected for Python loops
36```
37
38But it won't work for TF loops:
39
40```
41def loop_body():
42  self.x += i     # self.x has the wrong value!
43```
44
45get_state/set_state allow piping the mutations through the loop variables as
46well, in effect changing the loop body:
47
48```
49def loop_body(self_x):
50  self.x = self_x  # self.x now has the proper value
51  self.x += i      # the original block
52  self_x = self.x  # write self.x back into the loop vars
53  return self_x
54
55self_x = tf.while_loop(...)
self.x = self_x    # the result is properly captured
57```
58"""
59
60from __future__ import absolute_import
61from __future__ import division
62from __future__ import print_function
63
64import functools
65
66import numpy as np
67
68from tensorflow.python.autograph.operators import py_builtins
69from tensorflow.python.autograph.operators import variables
70from tensorflow.python.autograph.utils import ag_logging
71from tensorflow.python.autograph.utils import misc
72from tensorflow.python.autograph.utils import tensors
73from tensorflow.python.data.experimental.ops import scan_ops
74from tensorflow.python.data.experimental.ops import take_while_ops
75from tensorflow.python.data.ops import dataset_ops
76from tensorflow.python.data.ops import iterator_ops
77from tensorflow.python.framework import constant_op
78from tensorflow.python.framework import dtypes
79from tensorflow.python.framework import func_graph
80from tensorflow.python.framework import ops
81from tensorflow.python.framework import tensor_util
82from tensorflow.python.ops import array_ops
83from tensorflow.python.ops import control_flow_ops
84from tensorflow.python.ops import math_ops
85from tensorflow.python.ops import tensor_array_ops
86from tensorflow.python.ops.ragged import ragged_tensor
87from tensorflow.python.util import lazy_loader
88from tensorflow.python.util import nest
89
90
91# TODO(b/145618471): Remove this dependency.
92# Lazy import to work around circular dependencies
93input_lib = lazy_loader.LazyLoader(
94    'input_lib', globals(),
95    'tensorflow.python.distribute.input_lib')
96
# If True, caps the number of iterations of loops executed in Python
# (presumably at PYTHON_MAX_ITERATIONS; the consuming code is not visible in
# this chunk - confirm against _py_while_stmt).
LIMIT_PYTHON_ITERATIONS = True
PYTHON_MAX_ITERATIONS = 100000000  # Fails in about one minute for empty loops.
# If True, emit a warning when an unrolled Python loop appears large enough to
# be inefficient. The two thresholds below presumably set the minimum
# iteration count and minimum ops created per iteration that trigger the
# warning - consumers are outside this chunk, verify before relying on this.
WARN_INEFFICIENT_UNROLL = True
INEFFICIENT_UNROLL_MIN_ITERATIONS = 3000
INEFFICIENT_UNROLL_MIN_OPS = 1
102
103
def _disallow_undefs_into_loop(*values):
  """Ensures that all values in the state are defined when entering a loop."""
  undefined_names = [
      v.symbol_name for v in values if isinstance(v, variables.Undefined)
  ]
  if undefined_names:
    raise ValueError(
        '{} must be defined before the loop.'.format(
            ','.join(undefined_names)))
  for value in values:
    if isinstance(value, variables.UndefinedReturnValue):
      # Assumption: the loop will only capture the variable which tracks the
      # return value if the loop contained a return statement.
      # TODO(mdan): This should be checked at the place where return occurs.
      raise ValueError(
          'return statements are not supported within a TensorFlow loop.')
118
119
120def _is_subshape(left, right):
121  """Returns True if left shape is at least as specific as right shape."""
122  # TODO(mdan): This code should be in TensorShape.
123  # Note: this is not the same as TensorShape.is_compatible_with, which is
124  # symmetric.
125  # This code also duplicates _ShapeLessThanOrEqual from  control_flow_ops.py.
126  if right.dims is None:
127    return True
128  if left.ndims != right.ndims:
129    return False
130  for ldim, rdim in zip(left.dims, right.dims):
131    if rdim.value is not None and ldim.value != rdim.value:
132      return False
133  return True
134
135
136# TODO(mdan): Remove these verifications once TF ops can properly report names.
# TODO(mdan): Remove these verifications once TF ops can properly report names.
def _verify_single_loop_var(
    name, check_shape, init, entry, exit_, shape_invariant):
  """Verifies whether the initial, entry and exit values are consistent.

  Args:
    name: Name of the loop variable, used in error messages.
    check_shape: Whether shape consistency should be enforced.
    init: Value of the variable before the loop.
    entry: Value of the variable when entering an iteration.
    exit_: Value of the variable when exiting an iteration.
    shape_invariant: Optional shape invariant; when None, the entry shape is
      used as the reference instead.

  Raises:
    TypeError: If the dtype changes across an iteration.
    ValueError: If the shape is inconsistent across an iteration, or does not
      conform to the declared shape invariant.
  """
  # Normalize Python literals and ndarrays to tensors so the dtype/shape
  # checks below apply uniformly. np.ndarray is included for exit_ as well,
  # for consistency with init and entry (previously an ndarray exit value
  # silently skipped all checks).
  if isinstance(init, (bool, int, float, str, np.ndarray)):
    init = ops.convert_to_tensor_v2(init)
  if isinstance(entry, (bool, int, float, str, np.ndarray)):
    entry = ops.convert_to_tensor_v2(entry)
  if isinstance(exit_, (bool, int, float, str, np.ndarray)):
    exit_ = ops.convert_to_tensor_v2(exit_)

  # Non-tensor values are not verified.
  if (not tensor_util.is_tf_type(entry) or
      not tensor_util.is_tf_type(exit_)):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if (not hasattr(entry, 'dtype') or
      not hasattr(exit_, 'dtype')):
    return
  if (not hasattr(entry, 'shape') or
      not hasattr(exit_, 'shape')):
    return

  if entry.dtype != exit_.dtype:
    raise TypeError(
        '"{}" has dtype {} before the loop, but dtype {} after one'
        ' iteration. TensorFlow control flow requires it stays the'
        ' same.'.format(
            name,
            entry.dtype.name,
            exit_.dtype.name,
        ))
  if check_shape:
    exit_shape = exit_.shape
    if shape_invariant is None:
      # Without an explicit invariant, the exit shape must be at least as
      # specific as the entry shape.
      entry_shape = entry.shape
      if not _is_subshape(exit_shape, entry_shape):
        raise ValueError(
            '"{}" has shape {} before the loop, but shape {} after one'
            ' iteration. Use tf.autograph.experimental.set_loop_options to set'
            ' shape invariants.'.format(name, entry_shape, exit_shape))
    else:
      # With an explicit invariant, both the initial and the exit values must
      # conform to it.
      init_shape = init.shape
      if not _is_subshape(init_shape, shape_invariant):
        raise ValueError(
            '"{}" has shape {} before the loop, which does not conform with'
            ' the shape invariant {}.'.format(name, init_shape,
                                              shape_invariant))
      if not _is_subshape(exit_shape, shape_invariant):
        raise ValueError(
            '"{}" has shape {} after the loop, which does not conform with'
            ' the shape invariant {}.'.format(
                name, exit_shape, shape_invariant))
189
190
def _verify_tf_loop_vars(init_vars,
                         iter_entry_vars,
                         iter_exit_vars,
                         symbol_names,
                         opts,
                         check_shapes=True):
  """Verifies loop variables for consistency."""
  # Shape invariants only apply when shape checks are requested; otherwise
  # pair every entry variable with a None invariant.
  if check_shapes and 'shape_invariants' in opts:
    shape_invariants = opts['shape_invariants']
  else:
    shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars)

  for name, init, entry, exit_, invariant in zip(symbol_names, init_vars,
                                                 iter_entry_vars,
                                                 iter_exit_vars,
                                                 shape_invariants):
    _check_loop_var_structure(name, init, entry, exit_, invariant)
    nest.map_structure(
        functools.partial(_verify_single_loop_var, name, check_shapes), init,
        entry, exit_, invariant)


def _check_loop_var_structure(name, init, entry, exit_, invariant):
  """Checks that one loop variable keeps its nested structure."""
  try:
    nest.assert_same_structure(entry, exit_, expand_composites=True)
  except (ValueError, TypeError) as e:
    raise TypeError('"{}" does not have the same nested structure after one'
                    ' iteration.\n\n{}'.format(name, e))
  if invariant is not None:
    try:
      nest.assert_same_structure(init, invariant, expand_composites=False)
    except (ValueError, TypeError) as e:
      raise TypeError('"{}" does not have the same nested structure as its'
                      ' corresponding shape invariant.\n\n{}'.format(name, e))
221
222
def _verify_single_cond_var(name, body_var, orelse_var):
  """Verifies whether body_var and orelse_var are consistent."""
  # Lift Python literals to tensors so the dtype comparison below applies.
  if isinstance(body_var, (bool, int, float, str)):
    body_var = ops.convert_to_tensor_v2(body_var)
  if isinstance(orelse_var, (bool, int, float, str)):
    orelse_var = ops.convert_to_tensor_v2(orelse_var)

  # Non-tensor values are not verified.
  if not tensor_util.is_tf_type(body_var):
    return
  if not tensor_util.is_tf_type(orelse_var):
    return

  # TODO(mdan): Properly account for CompositeTensors.
  if not hasattr(body_var, 'dtype') or not hasattr(orelse_var, 'dtype'):
    return

  if body_var.dtype != orelse_var.dtype:
    raise TypeError(
        '"{}" has dtype {} in the TRUE branch, but dtype={} in the FALSE'
        ' branch. TensorFlow control flow requires that they are the'
        ' same.'.format(name, body_var.dtype.name,
                        orelse_var.dtype.name))
246
247
def _verify_tf_cond_vars(body_vars, orelse_vars, symbol_names):
  """Verifies variables manipulated by a conditional for consistency."""
  basic_body_vars, composite_body_vars = body_vars
  basic_orelse_vars, composite_orelse_vars = orelse_vars
  assert isinstance(composite_body_vars, tuple)
  assert isinstance(composite_orelse_vars, tuple)

  # TODO(kkb): Make this more consistent.
  # The basic outputs should always be a tuple.
  if not isinstance(basic_body_vars, tuple):
    basic_body_vars = (basic_body_vars,)
  if not isinstance(basic_orelse_vars, tuple):
    basic_orelse_vars = (basic_orelse_vars,)

  all_body_vars = basic_body_vars + composite_body_vars
  all_orelse_vars = basic_orelse_vars + composite_orelse_vars

  for name, body_var, orelse_var in zip(symbol_names, all_body_vars,
                                        all_orelse_vars):
    try:
      nest.assert_same_structure(
          body_var, orelse_var, expand_composites=True)
    except (ValueError, TypeError) as e:
      raise TypeError(
          '"{}" does not have the same nested structure in the TRUE and FALSE'
          ' branches.\n\n{}'.format(name, str(e)))

    nest.map_structure(
        functools.partial(_verify_single_cond_var, name), body_var, orelse_var)
277
278
def for_stmt(iter_,
             extra_test,
             body,
             get_state,
             set_state,
             init_vars,
             basic_symbol_names,
             composite_symbol_names,
             opts):
  """Functional form of a for statement.

  The loop operates on a state, which includes all symbols that are
  variant across loop iterations, excluding the iterate as well as the
  variables local to the loop.

  For example, given the loop below that calculates the geometric and
  arithmetic means of some numbers:

    geo_mean = 1
    arith_mean = 0
    for i in range(n):
      a = numbers[i]
      geo_mean *= a
      arith_mean += a

  The state is represented by the variables geo_mean and arith_mean. The
  argument for initial_state may contain the tuple (1, 0), the body will
  include the arguments geo_mean and arith_mean and will return a tuple
  representing the new values for geo_mean and respectively arith_mean.

  Args:
    iter_: The entity being iterated over.
    extra_test: Callable with the state as arguments, and boolean return type.
      An additional loop condition.
    body: Callable with the iterate and the state as arguments, and state as
      return type. The actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which save values captured by get_state back
      into the Python environment. This is only useful when staging the loop.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """
  # Dispatch on the type of the iterated entity. The order of the checks
  # matters: range tensors are also dense tensors, so they are tested first.
  if tensor_util.is_tf_type(iter_):
    if tensors.is_range_tensor(iter_):
      tf_impl = _tf_range_for_stmt
    else:
      tf_impl = _known_len_tf_for_stmt
    return tf_impl(iter_, extra_test, body, get_state, set_state, init_vars,
                   basic_symbol_names, composite_symbol_names, opts)

  if isinstance(iter_, dataset_ops.DatasetV2):
    return _tf_dataset_for_stmt(iter_, extra_test, body, get_state, set_state,
                                init_vars, basic_symbol_names,
                                composite_symbol_names, opts)

  if isinstance(iter_, iterator_ops.OwnedIterator):
    return _tf_iterator_for_stmt(iter_, extra_test, body, get_state, set_state,
                                 init_vars, basic_symbol_names,
                                 composite_symbol_names, opts)

  if isinstance(iter_, ragged_tensor.RaggedTensor):
    return _tf_ragged_for_stmt(iter_, extra_test, body, get_state, set_state,
                               init_vars, basic_symbol_names,
                               composite_symbol_names, opts)

  if isinstance(iter_, input_lib.DistributedIterator):
    raise NotImplementedError(
        'distributed iterators not supported yet, use the distributed dataset'
        ' directly')

  if isinstance(iter_, input_lib.DistributedDataset):
    return _tf_distributed_dataset_for_stmt(iter_, extra_test, body, init_vars)

  # Anything else is handled as a plain Python loop.
  return _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars)
362
363
364def _py_for_stmt(iter_, extra_test, body, get_state, set_state, init_vars):
365  """Overload of for_stmt that executes a Python for loop."""
366  del get_state, set_state
367  state = init_vars
368
369  if extra_test is not None:
370    if extra_test(*state):
371      for target in iter_:
372        state = body(target, *state)
373        if not extra_test(*state):
374          break
375
376  else:
377    for target in iter_:
378      state = body(target, *state)
379
380  return state
381
382
def _known_len_tf_for_stmt(iter_,
                           extra_test,
                           body,
                           get_state,
                           set_state,
                           init_vars,
                           basic_symbol_names,
                           composite_symbol_names,
                           opts):
  """Overload of for_stmt that iterates over TF entities that admit a length.

  The iterated entity is unstacked into a TensorArray and the loop is staged
  as a tf.while_loop whose first loop variable is the iteration index.

  Args:
    iter_: The entity being iterated over; must support py_builtins.len_ and
      have a dtype.
    extra_test: Optional callable with the state as arguments returning a
      boolean; an additional loop condition.
    body: Callable with the iterate and the state as arguments, and state as
      return type.
    get_state: Callable that captures additional (composite) state.
    set_state: Callable that restores state captured by get_state.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Dict of extra loop parameters; mutated here to set
      'maximum_iterations'.

  Returns:
    Tuple containing the final state.
  """
  _disallow_undefs_into_loop(*init_vars)

  n = py_builtins.len_(iter_)
  # TODO(b/117628877): Revisit performance once XLA has the necessary support.
  # Note: using a TensorArray creates an extra copy, but can calculate
  # gradients more efficiently than StridedSlice.
  ta = tensor_array_ops.TensorArray(iter_.dtype, size=n)
  iter_ = ta.unstack(iter_)

  def while_body(iterate_index, *loop_vars):
    """Main loop body."""
    iterate = iter_.read(iterate_index)
    new_vars = body(iterate, *loop_vars)

    # The iteration index is threaded as the first loop variable.
    loop_vars = (iterate_index + 1,)
    if new_vars:
      loop_vars += new_vars

    return loop_vars

  def while_cond(iterate_index, *loop_vars):
    # extra_test is only evaluated when the index is still in range, so that
    # it never observes out-of-range state.
    if extra_test is not None:
      return control_flow_ops.cond(iterate_index < n,
                                   lambda: extra_test(*loop_vars),
                                   lambda: False)
    return iterate_index < n

  opts['maximum_iterations'] = n

  results = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (array_ops.zeros_like(n),) + init_vars,
      ('<internal iterate>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )

  # Note: the iteration index is not returned by the while loop, however
  # if a symbol with the same name exists outside the loop, it will be captured
  # by the loop variables and ultimately updated correctly.
  if isinstance(results, (tuple, list)):
    assert len(results) >= 1  # Has at least the iterate.
    if len(results) > 1:
      results = results[1:]
  else:
    results = ()

  return results
444
445
def _tf_ragged_for_stmt(iter_,
                        extra_test,
                        body,
                        get_state,
                        set_state,
                        init_vars,
                        basic_symbol_names,
                        composite_symbol_names,
                        opts):
  """Overload of for_stmt that iterates over TF ragged tensors.

  Args:
    iter_: RaggedTensor being iterated over, along its first dimension.
    extra_test: Optional callable with the state as arguments returning a
      boolean; an additional loop condition.
    body: Callable with the iterate and the state as arguments, and state as
      return type.
    get_state: Callable that captures additional (composite) state.
    set_state: Callable that restores state captured by get_state.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Dict of extra loop parameters; mutated here to set
      'maximum_iterations'.

  Returns:
    Tuple containing the final state.
  """
  _disallow_undefs_into_loop(*init_vars)

  # TODO(mdan): Move this into len()? Requires eager support.
  if iter_.shape and iter_.shape[0] is not None:
    # The static outer dimension is known.
    n = iter_.shape[0]
  else:
    # NOTE(review): row_lengths()[0] is the length of the first row rather
    # than the number of rows - verify this is the intended dynamic fallback.
    n = iter_.row_lengths()[0]

  def while_body(iterate_index, *loop_vars):
    """Main loop body."""
    iterate = iter_[iterate_index]
    new_vars = body(iterate, *loop_vars)

    # The iteration index is threaded as the first loop variable.
    loop_vars = (iterate_index + 1,)
    if new_vars:
      loop_vars += new_vars

    return loop_vars

  def while_cond(iterate_index, *loop_vars):
    # extra_test is only evaluated when the index is still in range.
    if extra_test is not None:
      return control_flow_ops.cond(
          iterate_index < n,
          lambda: extra_test(*loop_vars),
          lambda: False,
      )
    return iterate_index < n

  # Set once, just before staging (it was previously also set redundantly
  # before the body definitions, with the same value).
  opts['maximum_iterations'] = n

  results = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (array_ops.zeros_like(n),) + init_vars,
      ('<internal iterate>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )

  # The internal iterate is dropped from the results.
  if isinstance(results, (tuple, list)):
    assert len(results) >= 1  # Has at least the iterate.
    if len(results) > 1:
      results = results[1:]
  else:
    results = ()

  return results
507
508
def _tf_range_for_stmt(iter_,
                       extra_test,
                       body,
                       get_state,
                       set_state,
                       init_vars,
                       basic_symbol_names,
                       composite_symbol_names,
                       opts):
  """Overload of for_stmt that iterates over a TF range (and elides it).

  Instead of materializing the range, the loop iterates directly over the
  range's (start, limit, delta) parameters, recovered from the op that
  produced `iter_`.

  Args:
    iter_: Output tensor of a range op (guaranteed by the
      tensors.is_range_tensor dispatch in for_stmt).
    extra_test: Optional callable with the state as arguments returning a
      boolean; an additional loop condition.
    body: Callable with the iterate and the state as arguments, and state as
      return type.
    get_state: Callable that captures additional (composite) state.
    set_state: Callable that restores state captured by get_state.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Dict of extra loop parameters; mutated here to set
      'maximum_iterations'.

  Returns:
    Tuple containing the final state.
  """
  _disallow_undefs_into_loop(*init_vars)

  start, limit, delta = iter_.op.inputs

  def while_body(iterate, *loop_vars):
    new_vars = body(iterate, *loop_vars)
    # The iterate itself is threaded as the first loop variable.
    loop_vars = (iterate + delta,)

    if new_vars:
      loop_vars += new_vars

    return loop_vars

  def while_cond(iterate, *loop_vars):
    """Cond function for `tf.while_loop`."""
    # Mirrors range() semantics: the comparison direction depends on the
    # sign of delta.
    main_test = math_ops.logical_or(
        math_ops.logical_and(delta >= 0, iterate < limit),
        math_ops.logical_and(delta < 0, iterate > limit))
    if extra_test is not None:
      # extra_test is only evaluated when the main test passes.
      return control_flow_ops.cond(
          main_test,
          lambda: extra_test(*loop_vars),
          lambda: False,
      )
    return main_test

  opts['maximum_iterations'] = math_ops.cast(
      misc.get_range_len(start, limit, delta), dtypes.int32)

  results = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (start,) + init_vars,
      ('<internal iterate>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )

  # Note: the iteration index is not returned by the while loop, however
  # if a symbol with the same name exists outside the loop, it will be captured
  # by the loop variables and ultimately updated correctly.
  if isinstance(results, (tuple, list)):
    assert len(results) >= 1  # Has at least the iterate.
    if len(results) > 1:
      results = results[1:]
  else:
    results = ()

  return results
570
571
def _tf_iterator_for_stmt(itr, extra_test, body, get_state, set_state,
                          init_vars, basic_symbol_names,
                          composite_symbol_names, opts):
  """Overload of for_stmt that iterates over TF Iterators. See for_loop.

  The loop is staged as a tf.while_loop whose first loop variable is a
  boolean signaling whether the iterator produced a value on the previous
  step.

  Args:
    itr: The iterator being consumed.
    extra_test: Optional callable with the state as arguments returning a
      boolean; an additional loop condition.
    body: Callable with the iterate and the state as arguments, and state as
      return type.
    get_state: Callable that captures additional (composite) state.
    set_state: Callable that restores state captured by get_state.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """
  _disallow_undefs_into_loop(*init_vars)

  def while_body_actual(opt_iterate, *loop_vars):
    """Actual main loop body."""
    new_vars = body(opt_iterate.get_value(), *loop_vars)
    # TODO(mdan): Fix this inconsistency in the converter.
    if new_vars is None:
      new_vars = ()
    # Note: this verification duplicates that performed in tf_while_stmt,
    # but needs to be done earlier to prevent the tf.cond inside while_body
    # from blowing up first.
    _verify_tf_loop_vars(init_vars, loop_vars, new_vars,
                         basic_symbol_names + composite_symbol_names, opts)
    return new_vars

  def while_body(has_next, *loop_vars):
    """Main loop body."""
    # The incoming has_next is intentionally discarded; it is recomputed from
    # the iterator at every step.
    opt_iterate = itr.get_next_as_optional()
    has_next = opt_iterate.has_value()

    if not init_vars:
      # cond_v2 requires at least one state tensor in V1.
      dummy_state = (constant_op.constant(()),)
    else:
      dummy_state = ()

    # TODO(mdan): If tf.while_loop supported Optional, this could be avoided.
    new_vars = control_flow_ops.cond(
        has_next,
        lambda: dummy_state + while_body_actual(opt_iterate, *loop_vars),
        lambda: dummy_state + loop_vars,
    )

    if dummy_state:
      # Drop the dummy state tensor inserted above.
      new_vars = new_vars[1:]

    return (has_next,) + new_vars

  def while_cond(has_next, *loop_vars):
    # extra_test is only evaluated when the iterator produced a value.
    if extra_test is not None:
      return control_flow_ops.cond(
          has_next,
          lambda: extra_test(*loop_vars),
          lambda: False,
      )
    return has_next

  final_vars = _tf_while_stmt(
      while_cond,
      while_body,
      get_state,
      set_state,
      (True,) + init_vars,
      ('<internal has_next>',) + basic_symbol_names,
      composite_symbol_names,
      opts,
  )
  # Drop the internal has_next loop variable from the results.
  return final_vars[1:]
634
635
def _tf_dataset_for_stmt(ds, extra_test, body, get_state, set_state, init_vars,
                         basic_symbol_names, composite_symbol_names, opts):
  """Overload of for_stmt that iterates over TF Datasets."""
  _disallow_undefs_into_loop(*init_vars)

  if extra_test is None:
    return _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state,
                                           init_vars, basic_symbol_names,
                                           composite_symbol_names, opts)

  # Early-stopping loops need state to pipe the stop condition through.
  assert init_vars, 'Lowering should always add state.'
  return _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
                                           set_state, init_vars,
                                           basic_symbol_names,
                                           composite_symbol_names, opts)
651
652
def _general_purpose_scan(ds, init_state, body):
  """Variant of Dataset.scan with semantics of general-purpose computation.

  Args:
    ds: Dataset to scan over.
    init_state: Initial state structure for the scan.
    body: Scan transition function; receives (state, element) and returns a
      (new_state, output) pair.

  Returns:
    The scanned dataset.
  """
  # Datasets are typically intended for data preprocessing. However, in
  # autograph loops they usually appear as general-purpose computations (for
  # example, a custom training loop). These two use cases require significantly
  # different optimization policies, the most important of which is the device
  # placement. The flag override for use_default_device below instructs the
  # runtime to treat the computation as general-purpose, rather than data
  # preprocessing.
  # TODO(mdan): s/use_default_device/specialize_for_input_pipeline.
  # TODO(mdan): Don't use private symbols.
  return scan_ops._ScanDataset(ds, init_state, body, use_default_device=False)  # pylint:disable=protected-access
665
666
def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
                                      set_state, init_vars, basic_symbol_names,
                                      composite_symbol_names, opts):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt.

  The loop is staged as a pipeline: scan (computes each step together with
  the stop condition) -> take_while (truncates the dataset once the stop
  condition fails) -> reduce (extracts the final state).

  Args:
    ds: Dataset being iterated.
    extra_test: Callable with the state as arguments, and boolean return type.
    body: Callable with the iterate and the state as arguments, and state as
      return type.
    get_state: Callable that captures additional (composite) state.
    set_state: Callable that restores state captured by get_state.
    init_vars: Tuple containing the initial state; non-empty (asserted by the
      caller).
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """

  # TODO(mdan): Simplify this - following it is extremely difficult.

  init_state = get_state()
  aug_init_vars = init_vars, init_state

  def scan_body(aug_vars, iterate):
    """The main loop body wrapper. Only calculates the stop condition."""
    loop_vars, state = aug_vars

    def true_fn():
      """Main path - stop condition is not set."""
      set_state(state)
      new_vars = body(iterate, *loop_vars)
      new_state = get_state()
      _verify_tf_loop_vars(
          init_vars + init_state,
          loop_vars + state,
          new_vars + new_state,
          basic_symbol_names + composite_symbol_names,
          opts,
          check_shapes=False)
      return new_vars, new_state

    # Once extra_cond is False, the body is skipped and the previous state is
    # passed through unchanged.
    extra_cond = extra_test(*loop_vars)
    new_vars, new_state = control_flow_ops.cond(
        extra_cond,
        true_fn,
        lambda: (loop_vars, state),
    )

    scan_outputs = new_vars, new_state, extra_cond
    # Note: new_aug_vars is the actual state of scan; scan_outputs is its output
    # (hence the redundancy).
    # get_state will pull any mutations that body may have made.
    new_aug_vars = new_vars, new_state
    return new_aug_vars, scan_outputs

  def take_while_predicate(unused_loop_vars, unused_state, extra_cond):
    # Stops dataset consumption as soon as the loop condition fails.
    return extra_cond

  def reduce_body(unused_aug_vars, scan_outputs):
    output_aug_vars, output_state, extra_cond = scan_outputs
    del extra_cond
    return output_aug_vars, output_state

  ds = _general_purpose_scan(ds, aug_init_vars, scan_body)
  ds = ds.apply(take_while_ops.take_while(take_while_predicate))
  final_aug_vars = ds.reduce(aug_init_vars, reduce_body)
  final_vars, final_state = final_aug_vars
  set_state(final_state)
  return final_vars
723
724
def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars,
                                    basic_symbol_names, composite_symbol_names,
                                    opts):
  """Overload of _dataset_for_stmt without early stopping. See for_stmt.

  Args:
    ds: Dataset being iterated.
    body: Callable with the iterate and the state as arguments, and state as
      return type.
    get_state: Callable that captures additional (composite) state.
    set_state: Callable that restores state captured by get_state.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """
  init_state = get_state()
  assert isinstance(init_vars, tuple)
  assert isinstance(init_state, tuple)

  symbol_names = basic_symbol_names + composite_symbol_names

  # Workaround for Dataset.reduce not allowing empty state tensors - create
  # a dummy state variable that remains unused.
  # TODO(mdan): reduce should allow and match empty structures.
  no_vars = not init_vars
  no_state = not init_state

  if no_vars:
    init_vars = (constant_op.constant(0),)
    symbol_names = ('<internal dummy>',) + symbol_names
  if no_state:
    init_state = (constant_op.constant(0),)
    symbol_names = symbol_names + ('<internal dummy>',)

  def scan_body(aug_vars, iterate):
    """The main loop body wrapper."""
    loop_vars, state = aug_vars
    if not no_state:
      set_state(state)

    if no_vars:
      body(iterate)
      new_vars = loop_vars
    else:
      new_vars = body(iterate, *loop_vars)

    if no_state:
      new_state = state
    else:
      new_state = get_state()

    _verify_tf_loop_vars(
        init_vars + init_state,
        loop_vars + state,
        new_vars + new_state,
        symbol_names,
        opts,
        check_shapes=False)

    scan_outputs = new_vars, new_state
    # Note: new_aug_vars is the actual state of scan; scan_outputs is its output
    # (hence the redundancy).
    # get_state will pull any mutations that body may have made.
    new_aug_vars = new_vars, new_state
    return new_aug_vars, scan_outputs

  def reduce_body(unused_aug_vars, scan_outputs):
    output_aug_vars, output_state = scan_outputs
    return output_aug_vars, output_state

  # Use init_state (which includes the dummy substituted above) rather than a
  # fresh get_state() call; the latter would bypass the empty-state workaround
  # when the captured state is empty.
  aug_vars = init_vars, init_state
  ds = _general_purpose_scan(ds, aug_vars, scan_body)
  final_vars, final_state = ds.reduce(aug_vars, reduce_body)
  # The dummy state is internal only - never write it back.
  if not no_state:
    set_state(final_state)

  if no_vars:
    return ()
  return final_vars
792
793
def _tf_distributed_dataset_for_stmt(iter_, extra_test, body, init_state):
  """Overload of for..in statement that iterates over the input."""
  _disallow_undefs_into_loop(*init_state)

  if extra_test is not None:
    raise NotImplementedError(
        'break and return statements are not yet supported in '
        'for ... in distributed input loops.')

  def reduce_body(state, iterate):
    return body(iterate, *state)

  if init_state:
    return iter_.reduce(init_state, reduce_body)

  # reduce does not accept empty state; thread a dummy value through and
  # discard it at the end.
  def reduce_body_with_dummy_state(state, iterate):
    reduce_body((), iterate)
    return state
  iter_.reduce((constant_op.constant(0),), reduce_body_with_dummy_state)
  return ()
815
816
def while_stmt(test,
               body,
               get_state,
               set_state,
               init_vars,
               basic_symbol_names,
               composite_symbol_names,
               opts):
  """Functional form of a while statement.

  The loop operates on a so-called state, which includes all symbols that are
  variant across loop iterations. In what follows we refer to state as either
  a tuple of entities that represent an actual state, or a list of arguments
  of the corresponding types.

  Args:
    test: Callable with the state as arguments, and boolean return type. The
      loop condition.
    body: Callable with the state as arguments, and state as return type. The
      actual loop body.
    get_state: Additional callable which can capture additional state (such as
      the values of composite symbols). This is only useful when staging the
      loop.
    set_state: Additional callable which save values captured by get_state back
      into the Python environment. This is only useful when staging the loop.
    init_vars: Tuple containing the initial state.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.
    opts: Optional dict of extra loop parameters.

  Returns:
    Tuple containing the final state.
  """

  # Evaluate the initial test once in order to do the dispatch. The evaluation
  # is isolated (in a throwaway FuncGraph) to minimize unwanted side effects.
  # TODO(mdan): Do a full iteration - some state types might lower to Tensor.
  with func_graph.FuncGraph('tmp').as_default():
    init_test = test(*init_vars)

  if tensors.is_dense_tensor(init_test):
    # TensorFlow: Multiple evaluations are acceptable in this case, so we're
    # fine with the re-evaluation of `test` that `_tf_while_stmt` will make.
    return _tf_while_stmt(test, body, get_state, set_state, init_vars,
                          basic_symbol_names, composite_symbol_names, opts)

  # Normal Python: We already consumed one evaluation of `test`; consistently,
  # unroll one iteration before dispatching to a normal loop.
  # TODO(mdan): Push the "init_test" value via opts into _py_while_stmt?
  if not init_test:
    return init_vars

  return _py_while_stmt(test, body, get_state, set_state, body(*init_vars),
                        opts)
871
872
873def _shape_invariants_mapping_to_positional_list(mapping, keys):
874  # The keys are not expected to be hashable.
875  mapping = {id(k): (k, v) for k, v in mapping}
876  result = []
877  for k in keys:
878    map_key, map_val = mapping.get(id(k), (None, None))
879    result.append(map_val if map_key is k else None)
880  return tuple(result)
881
882
def _tf_while_stmt(test, body, get_state, set_state, init_vars,
                   basic_symbol_names, composite_symbol_names, opts):
  """Overload of while_stmt that stages a TF while_stmt."""
  _disallow_undefs_into_loop(*init_vars)

  # Augmented loop vars: the basic loop vars followed by the composite-symbol
  # values captured via get_state.
  num_basic = len(init_vars)
  aug_init_vars = init_vars + get_state()

  # TODO(mdan): Simplify this.

  def aug_test(*aug_vars):
    # Restore composite state before evaluating the user condition.
    set_state(aug_vars[num_basic:])
    return test(*aug_vars[:num_basic])

  def aug_body(*aug_vars):
    """Loop body that pipes composite-symbol mutations through loop vars."""
    set_state(aug_vars[num_basic:])
    basic_vars = body(*aug_vars[:num_basic])
    new_aug_vars = basic_vars + get_state()
    _verify_tf_loop_vars(aug_init_vars, aug_vars, new_aug_vars,
                         basic_symbol_names + composite_symbol_names, opts)
    return new_aug_vars

  # Non-v2 while_loop unpacks the results when there is only one return value;
  # force the same structure across versions for consistency.
  opts['return_same_structure'] = True

  if 'shape_invariants' in opts:
    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
        opts['shape_invariants'], aug_init_vars)

  final_aug_vars = control_flow_ops.while_loop(aug_test, aug_body,
                                               aug_init_vars, **opts)
  # Write the final composite state back and return only the basic vars.
  set_state(final_aug_vars[num_basic:])
  return final_aug_vars[:num_basic]
923
924
class _PythonLoopChecker(object):
  """Verifies Python loops for TF-specific limits."""

  def __init__(self):
    # Number of completed iterations.
    self.iterations = 0
    # Whether to warn when a Python loop appears to unroll TF ops.
    self.check_inefficient_unroll = WARN_INEFFICIENT_UNROLL

    # Triggered when we decided to test the op counts.
    self.check_op_count_after_iteration = False

    # Snapshot of the graph's ops taken before an iteration; None whenever we
    # are not actively checking. Initialized here so the attribute always
    # exists, even if before_iteration was never called.
    self.ops_before_iteration = None

  def _get_ops(self):
    return ops.get_default_graph().get_operations()

  def _check_unroll_limits(self):
    if LIMIT_PYTHON_ITERATIONS and self.iterations > PYTHON_MAX_ITERATIONS:
      raise ValueError('iteration limit exceeded')

  def _stop_checking_inefficient_unroll(self):
    self.check_inefficient_unroll = False
    self.ops_before_iteration = None

  def _verify_inefficient_unroll(self):
    """Checks for possibly-inefficient creation of ops in a Python loop."""
    assert self.ops_before_iteration is not None
    ops_after_iteration = self._get_ops()
    new_ops = tuple(
        op for op in ops_after_iteration if op not in self.ops_before_iteration)

    if len(new_ops) < INEFFICIENT_UNROLL_MIN_OPS:
      return False

    # TODO(mdan): Add location information.
    ag_logging.warn(
        'TensorFlow ops are being created in a Python loop with large number'
        ' of iterations. This can lead to slow startup. Did you mean to use a'
        ' TensorFlow loop? For example, `while True:` is a Python loop, and'
        ' `while tf.constant(True):` is a TensorFlow loop. The following'
        ' ops were created after iteration %s: %s', self.iterations, new_ops)
    return True

  def before_iteration(self):
    """Called before each iteration in a Python loop."""
    if (self.check_inefficient_unroll and
        self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS):
      self.ops_before_iteration = self._get_ops()
      self.check_op_count_after_iteration = True

  def after_iteration(self):
    """Called after each iteration in a Python loop."""
    self.iterations += 1

    self._check_unroll_limits()

    if self.check_inefficient_unroll and self.check_op_count_after_iteration:
      did_warn = self._verify_inefficient_unroll()
      if did_warn:
        self._stop_checking_inefficient_unroll()  # Only warn once.
      elif self.iterations > INEFFICIENT_UNROLL_MIN_ITERATIONS + 3:
        # Once deciding to check the op counts, only do it for a few iterations.
        self._stop_checking_inefficient_unroll()
985
986
def _py_while_stmt(test, body, get_state, set_state, init_vars, opts):
  """Overload of while_stmt that executes a Python while loop."""
  # State piping and staging options are only needed for TF loops.
  del opts, get_state, set_state

  # The checker warns about loops that are likely unintended TF-op unrolls.
  # It is skipped entirely when running with -O (i.e. __debug__ is False).
  checker = _PythonLoopChecker() if __debug__ else None

  loop_vars = init_vars
  while test(*loop_vars):
    if checker is not None:
      checker.before_iteration()

    loop_vars = body(*loop_vars)

    if checker is not None:
      checker.after_iteration()

  return loop_vars
1006
1007
def if_stmt(cond,
            body,
            orelse,
            get_state,
            set_state,
            basic_symbol_names,
            composite_symbol_names):
  """Functional form of an if statement.

  Args:
    cond: Boolean.
    body: Callable with no arguments; returns the outputs of the positive (if)
      branch.
    orelse: Callable with no arguments; returns the outputs of the negative
      (else) branch.
    get_state: Function returning a tuple of the values of all composite
      symbols modified within the conditional. Gives access to state the
      branches may mutate through side effects. Not needed - and not called -
      when dispatching to code matching Python's default semantics. Useful for
      checkpointing, to avoid unintended side effects when staging requires
      evaluating all code paths.
    set_state: Function that sets the values of all composite symbols modified
      within the conditional; the complement of get_state, used to restore
      checkpointed values. Takes a single tuple argument with a value for each
      composite symbol a branch may modify, usually a prior get_state result.
    basic_symbol_names: Tuple containing basic loop var names.
    composite_symbol_names: Tuple containing composite loop var names.

  Returns:
    Tuple containing the statement outputs.
  """
  # Note: tf.cond doesn't support SparseTensor.
  if not tensors.is_dense_tensor(cond):
    return _py_if_stmt(cond, body, orelse)
  return tf_if_stmt(cond, body, orelse, get_state, set_state,
                    basic_symbol_names, composite_symbol_names)
1046
1047
def tf_if_stmt(cond, body, orelse, get_state, set_state, basic_symbol_names,
               composite_symbol_names):
  """Overload of if_stmt that stages a TF cond."""
  # Each branch is first guarded against returning undefined symbols, then
  # wrapped so its state mutations are isolated (see _isolate_state).
  checked_body = _isolate_state(
      _wrap_disallow_undefs_from_cond(body, branch_name='if'), get_state,
      set_state)
  checked_orelse = _isolate_state(
      _wrap_disallow_undefs_from_cond(orelse, branch_name='else'), get_state,
      set_state)

  # After _isolate_state, each branch returns (final_vars, state): final_vars
  # holds basic symbols (e.g. `a`), which can't be passed by reference and
  # must be returned, while state holds composite symbols (e.g. `a.b`)
  # modified by the branch.
  # TODO(mdan): We should minimize calls to get/set_state.

  # Whichever branch is traced second cross-checks its outputs against those
  # recorded by the first.
  branch_outputs = {'if': None, 'else': None}

  def error_checking_body():
    branch_outputs['if'] = checked_body()
    if branch_outputs['else'] is not None:
      _verify_tf_cond_vars(branch_outputs['if'], branch_outputs['else'],
                           basic_symbol_names + composite_symbol_names)
    return branch_outputs['if']

  def error_checking_orelse():
    branch_outputs['else'] = checked_orelse()
    if branch_outputs['if'] is not None:
      _verify_tf_cond_vars(branch_outputs['if'], branch_outputs['else'],
                           basic_symbol_names + composite_symbol_names)
    return branch_outputs['else']

  final_vars, final_state = control_flow_ops.cond(cond, error_checking_body,
                                                  error_checking_orelse)

  set_state(final_state)

  return final_vars
1086
1087
1088def _isolate_state(func, get_state, set_state):
1089  """Wraps func to (best-effort) isolate state mutations that func may do.
1090
1091  The simplest example of state mutation is mutation of variables (via e.g.
1092  attributes), or modification of globals.
1093
1094  This allows us to more safely execute this function without worrying about
1095  side effects when the function wasn't normally expected to execute. For
1096  example, staging requires that the function is executed ahead of time, and
1097  we need to ensure its effects are not observed during normal execution.
1098
1099  Args:
1100    func: () -> Any
1101    get_state: () -> Any, returns the current state
1102    set_state: (Any) -> None, resets the state to the specified values.
1103      Typically the result of an earlier call to `get_state`.
1104
1105  Returns:
1106    Tuple[Any, Any], where the first element is the return value of `func`,
1107    and the second is the final state values.
1108  """
1109
1110  def wrapper():
1111    init_state = get_state()
1112    new_vars = func()
1113    # TODO(mdan): These should be copies, lest set_state might affect them.
1114    new_state = get_state()
1115    set_state(init_state)
1116    return new_vars, new_state
1117
1118  return wrapper
1119
1120
def _wrap_disallow_undefs_from_cond(func, branch_name):
  """Wraps conditional branch to disallow returning undefined symbols."""

  def wrapper():
    """Calls function and raises an error if undefined symbols are returned."""
    results = func()

    # Normalize to a tuple so single returns are checked the same way.
    results_tuple = results if isinstance(results, tuple) else (results,)

    undefined = tuple(
        v for v in results_tuple if isinstance(v, variables.Undefined))
    if undefined:
      raise ValueError(
          'The following symbols must also be initialized in the {} branch: {}.'
          ' Alternatively, you may initialize them before the if'
          ' statement.'.format(branch_name,
                               tuple(s.symbol_name for s in undefined)))

    if any(
        isinstance(r, variables.UndefinedReturnValue) for r in results_tuple):
      raise ValueError(
          'A value must also be returned from the {} branch. If a value is '
          'returned from one branch of a conditional a value must be '
          'returned from all branches.'.format(branch_name))

    return results

  return wrapper
1150
1151
1152def _py_if_stmt(cond, body, orelse):
1153  """Overload of if_stmt that executes a Python if statement."""
1154  return body() if cond else orelse()
1155