# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Control Flow Operations.

See the [autograph](https://www.tensorflow.org/guide/autograph) guide.
"""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import functools

import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import control_flow_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util as util
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.gen_control_flow_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export

# This is to avoid a circular dependency:
# cond_v2 -> gradients_util -> control_flow_ops
cond_v2 = LazyLoader("cond_v2", globals(),
                     "tensorflow.python.ops.cond_v2")

# This is to avoid circular dependencies:
# while_v2 -> control_flow_ops
# while_v2 -> gradients_util -> control_flow_ops
while_v2 = LazyLoader("while_v2", globals(),
                      "tensorflow.python.ops.while_v2")

# def_function also uses cond
def_function = LazyLoader(
    "def_function", globals(),
    "tensorflow.python.eager.def_function")


# We override the 'tuple' for a control flow op, so we keep python's
# existing 'tuple' for later use in this module.
_basetuple = tuple


def _summarize_eager(tensor, summarize=None):
  """Returns a summarized string representation of eager `tensor`.

  Args:
    tensor: EagerTensor to summarize
    summarize: Include this many leading elements of `tensor` in the summary.
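
  For example, a minimal illustrative call (assuming eager execution):
  `_summarize_eager(constant_op.constant([1, 2, 3, 4]), summarize=2)` returns
  `"1, 2, ..."`.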
92  """
93  # Emulate the behavior of Tensor::SummarizeValue()
94  if summarize is None:
95    summarize = 3
96  elif summarize < 0:
97    summarize = array_ops.size(tensor)
98
99  # reshape((-1,)) is the fastest way to get a flat array view
100  if tensor._rank():  # pylint: disable=protected-access
101    flat = tensor.numpy().reshape((-1,))
102    lst = [str(x) for x in flat[:summarize]]
103    if len(lst) < flat.size:
104      lst.append("...")
105  else:
106    # tensor.numpy() returns a scalar for zero dimensional arrays
107    if gen_math_ops.not_equal(summarize, 0):
108      lst = [str(tensor.numpy())]
109    else:
110      lst = []
111
112  return ", ".join(lst)
113
114
115# pylint: disable=protected-access
116
117
118# Assert and Print are special symbols in python, so we must
119# use an upper-case version of them.
120@tf_export("debugging.Assert", "Assert")
121@dispatch.add_dispatch_support
122@tf_should_use.should_use_result
123def Assert(condition, data, summarize=None, name=None):
124  """Asserts that the given condition is true.
125
126  If `condition` evaluates to false, print the list of tensors in `data`.
127  `summarize` determines how many entries of the tensors to print.
128
129  Args:
130    condition: The condition to evaluate.
131    data: The tensors to print out when condition is false.
132    summarize: Print this many entries of each tensor.
133    name: A name for this operation (optional).
134
135  Returns:
136    assert_op: An `Operation` that, when executed, raises a
137    `tf.errors.InvalidArgumentError` if `condition` is not true.
138    @compatibility(eager)
139    returns None
140    @end_compatibility
141
142  Raises:
143    @compatibility(TF1)
144    When in TF V1 mode (that is, outside `tf.function`) Assert needs a control
145    dependency on the output to ensure the assertion executes:
146
147  ```python
148  # Ensure maximum element of x is smaller or equal to 1
149  assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
150  with tf.control_dependencies([assert_op]):
151    ... code using x ...
152  ```
153
154    @end_compatibility
155  """
156  if context.executing_eagerly():
157    if not condition:
158      xs = ops.convert_n_to_tensor(data)
159      data_str = [_summarize_eager(x, summarize) for x in xs]
160      raise errors.InvalidArgumentError(
161          node_def=None,
162          op=None,
163          message="Expected '%s' to be true. Summarized data: %s" %
164          (condition, "\n".join(data_str)))
165    return
166
167  with ops.name_scope(name, "Assert", [condition, data]) as name:
168    xs = ops.convert_n_to_tensor(data)
169    if all(x.dtype in {dtypes.string, dtypes.int32} for x in xs):
      # As a simple heuristic, we assume that string and int32 are
      # on host to avoid the need to use cond. If that is not the case,
      # we will pay the price of copying the tensor to host memory.
      return gen_logging_ops._assert(condition, data, summarize, name="Assert")
    else:
      condition = ops.convert_to_tensor(condition, name="Condition")

      def true_assert():
        return gen_logging_ops._assert(
            condition, data, summarize, name="Assert")

      guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
      if context.executing_eagerly():
        return
      return guarded_assert.op


def _Identity(data, name=None):
  """Return a tensor with the same shape and contents as the input tensor.

  Args:
    data: A Tensor.
    name: A name for this operation (optional).

  Returns:
    A Tensor with the same type and value as the input Tensor.
  """
  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
      return gen_array_ops.ref_identity(data, name=name)
    else:
      return array_ops.identity(data, name=name)
  elif isinstance(data, composite_tensor.CompositeTensor):
    return nest.map_structure(_Identity, data, expand_composites=True)
  else:
    raise TypeError("Type %s not supported" % type(data))


def _NextIteration(data, name=None):
  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
  if isinstance(data, ops.Tensor):
    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
      return ref_next_iteration(data, name=name)
    else:
      return next_iteration(data, name=name)
  elif isinstance(data, composite_tensor.CompositeTensor):
    return nest.map_structure(_NextIteration, data, expand_composites=True)
  else:
    raise TypeError("Type %s not supported" % type(data))


def _Enter(data,
           frame_name,
           is_constant=False,
           parallel_iterations=10,
           use_ref=True,
           use_input_shape=True,
           name=None):
  """Creates or finds a child frame, and makes `data` available to it.

  The unique `frame_name` is used by the `Executor` to identify frames. If
  `is_constant` is true, `data` is a constant in the child frame; otherwise
  it may be changed in the child frame. At most `parallel_iterations`
  iterations are run in parallel in the child frame.

  Args:
    data: The tensor to be made available to the child frame.
    frame_name: The name of the child frame.
    is_constant: If true, the output is constant within the child frame.
    parallel_iterations: The number of iterations allowed to run in parallel.
    use_ref: If true, use ref_enter if data is of ref type.
    use_input_shape: If true, set the result's shape based on data's shape.
    name: A name for this operation (optional).

  Returns:
    The same tensor as `data`.
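
  A minimal graph-construction sketch (illustrative only; real loops are built
  via `tf.while_loop` rather than by calling this helper directly, and the
  frame name below is hypothetical):

  ```python
  with ops.Graph().as_default():
    x = constant_op.constant(0)
    entered = _Enter(x, frame_name="my_frame", is_constant=True)
  ```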
247  """
248  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
249  if isinstance(data, ops.Tensor):
250    if data.dtype._is_ref_dtype and use_ref:  # pylint: disable=protected-access
251      result = gen_control_flow_ops.ref_enter(
252          data, frame_name, is_constant, parallel_iterations, name=name)
253    else:
254      result = gen_control_flow_ops.enter(
255          data, frame_name, is_constant, parallel_iterations, name=name)
256    if use_input_shape:
257      result.set_shape(data.get_shape())
258    return result
259  elif isinstance(data, composite_tensor.CompositeTensor):
260
261    def enter_component(t):
262      return _Enter(t, frame_name, is_constant, parallel_iterations, use_ref,
263                    use_input_shape)
264
265    return nest.map_structure(enter_component, data, expand_composites=True)
266  else:
267    raise TypeError("Type %s not supported" % type(data))
268
269
270def exit(data, name=None):  # pylint: disable=redefined-builtin
271  """Exits the current frame to its parent frame.
272
273  Exit makes its input `data` available to the parent frame.
274
275  Args:
276    data: The tensor to be made available to the parent frame.
277    name: A name for this operation (optional).
278
279  Returns:
280    The same tensor as `data`.
281  """
282  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
283  if isinstance(data, ops.Tensor):
284    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
285      return gen_control_flow_ops.ref_exit(data, name)
286    else:
287      return gen_control_flow_ops._exit(data, name)
288  elif isinstance(data, composite_tensor.CompositeTensor):
289    return nest.map_structure(exit, data, expand_composites=True)
290  else:
291    raise TypeError("Type %s not supported" % type(data))
292
293
294def switch(data, pred, dtype=None, name=None):
295  """Forwards `data` to an output determined by `pred`.
296
297  If `pred` is false, the `data` input is forwarded to the first output.
298  Otherwise, the data goes to the second output.
299
300  This op handles `Tensor`s and `IndexedSlices`.
301
302  Args:
303    data: The tensor to be forwarded to the appropriate output.
304    pred: A scalar that specifies which output port will receive data.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `data`.
    name: A name for this operation (optional).

  Returns:
    `(output_false, output_true)`: If `pred` is true, data will be forwarded
    to `output_true`, otherwise it goes to `output_false`.
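
  For example, a graph-mode sketch (illustrative only):

  ```python
  with ops.Graph().as_default():
    pred = constant_op.constant(True)
    data = constant_op.constant([1, 2, 3])
    output_false, output_true = switch(data, pred)
    # At runtime only `output_true` receives `data`, because `pred` is true.
  ```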
312  """
313  with ops.name_scope(name, "Switch", [data, pred]) as name:
314    data = ops.internal_convert_to_tensor_or_composite(
315        data, dtype=dtype, name="data", as_ref=True)
316    pred = ops.convert_to_tensor(pred, name="pred")
317    if isinstance(data, ops.Tensor):
318      return gen_control_flow_ops.switch(data, pred, name=name)
319    else:
320      if not isinstance(data, composite_tensor.CompositeTensor):
321        raise TypeError("Type %s not supported" % type(data))
322      tensors = nest.flatten(data, expand_composites=True)
323      mapped = [gen_control_flow_ops.switch(tensor, pred) for tensor in tensors]
324      mapped_f, mapped_t = zip(*mapped)
325      return (nest.pack_sequence_as(data, mapped_f, expand_composites=True),
326              nest.pack_sequence_as(data, mapped_t, expand_composites=True))
327
328
329def _SwitchRefOrTensor(data, pred, name="Switch"):
330  """Forwards `data` to an output determined by `pred`.
331
332  If `pred` is false, the `data` input is forwarded to the first output.
333  Otherwise, the data goes to the second output.
334
335  This op handles `Tensor`s and `IndexedSlices`.
336
337  Args:
338    data: The tensor to be forwarded to the appropriate output.
339    pred: A scalar that specifies which output port will receive data.
340    name: A name for this operation (optional).
341
342  Returns:
343    `(output_false, output_true)`: If `pred` is true, data will be forwarded to
344    `output_true`, otherwise it goes to `output_false`.
345
346  Raises:
347    TypeError: if data is not a Tensor or IndexedSlices
348  """
349  data = ops.convert_to_tensor_or_composite(data, name="data")
350  # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
351  # addresses the following scenario.
352  #
353  # Assume you execute Optimizer.apply_gradients() in a branch of a cond().
354  #
355  # 1. The update op is created inside a `with ops.colocate(var):` block
356  #
357  # 2. Some tensor `data` is captured and a switch is created in a
358  #    `with ops.colocate_with(data):` block.
359  #
360  # with ops.colocate_with(var):
361  #  with ops.colocate_with(data):
362  #    op = ...
363  #
  # var and data may be pinned to different devices, so we want the ops
  # created within ops.colocate_with(data) to ignore the existing stack.
  with ops.colocate_with(data, ignore_existing=True):
    if isinstance(data, ops.Tensor):
      if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
        return ref_switch(data, pred, name=name)
    return switch(data, pred, name=name)


def merge(inputs, name=None):
  """Returns the value of an available element of `inputs`.

  This op tests each of the tensors in `inputs` in turn to determine if any of
  them is available. If it finds an available tensor, it returns it and its
  index in `inputs`.

  It is an error if more than one tensor in `inputs` is available. If no tensor
  in `inputs` is available, the returned tensor and index are not set.

  This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
  `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
  before merging.

  Args:
    inputs: The input tensors, at most one of which is available.
    name: A name for this operation (optional).

  Returns:
    A tuple containing the chosen input tensor and its index in `inputs`.

  Raises:
    ValueError: If any of the inputs is None, or inputs are IndexedSlices and
      some but not all have a dense_shape property.
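
  For example, a graph-mode sketch combining `switch` and `merge`
  (illustrative only):

  ```python
  with ops.Graph().as_default():
    pred = constant_op.constant(False)
    data = constant_op.constant(7)
    out_false, out_true = switch(data, pred)
    # Exactly one switch output is produced at runtime, so merging the two
    # yields that value and its index (0 here, since `pred` is false).
    value, index = merge([out_false, out_true])
  ```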
397  """
398  if any(inp is None for inp in inputs):
399    raise ValueError("At least one of the merge inputs is None: %s" % inputs)
400  with ops.name_scope(name, "Merge", inputs) as name:
401    inputs = [
402        ops.internal_convert_to_tensor_or_composite(inp, as_ref=True)
403        for inp in inputs
404    ]
405    if all(isinstance(v, ops.Tensor) for v in inputs):
406      if all(v.dtype._is_ref_dtype for v in inputs):  # pylint: disable=protected-access
407        return gen_control_flow_ops.ref_merge(inputs, name)
408      else:
409        return gen_control_flow_ops.merge(inputs, name)
410    else:
411      # If there is a mix of tensors and indexed slices, then convert the
412      # tensors to indexed slices.
413      if all(isinstance(v, (ops.IndexedSlices, ops.Tensor)) for v in inputs):
414        inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)
415
416      for v in inputs:
417        if not isinstance(v, composite_tensor.CompositeTensor):
418          raise TypeError("Type %s not supported" % type(v))
419
420      for v in inputs[1:]:
421        nest.assert_same_structure(inputs[0], v, expand_composites=True)
422
423      flat_inputs = [nest.flatten(v, expand_composites=True) for v in inputs]
424      merged_results = [
425          gen_control_flow_ops.merge(component)
426          for component in zip(*flat_inputs)
427      ]
428      flat_merged = [tensor for (tensor, _) in merged_results]
429      chosen_index = merged_results[0][1]
430      merged_inputs = nest.pack_sequence_as(
431          inputs[0], flat_merged, expand_composites=True)
432      return (merged_inputs, chosen_index)
433
434
435# pylint: enable=protected-access
436
437
438def _convert_tensorarray_to_flow(tensor_or_tensor_array):
439  if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
440    return tensor_or_tensor_array.flow
441  else:
442    return tensor_or_tensor_array
443
444
445def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
446  if len(tensors_or_tensorarrays) != len(tensors_or_flows):
447    raise ValueError(
448        "Lengths of original Tensor list and new list do not match: %d vs. %d" %
449        (len(tensors_or_tensorarrays), len(tensors_or_flows)))
450  return [
451      tensor_array_ops.build_ta_with_new_flow(ta, t_or_flow) if isinstance(
452          ta, tensor_array_ops.TensorArray) else t_or_flow
453      for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)
454  ]
455
456
457def _ShapeLessThanOrEqual(shape1, shape2):
458  if shape2.dims is None:
459    return True
460  if shape1.ndims != shape2.ndims:
461    return False
462  for dim1, dim2 in zip(shape1.dims, shape2.dims):
463    if dim2.value is not None and dim1.value != dim2.value:
464      return False
465  return True
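
# For example (illustrative): TensorShape([2, 3]) satisfies the invariant
# TensorShape([None, 3]), so _ShapeLessThanOrEqual returns True for that pair,
# while it returns False for TensorShape([2, 3]) vs. TensorShape([2, 4]).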


def _get_shape_invariant(var, shape=None):
  """Returns shape invariant(s) for the given variable.

  Args:
    var: The tensor whose shape is described.
    shape: The shape invariant for the tensor.  If not specified, then a default
      shape invariant for `var` is returned.

  Returns:
    `TensorShape` or `list` of `TensorShape`: The shape invariant for `var` (if
    it is a `Tensor`), or the shape invariants for the components that comprise
    `var` (if it is a `CompositeTensor`).
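
  For example, a minimal sketch (illustrative):

  ```python
  v = constant_op.constant([[1, 2], [3, 4]])
  _get_shape_invariant(v)  # TensorShape([2, 2])
  _get_shape_invariant(v, tensor_shape.TensorShape([None, 2]))  # [None, 2]
  ```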
480  """
481  if isinstance(var, composite_tensor.CompositeTensor):
482    # Get a TypeSpec for `var`.
483    if shape is None:
484      spec = var._type_spec  # pylint: disable=protected-access
485    else:
486      spec = _shape_invariant_to_type_spec(var, shape)
487
488    tensor_specs = nest.flatten(spec, expand_composites=True)
489    return [tspec.shape for tspec in tensor_specs]
490
491  elif shape is None:
492    return var.shape
493  elif isinstance(shape, tensor_spec.TensorSpec):
494    if var.dtype != shape.dtype:
495      raise TypeError("TensorSpec %r is not compatible with %r" % (shape, var))
496    return shape.shape
497  elif isinstance(shape, type_spec.TypeSpec):
498    raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
499  else:
500    return shape
501
502
503def _shape_invariant_to_type_spec(var, shape):
504  """Converts a shape invariant to a TypeSpec.
505
506  Args:
507    var: The tensor whose shape is described by the shape invariant.
    shape: A `TypeSpec`, `TensorShape`, or None.  If `shape` is already a
      `TypeSpec`, then it is simply returned as-is; if it is None, a `TypeSpec`
      is inferred from `var`.

  Returns:
    A `TypeSpec` for `var`, consistent with the given shape.
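
  For example, a minimal sketch (illustrative):

  ```python
  t = constant_op.constant([1.0, 2.0, 3.0])
  _shape_invariant_to_type_spec(t, tensor_shape.TensorShape([None]))
  # -> TensorSpec(shape=(None,), dtype=tf.float32)
  ```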
513  """
514  if shape is None:
515    return type_spec.type_spec_from_value(var)
516  elif isinstance(shape, type_spec.TypeSpec):
517    if not shape.is_compatible_with(var):
518      raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
519    return shape
520  elif not isinstance(shape, tensor_shape.TensorShape):
521    raise TypeError(
522        "Expected shape to be a TypeSpec, TensorShape or None, got %r for"
523        " value %r" % (shape, var))
524
525  if isinstance(var, ops.Tensor):
526    return tensor_spec.TensorSpec(shape, var.dtype)
527
528  elif isinstance(var, composite_tensor.CompositeTensor):
529    try:
530      return var._shape_invariant_to_type_spec(shape)  # pylint: disable=protected-access
531    except NotImplementedError:
532      raise TypeError(
533          "To describe or constrain a %s, use a %s instead of a TensorShape." %
534          (type(var).__name__, type(var._type_spec).__name__))  # pylint: disable=protected-access
535
536  else:
537    raise TypeError("Expected var to be a Tensor or CompositeTensor, got %s"
538                    % var)
539
540
541def _SetShapeInvariants(input_vars, enter_vars, shapes):
542  """Set the shapes of the tensors in `enter_vars` to `shapes`.
543
544  Args:
545    input_vars: A list of tensors that are inputs to `enter_vars`.
546    enter_vars: A list of tensors whose shapes will be set.
547    shapes: A (possibly nested) list of shapes.
548
549  Raises:
550    ValueError: If any tensor in `enter_vars` has a less specific shape
551      than its corresponding shape in `shapes`.
552  """
553  if shapes is None:
554    return
555  flat_shapes = nest.flatten(shapes)
556  if not all(isinstance(s, tensor_shape.TensorShape) for s in flat_shapes):
557    raise ValueError("`shapes` must be a (possibly nested) list of shapes.")
558  # Check that the shapes of the inputs are less than the shape invariants,
559  # and set the shapes of `enter_vars` to the shape invariants.
560  for inp, var, shape in zip(input_vars, enter_vars, flat_shapes):
561    if isinstance(var, ops.Tensor):
562      if not _ShapeLessThanOrEqual(inp.get_shape(), shape):
563        raise ValueError(
564            "The shape invariant specified for %s is not compatible with "
565            "the initial shape of the loop variable. It enters the loop "
566            "with shape %s, but the specified shape invariant is %s." %
567            (inp.name, inp.get_shape(), shape))
568      var.set_shape(shape)
569    else:
570      raise TypeError("Type %s not supported" % type(var))
571
572
573def _EnforceShapeInvariant(merge_var, next_var):
574  """Check if the shapes of the loops variables are invariants.
575
576  Args:
577    merge_var: The list of tensors representing the initial values of the loop
578      variables.
579    next_var: The list of tensors representing the values of the loop variables
580      after one loop iteration.
581
582  Raises:
583    ValueError: If any tensor in `merge_var` has a more specific shape than
584      its corresponding tensor in `next_var`.
585  """
  if isinstance(merge_var, ops.Tensor):
    m_shape = merge_var.get_shape()
    n_shape = next_var.get_shape()
    if not _ShapeLessThanOrEqual(n_shape, m_shape):
      enter = merge_var.op.inputs[0].op
      assert util.IsLoopEnter(enter)
      input_t = enter.inputs[0]
      raise ValueError(
          "Input tensor '%s' enters the loop with shape %s, but has shape %s "
          "after one iteration. To allow the shape to vary across iterations, "
          "use the `shape_invariants` argument of tf.while_loop to specify a "
          "less-specific shape." % (input_t.name, input_t.shape, n_shape))
  else:
    raise TypeError("Type %s not supported" % type(merge_var))
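
# For example, the error above is avoided by passing tf.while_loop a relaxed
# invariant for a loop variable whose shape changes across iterations
# (an illustrative sketch):
#
#   i0 = constant_op.constant(0)
#   x0 = constant_op.constant([0.0])
#   _, x = while_loop(
#       lambda i, x: i < 3,
#       lambda i, x: [i + 1, array_ops.concat([x, x], axis=0)],
#       [i0, x0],
#       shape_invariants=[i0.get_shape(), tensor_shape.TensorShape([None])])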


def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
  """Add NextIteration and back edge from v to m."""
  if isinstance(m, ops.Tensor):
    v = ops.convert_to_tensor(v)
    v = _NextIteration(v)
    if enforce_shape_invariant:
      # Make sure the shapes of loop outputs are correct. We do this before
      # calling _update_input, which will raise a less-helpful error message if
      # the types don't match.
      # TODO(skyewm): call this for other cases below (needs testing)
      _EnforceShapeInvariant(m, v)
    m.op._update_input(1, v)  # pylint: disable=protected-access
  elif isinstance(m, composite_tensor.CompositeTensor):
    # pylint: disable=protected-access
    def update_component(m_component, v_component):
      m_component.op._update_input(1, v_component)

    if isinstance(m, ops.IndexedSlices):
      v = math_ops._as_indexed_slices(v, optimize=False)
    # pylint: enable=protected-access
    v = _NextIteration(v)
    return nest.map_structure(update_component, m, v, expand_composites=True)
  else:
    raise TypeError("Type %s not supported" % type(m))
  return v


@six.add_metaclass(abc.ABCMeta)
class ControlFlowContext(object):
631  """The base class for control flow context.

  The usage pattern is a sequence of (Enter, Exit) followed by a final
  ExitResult.

  We maintain the following state for control flow contexts during graph
  construction:
   1. graph has _control_flow_context: the current context used to
      construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
   2. op has _control_flow_context: the context to which the op belongs.
      Set at the time the op is created. Immutable.
   3. A ControlFlowContext has _outer_context: the context in which this
      context is created. Set at the time a context is created. Immutable.
   4. A ControlFlowContext has _context_stack.
      Pushed and popped by ctxt.Enter() and ctxt.Exit()
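
  A minimal sketch of the usage pattern during graph construction (illustrative
  only; `SomeContext` stands for a concrete subclass such as CondContext or
  WhileContext):

  ```python
  ctxt = SomeContext(...)
  ctxt.Enter()
  # ... construct the ops that belong to this context ...
  ctxt.Exit()
  ctxt.ExitResult(result_tensors)  # expose results to the outer context
  ```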
646  """
647
648  def __init__(self, values_def=None, import_scope=None):
649    self._nested_contexts = []
650    self._outer_context = ops.get_default_graph()._get_control_flow_context()
651    if self._outer_context:
652      self._outer_context._nested_contexts.append(self)  # pylint: disable=protected-access
653    self._context_stack = []
654    if values_def:
655      self._init_values_from_proto(values_def, import_scope=import_scope)
656    else:
657      # The names of tensors that have been already seen in this context.
658      self._values = set()
659      # The keys are the names of tensors referenced by but external to this
660      # context. Each value is the Tensor that should be used by this context to
661      # access the key value (e.g. a switch output guarding a cond input value).
662      self._external_values = {}
663
664  def _init_values_from_proto(self, values_def, import_scope=None):
665    """Initializes values and external_values from `ValuesDef` protocol buffer.
666
667    Args:
668      values_def: `ValuesDef` protocol buffer.
669      import_scope: Optional `string`. Name scope to add.
670    """
671    assert isinstance(values_def, control_flow_pb2.ValuesDef)
672    self._values = set(
673        ops.prepend_name_scope(value, import_scope)
674        for value in values_def.values)
675    g = ops.get_default_graph()
676    self._external_values = {}
677    for k, v in values_def.external_values.items():
678      k = ops.prepend_name_scope(k, import_scope)
679      self._external_values[k] = g.as_graph_element(
680          ops.prepend_name_scope(v, import_scope))
681    op_names = set([
682        op.split(":")[0]
683        for op in self._values - set(self._external_values.keys())
684    ])
685    for op in op_names:
686      # pylint: disable=protected-access
687      g.as_graph_element(op)._set_control_flow_context(self)
688      # pylint: enable=protected-access
689
690  @property
691  def name(self):
692    return self._name
693
694  @property
695  def outer_context(self):
696    """Return the context containing this context."""
697    return self._outer_context
698
699  @property
700  def grad_state(self):
701    raise NotImplementedError("Abstract method")
702
703  @property
704  def back_prop(self):
705    raise NotImplementedError("Abstract method")
706
707  @abc.abstractmethod
708  def to_control_flow_context_def(self, context_def, export_scope=None):
709    """Serializes this into `context_def`.
710
711    Args:
712      context_def: a `ControlFlowContextDef` protocol buffer.
713      export_scope: Optional `string`. Name scope to remove.
714    """
715    raise NotImplementedError("Abstract method")
716
717  def _to_values_def(self, export_scope=None):
718    """Converts the values to a `ValuesDef` protocol buffer.
719
720    Args:
721      export_scope: Optional `string`. Name scope to remove.
722
723    Returns:
724      A `ValuesDef` protocol buffer.
725    """
726    values_def = control_flow_pb2.ValuesDef()
727    values_def.values.extend(
728        [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)])
729    for k, v in self._external_values.items():
730      k = ops.strip_name_scope(k, export_scope)
731      values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
732    return values_def
733
734  def AddName(self, name):
735    self._values.add(name)
736
737  # pylint: disable=protected-access
738  def Enter(self):
739    """Enter this control flow context."""
740    graph = ops.get_default_graph()
741    self._context_stack.append(graph._get_control_flow_context())
742    graph._set_control_flow_context(self)
743
744  def Exit(self):
745    """Exit this control flow context."""
746    graph = ops.get_default_graph()
747    last_context = self._context_stack.pop()
748    graph._set_control_flow_context(last_context)
749
750  def EnterGradientColocation(self, op, gradient_uid):
751    """Start building a gradient colocated with an op."""
752    if self._outer_context:
753      self._outer_context.EnterGradientColocation(op, gradient_uid)
754
755  def ExitGradientColocation(self, op, gradient_uid):
756    """Start building a gradient colocated with an op."""
    if self._outer_context:
      self._outer_context.ExitGradientColocation(op, gradient_uid)

  def ExitResult(self, result):
    """Make a list of tensors available in the outer context."""
    if self._outer_context:
      def fn(x):
        self._outer_context.AddName(x.name)
        return x
      nest.map_structure(fn, result, expand_composites=True)

  def GetWhileContext(self):
    """Return the while context containing this context."""
    if self._outer_context:
      return self._outer_context.GetWhileContext()
    return None

  def _RemoveExternalControlEdges(self, op):
    """Remove any external control dependency on this op."""
    while_ctxt = self.GetWhileContext()
    # A control input of `op` is internal if it is in the same while
    # loop context as the enclosing while loop context of self.
    if while_ctxt is None:
      internal_control_inputs = op.control_inputs
    else:
      internal_control_inputs = []
      for x in op.control_inputs:
        ctxt = util.GetOutputContext(x)
        if ctxt is not None and ctxt.GetWhileContext() == while_ctxt:
          internal_control_inputs.append(x)
    external_control_inputs = []
    if len(internal_control_inputs) != len(op.control_inputs):
      external_control_inputs = list(
          set(op.control_inputs) - set(internal_control_inputs))
      op._remove_all_control_inputs()
      op._add_control_inputs(internal_control_inputs)
    return internal_control_inputs, external_control_inputs

  # pylint: enable=protected-access

  def AddInnerOp(self, op):
    """Notifies a scope about an operator added to an inner scope."""
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def GetControlPivot(self):
    """Returns the pivot node for this context, or None."""
    return None

  def IsWhileContext(self):
    return False

  def IsCondContext(self):
    return False

  def IsXLAContext(self):
    return False

  def __str__(self):
    return self.name


class CondContext(ControlFlowContext):
  """The context for the conditional construct."""

  def __init__(self,
               pred=None,
               pivot=None,
               branch=None,
               name="cond_text",
               context_def=None,
               import_scope=None):
    """Creates a `CondContext`.

    Args:
      pred: The `boolean` tensor for the conditional predicate.
      pivot: The predicate tensor in this branch.
      branch: 0 or 1 representing this branch.
      name: Name of the `CondContext` python object.
      context_def: Optional `ContextDef` protocol buffer to initialize the
        `CondContext` object from.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.
840    """
841    self._name = ops.get_default_graph().unique_name(name)
842
843    if context_def:
844      self._init_from_proto(context_def, import_scope=import_scope)
845    else:
846      # Initializes the default fields.
847      ControlFlowContext.__init__(self)
848      self._pred = pred  # The boolean tensor for the cond predicate
849      self._pivot = pivot  # The predicate tensor in this branch
850      self._branch = branch  # 0 or 1 representing this branch
851
852      # Values considered to have been already seen in this context. pred is not
853      # included in this context.
854      self._values.add(pred.name)
855      self._external_values[pred.name] = pred
856      self._values.add(pivot.name)
857      pivot.op._set_control_flow_context(self)  # pylint: disable=protected-access
858
859  def _init_from_proto(self, context_def, import_scope=None):
860    """Creates a new `CondContext` from protocol buffer.
861
862    Args:
863      context_def: `CondContextDef` protocol buffer.
864      import_scope: Optional `string`. Name scope to add.
865    """
866    assert isinstance(context_def, control_flow_pb2.CondContextDef)
867    # Create from context_def.
868    g = ops.get_default_graph()
869    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
870    self._pred = g.as_graph_element(
871        ops.prepend_name_scope(context_def.pred_name, import_scope))
872    self._pivot = g.as_graph_element(
873        ops.prepend_name_scope(context_def.pivot_name, import_scope))
874    self._branch = context_def.branch
875    super(CondContext, self).__init__(
876        values_def=context_def.values_def, import_scope=import_scope)
877
878  @property
879  def pred(self):
880    return self._pred
881
882  @property
883  def pivot(self):
884    return self._pivot
885
886  @property
887  def branch(self):
888    return self._branch
889
890  @property
891  def grad_state(self):
892    if self.GetWhileContext():
893      return self.GetWhileContext().grad_state
894    return None
895
896  @property
897  def back_prop(self):
898    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False

  def GetControlPivot(self):
    return self._pivot

  def to_proto(self, export_scope=None):
    """Converts a `CondContext` to a `CondContextDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `CondContextDef` protocol buffer.
    """
    if (export_scope is None or self.name.startswith(export_scope)):
      context_def = control_flow_pb2.CondContextDef()
      context_def.context_name = ops.strip_name_scope(self.name, export_scope)
      context_def.pred_name = ops.strip_name_scope(self._pred.name,
                                                   export_scope)
      context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
                                                    export_scope)
      context_def.branch = self._branch
      context_def.values_def.MergeFrom(
          super(CondContext, self)._to_values_def(export_scope))
      for nested in self._nested_contexts:
        nested_def = context_def.nested_contexts.add()
        nested.to_control_flow_context_def(nested_def)

      return context_def
    else:
      return None

  @staticmethod
  def from_proto(context_def, import_scope=None):
    """Returns a `CondContext` object created from `context_def`."""
    ret = CondContext(context_def=context_def, import_scope=import_scope)

    ret.Enter()
    for nested_def in context_def.nested_contexts:
      from_control_flow_context_def(nested_def, import_scope=import_scope)
    ret.Exit()
    return ret

  def to_control_flow_context_def(self, context_def, export_scope=None):
    context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))

  def AddValue(self, val):
    """Add `val` to the current context and its outer context recursively."""
    if val.name in self._values:
      # Use the real value if it comes from outer context. This is needed in
      # particular for nested conds.
      result = self._external_values.get(val.name)
      result = val if result is None else result
    else:
      result = val
      self._values.add(val.name)
      if self._outer_context:
        result = self._outer_context.AddValue(val)
        self._values.add(result.name)
        self._external_values[result.name] = result
      with ops.control_dependencies(None):
        result = _SwitchRefOrTensor(result, self._pred)[self._branch]
        if self._outer_context:
          self._outer_context.AddInnerOp(result.op)

      result.op.graph.prevent_fetching(result.op)
      # pylint: disable=protected-access
      result.op._set_control_flow_context(self)
      # pylint: enable=protected-access

      # Mark Switch output as seen by this context and any outer contexts,
      # just like what we do for normal op outputs in _AddOpInternal() below.
      ctxt = self
      while ctxt is not None:
        # pylint: disable=protected-access
        ctxt._values.add(result.name)
        ctxt = ctxt._outer_context
        # pylint: enable=protected-access

      self._external_values[val.name] = result
    return result

  def AddOp(self, op):
    self._AddOpInternal(op)

  def _AddOpInternal(self, op):
    """Add `op` to the current context."""
    if not op.inputs:
      # If we're in a while loop, remove any control inputs from outside the
      # loop.
      self._RemoveExternalControlEdges(op)

      if not any(
          util.OpInContext(input_op, self) for input_op in op.control_inputs):
        # pylint: disable=protected-access
        op._add_control_input(self._pivot.op)
        # pylint: enable=protected-access
    else:
      # Make each input to 'op' available in this CondContext. If an input is
      # already part of this context there's nothing to do, but if it's
      # external, AddValue() will handle adding the appropriate Switch node and
      # other bookkeeping.
      for index in range(len(op.inputs)):
        x = op.inputs[index]
        if op.type == "Merge" and x.op.type == "NextIteration":
          # Edge case: if we're importing a while loop inside this CondContext,
          # AddValue() will not correctly handle the NextIteration inputs to
          # Merge node. The problem is that the NextIteration should also be
          # part of this context, but if we're importing it won't have been
          # processed and added to the context yet, so AddValue() will try to
          # add a Switch which results in an invalid graph. Instead, we use the
          # NextIteration input as-is here, and it will eventually be added to
          # the context via AddOp().
          real_x = x
        else:
          real_x = self.AddValue(x)
        if real_x != x:
          # pylint: disable=protected-access
          op._update_input(index, real_x)
          # pylint: enable=protected-access
      # Remove any external control dependency on this op.
      self._RemoveExternalControlEdges(op)
      # pylint: disable=protected-access
      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
        op._add_control_input(self._pivot.op)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    ctxt = self
    while ctxt is not None:
      # pylint: disable=protected-access
      ctxt._values.update(output_names)
      ctxt = ctxt._outer_context
      # pylint: enable=protected-access

    if self._outer_context or not util.IsLoopExit(op):
      op.graph.prevent_fetching(op)

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def _ProcessOutputTensor(self, val):
    """Process an output tensor of a conditional branch."""
    real_val = val
    if val.name not in self._values:
      # Handle the special case of lambda: x
      self._values.add(val.name)
      if self._outer_context:
        real_val = self._outer_context.AddValue(val)
        self._values.add(real_val.name)
        self._external_values[real_val.name] = real_val
      real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
      self._external_values[val.name] = real_val
    else:
      external_val = self._external_values.get(val.name)
      if external_val is not None:
        real_val = external_val
    return real_val

  def _BuildCondTensor(self, v):
    if isinstance(v, ops.Operation):
      # Use pivot as the proxy for this op.
      return with_dependencies([v], self._pivot)
    else:
      v = nest.map_structure(
          _convert_tensorarray_to_flow, v, expand_composites=True)
      return self._ProcessOutputTensor(ops.convert_to_tensor(v))

  def BuildCondBranch(self, fn):
    """Add the subgraph defined by fn() to the graph."""
    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    original_result = fn()
    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    if len(post_summaries) > len(pre_summaries):
      new_summaries = post_summaries[len(pre_summaries):]
      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
      summary_ref[:] = pre_summaries
      with ops.control_dependencies(new_summaries):
        if original_result is None:
          return no_op(), None
        elif not isinstance(original_result, ops.Operation):
          original_result = nest.map_structure(
              array_ops.identity, original_result, expand_composites=True)
    if original_result is None:
      return None, None

    result = nest.map_structure(
        self._BuildCondTensor, original_result, expand_composites=True)
    if not isinstance(result, (list, _basetuple)):
      result = [result]
    return original_result, result

  def IsCondContext(self):
    return True


def _UnpackIfSingleton(res):
  if isinstance(res, (list, _basetuple)) and len(res) == 1:
    return res[0]
  else:
    return res


def _eager_cond_implementation(pred, true_fn, false_fn, strict, name):
  """Special cases for `cond` when executing eagerly."""
  pred = ops.convert_to_tensor(pred)
  pred_constant_value = tensor_util.constant_value(pred)
  if pred_constant_value is None:
    # Eager tensors from a parallel device may not have a constant
    # value. Running the cond op itself would work, but we don't have logic to
    # build cond ops without wrapping in a function first.
    if (not isinstance(true_fn, def_function.Function)
        or not isinstance(false_fn, def_function.Function)):
      raise TypeError("When running tf.cond on a parallel device, `true_fn` "
                      "and `false_fn` must be decorated with `tf.function`.")
    @def_function.function
    def _parallel_device_cond_wrapper():
      return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    functions_run_eagerly = def_function.functions_run_eagerly()
    if functions_run_eagerly:
      # We need to use tf.function to deal with variable creation inside the
      # cond, and skipping it because of run_functions_eagerly would just
      # crash immediately.
      logging.warning(
          "It looks like tf.function behavior was disabled, perhaps using "
          "tf.config.run_functions_eagerly. Parallelized tf.cond requires "
          "tf.function to work. This primitive will override the disable.")
    def_function.run_functions_eagerly(False)
    try:
      return _parallel_device_cond_wrapper()
    finally:
      if functions_run_eagerly is not None:
        def_function.run_functions_eagerly(functions_run_eagerly)
  else:
    # For conditions which are eager tensors with a constant value (most of
    # them), we only call the relevant branch function and execute it eagerly.
    with ops.name_scope(name, "cond", [pred]):
      if pred_constant_value:
        result = true_fn()
      else:
        result = false_fn()
      if not strict:
        result = _UnpackIfSingleton(result)
      return result


# pylint: disable=redefined-outer-name
# pylint: disable=g-doc-args
@tf_export(v1=["cond"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(
    None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
    "fn1", "fn2")
def cond(pred,
         true_fn=None,
         false_fn=None,
         strict=False,
         name=None,
         fn1=None,
         fn2=None):
  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
  `false_fn` must have the same non-zero number and type of outputs.

  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
  `false_fn` will be executed regardless of which branch is selected at runtime.

  Although this behavior is consistent with the dataflow model of TensorFlow,
  it has frequently surprised users who expected a lazier semantics.
  Consider the following simple program:

  ```python
  z = tf.multiply(a, b)
  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  ```

  If `x < y`, the `tf.add` operation will be executed and `tf.square`
  operation will not be executed. Since `z` is needed for at least one
  branch of the `cond`, the `tf.multiply` operation is always executed,
  unconditionally.

  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
  call to `cond`, and not at all during `Session.run()`). `cond`
  stitches together the graph fragments created during the `true_fn` and
  `false_fn` calls with some additional graph nodes to ensure that the right
  branch gets executed depending on the value of `pred`.

  `tf.cond` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
  same (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
  This behavior is disabled by passing `strict=True`.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: Optional name prefix for the returned tensors.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`. If the
    callables return a singleton list, the element is extracted from the list.

  Raises:
    TypeError: if `true_fn` or `false_fn` is not callable.
    ValueError: if `true_fn` and `false_fn` do not return the same number of
      tensors, or return tensors of different types.

  Example:

  ```python
  x = tf.constant(2)
  y = tf.constant(5)
  def f1(): return tf.multiply(x, 17)
  def f2(): return tf.add(y, 23)
  r = tf.cond(tf.less(x, y), f1, f2)
  # r is set to f1().
  # Operations in f2 (e.g., tf.add) are not executed.
  ```

  """
  # We needed to make true_fn/false_fn keyword arguments for
  # backwards-compatibility. This check exists so that we can convert back to
  # having them be positional arguments.
  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
  # `fn1` and `fn2` are deleted.
  if fn1 is not None:
    if true_fn is not None:
      raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
    true_fn = fn1
  elif true_fn is None:
    raise TypeError("cond(): true_fn argument required")
  if fn2 is not None:
    if false_fn is not None:
      raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
    false_fn = fn2
  elif false_fn is None:
    raise TypeError("cond(): false_fn argument required")

  if not callable(true_fn):
    raise TypeError("true_fn must be callable.")
  if not callable(false_fn):
    raise TypeError("false_fn must be callable.")

  if context.executing_eagerly():
    return _eager_cond_implementation(pred, true_fn, false_fn, strict, name)

  # Always enable control flow v2 if building a function, regardless of toggle.
  if util.EnableControlFlowV2(ops.get_default_graph()):
    return cond_v2.cond_v2(pred, true_fn, false_fn, name)

  with ops.name_scope(name, "cond", [pred]):
    # Add the Switch to the graph.
    if isinstance(pred, bool):
      raise TypeError("pred must not be a Python bool")
    p_2, p_1 = switch(pred, pred)
    pivot_1 = array_ops.identity(p_1, name="switch_t")
    pivot_2 = array_ops.identity(p_2, name="switch_f")
    pred = array_ops.identity(pred, name="pred_id")
    # Disable the fetching of tensors that are only on one branch of cond.
    for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:
      tensor.op.graph.prevent_fetching(tensor.op)

    # Build the graph for the true branch in a new context.
    context_t = CondContext(pred, pivot_1, branch=1)
    try:
      context_t.Enter()
      orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
      if orig_res_t is None:
        raise ValueError("true_fn must have a return value.")
      context_t.ExitResult(res_t)
    finally:
      context_t.Exit()

    # Build the graph for the false branch in a new context.
    context_f = CondContext(pred, pivot_2, branch=0)
    try:
      context_f.Enter()
      orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
      if orig_res_f is None:
        raise ValueError("false_fn must have a return value.")
      context_f.ExitResult(res_f)
    finally:
      context_f.Exit()

    if not strict:
      orig_res_t = _UnpackIfSingleton(orig_res_t)
      orig_res_f = _UnpackIfSingleton(orig_res_f)

    # Check that the return values of the two branches have the same structure.
    try:
      nest.assert_same_structure(orig_res_t, orig_res_f, expand_composites=True)
    except (TypeError, ValueError):
      nest.map_structure(_cast_indexed_slice_indices, orig_res_t, orig_res_f)
      nest.map_structure(_cast_indexed_slice_indices, res_t, res_f)
      try:
        nest.assert_same_structure(orig_res_t, orig_res_f,
                                   expand_composites=True)
      except TypeError as e:
        raise TypeError(
            "Incompatible return types of true_fn and false_fn: {}".format(e))
      except ValueError as e:
        raise ValueError(
            "Incompatible return values of true_fn and false_fn: {}".format(e))

    # Add the final merge to the graph.
    if not res_t:
      raise ValueError("true_fn and false_fn must return at least one result.")

    res_t_flat = nest.flatten(res_t, expand_composites=True)
    res_f_flat = nest.flatten(res_f, expand_composites=True)

    for (x, y) in zip(res_t_flat, res_f_flat):
      assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
      if x.dtype.base_dtype != y.dtype.base_dtype:
        raise ValueError(
            "Outputs of true_fn and false_fn must have the same type: "
            "%s, %s" % (x.dtype.name, y.dtype.name))

    merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
    merges = _convert_flows_to_tensorarrays(
        nest.flatten(orig_res_t, expand_composites=True), merges)

    # Only add non-nested conds to the collection. Any nested control flow will
    # be encapsulated in the root context.
    assert context_t.outer_context == context_f.outer_context
    if context_t.outer_context is None:
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)

    merges = nest.pack_sequence_as(
        structure=orig_res_t, flat_sequence=merges, expand_composites=True)

    # Singleton lists and tuples are automatically unpacked if strict == False.
    if not strict:
      merges = _UnpackIfSingleton(merges)
    return merges


def _cast_indexed_slice_indices(a, b):
  """Cast IndexedSlice.indices from int32 to int64 where necessary.

  If `a` and `b` are both IndexedSlices, and their indices have different
  dtypes, then cast both their dtypes to `int64` (modifies `a` and `b`
  in-place).  Otherwise, does nothing.

  Args:
    a: A value, which may be an IndexedSlices.
    b: A value, which may be an IndexedSlices.
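
  For example, a minimal sketch (illustrative):

  ```python
  a = ops.IndexedSlices(constant_op.constant([1.0]),
                        constant_op.constant([0], dtype=dtypes.int32))
  b = ops.IndexedSlices(constant_op.constant([2.0]),
                        constant_op.constant([1], dtype=dtypes.int64))
  _cast_indexed_slice_indices(a, b)
  # Both `a.indices` and `b.indices` now have dtype int64.
  ```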
1354  """
1355  if (isinstance(a, ops.IndexedSlices) and isinstance(b, ops.IndexedSlices)
1356      and a.indices.dtype != b.indices.dtype):
1357    # pylint: disable=protected-access
1358    a._indices = math_ops.cast(a.indices, dtypes.int64)
1359    b._indices = math_ops.cast(b.indices, dtypes.int64)
1360
1361
1362# pylint: enable=g-doc-args
1363# pylint: enable=redefined-outer-name
1364
1365
1366@tf_export("cond", v1=[])
1367@dispatch.add_dispatch_support
1368def cond_for_tf_v2(pred, true_fn=None, false_fn=None, name=None):
1369  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
1370
1371  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
1372  `false_fn` must have the same non-zero number and type of outputs.
1373
1374  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
1375  `false_fn` will be executed regardless of which branch is selected at runtime.
1376
1377  Although this behavior is consistent with the dataflow model of TensorFlow,
1378  it has frequently surprised users who expected a lazier semantics.
1379  Consider the following simple program:
1380
1381  ```python
1382  z = tf.multiply(a, b)
1383  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
1384  ```
1385
1386  If `x < y`, the `tf.add` operation will be executed and `tf.square`
1387  operation will not be executed. Since `z` is needed for at least one
1388  branch of the `cond`, the `tf.multiply` operation is always executed,
1389  unconditionally.
1390
1391  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
1392  call to `cond`, and not at all during `Session.run()`). `cond`
1393  stitches together the graph fragments created during the `true_fn` and
1394  `false_fn` calls with some additional graph nodes to ensure that the right
1395  branch gets executed depending on the value of `pred`.
1396
1397  `tf.cond` supports nested structures as implemented in
1398  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
1399  same (possibly nested) value structure of lists, tuples, and/or named tuples.
1400  Singleton lists and tuples form the only exceptions to this: when returned by
1401  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
1402
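  For example, a minimal sketch of branches returning a matching nested
  structure:

  ```python
  x = tf.constant(1.0)
  y = tf.constant(2.0)
  def true_fn(): return (x + y, (x, y))
  def false_fn(): return (x - y, (y, x))
  out = tf.cond(x < y, true_fn, false_fn)
  # `out` is a tuple with the same nested structure as the branch outputs:
  # (scalar Tensor, (Tensor, Tensor)).
  ```
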
1403  Note: It is illegal to "directly" use tensors created inside a cond branch
1404  outside it, e.g. by storing a reference to a branch tensor in the python
1405  state. If you need to use a tensor created in a branch function you should
1406  return it as an output of the branch function and use the output from
1407  `tf.cond` instead.
1408
1409  Args:
1410    pred: A scalar determining whether to return the result of `true_fn` or
1411      `false_fn`.
1412    true_fn: The callable to be performed if pred is true.
1413    false_fn: The callable to be performed if pred is false.
1414    name: Optional name prefix for the returned tensors.
1415
1416  Returns:
1417    Tensors returned by the call to either `true_fn` or `false_fn`. If the
1418    callables return a singleton list, the element is extracted from the list.
1419
1420  Raises:
1421    TypeError: if `true_fn` or `false_fn` is not callable.
1422    ValueError: if `true_fn` and `false_fn` do not return the same number of
1423      tensors, or return tensors of different types.
1424
1425  Example:
1426
1427  ```python
1428  x = tf.constant(2)
1429  y = tf.constant(5)
1430  def f1(): return tf.multiply(x, 17)
1431  def f2(): return tf.add(y, 23)
1432  r = tf.cond(tf.less(x, y), f1, f2)
1433  # r is set to f1().
1434  # Operations in f2 (e.g., tf.add) are not executed.
1435  ```
1436
1437  """
1438  return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
1439
1440
1441def _resource_safe_shape(t):
1442  """Returns the shape of t or the variable it points to."""
1443  if t.dtype == dtypes.resource:
1444    while t.op.inputs:
1445      t = t.op.inputs[0]
1446    return tensor_shape.TensorShape(t.op.get_attr("shape"))
1447  return array_ops.shape_internal(t, optimize=False)
1448
1449
1450# TODO(yuanbyu): Consider having a unified notion of context for
1451# not only conditionals and loops but also control dependency and
1452# subgraphs.
1453class WhileContext(ControlFlowContext):
1454  """The context for the loop construct."""
1455
1456  def __init__(self,
1457               maximum_iterations=None,
1458               parallel_iterations=10,
1459               back_prop=True,
1460               swap_memory=False,
1461               name="while_context",
1462               grad_state=None,
1463               context_def=None,
1464               import_scope=None):
1465    """"Creates a `WhileContext`.
1466
1467    Args:
1468      maximum_iterations: Optional upper bound on number of loop iterations.
1469      parallel_iterations: The number of iterations allowed to run in parallel.
1470      back_prop: Whether backprop is enabled for this while loop.
1471      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
1472      name: Optional name prefix for the returned tensors.
1473      grad_state: The gradient loop state.
1474      context_def: Optional `WhileContextDef` protocol buffer to initialize the
1475        `WhileContext` Python object from.
1476      import_scope: Optional `string`. Name scope to add. Only used when
1477        initializing from protocol buffer.
1478    """
1479    if context_def:
1480      self._init_from_proto(context_def, import_scope=import_scope)
1481    else:
1482      ControlFlowContext.__init__(self)
1483      self._init_from_args(maximum_iterations, parallel_iterations, back_prop,
1484                           swap_memory, name)
1485    # The gradient loop state.
1486    self._grad_state = grad_state
1487
1488  def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop,
1489                      swap_memory, name):
1490    """Creates a new `WhileContext` from arguments.
1491
1492    Args:
1493      maximum_iterations: Optional upper bound on number of loop iterations.
1494      parallel_iterations: The number of iterations allowed to run in parallel.
1495      back_prop: Whether backprop is enabled for this while loop.
1496      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
1497      name: Optional name prefix for the returned tensors.
1498
1499    Raises:
1500      ValueError: If `parallel_iterations` has invalid value.
1501    """
1502    if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0):
1503      raise ValueError("`parallel_iterations` must be a positive integer: "
1504                       "%s" % parallel_iterations)
1505    self._name = ops.get_default_graph().unique_name(name)
1506    self._maximum_iterations = maximum_iterations
1507    self._parallel_iterations = parallel_iterations
1508    self._back_prop = back_prop
1509    self._swap_memory = swap_memory
1510    # We use this node to control constants created by the pred lambda.
1511    self._pivot_for_pred = None
1512    # We use this node to control constants created by the body lambda.
1513    self._pivot_for_body = None
1514    # The boolean tensor for loop termination condition. Used in code
1515    # generation for gradient computation.
1516    self._pivot = None
1517    # The list of exit tensors for loop variables.
1518    self._loop_exits = []
1519    # The list of enter tensors for loop variables.
1520    self._loop_enters = []
1521    self._graph = ops.get_default_graph()
1522
1523  def _init_from_proto(self, context_def, import_scope=None):
1524    """Creates a new `WhileContext` from protocol buffer.
1525
1526    Args:
1527      context_def: `WhileContextDef` protocol buffer.
1528      import_scope: Optional `string`. Name scope to add.
1529    """
1530    assert isinstance(context_def, control_flow_pb2.WhileContextDef)
1531    # Create from context_def.
1532    g = ops.get_default_graph()
1533    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
1534    if context_def.maximum_iterations_name:
1535      self._maximum_iterations = g.as_graph_element(
1536          ops.prepend_name_scope(context_def.maximum_iterations_name,
1537                                 import_scope))
1538    else:
1539      self._maximum_iterations = None
1540    self._parallel_iterations = context_def.parallel_iterations
1541    self._back_prop = context_def.back_prop
1542    self._swap_memory = context_def.swap_memory
1543    self._pivot_for_pred = g.as_graph_element(
1544        ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope))
1545    # We use this node to control constants created by the body lambda.
1546    self._pivot_for_body = g.as_graph_element(
1547        ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope))
1548    # The boolean tensor for loop termination condition. Used in code
1549    # generation for gradient computation.
1550    self._pivot = g.as_graph_element(
1551        ops.prepend_name_scope(context_def.pivot_name, import_scope))
1552    # The list of exit tensors for loop variables.
1553    self._loop_exits = [
1554        g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope))
1555        for exit_name in context_def.loop_exit_names
1556    ]
1557    # The list of enter tensors for loop variables.
1558    self._loop_enters = [
1559        g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope))
1560        for enter_name in context_def.loop_enter_names
1561    ]
1562    super(WhileContext, self).__init__(
1563        values_def=context_def.values_def, import_scope=import_scope)
1564
1565    # import_scope causes self.name to be different from the original serialized
1566    # context's name. Rewrite "frame_name" attrs with the new name.
1567    if import_scope:
1568      for tensor_name in self._values:
1569        op = g.as_graph_element(tensor_name).op
1570        if util.IsLoopEnter(op):
1571          # pylint: disable=protected-access
1572          op._set_attr("frame_name",
1573                       attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
1574          # pylint: enable=protected-access
1575    self._graph = ops.get_default_graph()
1576
1577  @property
1578  def maximum_iterations(self):
1579    """The maximum number of iterations that will be executed."""
1580    return self._maximum_iterations
1581
1582  @property
1583  def parallel_iterations(self):
1584    """The number of iterations allowed to run in parallel."""
1585    return self._parallel_iterations
1586
1587  @property
1588  def back_prop(self):
1589    """True iff backprop is enabled for this while loop."""
1590    return self._back_prop
1591
1592  @property
1593  def swap_memory(self):
1594    """True iff GPU-CPU memory swap is enabled for this while loop."""
1595    return self._swap_memory
1596
1597  @property
1598  def pivot(self):
1599    """The boolean tensor representing the loop termination condition."""
1600    return self._pivot
1601
1602  @property
1603  def loop_enters(self):
1604    """The list of enter tensors for loop variables."""
1605    return self._loop_enters
1606
1607  @property
1608  def loop_exits(self):
1609    """The list of exit tensors for loop variables."""
1610    return self._loop_exits
1611
1612  @property
1613  def grad_state(self):
1614    """The gradient loop state."""
1615    return self._grad_state
1616
1617  def to_proto(self, export_scope=None):
1618    """Converts a `WhileContext` to a `WhileContextDef` protocol buffer.
1619
1620    Args:
1621      export_scope: Optional `string`. Name scope to remove.
1622
1623    Returns:
1624      A `WhileContextDef` protocol buffer.
1625    """
1626    if (export_scope is None or self.name.startswith(export_scope)):
1627      context_def = control_flow_pb2.WhileContextDef()
1628      context_def.context_name = ops.strip_name_scope(self.name, export_scope)
1629      context_def.parallel_iterations = self._parallel_iterations
1630      if self._maximum_iterations is not None:
1631        context_def.maximum_iterations_name = ops.strip_name_scope(
1632            self._maximum_iterations.name, export_scope)
1633      context_def.back_prop = self._back_prop
1634      context_def.swap_memory = self._swap_memory
1635      context_def.pivot_for_pred_name = ops.strip_name_scope(
1636          self._pivot_for_pred.name, export_scope)
1637      context_def.pivot_for_body_name = ops.strip_name_scope(
1638          self._pivot_for_body.name, export_scope)
1639      context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
1640                                                    export_scope)
1641      context_def.loop_exit_names.extend([
1642          ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits
1643      ])
1644      context_def.loop_enter_names.extend([
1645          ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters
1646      ])
1647      context_def.values_def.MergeFrom(
1648          super(WhileContext, self)._to_values_def(export_scope=export_scope))
1649      for nested in self._nested_contexts:
1650        nested_def = context_def.nested_contexts.add()
1651        nested.to_control_flow_context_def(nested_def)
1652
1653      return context_def
1654    else:
1655      return None
1656
1657  def to_control_flow_context_def(self, context_def, export_scope=None):
1658    context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
1659
1660  @staticmethod
1661  def from_proto(context_def, import_scope=None):
1662    """Returns a `WhileContext` object created from `context_def`.
1663
1664    Args:
1665      context_def: A `WhileContextDef` protocol buffer.
1666      import_scope: Optional `string`. Name scope to add.
1667
1668    Returns:
1669      A `WhileContext` Python object.
1670    """
1671    ret = WhileContext(context_def=context_def, import_scope=import_scope)
1672    ret.Enter()
1673    for nested_def in context_def.nested_contexts:
1674      from_control_flow_context_def(nested_def, import_scope=import_scope)
1675    ret.Exit()
1676    return ret
1677
1678  def GetWhileContext(self):
1679    return self
1680
1681  def GetControlPivot(self):
1682    if self._pivot_for_body is not None:
1683      return self._pivot_for_body
1684    return self._pivot_for_pred
1685
1686  def AddValue(self, val):
1687    """Add `val` to the current context and its outer context recursively."""
1688    result = val
1689    new_value = val.name not in self._values
1690    # Don't treat ops in this context as new values. Usually all known values
1691    # are in self._values, except when we're importing a while loop inside this
1692    # WhileContext. Since there's a cycle in this case, `val` may be part of the
1693    # imported while loop but not yet processed by this context and added to
1694    # self._values in _AddOpInternal. We only want to process external input
1695    # tensors to the while loop here.
1696    new_value &= val.op._control_flow_context is not self  # pylint: disable=protected-access
1697    if new_value:
1698      self._values.add(val.name)
1699
1700      # If we are in a grad context and val is from its forward context,
1701      # use GetRealValue(), which adds the logic to save the history of
1702      # val in forward.
1703      grad_ctxt = ops.get_default_graph()._get_control_flow_context()
1704      if grad_ctxt:
1705        grad_ctxt = grad_ctxt.GetWhileContext()
1706        if grad_ctxt.grad_state:
1707          forward_ctxt = util.GetWhileContext(val.op)
1708          if util.IsLoopExit(val.op):
1709            forward_ctxt = forward_ctxt.outer_context
1710            if forward_ctxt:
1711              forward_ctxt = forward_ctxt.GetWhileContext()
1712          if forward_ctxt == grad_ctxt.grad_state.forward_context:
1713            real_val = grad_ctxt.grad_state.GetRealValue(val)
1714            self._external_values[val.name] = real_val
1715            return real_val
1716
1717      if self._outer_context is not None:
1718        result = self._outer_context.AddValue(val)
1719      # Create an Enter to make `result` known to this loop context.
1720      with ops.control_dependencies(None):
1721        enter = _Enter(
1722            result,
1723            self._name,
1724            is_constant=True,
1725            parallel_iterations=self._parallel_iterations)
1726        enter.graph.prevent_feeding(enter)
1727        if self._outer_context:
1728          self._outer_context.AddInnerOp(enter.op)
1729      # Fix the control inputs and control flow context of these enter ops.
1730      self._FixControlInputsAndContext([enter])
1731
1732      # Add `enter` in this context.
1733      self._values.add(enter.name)
1734      self._external_values[val.name] = enter
1735      result = enter
1736    else:
1737      actual_val = self._external_values.get(val.name)
1738      if actual_val is not None:
1739        result = actual_val
1740    return result
1741
1742  def AddOp(self, op):
1743    """Add `op` to the current context."""
1744    # For a reduction op, if op is in a grad context and its input is from
1745    # its forward context, moving op to the forward context means we would
1746    # store the tensor after the reduction as opposed to the tensor before
1747    # reduction, and therefore could significantly reduce memory consumption.
1748    # For now, we do this only for a few ops.
1749    #
1750    # If in XLA context, do not move constant ops to forward pass as pushing to
1751    # and popping from a stack removes the constant property of an op and breaks
1752    # XLA compilation, which requires certain inputs to be constant for certain
1753    # ops.
1754    if not util.IsInXLAContext(op) and op.type in {"Shape", "Size", "Rank"}:
1755      grad_ctxt = ops.get_default_graph()._get_control_flow_context()
1756      if grad_ctxt:
1757        grad_ctxt = grad_ctxt.GetWhileContext()
1758        if grad_ctxt.grad_state:
1759          op_input_forward_ctxt = util.GetWhileContext(op.inputs[0].op)
1760          if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context:
1761            op_input_ctxt = op.inputs[0].op._get_control_flow_context()
1762            op._set_control_flow_context(op_input_ctxt)
1763            op_input_ctxt._AddOpInternal(op)
1764            return
1765    self._AddOpInternal(op)
1766
1767  def _AddOpInternal(self, op):
1768    """Add `op` to the current context.
1769
1770    We move any external control dependencies of the op to the loop pivot, to
1771    ensure they get executed.
1772    """
1773    # This is needed to prevent frame mismatch errors where there are Const
1774    # nodes inside tf.function in v1 while_loop and inlining is turned on.
1775    if op.type in ["PartitionedCall", "StatefulPartitionedCall"]:
1776      op._add_control_input(self.GetControlPivot().op)  # pylint: disable=protected-access
1777    if not op.inputs:
1778      # Remove any external control dependency on this op
1779      control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
1780      # Add a control edge from the control pivot to this op.
1781      if not control_inputs:
1782        # pylint: disable=protected-access
1783        op._add_control_input(self.GetControlPivot().op)
1784        # pylint: enable=protected-access
1785      for x in op.outputs:
1786        self._values.add(x.name)
1787    else:
1788      for index in range(len(op.inputs)):
1789        x = op.inputs[index]
1790        real_x = self.AddValue(x)
1791        if real_x != x:
1792          op._update_input(index, real_x)  # pylint: disable=protected-access
1793      # Remove any external control dependency on this op.
1794      _, external_inputs = self._RemoveExternalControlEdges(op)
1795      # Add a control dependency to prevent loop invariants from
1796      # enabling ops that should not be executed.
1797      self._MaybeAddControlDependency(op)
1798      for x in op.outputs:
1799        self._values.add(x.name)
1800    if external_inputs:
1801      # Use an identity to pull control inputs as data inputs. Note that we
1802      # ignore ops which don't have outputs. TODO(apassos): fix that
1803      with ops.control_dependencies(None):
1804        self.Enter()
1805        external_inputs = [
1806            array_ops.identity(x.outputs[0]).op
1807            for x in external_inputs
1808            if x.outputs
1809        ]
1810        self.Exit()
1811      op._add_control_inputs(external_inputs)  # pylint: disable=protected-access
1812    if self._outer_context or not util.IsLoopExit(op):
1813      op.graph.prevent_fetching(op)
1814      for x in op.outputs:
1815        op.graph.prevent_feeding(x)
1816
1817    if self._outer_context:
1818      self._outer_context.AddInnerOp(op)
1819
1820  def _MaybeAddControlDependency(self, op):
1821    """Add a control input to the op if it only depends on loop invariants."""
1822
1823    def _IsOpFree(op):
1824      """Determines if `op` needs a control dependency."""
1825      if op.control_inputs:
1826        return False
1827      # pylint: disable=protected-access
1828      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
1829        return True
1830      # pylint: enable=protected-access
1831      for x in op.inputs:
1832        if not util.IsLoopConstantEnter(x.op):
1833          return False
1834      return True
1835
1836    if _IsOpFree(op):
1837      # pylint: disable=protected-access
1838      op._add_control_input(self.GetControlPivot().op)
1839      # pylint: enable=protected-access
1840
1841  def AddForwardLoopCounter(self, outer_grad_state):
1842    """Adds a loop that counts the number of iterations.
1843
1844    This is added to the forward loop at the time when we start to
1845    create the loop for backprop gradient computation. Called in
1846    the outer context of this forward context.
1847
1848    The pseudocode is:
1849      `n = 0; while (_pivot) { n++; }`
1850
1851    Note that a control dependency is added to `n` to ensure the correct
1852    execution order of stack push ops.
1853
1854    Args:
1855      outer_grad_state: The outer grad state. None if not nested.
1856
1857    Returns:
1858      The number of iterations taken by the forward loop and the loop index.
1859    """
1860    n = constant_op.constant(0, name="f_count")
1861    if outer_grad_state is not None:
1862      # Force the stack pushes of i-th execution of an inner loop to be ordered
1863      # before the pushes of (i+1)-th execution of the same inner loop.
1864      outer_add_op = outer_grad_state.forward_index.op.inputs[0].op
1865      n.op._add_control_input(outer_add_op)  # pylint: disable=protected-access
1866
1867    self.Enter()
1868    self.AddName(n.name)
1869    enter_n = _Enter(
1870        n,
1871        self._name,
1872        is_constant=False,
1873        parallel_iterations=self._parallel_iterations,
1874        name="f_count")
1875    self.loop_enters.append(enter_n)
1876
1877    merge_n = merge([enter_n, enter_n])[0]
1878    switch_n = switch(merge_n, self._pivot)
1879
1880    index = math_ops.add(switch_n[1], 1)
1881    next_n = _NextIteration(index)
1882    merge_n.op._update_input(1, next_n)
1883
1884    total_iterations = exit(switch_n[0], name="f_count")
1885    self.loop_exits.append(total_iterations)
1886    self.ExitResult([total_iterations])
1887    self.Exit()
1888    return total_iterations, next_n
1889
1890  def AddBackpropLoopCounter(self, count, outer_grad_state):
1891    """Add the backprop loop that controls the iterations.
1892
1893    This is added to the backprop loop. It is used to control the loop
1894    termination of the backprop loop. Called in the outer context of
1895    this grad context.
1896
1897    The pseudocode is:
1898      `n = count; while (n >= 1) { n--; }`
1899
1900    Note that a control dependency is added to `final_zero` to ensure the
1901    correct execution order of stack pop ops.
1902
1903    Args:
1904      count: The number of iterations for backprop.
1905      outer_grad_state: The outer grad state. None if not nested.
1906
1907    Returns:
1908      The loop index.
1909    """
1910    in_separate_functions = count.graph is not ops.get_default_graph()
1911    if in_separate_functions:
1912      # Brings the count into this graph
1913      count = array_ops.identity(count)
1914    else:
1915      # TODO(apassos) XLA expects this constant to be created outside the loop,
1916      # so doing that for now.
1917      one = constant_op.constant(1, name="b_count")
1918
1919    self.Enter()
1920    self.AddName(count.name)
1921    enter_count = _Enter(
1922        count,
1923        self._name,
1924        is_constant=False,
1925        parallel_iterations=self._parallel_iterations,
1926        name="b_count")
1927    self.loop_enters.append(enter_count)
1928
1929    merge_count = merge([enter_count, enter_count])[0]
1930    self._pivot_for_pred = merge_count
1931
1932    if in_separate_functions:
1933      one = constant_op.constant(1, name="b_count")
1934    pred = math_ops.greater_equal(merge_count, one)
1935    self._pivot = loop_cond(pred, name="b_count")
1936    switch_count = switch(merge_count, self._pivot)
1937
1938    index = math_ops.subtract(switch_count[1], one)
1939    self._pivot_for_body = index
1940    next_count = _NextIteration(index)
1941    merge_count.op._update_input(1, next_count)
1942
1943    final_zero = exit(switch_count[0], name="b_count")
1944    self.loop_exits.append(final_zero)
1945    if outer_grad_state is not None:
1946      # Force the stack pops of i-th execution of an inner loop to be ordered
1947      # before the pops of (i+1)-th execution of the same inner loop.
1948      # pylint: disable=protected-access
1949      outer_grad_state.grad_sync._add_control_input(final_zero.op)
1950      # pylint: enable=protected-access
1951
1952    self.ExitResult([final_zero])
1953    self.Exit()
1954    return next_count
1955
1956  def AddBackpropAccumulator(self, op, grad):
1957    """Add an accumulation loop for every loop invariant.
1958
1959    This is added to the backprop loop. It is used to accumulate partial
1960    gradients within each loop iteration. Called when in the gradient while
1961    context.
1962
1963    The pseudocode is:
1964      ```
1965      acc = 0.0;
1966      while (_pivot) {
1967        acc += grad;
1968      }
1969      ```
1970
1971    Args:
1972      op: The Enter op for a loop invariant.
1973      grad: The partial gradient of an iteration for a loop invariant.
1974
1975    Returns:
1976      The gradient for a loop invariant.
1977    """
1978    self.Exit()
1979    # Create a zeros tensor with the right shape for acc. If we don't
1980    # know the full shape statically, we will have to get the shape
1981    # dynamically from the forward inference. Getting the shape right
1982    # for the zeros is only needed for the base case when the loop exits
1983    # without running any iterations.
1984    shape = grad.get_shape()
1985    if shape.is_fully_defined():
1986      if self.outer_context:
1987        self.outer_context.Enter()
1988      acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
1989      if self.outer_context:
1990        self.outer_context.Exit()
1991    else:
1992      value = op.inputs[0]
1993      if (isinstance(self.outer_context, WhileContext) and
1994          self.outer_context.grad_state is not None):
1995        # We are in a nested while loop.
1996        forward_ctxt = self.grad_state.forward_context
1997        forward_ctxt.outer_context.Enter()
1998        zeros_shape = array_ops.shape_internal(value, optimize=False)
1999        forward_ctxt.outer_context.Exit()
2000        outer_grad_state = self.grad_state.outer_grad_state
2001        history_zeros_shape = outer_grad_state.AddForwardAccumulator(
2002            zeros_shape)
2003        self.outer_context.Enter()
2004        real_shape = outer_grad_state.AddBackpropAccumulatedValue(
2005            history_zeros_shape, zeros_shape)
2006        acc = array_ops.zeros(real_shape, grad.dtype)
2007        self.outer_context.Exit()
2008      else:
2009        if self.outer_context:
2010          self.outer_context.Enter()
2011        zeros_shape = array_ops.shape_internal(value, optimize=False)
2012        acc = array_ops.zeros(zeros_shape, grad.dtype)
2013        if self.outer_context:
2014          self.outer_context.Exit()
2015
2016    self.Enter()
2017    self.AddName(acc.name)
2018    enter_acc = _Enter(
2019        acc,
2020        self._name,
2021        is_constant=False,
2022        parallel_iterations=self._parallel_iterations,
2023        name="b_acc")
2024    self.loop_enters.append(enter_acc)
2025
2026    merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
2027    switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)
2028
2029    add_acc = math_ops.add(switch_acc_true, grad)
2030    next_acc = _NextIteration(add_acc)
2031    merge_acc.op._update_input(1, next_acc)  # pylint: disable=protected-access
2032
2033    result_acc = exit(switch_acc_false, name="b_acc")
2034    self.loop_exits.append(result_acc)
2035    self.ExitResult([result_acc])
2036    return result_acc
2037
2038  def AddBackpropIndexedSlicesAccumulator(self, op, grad):
2039    """This is used for accumulating gradients that are IndexedSlices.
2040
2041    This is essentially the equivalent of AddBackpropAccumulator but optimized
2042    for things like updating embeddings from within a while loop.
2043
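    The pseudocode is roughly (a sketch mirroring `AddBackpropAccumulator`;
    the dense shape accumulator, if present, just keeps a running maximum):
      ```
      indices_acc = [0]; values_acc = zeros([1, ...]);
      while (_pivot) {
        indices_acc = concat([indices_acc, grad.indices]);
        values_acc = concat([values_acc, grad.values]);
      }
      ```
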
2044    Args:
2045      op: The Enter op for a loop invariant.
2046      grad: The partial gradients represented as an IndexedSlices.
2047
2048    Returns:
2049      The accumulated IndexedSlices gradient of the loop invariant.
2050    """
2051    values = grad.values
2052    indices = grad.indices
2053    dense_shape = grad.dense_shape
2054
2055    self.Exit()
2056    if self.outer_context:
2057      self.outer_context.Enter()
2058    if values.get_shape().is_fully_defined():
2059      values_shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] +
2060                                              values.get_shape().dims[1:])
2061      if self.outer_context:
2062        self.outer_context.Enter()
2063      values_acc = constant_op.constant(
2064          0, values.dtype, shape=values_shape, name="b_acc")
2065      if self.outer_context:
2066        self.outer_context.Exit()
2067    else:
2068      values_shape = _resource_safe_shape(op.inputs[0])[1:]
2069      values_shape = array_ops.concat([[1], values_shape], 0)
2070      values_acc = array_ops.zeros(values_shape, dtype=values.dtype)
2071    indices_acc = constant_op.constant([0], indices.dtype)
2072    shape_acc = None
2073    if dense_shape is not None:
2074      if dense_shape.get_shape().is_fully_defined():
2075        if self.outer_context:
2076          self.outer_context.Enter()
2077        shape_acc = constant_op.constant(
2078            0, dense_shape.dtype, shape=dense_shape.get_shape())
2079        if self.outer_context:
2080          self.outer_context.Exit()
2081      else:
2082        shape_acc = array_ops.zeros_like(
2083            array_ops.shape_internal(
2084                op.inputs[0], optimize=False, out_type=dense_shape.dtype),
2085            optimize=False)
2086
2087    if self.outer_context:
2088      self.outer_context.Exit()
2089
2090    self.Enter()
2091    self.AddName(values_acc.name)
2092    self.AddName(indices_acc.name)
2093    init_acc = [indices_acc, values_acc]
2094    if shape_acc is not None:
2095      self.AddName(shape_acc.name)
2096      init_acc.append(shape_acc)
2097
2098    # Set use_input_shape=False since the accumulator tensors will grow in
2099    # size. If use_input_shape=True, the _update_input call below will result in
2100    # incompatible shapes.
2101    enter_acc = [
2102        _Enter(
2103            x,
2104            self._name,
2105            is_constant=False,
2106            parallel_iterations=self._parallel_iterations,
2107            use_input_shape=False,
2108            name="b_acc") for x in init_acc
2109    ]
2110    # Manually set appropriate partial shapes.
2111    enter_acc[0].set_shape([None])
2112    if values_acc.shape.dims is not None:
2113      enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:])
2114    self.loop_enters.extend(enter_acc)
2115
2116    merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]
2117    switch_acc = [switch(x, self._pivot) for x in merge_acc]
2118
2119    # The actual accumulation.
2120    acc_indexed_slices = [
2121        array_ops.concat([xa[1], xv], 0)
2122        for xa, xv in zip(switch_acc[:2], [indices, values])
2123    ]
2124    if shape_acc is not None:
2125      # For the shape we just keep the maximum
2126      acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1]))
2127
2128    next_acc = [_NextIteration(x) for x in acc_indexed_slices]
2129    for xm, xn in zip(merge_acc, next_acc):
2130      xm.op._update_input(1, xn)  # pylint: disable=protected-access
2131
2132    exit_acc = [exit(x[0], name="b_acc") for x in switch_acc]
2133    self.loop_exits.extend(exit_acc)
2134
2135    self.ExitResult(exit_acc)
2136    return ops.IndexedSlices(
2137        indices=exit_acc[0],
2138        values=exit_acc[1],
2139        dense_shape=exit_acc[2] if shape_acc is not None else None)
2140
2141  def _InitializeValues(self, values):
2142    """Makes the values known to this context."""
2143    self._values = set()
2144    for x in values:
2145      if isinstance(x, ops.Tensor):
2146        self._values.add(x.name)
2147      else:
2148        raise TypeError("Type %s not supported" % type(x))
2149
2150  def _BuildLoop(self, pred, body, original_loop_vars, loop_vars,
2151                 shape_invariants):
2152    """Core: Add the loop termination condition and body to the graph."""
2153    flat_loop_vars = nest.flatten(original_loop_vars, expand_composites=True)
2154
2155    # Let the context know the loop variables so the loop variables
2156    # would be added in the outer contexts properly.
2157    self._InitializeValues(loop_vars)
2158    real_vars = loop_vars
2159    if self._outer_context:
2160      real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
2161    with ops.control_dependencies(None):
2162      enter_vars = [
2163          _Enter(
2164              x,
2165              self._name,
2166              is_constant=False,
2167              parallel_iterations=self._parallel_iterations,
2168              use_input_shape=(shape_invariants is None)) for x in real_vars
2169      ]
2170      for x in enter_vars:
2171        x.graph.prevent_feeding(x)
2172        if self._outer_context:
2173          self._outer_context.AddInnerOp(x.op)
2174
2175    # Finds the closest enclosing non-None control pivot.
2176    outer_context = self._outer_context
2177    control_pivot = None
2178    while outer_context is not None and control_pivot is None:
2179      control_pivot = outer_context.GetControlPivot()
2180      # pylint: disable=protected-access
2181      outer_context = outer_context._outer_context
2182      # pylint: enable=protected-access
2183
2184    if control_pivot is not None:
2185      for var in enter_vars:
2186        if util.IsLoopConstantEnter(var.op.inputs[0].op):
2187          # pylint: disable=protected-access
2188          var.op._add_control_input(control_pivot.op)
2189          # pylint: enable=protected-access
2190    _SetShapeInvariants(real_vars, enter_vars, shape_invariants)
2191
2192    # Fix the control inputs and control flow context of these enter ops.
2193    self._FixControlInputsAndContext(enter_vars)
2194    self._InitializeValues(enter_vars)
2195    self._loop_enters = enter_vars
2196
2197    merge_vars = [merge([x, x])[0] for x in enter_vars]
2198    self._pivot_for_pred = merge_vars[0]
2199
2200    # Build the graph for pred.
2201    merge_vars_with_tensor_arrays = (
2202        _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars))
2203    packed_vars = nest.pack_sequence_as(
2204        structure=original_loop_vars,
2205        flat_sequence=merge_vars_with_tensor_arrays,
2206        expand_composites=True)
2207    c = ops.convert_to_tensor(pred(*packed_vars))
2208    self._pivot = loop_cond(c, name="LoopCond")
2209    switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]
2210
2211    # Build the graph for body.
2212    vars_for_body = [_Identity(x[1]) for x in switch_vars]
2213    self._pivot_for_body = vars_for_body[0]
2214    # Convert TensorArray flow variables inside the context back into
2215    # their associated TensorArrays for calling the body.
2216    vars_for_body_with_tensor_arrays = (
2217        _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body))
2218    packed_vars_for_body = nest.pack_sequence_as(
2219        structure=original_loop_vars,
2220        flat_sequence=vars_for_body_with_tensor_arrays,
2221        expand_composites=True)
2222    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
2223    body_result = body(*packed_vars_for_body)
2224    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
2225    if not nest.is_sequence_or_composite(body_result):
2226      body_result = [body_result]
2227    if len(post_summaries) > len(pre_summaries):
2228      new_summaries = post_summaries[len(pre_summaries):]
2229      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
2230      summary_ref[:] = pre_summaries
2231      with ops.control_dependencies(new_summaries):
2232
2233        def map_fn(x):
2234          # TODO(apassos) figure out how to trigger with tensor arrays as well
2235          if isinstance(x, tensor_array_ops.TensorArray):
2236            return x
2237          return array_ops.identity(x)
2238
2239        body_result = nest.map_structure(
2240            map_fn, body_result, expand_composites=True)
2241
2242    # Compare the structure types of input and output of body.
2243    # For backwards compatibility, the first layer is forced to a list
2244    # during this comparison, because inputs are typically lists and
2245    # outputs of the body are typically tuples.
2246    nest.assert_same_structure(
2247        list(packed_vars_for_body), list(body_result), expand_composites=True)
2248
2249    # Store body_result to keep track of TensorArrays returned by body
2250    original_body_result = body_result
2251    # Convert TensorArrays returned by body into their flow variables
2252    result = nest.map_structure(
2253        _convert_tensorarray_to_flow,
2254        nest.flatten(body_result, expand_composites=True),
2255        expand_composites=True)
2256    result = ops.convert_n_to_tensor_or_composite(result)
2257
2258    # Add NextIteration and the back edges to complete the loop.
2259    if len(merge_vars) != len(result):
2260      raise ValueError("Number of inputs and outputs of body must match "
2261                       "loop_vars: %d, %d" % (len(merge_vars), len(result)))
2262    next_vars = []
2263    for m, v in zip(merge_vars, result):
2264      next_vars.append(_AddNextAndBackEdge(m, v))
2265
2266    # Add the exit ops.
2267    exit_vars = [exit(x[0]) for x in switch_vars]
2268    self._loop_exits = exit_vars
2269
2270    # Exit the loop.
2271    self.ExitResult(exit_vars)
2272
2273    return original_body_result, exit_vars
2274
2275  def BuildLoop(self, pred, body, loop_vars, shape_invariants,
2276                return_same_structure):
2277    """Add the loop termination condition and body to the graph."""
2278
2279    # Keep original_loop_vars to identify which are TensorArrays
2280    original_loop_vars = loop_vars
2281    # Convert TensorArrays to their flow variables
2282    loop_vars = nest.map_structure(
2283        _convert_tensorarray_to_flow,
2284        nest.flatten(loop_vars, expand_composites=False),
2285        expand_composites=True)
2286    loop_vars = ops.convert_n_to_tensor_or_composite(loop_vars)
2287    if shape_invariants is None:
2288      shape_invariants = nest.map_structure(
2289          _get_shape_invariant, loop_vars, expand_composites=False)
2290    loop_vars = nest.flatten(loop_vars, expand_composites=True)
2291    try:
2292      self.Enter()
2293      # _BuildLoop calls _update_input in several places. _mutation_lock()
2294      # ensures a Session.run call cannot occur between creating and mutating
2295      # new ops.
2296      with ops.get_default_graph()._mutation_lock():  # pylint: disable=protected-access
2297        original_body_result, exit_vars = self._BuildLoop(
2298            pred, body, original_loop_vars, loop_vars, shape_invariants)
2299    finally:
2300      self.Exit()
2301
2302    flat_result = nest.flatten(original_body_result, expand_composites=True)
2303    # Convert TensorArray flow variables outside the context back into
2304    # their associated TensorArrays for returning to caller.
2305    exit_vars_with_tensor_arrays = (
2306        _convert_flows_to_tensorarrays(flat_result, exit_vars))
2307    packed_exit_vars = nest.pack_sequence_as(
2308        structure=original_body_result,
2309        flat_sequence=exit_vars_with_tensor_arrays,
2310        expand_composites=True)
2311
2312    if return_same_structure:
2313      return packed_exit_vars
2314    else:
2315      return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars
2316
2317  def _FixControlInputsAndContext(self, enters):
2318    graph = ops.get_default_graph()
2319    # pylint: disable=protected-access
2320    for e in enters:
2321      if isinstance(e, ops.Tensor):
2322        xs = [e]
2323      else:
2324        raise TypeError("Type %s not supported" % type(e))
2325      for x in xs:
2326        inp_op = x.op.inputs[0].op
2327        control_inputs = graph._control_dependencies_for_inputs([inp_op])
2328        outer_control_inputs = []
2329        for op in control_inputs:
2330          # We need to keep control inputs that are in any ancestor
2331          # ControlFlowContext, and within outer WhileContext.
2332          keep_as_control_input = True
2333          op_ctxt = util.GetOutputContext(op)
2334          outer_ctxt = self.outer_context
2335          outer_while_context = (None if outer_ctxt is None else
2336                                 outer_ctxt.GetWhileContext())
2337          while outer_ctxt != op_ctxt:
2338            if outer_ctxt is None or outer_ctxt == outer_while_context:
2339              keep_as_control_input = False
2340              break
2341            outer_ctxt = outer_ctxt.outer_context
2342          if keep_as_control_input:
2343            outer_control_inputs.append(op)
2344        x.op._set_control_flow_context(self)
2345        x.op._add_control_inputs(outer_control_inputs)
2346        graph._record_op_seen_by_control_dependencies(x.op)
2347    # pylint: enable=protected-access
2348
2349  def IsWhileContext(self):
2350    return True
2351
2352
2353# @TODO(b/133606651) Replace "shape_invariants" with "loop_vars_signature".
2354# pylint: disable=redefined-outer-name
2355@tf_export("while_loop", v1=[])
2356@deprecation.deprecated_arg_values(
2357    None,
2358    """back_prop=False is deprecated. Consider using tf.stop_gradient instead.
2359Instead of:
2360results = tf.while_loop(c, b, vars, back_prop=False)
2361Use:
2362results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))""",
2363    warn_once=True,
2364    back_prop=False)
2365def while_loop_v2(cond,
2366                  body,
2367                  loop_vars,
2368                  shape_invariants=None,
2369                  parallel_iterations=10,
2370                  back_prop=True,
2371                  swap_memory=False,
2372                  maximum_iterations=None,
2373                  name=None):
2374  """Repeat `body` while the condition `cond` is true.
2375
2376  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
2377  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
2378  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
2379  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
2380  `cond` and `body`. `cond` and `body` both take as many arguments as there are
2381  `loop_vars`.
2382
2383  In addition to regular Tensors or IndexedSlices, the body may accept and
2384  return TensorArray objects.  The flows of the TensorArray objects will
2385  be appropriately forwarded between loops and during gradient calculations.
2386
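  For instance, a small sketch accumulating per-iteration values in a
  `tf.TensorArray`:

  ```python
  ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
  c = lambda i, ta: i < 5
  b = lambda i, ta: (i + 1, ta.write(i, tf.cast(i, tf.float32)))
  _, ta_final = tf.while_loop(c, b, (0, ta))
  # ta_final.stack() is [0., 1., 2., 3., 4.]
  ```
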
2387  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
2388  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
2389  stitches together the graph fragments created during the `cond` and `body`
2390  calls with some additional graph nodes to create the graph flow that
2391  repeats `body` until `cond` returns false.
2392
2393  For correctness, `tf.while_loop()` strictly enforces shape invariants for
2394  the loop variables. A shape invariant is a (possibly partial) shape that
2395  is unchanged across the iterations of the loop. An error will be raised
2396  if the shape of a loop variable after an iteration is determined to be more
2397  general than or incompatible with its shape invariant. For example, a shape
2398  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
2399  compatible with [11, 17]. By default (if the argument `shape_invariants` is
2400  not specified), it is assumed that the initial shape of each tensor in
2401  `loop_vars` is the same in every iteration. The `shape_invariants` argument
2402  allows the caller to specify a less specific shape invariant for each loop
2403  variable, which is needed if the shape varies between iterations. The
2404  `tf.Tensor.set_shape`
2405  function may also be used in the `body` function to indicate that
2406  the output loop variable has a particular shape. The shape invariants for
2407  SparseTensor and IndexedSlices are treated specially as follows:
2408
2409  a) If a loop variable is a SparseTensor, the shape invariant must be
2410  TensorShape([r]) where r is the rank of the dense tensor represented
2411  by the sparse tensor. It means the shapes of the three tensors of the
2412  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
2413  is the shape of the SparseTensor.dense_shape property. It must be the shape of
2414  a vector.
2415
2416  b) If a loop variable is an IndexedSlices, the shape invariant must be
2417  a shape invariant of the values tensor of the IndexedSlices. It means
2418  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
2419  [shape.ndims]).
2420
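  For example, a sketch of case (a) above: a rank-1 `SparseTensor` loop
  variable grows each iteration, so its invariant is `TensorShape([1])` (the
  rank of the dense tensor it represents):

  ```python
  i0 = tf.constant(0)
  st0 = tf.sparse.SparseTensor(indices=[[0]], values=[1.0], dense_shape=[2])
  c = lambda i, st: i < 3
  b = lambda i, st: [i + 1, tf.sparse.concat(0, [st, st])]
  tf.while_loop(c, b, loop_vars=[i0, st0],
                shape_invariants=[i0.get_shape(), tf.TensorShape([1])])
  ```
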
2421  `while_loop` implements non-strict semantics, enabling multiple iterations
2422  to run in parallel. The maximum number of parallel iterations can be
2423  controlled by `parallel_iterations`, which gives users some control over
2424  memory consumption and execution order. For correct programs, `while_loop`
2425  should return the same result for any parallel_iterations > 0.
2426
2427  For training, TensorFlow stores the tensors that are produced in the
2428  forward inference and are needed in back propagation. These tensors are a
2429  main source of memory consumption and often cause OOM errors when training
2430  on GPUs. When the flag swap_memory is true, we swap out these tensors from
2431  GPU to CPU. This for example allows us to train RNN models with very long
2432  sequences and large batches.
2433
2434  Args:
2435    cond: A callable that represents the termination condition of the loop.
2436    body: A callable that represents the loop body.
2437    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
2438      `Tensor`, and `TensorArray` objects.
2439    shape_invariants: The shape invariants for the loop variables.
2440    parallel_iterations: The number of iterations allowed to run in parallel. It
2441      must be a positive integer.
2442    back_prop: (optional) Deprecated. False disables support for back
2443      propagation. Prefer using `tf.stop_gradient` instead.
2444    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
2445    maximum_iterations: Optional maximum number of iterations of the while loop
2446      to run.  If provided, the `cond` output is AND-ed with an additional
2447      condition ensuring the number of iterations executed is no greater than
2448      `maximum_iterations`.
2449    name: Optional name prefix for the returned tensors.
2450
2451  Returns:
2452    The output tensors for the loop variables after the loop. The return value
2453      has the same structure as `loop_vars`.
2454
2455  Raises:
2456    TypeError: if `cond` or `body` is not callable.
2457    ValueError: if `loop_vars` is empty.
2458
2459  Example:
2460
2461  ```python
2462  i = tf.constant(0)
2463  c = lambda i: tf.less(i, 10)
2464  b = lambda i: (tf.add(i, 1), )
2465  r = tf.while_loop(c, b, [i])
2466  ```
2467
2468  Example with nesting and a namedtuple:
2469
2470  ```python
2471  import collections
2472  Pair = collections.namedtuple('Pair', 'j, k')
2473  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
2474  c = lambda i, p: i < 10
2475  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
2476  ijk_final = tf.while_loop(c, b, ijk_0)
2477  ```
2478
2479  Example using shape_invariants:
2480
2481  ```python
2482  i0 = tf.constant(0)
2483  m0 = tf.ones([2, 2])
2484  c = lambda i, m: i < 10
2485  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
2486  tf.while_loop(
2487      c, b, loop_vars=[i0, m0],
2488      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
2489  ```
2490
2491  Example which demonstrates non-strict semantics: In the following
2492  example, the final value of the counter `i` does not depend on `x`. So
2493  the `while_loop` can increment the counter parallel to updates of `x`.
2494  However, because the loop counter at one loop iteration depends
2495  on the value at the previous iteration, the loop counter itself cannot
2496  be incremented in parallel. Hence if we just want the final value of the
2497  counter (which we print on the line `print(sess.run(i))`), then
2498  `x` will never be incremented, but the counter will be updated on a
2499  single thread. Conversely, if we want the value of the output (which we
2500  print on the line `print(sess.run(out).shape)`), then the counter may be
2501  incremented on its own thread, while `x` can be incremented in
2502  parallel on a separate thread. In the extreme case, it is conceivable
2503  that the thread incrementing the counter runs until completion before
2504  `x` is incremented even a single time. The only thing that can never
2505  happen is the thread updating `x` getting ahead of the
2506  counter thread, because the thread incrementing `x` depends on the value
2507  of the counter.
2508
2509  ```python
2510  import tensorflow as tf
2511
2512  n = 10000
2513  x = tf.constant(list(range(n)))
2514  c = lambda i, x: i < n
2515  b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
2516  [i], "x:"))
2517  i, out = tf.while_loop(c, b, (0, x))
2518  with tf.compat.v1.Session() as sess:
2519      print(sess.run(i))  # prints [0] ... [9999]
2520
2521      # The following line may increment the counter and x in parallel.
2522      # The counter thread may get ahead of the other thread, but not the
2523      # other way around. So you may see things like
2524      # [9996] x:[9987]
2525      # meaning that the counter thread is on iteration 9996,
2526      # while the other thread is on iteration 9987
2527      print(sess.run(out).shape)
2528  ```
2529
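  A further sketch illustrating `maximum_iterations`: even though `cond` below
  is always true, the loop stops after 5 iterations.

  ```python
  i = tf.constant(0)
  c = lambda i: tf.constant(True)
  b = lambda i: (tf.add(i, 1), )
  r = tf.while_loop(c, b, [i], maximum_iterations=5)
  # r[0] is 5 after the loop.
  ```
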
2530  """
2531  return while_loop(
2532      cond=cond,
2533      body=body,
2534      loop_vars=loop_vars,
2535      shape_invariants=shape_invariants,
2536      parallel_iterations=parallel_iterations,
2537      back_prop=back_prop,
2538      swap_memory=swap_memory,
2539      name=name,
2540      maximum_iterations=maximum_iterations,
2541      return_same_structure=True)
2542
2543
2544# pylint: disable=redefined-outer-name
2545@tf_export(v1=["while_loop"])
2546def while_loop(cond,
2547               body,
2548               loop_vars,
2549               shape_invariants=None,
2550               parallel_iterations=10,
2551               back_prop=True,
2552               swap_memory=False,
2553               name=None,
2554               maximum_iterations=None,
2555               return_same_structure=False):
2556  """Repeat `body` while the condition `cond` is true.
2557
2558  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
2559  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
2560  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
2561  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
2562  `cond` and `body`. `cond` and `body` both take as many arguments as there are
2563  `loop_vars`.
2564
2565  In addition to regular Tensors or IndexedSlices, the body may accept and
2566  return TensorArray objects.  The flows of the TensorArray objects will
2567  be appropriately forwarded between loops and during gradient calculations.
2568
2569  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
2570  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
2571  stitches together the graph fragments created during the `cond` and `body`
2572  calls with some additional graph nodes to create the graph flow that
2573  repeats `body` until `cond` returns false.
2574
2575  For correctness, `tf.while_loop()` strictly enforces shape invariants for
2576  the loop variables. A shape invariant is a (possibly partial) shape that
2577  is unchanged across the iterations of the loop. An error will be raised
2578  if the shape of a loop variable after an iteration is determined to be more
2579  general than or incompatible with its shape invariant. For example, a shape
2580  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
2581  compatible with [11, 17]. By default (if the argument `shape_invariants` is
2582  not specified), it is assumed that the initial shape of each tensor in
2583  `loop_vars` is the same in every iteration. The `shape_invariants` argument
2584  allows the caller to specify a less specific shape invariant for each loop
2585  variable, which is needed if the shape varies between iterations. The
2586  `tf.Tensor.set_shape`
2587  function may also be used in the `body` function to indicate that
2588  the output loop variable has a particular shape. The shape invariants for
2589  SparseTensor and IndexedSlices are treated specially as follows:
2590
2591  a) If a loop variable is a SparseTensor, the shape invariant must be
2592  TensorShape([r]) where r is the rank of the dense tensor represented
2593  by the sparse tensor. It means the shapes of the three tensors of the
2594  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
2595  is the shape of the SparseTensor.dense_shape property. It must be the shape of
2596  a vector.
2597
2598  b) If a loop variable is an IndexedSlices, the shape invariant must be
2599  a shape invariant of the values tensor of the IndexedSlices. It means
2600  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
2601  [shape.ndims]).
2602
2603  `while_loop` implements non-strict semantics, enabling multiple iterations
2604  to run in parallel. The maximum number of parallel iterations can be
2605  controlled by `parallel_iterations`, which gives users some control over
2606  memory consumption and execution order. For correct programs, `while_loop`
2607  should return the same result for any parallel_iterations > 0.
2608
2609  For training, TensorFlow stores the tensors that are produced in the
2610  forward inference and are needed in back propagation. These tensors are a
2611  main source of memory consumption and often cause OOM errors when training
2612  on GPUs. When the flag swap_memory is true, we swap out these tensors from
2613  GPU to CPU. This for example allows us to train RNN models with very long
2614  sequences and large batches.
2615
2616  Args:
2617    cond: A callable that represents the termination condition of the loop.
2618    body: A callable that represents the loop body.
2619    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
2620      `Tensor`, and `TensorArray` objects.
2621    shape_invariants: The shape invariants for the loop variables.
2622    parallel_iterations: The number of iterations allowed to run in parallel. It
2623      must be a positive integer.
2624    back_prop: Whether backprop is enabled for this while loop.
2625    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
2626    name: Optional name prefix for the returned tensors.
2627    maximum_iterations: Optional maximum number of iterations of the while loop
2628      to run.  If provided, the `cond` output is AND-ed with an additional
2629      condition ensuring the number of iterations executed is no greater than
2630      `maximum_iterations`.
2631    return_same_structure: If True, output has same structure as `loop_vars`. If
2632      eager execution is enabled, this is ignored (and always treated as True).
2633
2634  Returns:
2635    The output tensors for the loop variables after the loop.
2636     If `return_same_structure` is True, the return value has the same
2637     structure as `loop_vars`.
2638     If `return_same_structure` is False, the return value is a Tensor,
2639     TensorArray or IndexedSlices if the length of `loop_vars` is 1, or a list
2640     otherwise.
2641
2642  Raises:
2643    TypeError: if `cond` or `body` is not callable.
2644    ValueError: if `loop_vars` is empty.
2645
2646  Example:
2647
2648  ```python
2649  i = tf.constant(0)
2650  c = lambda i: tf.less(i, 10)
2651  b = lambda i: tf.add(i, 1)
2652  r = tf.while_loop(c, b, [i])
2653  ```
2654
2655  Example with nesting and a namedtuple:
2656
2657  ```python
2658  import collections
2659  Pair = collections.namedtuple('Pair', 'j, k')
2660  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
2661  c = lambda i, p: i < 10
2662  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
2663  ijk_final = tf.while_loop(c, b, ijk_0)
2664  ```
2665
2666  Example using shape_invariants:
2667
2668  ```python
2669  i0 = tf.constant(0)
2670  m0 = tf.ones([2, 2])
2671  c = lambda i, m: i < 10
2672  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
2673  tf.while_loop(
2674      c, b, loop_vars=[i0, m0],
2675      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
2676  ```
2677
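  Example using `maximum_iterations` (a minimal sketch): the loop stops after
  at most 5 iterations even though `cond` alone would allow 10:

  ```python
  i = tf.constant(0)
  c = lambda i: tf.less(i, 10)
  b = lambda i: tf.add(i, 1)
  r = tf.while_loop(c, b, [i], maximum_iterations=5)  # r == 5
  ```
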
2678  Example which demonstrates non-strict semantics: In the following
2679  example, the final value of the counter `i` does not depend on `x`. So
  the `while_loop` can increment the counter in parallel with updates of `x`.
2681  However, because the loop counter at one loop iteration depends
2682  on the value at the previous iteration, the loop counter itself cannot
2683  be incremented in parallel. Hence if we just want the final value of the
2684  counter (which we print on the line `print(sess.run(i))`), then
2685  `x` will never be incremented, but the counter will be updated on a
2686  single thread. Conversely, if we want the value of the output (which we
2687  print on the line `print(sess.run(out).shape)`), then the counter may be
2688  incremented on its own thread, while `x` can be incremented in
2689  parallel on a separate thread. In the extreme case, it is conceivable
2690  that the thread incrementing the counter runs until completion before
  `x` is incremented even a single time. The only thing that can never
  happen is for the thread updating `x` to get ahead of the counter thread,
  because the thread incrementing `x` depends on the value of the counter.
2695
2696  ```python
2697  import tensorflow as tf
2698
2699  n = 10000
2700  x = tf.constant(list(range(n)))
2701  c = lambda i, x: i < n
2702  b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]), tf.compat.v1.Print(x + 1,
2703  [i], "x:"))
2704  i, out = tf.while_loop(c, b, (0, x))
2705  with tf.compat.v1.Session() as sess:
2706      print(sess.run(i))  # prints [0] ... [9999]
2707
2708      # The following line may increment the counter and x in parallel.
2709      # The counter thread may get ahead of the other thread, but not the
2710      # other way around. So you may see things like
2711      # [9996] x:[9987]
2712      # meaning that the counter thread is on iteration 9996,
2713      # while the other thread is on iteration 9987
2714      print(sess.run(out).shape)
2715  ```
2716
2717  """
2718  if not callable(cond):
2719    raise TypeError("cond must be callable.")
2720  if not callable(body):
2721    raise TypeError("body must be callable.")
2722  if parallel_iterations < 1:
2723    raise TypeError("parallel_iterations must be a positive integer.")
2724
2725  # Always enable control flow v2 if building a function, regardless of toggle.
2726  executing_eagerly = context.executing_eagerly()
2727  if (util.EnableControlFlowV2(ops.get_default_graph()) and
2728      not executing_eagerly):
2729    return while_v2.while_loop(
2730        cond,
2731        body,
2732        loop_vars,
2733        shape_invariants=shape_invariants,
2734        parallel_iterations=parallel_iterations,
2735        maximum_iterations=maximum_iterations,
2736        name=name,
2737        return_same_structure=return_same_structure,
2738        back_prop=back_prop)
2739
2740  with ops.name_scope(name, "while", loop_vars):
2741    if not loop_vars:
2742      raise ValueError("No loop variables provided")
2743    try_to_pack = (len(loop_vars) == 1 and not return_same_structure)
2744    if maximum_iterations is not None:
2745      maximum_iterations = ops.convert_to_tensor(
2746          maximum_iterations, name="maximum_iterations")
2747      if maximum_iterations.shape.ndims != 0:
2748        raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
2749                         maximum_iterations.shape)
2750
2751      if executing_eagerly:
2752        counter = 0
2753        maximum_iterations = int(maximum_iterations.numpy())
2754      else:
2755        counter = constant_op.constant(
2756            0, dtype=maximum_iterations.dtype, name="iteration_counter")
2757      orig_cond = cond
2758      orig_body = body
2759      if try_to_pack:
2760        loop_vars = (counter, loop_vars[0])
2761        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
2762            math_ops.logical_and(i < maximum_iterations, orig_cond(lv)))
2763        body = lambda i, lv: (i + 1, orig_body(lv))
2764      else:
2765        loop_vars = (counter, loop_vars)
2766        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
2767            math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
2768        body = lambda i, lv: (i + 1, orig_body(*lv))
2769      try_to_pack = False
2770
2771    if executing_eagerly:
2772      packed = False  # whether the body result was packed into a 1-item tuple
2773
2774      loop_var_structure = nest.map_structure(type_spec.type_spec_from_value,
2775                                              list(loop_vars))
2776      while cond(*loop_vars):
2777        loop_vars = body(*loop_vars)
2778        if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
2779          packed = True
2780          loop_vars = (loop_vars,)
2781        nest.assert_same_structure(loop_var_structure, list(loop_vars))
2782
2783      def convert(x):
2784        if isinstance(x, tensor_array_ops.TensorArray):
2785          return x
2786        return ops.convert_to_tensor(x)
2787
2788      loop_vars = nest.map_structure(convert, loop_vars, expand_composites=True)
2789      if maximum_iterations is not None:
2790        return loop_vars[1]
2791      else:
2792        return loop_vars[0] if packed else loop_vars
2793
2794    if shape_invariants is not None:
2795      if maximum_iterations is not None:
2796        shape_invariants = (tensor_shape.TensorShape([]), shape_invariants)
2797
2798      nest.assert_same_structure(
2799          loop_vars, shape_invariants, expand_composites=False)
2800      shape_invariants = nest.map_structure(
2801          _get_shape_invariant,
2802          loop_vars,
2803          shape_invariants,
2804          expand_composites=False)
2805
2806    loop_context = WhileContext(
2807        maximum_iterations=maximum_iterations,
2808        parallel_iterations=parallel_iterations,
2809        back_prop=back_prop,
2810        swap_memory=swap_memory)
2811    # Only add non-nested loops to the collection. Any nested control flow will
2812    # be encapsulated in the root context.
2813    if loop_context.outer_context is None:
2814      ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
2815    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
2816                                    return_same_structure)
2817    if maximum_iterations is not None:
2818      return result[1]
2819    else:
2820      return result
2821
2822
2823# pylint: enable=redefined-outer-name
2824
2825
2826def _AsTensorList(x, p):
2827  """Return x as a list of Tensors or IndexedSlices.
2828
2829  For entries of `x` that are Operations, this returns an Identity of `p`
2830  with a dependency on the operation.
2831
2832  Args:
2833    x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
2834    p: A Tensor to return for entries in `x` that are Operations.
2835
2836  Returns:
2837    A list of Tensors or IndexedSlices.
2838  """
2839  if not isinstance(x, (list, _basetuple)):
2840    x = [x]
2841
2842  l = []
2843  for v in x:
2844    if isinstance(v, ops.Operation):
2845      v = with_dependencies([v], p)
2846    v = ops.convert_to_tensor_or_composite(v)
2847    if isinstance(v, ops.Tensor):
2848      l.append(array_ops.identity(v))
2849    else:
2850      l.append(
2851          ops.IndexedSlices(
2852              array_ops.identity(v.values), array_ops.identity(v.indices)))
2853  return l
2854
2855
2856def _CheckResults(a, b):
2857  assert len(a) == len(b), (
2858      "Values returned by a() and b() must have the same length.")
2859  for x, y in zip(a, b):
2860    assert x.dtype == y.dtype, (
2861        "Values returned by a() [%s] and b() [%s] must have "
2862        "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name))
2863
2864
2865def with_dependencies(dependencies, output_tensor, name=None):
2866  """Produces the content of `output_tensor` only after `dependencies`.
2867
2868  In some cases, a user may want the output of an operation to be
2869  consumed externally only after some other dependencies have run
  first. This function returns `output_tensor`, but only after all
  operations in `dependencies` have run. Note that the ordering guarantee
  applies only to the returned tensor: the original `output_tensor` itself
  may still be computed before the `dependencies` have run.
2874
2875  See also `tf.tuple` and `tf.group`.
2876
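  For example, a minimal sketch in a v1 graph (the variable and update op are
  illustrative):

  ```python
  v = tf.compat.v1.get_variable("v", initializer=1.0)
  update = tf.compat.v1.assign_add(v, 1.0)
  # `out` evaluates to the value of `v`, but only after `update` has run.
  # Reading `v` directly carries no such guarantee.
  out = with_dependencies([update], v)
  ```
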
2877  Args:
2878    dependencies: Iterable of operations to run before this op finishes.
2879    output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
2880    name: (Optional) A name for this operation.
2881
2882  Returns:
2883    Same as `output_tensor`.
2884
2885  Raises:
2886    TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
2887  """
2888  if context.executing_eagerly():
2889    return output_tensor
2890  with ops.name_scope(name, "control_dependency",
2891                      list(dependencies) + [output_tensor]) as name:
2892    with ops.colocate_with(output_tensor):
2893      with ops.control_dependencies(dependencies):
2894        output_tensor = ops.convert_to_tensor_or_composite(output_tensor)
2895        if isinstance(output_tensor, ops.Tensor):
2896          return _Identity(output_tensor, name=name)
2897        else:
2898          return ops.IndexedSlices(
2899              _Identity(output_tensor.values, name=name), output_tensor.indices,
2900              output_tensor.dense_shape)
2901
2902
2903def _GroupControlDeps(dev, deps, name=None):
2904  with ops.control_dependencies(deps):
2905    if dev is None:
2906      return no_op(name=name)
2907    else:
2908      with ops.device(dev):
2909        return no_op(name=name)
2910
2911
2912# TODO(touts): Accept "inputs" as a list.
2913@tf_export("group")
2914def group(*inputs, **kwargs):
2915  """Create an op that groups multiple operations.
2916
2917  When this op finishes, all ops in `inputs` have finished. This op has no
2918  output.
2919
2920  Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
2921  this method, as ops execute in the expected order thanks to automatic control
2922  dependencies.* Only use `tf.group` when working with v1
2923  `tf.Graph` code.
2924
2925  When operating in a v1-style graph context, ops are not executed in the same
2926  order as specified in the code; TensorFlow will attempt to execute ops in
2927  parallel or in an order convenient to the result it is computing.  `tf.group`
2928  allows you to request that one or more results finish before execution
2929  continues.
2930
2931  `tf.group` creates a single op (of type `NoOp`), and then adds appropriate
2932  control dependencies.  Thus, `c = tf.group(a, b)` will compute the same graph
2933  as this:
2934
2935      with tf.control_dependencies([a, b]):
2936          c = tf.no_op()
2937
2938  See also `tf.tuple` and
2939  `tf.control_dependencies`.
2940
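  For example, a minimal v1-graph sketch (the variables and assign ops are
  illustrative):

  ```python
  v1 = tf.compat.v1.get_variable("v1", initializer=0.0)
  v2 = tf.compat.v1.get_variable("v2", initializer=0.0)
  update = tf.group(v1.assign_add(1.0), v2.assign_add(2.0))
  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    sess.run(update)  # Runs both assignments; `update` itself has no output.
  ```
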
2941  Args:
2942    *inputs: Zero or more tensors to group.
2943    name: A name for this operation (optional).
2944
2945  Returns:
2946    An Operation that executes all its inputs.
2947
2948  Raises:
2949    ValueError: If an unknown keyword argument is provided.
2950  """
2951  if context.executing_eagerly():
2952    return None
2953  name = kwargs.pop("name", None)
2954  if kwargs:
2955    raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))
2956  with ops.name_scope(name, "group_deps", inputs) as name:
2957    # Grouping no inputs means do nothing
2958    if not inputs:
2959      return no_op(name=name)
2960
2961    # Sorts *inputs according to their devices.
2962    ops_on_device = {}  # device -> operations specified on the device.
2963    for inp in nest.flatten(inputs, expand_composites=True):
2964      if not hasattr(inp, "device"):
        raise TypeError("tf.group() expected Tensor arguments, not "
                        "'%s' with type '%s'" % (inp, type(inp)))
2967      dev = inp.device
2968      if dev in ops_on_device:
2969        ops_on_device[dev].append(inp)
2970      else:
2971        ops_on_device[dev] = [inp]
2972    if len(ops_on_device) == 1:
2973      # 1-level tree. The root node is the returned NoOp node.
2974      (dev, deps), = ops_on_device.items()
2975      return _GroupControlDeps(dev, deps, name=name)
2976
2977    # 2-level tree. The root node is the returned NoOp node.
2978    # deps contains 1 NoOp node for each device.
2979    deps = []
2980
2981    def device_key(dev):
2982      """A sort key that allows None to be compared to strings."""
2983      return "" if dev is None else dev
2984
2985    for dev in sorted(ops_on_device, key=device_key):
2986      deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
2987
2988    with ops.control_dependencies(deps):
2989      return no_op(name=name)
2990
2991
2992@tf_export("tuple", v1=[])
2993@dispatch.add_dispatch_support
2994def tuple_v2(tensors, control_inputs=None, name=None):
2995  """Groups tensors together.
2996
2997  The returned tensors have the same value as the input tensors, but they
2998  are computed only after all the input tensors have been computed.
2999
3000  Note: *In TensorFlow 2 with eager and/or Autograph, you should not require
3001  this method, as ops execute in the expected order thanks to automatic control
3002  dependencies.* Only use `tf.tuple` when working with v1 `tf.Graph` code.
3003
3004  See also `tf.group` and `tf.control_dependencies`.
3005
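  For example, a minimal sketch in a v1 graph (the ops are illustrative):

  ```python
  a = tf.constant(1.0) * 2.0
  b = tf.constant(2.0) * 3.0
  # `ga` and `gb` carry the same values as `a` and `b`, but neither is
  # available until both `a` and `b` have been computed.
  ga, gb = tf.tuple([a, b])
  ```
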
3006  Args:
3007    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
3008    control_inputs: List of additional ops to finish before returning.
3009    name: (optional) A name to use as a `name_scope` for the operation.
3010
3011  Returns:
3012    Same as `tensors`.
3013
3014  Raises:
3015    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
3016    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
3017      objects.
3018
3019  """
3020  return tuple(tensors=tensors, name=name, control_inputs=control_inputs)  # pylint: disable=redefined-builtin
3021
3022
3023@tf_export(v1=["tuple"])
3024@dispatch.add_dispatch_support
3025def tuple(tensors, name=None, control_inputs=None):  # pylint: disable=redefined-builtin
3026  """Group tensors together.
3027
3028  This creates a tuple of tensors with the same values as the `tensors`
3029  argument, except that the value of each tensor is only returned after the
3030  values of all tensors have been computed.
3031
3032  `control_inputs` contains additional ops that have to finish before this op
3033  finishes, but whose outputs are not returned.
3034
3035  This can be used as a "join" mechanism for parallel computations: all the
3036  argument tensors can be computed in parallel, but the values of any tensor
3037  returned by `tuple` are only available after all the parallel computations
3038  are done.
3039
3040  See also `tf.group` and
3041  `tf.control_dependencies`.
3042
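  For example, a minimal sketch showing `control_inputs` (the variable and
  update op are illustrative):

  ```python
  counter = tf.compat.v1.get_variable("counter", initializer=0.0)
  update = tf.compat.v1.assign_add(counter, 1.0)
  a = tf.constant(1.0) * 2.0
  b = tf.constant(2.0) * 3.0
  # `ga` and `gb` become available only after `a`, `b`, and `update` have
  # all run.
  ga, gb = tf.compat.v1.tuple([a, b], control_inputs=[update])
  ```
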
3043  Args:
3044    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
3045    name: (optional) A name to use as a `name_scope` for the operation.
3046    control_inputs: List of additional ops to finish before returning.
3047
3048  Returns:
3049    Same as `tensors`.
3050
3051  Raises:
3052    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
3053    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
3054      objects.
3055
3056  """
3057  if context.executing_eagerly():
3058    return tensors
3059  with ops.name_scope(name, "tuple", tensors) as name:
3060    tensors = [
3061        t if (isinstance(t, ops.Operation) or tensor_util.is_tf_type(t) or
3062              t is None) else ops.convert_to_tensor(t) for t in tensors
3063    ]
3064    gating_ops = [
3065        t if isinstance(t, ops.Operation) else t.op
3066        for t in tensors
3067        if t is not None
3068    ]
3069    if control_inputs:
3070      for c in control_inputs:
3071        if isinstance(c, ops.Tensor):
3072          c = c.op
3073        elif not isinstance(c, ops.Operation):
3074          raise TypeError("Control input must be Operation or Tensor: %s" % c)
3075        gating_ops.append(c)
    # Note that to get a deterministic ordering in the pbtxt, we must impose
    # an explicit order here.
3078    gating_ops = sorted(set(gating_ops), key=lambda op: op._id)  # Uniquify ops.
3079    if not gating_ops:
3080      raise ValueError("Must have at least one Tensor: %s" % tensors)
3081    gate = group(*gating_ops)
3082    tpl = []
3083    for t in tensors:
3084      if tensor_util.is_tf_type(t):
3085        tpl.append(with_dependencies([gate], t))
3086      elif isinstance(t, ops.Operation):
3087        with ops.control_dependencies([gate]):
3088          tpl.append(group(t))
3089      else:
3090        tpl.append(None)
3091    return tpl
3092
3093
3094def _assert_at_most_n_true(predicates, n, msg):
3095  """Returns an Assert op that checks that at most n predicates are True.
3096
3097  Args:
3098    predicates: list of bool scalar tensors.
3099    n: maximum number of true predicates allowed.
3100    msg: Error message.
3101  """
3102  preds_c = array_ops.stack(predicates, name="preds_c")
3103  num_true_conditions = math_ops.reduce_sum(
3104      math_ops.cast(preds_c, dtypes.int32), name="num_true_conds")
3105  condition = math_ops.less_equal(num_true_conditions,
3106                                  constant_op.constant(n, name="n_true_conds"))
3107  preds_names = ", ".join(getattr(p, "name", "?") for p in predicates)
3108  error_msg = [
3109      "%s: more than %d conditions (%s) evaluated as True:" %
3110      (msg, n, preds_names), preds_c
3111  ]
3112  return Assert(condition, data=error_msg, summarize=len(predicates))
3113
3114
3115def _case_create_default_action(predicates, actions):
3116  """Creates default action for a list of actions and their predicates.
3117
  It uses the input actions to select an arbitrary one as the default and
  makes sure that the corresponding predicates have valid values.
3120
3121  Args:
3122    predicates: a list of bool scalar tensors
3123    actions: a list of callable objects which return tensors.
3124
3125  Returns:
3126    a callable
3127  """
3128  k = len(predicates) - 1  # could pick any
3129  predicate, action = predicates[k], actions[k]
3130  other_predicates, other_actions = predicates[:k], actions[:k]
3131
3132  def default_action():
    others_msg = ("Implementation error: "
                  "selected default action #%d was called, but some of the "
                  "other predicates are True: " % k)
    default_msg = ("Input error: "
                   "None of the conditions evaluated as True:",
                   array_ops.stack(predicates, name="preds_c"))
3139    with ops.control_dependencies([
3140        _assert_at_most_n_true(other_predicates, n=0, msg=others_msg),
3141        Assert(predicate, data=default_msg)
3142    ]):
3143      return action()
3144
3145  return default_action, other_predicates, other_actions
3146
3147
3148def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name,
3149                                       allow_python_preds):
3150  """Verifies input arguments for the case function.
3151
3152  Args:
3153    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
3154      callable which returns a list of tensors.
3155    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3156    name: A name for the case operation.
3157    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
3158      addition to boolean Tensors
3159
3160  Raises:
3161    TypeError: If `pred_fn_pairs` is not a list/dictionary.
3162    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3163    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3164               callable.
3165
3166  Returns:
3167    a tuple <list of scalar bool tensors, list of callables>.
3168  """
3169  if not isinstance(pred_fn_pairs, (list, _basetuple, dict)):
3170    raise TypeError("fns must be a list, tuple, or dict")
3171
3172  if isinstance(pred_fn_pairs, collections.OrderedDict):
3173    pred_fn_pairs = pred_fn_pairs.items()
3174  elif isinstance(pred_fn_pairs, dict):
3175    if context.executing_eagerly():
3176      # No name to sort on in eager mode. Use dictionary traversal order,
3177      # which is nondeterministic in versions of Python < 3.6
3178      if not exclusive:
3179        raise ValueError("Unordered dictionaries are not supported for the "
3180                         "`pred_fn_pairs` argument when `exclusive=False` and "
3181                         "eager mode is enabled.")
3182      pred_fn_pairs = list(pred_fn_pairs.items())
3183    else:
3184      pred_fn_pairs = sorted(
3185          pred_fn_pairs.items(), key=lambda item: item[0].name)
3186      if not exclusive:
3187        logging.warn(
3188            "%s: An unordered dictionary of predicate/fn pairs was "
3189            "provided, but exclusive=False. The order of conditional "
            "tests is deterministic but may not match the order in which "
            "the pairs were provided.", name)
3191  for pred_fn_pair in pred_fn_pairs:
3192    if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2:
3193      raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")
3194    pred, fn = pred_fn_pair
3195
3196    if isinstance(pred, ops.Tensor):
3197      if pred.dtype != dtypes.bool:
3198        raise TypeError("pred must be Tensor of type bool: %s" % pred.name)
3199    elif not allow_python_preds:
3200      raise TypeError("pred must be a Tensor, got: %s" % pred)
3201    elif not isinstance(pred, bool):
3202      raise TypeError("pred must be a Tensor or bool, got: %s" % pred)
3203
3204    if not callable(fn):
3205      raise TypeError("fn for pred %s must be callable." % pred.name)
3206
3207  predicates, actions = zip(*pred_fn_pairs)
3208  return predicates, actions
3209
3210
3211def _case_helper(cond_fn,
3212                 pred_fn_pairs,
3213                 default,
3214                 exclusive,
3215                 name,
3216                 allow_python_preds=False,
3217                 **cond_kwargs):
3218  """Implementation of case that allows for different cond functions.
3219
3220  Args:
3221    cond_fn: method that has signature and semantics of `cond` above.
3222    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
3223      callable which returns a list of tensors.
3224    default: Optional callable that returns a list of tensors.
3225    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3226    name: A name for this operation (optional).
3227    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
3228      addition to boolean Tensors
3229    **cond_kwargs: keyword arguments that will be passed to `cond_fn`.
3230
3231  Returns:
3232    The tensors returned by the first pair whose predicate evaluated to True, or
3233    those returned by `default` if none does.
3234
3235  Raises:
3236    TypeError: If `pred_fn_pairs` is not a list/dictionary.
3237    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3238    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3239               callable.
3240  """
3241  predicates, actions = _case_verify_and_canonicalize_args(
3242      pred_fn_pairs, exclusive, name, allow_python_preds)
3243  with ops.name_scope(name, "case", [predicates]):
3244    if default is None:
3245      default, predicates, actions = _case_create_default_action(
3246          predicates, actions)
3247    fn = default
3248    # To eval conditions in direct order we create nested conditions in reverse:
3249    #   cond_fn(c[0], true_fn=.., false_fn=cond_fn(c[1], ...))
3250    for predicate, action in reversed(list(zip(predicates, actions))):
3251      fn = functools.partial(
3252          cond_fn, predicate, true_fn=action, false_fn=fn, **cond_kwargs)
3253    if exclusive:
3254      with ops.control_dependencies([
3255          _assert_at_most_n_true(
3256              predicates, n=1, msg="Input error: exclusive=True")
3257      ]):
3258        return fn()
3259    else:
3260      return fn()
3261
3262
3263def _indexed_case_verify_and_canonicalize_args(branch_fns, default,
3264                                               branch_index):
3265  """Verifies input arguments for the case function.
3266
3267  Args:
3268    branch_fns: Dict or list of pairs of an `int` and a callable which
3269      returns a list of tensors.
3270    default: Optional callable that returns a list of tensors.
3271    branch_index: Optional int `Tensor`, which selects for the corresponding
3272      pred_fn_pair.
3273
3274  Raises:
3275    TypeError: If `branch_fns` is not a list/dictionary.
3276    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
3277               callables.
3278    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3279               callable.
3280
3281  Returns:
3282    branch_fns: validated list of callables for each branch (default last).
3283  """
3284  if not isinstance(branch_index, ops.Tensor):
3285    raise TypeError("branch_index must be a Tensor, got {}".format(
3286        type(branch_index)))
3287  if not branch_index.dtype.is_integer:
3288    raise TypeError("branch_index must be an integer Tensor, got {}".format(
3289        branch_index.dtype))
3290
3291  if not branch_fns:
3292    raise ValueError("Must provide at least one item in branch_fns")
3293  if not isinstance(branch_fns, (list, _basetuple, dict)):
3294    raise TypeError("branch_fns must be a list, tuple, or dict")
3295
3296  if isinstance(branch_fns, dict):
3297    branch_fns = branch_fns.items()
3298
3299  if all(callable(fn) for fn in branch_fns):
3300    branch_fns = list(enumerate(branch_fns))
3301
3302  for key_fn_pair in branch_fns:
3303    if not isinstance(key_fn_pair, _basetuple) or len(key_fn_pair) != 2:
3304      raise TypeError("Each entry in branch_fns must be a 2-tuple")
3305    key, branch_fn = key_fn_pair
3306
3307    if not isinstance(key, int):
3308      raise TypeError("key must be a Python `int`, got {}".format(type(key)))
3309
3310    if not callable(branch_fn):
3311      raise TypeError("fn for key {} must be callable.".format(key))
3312
3313  keys = [p[0] for p in branch_fns]
3314  if min(keys) < 0 or max(keys) >= len(keys) or len(set(keys)) != len(keys):
3315    raise ValueError(
        "branch indices (keys) must form a contiguous range [0, {}) but "
3317        "found {{{}}}".format(len(keys), ",".join(map(str, sorted(keys)))))
3318  actions = [p[1] for p in sorted(branch_fns)]
3319  if default is not None:
3320    actions.append(default)
3321  return actions
3322
3323
3324def _indexed_case_helper(branch_fns,
3325                         default,
3326                         branch_index,
3327                         name,
3328                         lower_using_switch_merge=None):
3329  """Implementation of case that emits the n-way indexed Case op.
3330
3331  Args:
3332    branch_fns: Dict or list of pairs of a boolean scalar tensor, and a
3333      callable which returns a list of tensors.
3334    default: Optional callable that returns a list of tensors.
3335    branch_index: Optional int `Tensor`, which selects for the corresponding
3336      pred_fn_pair.
3337    name: A name for this operation (optional).
3338    lower_using_switch_merge: Lower this op using switch merge ops (optional).
3339
3340  Returns:
3341    The tensors returned by the pair whose key matched branch_index, or
3342    those returned by `default` if none does.
3343
3344  Raises:
3345    TypeError: If `branch_fns` is not a list/dictionary.
3346    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
3347               callables.
3348    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3349               callable.
3350  """
3351  branch_fns = _indexed_case_verify_and_canonicalize_args(
3352      branch_fns, default, branch_index)
3353  with ops.name_scope(name, "case", [branch_index]):
3354    if context.executing_eagerly() and not hasattr(branch_index, "graph"):
3355      branch_index = array_ops.where(
3356          math_ops.less(branch_index, 0)
3357          | math_ops.greater_equal(branch_index, len(branch_fns)),
3358          len(branch_fns) - 1, branch_index)
3359      return branch_fns[int(branch_index)]()
3360    return cond_v2.indexed_case(
3361        branch_index,
3362        branch_fns,
3363        lower_using_switch_merge=lower_using_switch_merge)
3364
3365
3366@tf_export("case", v1=[])
3367@dispatch.add_dispatch_support
3368def case_v2(pred_fn_pairs,
3369            default=None,
3370            exclusive=False,
3371            strict=False,
3372            name="case"):
3373  """Create a case operation.
3374
3375  See also `tf.switch_case`.
3376
3377  The `pred_fn_pairs` parameter is a list of pairs of size N.
3378  Each pair contains a boolean scalar tensor and a python callable that
3379  creates the tensors to be returned if the boolean evaluates to True.
3380  `default` is a callable generating a list of tensors. All the callables
3381  in `pred_fn_pairs` as well as `default` (if provided) should return the same
3382  number and types of tensors.
3383
3384  If `exclusive==True`, all predicates are evaluated, and an exception is
3385  thrown if more than one of the predicates evaluates to `True`.
3386  If `exclusive==False`, execution stops at the first predicate which
3387  evaluates to True, and the tensors generated by the corresponding function
3388  are returned immediately. If none of the predicates evaluate to True, this
3389  operation returns the tensors generated by `default`.
3390
  `tf.case` supports nested structures as implemented in `tf.nest`. All of
  the callables must return the same
3393  (possibly nested) value structure of lists, tuples, and/or named tuples.
3394  Singleton lists and tuples form the only exceptions to this: when returned by
3395  a callable, they are implicitly unpacked to single values. This
3396  behavior is disabled by passing `strict=True`.
3397
3398  @compatibility(v2)
3399  `pred_fn_pairs` could be a dictionary in v1. However, tf.Tensor and
3400  tf.Variable are no longer hashable in v2, so cannot be used as a key for a
3401  dictionary.  Please use a list or a tuple instead.
3402  @end_compatibility
3403
3404
3405  **Example 1:**
3406
3407  Pseudocode:
3408
3409  ```
3410  if (x < y) return 17;
3411  else return 23;
3412  ```
3413
3414  Expressions:
3415
3416  ```python
3417  f1 = lambda: tf.constant(17)
3418  f2 = lambda: tf.constant(23)
3419  r = tf.case([(tf.less(x, y), f1)], default=f2)
3420  ```
3421
3422  **Example 2:**
3423
3424  Pseudocode:
3425
3426  ```
3427  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
3428  if (x < y) return 17;
3429  else if (x > z) return 23;
3430  else return -1;
3431  ```
3432
3433  Expressions:
3434
3435  ```python
3436  def f1(): return tf.constant(17)
3437  def f2(): return tf.constant(23)
3438  def f3(): return tf.constant(-1)
3439  r = tf.case([(tf.less(x, y), f1), (tf.greater(x, z), f2)],
3440           default=f3, exclusive=True)
3441  ```
3442
3443  Args:
3444    pred_fn_pairs: List of pairs of a boolean scalar tensor and a callable which
3445      returns a list of tensors.
3446    default: Optional callable that returns a list of tensors.
3447    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3448    strict: A boolean that enables/disables 'strict' mode; see above.
3449    name: A name for this operation (optional).
3450
3451  Returns:
3452    The tensors returned by the first pair whose predicate evaluated to True, or
3453    those returned by `default` if none does.
3454
3455  Raises:
3456    TypeError: If `pred_fn_pairs` is not a list/tuple.
3457    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3458    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3459               callable.
3460  """
3461  return _case_helper(
3462      cond,
3463      pred_fn_pairs,
3464      default,
3465      exclusive,
3466      name,
3467      allow_python_preds=False,
3468      strict=strict)
3469
3470
3471@tf_export(v1=["case"])
3472@dispatch.add_dispatch_support
3473def case(pred_fn_pairs,
3474         default=None,
3475         exclusive=False,
3476         strict=False,
3477         name="case"):
3478  """Create a case operation.
3479
3480  See also `tf.switch_case`.
3481
3482  The `pred_fn_pairs` parameter is a dict or list of pairs of size N.
3483  Each pair contains a boolean scalar tensor and a python callable that
3484  creates the tensors to be returned if the boolean evaluates to True.
3485  `default` is a callable generating a list of tensors. All the callables
3486  in `pred_fn_pairs` as well as `default` (if provided) should return the same
3487  number and types of tensors.
3488
3489  If `exclusive==True`, all predicates are evaluated, and an exception is
3490  thrown if more than one of the predicates evaluates to `True`.
3491  If `exclusive==False`, execution stops at the first predicate which
3492  evaluates to True, and the tensors generated by the corresponding function
3493  are returned immediately. If none of the predicates evaluate to True, this
3494  operation returns the tensors generated by `default`.
3495
3496  `tf.case` supports nested structures as implemented in
3497  `tf.contrib.framework.nest`. All of the callables must return the same
3498  (possibly nested) value structure of lists, tuples, and/or named tuples.
3499  Singleton lists and tuples form the only exceptions to this: when returned by
3500  a callable, they are implicitly unpacked to single values. This
3501  behavior is disabled by passing `strict=True`.
3502
3503  If an unordered dictionary is used for `pred_fn_pairs`, the order of the
3504  conditional tests is not guaranteed. However, the order is guaranteed to be
3505  deterministic, so that variables created in conditional branches are created
3506  in fixed order across runs.
3507
3508  @compatibility(eager)
3509  Unordered dictionaries are not supported in eager mode when `exclusive=False`.
3510  Use a list of tuples instead.
3511  @end_compatibility
3512
3513
3514  **Example 1:**
3515
3516  Pseudocode:
3517
3518  ```
3519  if (x < y) return 17;
3520  else return 23;
3521  ```
3522
3523  Expressions:
3524
3525  ```python
3526  f1 = lambda: tf.constant(17)
3527  f2 = lambda: tf.constant(23)
3528  r = tf.case([(tf.less(x, y), f1)], default=f2)
3529  ```
3530
3531  **Example 2:**
3532
3533  Pseudocode:
3534
3535  ```
3536  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
3537  if (x < y) return 17;
3538  else if (x > z) return 23;
3539  else return -1;
3540  ```
3541
3542  Expressions:
3543
3544  ```python
3545  def f1(): return tf.constant(17)
3546  def f2(): return tf.constant(23)
3547  def f3(): return tf.constant(-1)
3548  r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
3549           default=f3, exclusive=True)
3550  ```
3551
3552  Args:
3553    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
3554      callable which returns a list of tensors.
3555    default: Optional callable that returns a list of tensors.
3556    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3557    strict: A boolean that enables/disables 'strict' mode; see above.
3558    name: A name for this operation (optional).
3559
3560  Returns:
3561    The tensors returned by the first pair whose predicate evaluated to True, or
3562    those returned by `default` if none does.
3563
3564  Raises:
3565    TypeError: If `pred_fn_pairs` is not a list/dictionary.
3566    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3567    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3568               callable.
3569  """
3570  return _case_helper(
3571      cond,
3572      pred_fn_pairs,
3573      default,
3574      exclusive,
3575      name,
3576      allow_python_preds=False,
3577      strict=strict)
3578
3579
3580@tf_export("switch_case")
3581def switch_case(branch_index,
3582                branch_fns,
3583                default=None,
3584                name="switch_case"):
3585  """Create a switch/case operation, i.e. an integer-indexed conditional.
3586
3587  See also `tf.case`.
3588
3589  This op can be substantially more efficient than `tf.case` when exactly one
3590  branch will be selected. `tf.switch_case` is more like a C++ switch/case
3591  statement than `tf.case`, which is more like an if/elif/elif/else chain.
3592
3593  The `branch_fns` parameter is either a dict from `int` to callables, or list
3594  of (`int`, callable) pairs, or simply a list of callables (in which case the
3595  index is implicitly the key). The `branch_index` `Tensor` is used to select an
3596  element in `branch_fns` with matching `int` key, falling back to `default`
3597  if none match, or `max(keys)` if no `default` is provided. The keys must form
3598  a contiguous set from `0` to `len(branch_fns) - 1`.
3599
3600  `tf.switch_case` supports nested structures as implemented in `tf.nest`. All
3601  callables must return the same (possibly nested) value structure of lists,
3602  tuples, and/or named tuples.
3603
3604  **Example:**
3605
3606  Pseudocode:
3607
3608  ```c++
3609  switch (branch_index) {  // c-style switch
3610    case 0: return 17;
3611    case 1: return 31;
3612    default: return -1;
3613  }
3614  ```
3615  or
3616  ```python
3617  branches = {0: lambda: 17, 1: lambda: 31}
3618  branches.get(branch_index, lambda: -1)()
3619  ```
3620
3621  Expressions:
3622
3623  ```python
3624  def f1(): return tf.constant(17)
3625  def f2(): return tf.constant(31)
3626  def f3(): return tf.constant(-1)
3627  r = tf.switch_case(branch_index, branch_fns={0: f1, 1: f2}, default=f3)
3628  # Equivalent: tf.switch_case(branch_index, branch_fns={0: f1, 1: f2, 2: f3})
3629  ```
3630
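  Since `branch_fns` may also be a plain list of callables (the list index
  serving as the key), the last line above can equivalently be written as:

  ```python
  r = tf.switch_case(branch_index, branch_fns=[f1, f2, f3])
  ```
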
3631  Args:
3632    branch_index: An int Tensor specifying which of `branch_fns` should be
3633      executed.
3634    branch_fns: A `dict` mapping `int`s to callables, or a `list` of
3635      (`int`, callable) pairs, or simply a list of callables (in which case the
3636      index serves as the key). Each callable must return a matching structure
3637      of tensors.
3638    default: Optional callable that returns a structure of tensors.
3639    name: A name for this operation (optional).
3640
3641  Returns:
3642    The tensors returned by the callable identified by `branch_index`, or those
3643    returned by `default` if no key matches and `default` was provided, or those
3644    returned by the max-keyed `branch_fn` if no `default` is provided.
3645
3646  Raises:
3647    TypeError: If `branch_fns` is not a list/dictionary.
3648    TypeError: If `branch_fns` is a list but does not contain 2-tuples or
3649               callables.
3650    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3651               callable.
3652  """
3653  return _indexed_case_helper(branch_fns, default, branch_index, name)
3654
3655
3656@tf_export("__internal__.execute_fn_for_device", v1=[])
3657def execute_fn_for_device(device_branch_fns, default_fn, name="execute_fn"):
3658  """Executes one of the provided callables based on the device placement.
3659
3660  This API is used when the implementations for high level function depend on
3661  the underlying device placement. It takes a dictionary of device type to
  callables. The device type includes "CPU", "GPU", "TPU", etc. When the type
  of the device on which this op runs matches a key in `device_branch_fns`,
  the corresponding callable is executed, falling back to `default_fn` if no
  key matches.
3666
3667  **Example:**
3668  ```python
3669  def f1(): return tf.constant(1)
3670  def f2(): return tf.constant(2)
3671  r = tf.execute_fn_for_device({"CPU": f1, "GPU": f2}, default_fn=f1)
3672  ```
  `r` evaluates to 1 when running on a CPU, to 2 on a GPU, and to 1 on any
  other device type.
3675
3676
3677  Args:
3678    device_branch_fns: a dictionary of device types to the callables. Each
3679      callable must return a matching structure of tensors.
3680    default_fn: fallback callable when the underlying device does not match any
3681      key in the 'device_branch_fns'.
3682    name: A name for this operation (optional).
3683
3684  Returns:
3685    The tensors returned by the callable identified by device type during
3686    execution, or those returned by 'default_fn' if no key matches.
3687  """
  # For XLA, always execute the default fn, to avoid the complicated graph
  # generated by the Case op. See b/167276293 for more discussion.
3690  is_in_xla = util.GraphOrParentsInXlaContext(ops.get_default_graph())
3691  if is_in_xla:
3692    return default_fn()
3693  device_branch_fns_upper = {k.upper(): v for k, v in device_branch_fns.items()}
3694  branch_fns = list(device_branch_fns_upper.values())
3695  devices = list(device_branch_fns_upper.keys())
3696  device_index = gen_functional_ops.device_index(device_names=devices)
3697  return _indexed_case_helper(
3698      branch_fns,
3699      default_fn,
3700      device_index,
3701      name,
3702      lower_using_switch_merge=False)
3703
3704
3705class XLAControlFlowContext(ControlFlowContext):
3706  """Base class for XLA and TPU control flow contexts."""
3707
3708  def __init__(self):
3709    super(XLAControlFlowContext, self).__init__()
3710    self._name = "XLAControlFlowContext"
3711
3712  def to_control_flow_context_def(self, context_def, export_scope=None):
3713    # pylint: disable=useless-super-delegation
3714    # NOTE(slebedev): the method is required by `ControlFlowContext`.
3715    super(XLAControlFlowContext,
3716          self).to_control_flow_context_def(context_def, export_scope)
3717
3718  def IsXLAContext(self):
3719    return True
3720
3721  def AddOp(self, _):
3722    pass
3723
3724  def AddValue(self, x):
3725    return x
3726
3727  def RequiresUniqueFunctionRetracing(self):
3728    """Returns whether the tf.function should be retraced if the context changes.
3729    """
3730    return False
3731
3732
3733@tf_export("__internal__.get_enclosing_xla_context", v1=[])
3734def get_enclosing_xla_context():
3735  """Recursively find and return the XLAControlFlowContext."""
3736  graph = ops.get_default_graph()
3737  while graph is not None:
3738    # pylint: disable=protected-access
3739    context_ = graph._get_control_flow_context()
3740    # pylint: enable=protected-access
3741    while context_ is not None:
3742      if isinstance(context_, XLAControlFlowContext):
3743        return context_
3744      context_ = context_.outer_context
3745    # This may be a FuncGraph due to defuns or v2 control flow. We need to
3746    # find the original graph with the XLAControlFlowContext.
3747    graph = getattr(graph, "outer_graph", None)
3748  return None
3749
3750
3751def from_control_flow_context_def(context_def, import_scope=None):
3752  """Deserializes `context_def` into the appropriate ControlFlowContext.
3753
3754  Args:
3755    context_def: ControlFlowContextDef proto
3756    import_scope: Optional `string`. Name scope to add.
3757
3758  Returns:
3759    A ControlFlowContext subclass
3760  """
3761  if context_def.HasField("cond_ctxt"):
3762    return CondContext.from_proto(
3763        context_def.cond_ctxt, import_scope=import_scope)
3764  if context_def.HasField("while_ctxt"):
3765    return WhileContext.from_proto(
3766        context_def.while_ctxt, import_scope=import_scope)
3767  raise NotImplementedError("Unknown ControlFlowContextDef field: %s" %
3768                            context_def.WhichOneof("ctxt"))
3769
3770
3771ops.register_proto_function(
3772    ops.GraphKeys.COND_CONTEXT,
3773    proto_type=control_flow_pb2.CondContextDef,
3774    to_proto=CondContext.to_proto,
3775    from_proto=CondContext.from_proto)
3776
3777ops.register_proto_function(
3778    ops.GraphKeys.WHILE_CONTEXT,
3779    proto_type=control_flow_pb2.WhileContextDef,
3780    to_proto=WhileContext.to_proto,
3781    from_proto=WhileContext.from_proto)
3782