1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Control Flow Operations.
16
See the [autograph](https://www.tensorflow.org/guide/autograph) guide.
18"""
19# pylint: disable=g-bad-name
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import abc
25import collections
26import functools
27
28import six
29
30from tensorflow.core.framework import attr_value_pb2
31from tensorflow.core.protobuf import control_flow_pb2
32from tensorflow.python.eager import context
33from tensorflow.python.framework import composite_tensor
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import errors
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import tensor_shape
39from tensorflow.python.framework import tensor_util
40from tensorflow.python.ops import array_ops
41from tensorflow.python.ops import control_flow_util as util
42from tensorflow.python.ops import gen_array_ops
43from tensorflow.python.ops import gen_control_flow_ops
44from tensorflow.python.ops import gen_data_flow_ops
45from tensorflow.python.ops import gen_logging_ops
46from tensorflow.python.ops import gen_resource_variable_ops
47from tensorflow.python.ops import math_ops
48from tensorflow.python.ops import tensor_array_ops
49# go/tf-wildcard-import
50# pylint: disable=wildcard-import,undefined-variable
51from tensorflow.python.ops.gen_control_flow_ops import *
52# pylint: enable=wildcard-import
53from tensorflow.python.platform import tf_logging as logging
54from tensorflow.python.util import compat
55from tensorflow.python.util import deprecation
56from tensorflow.python.util import nest
57from tensorflow.python.util import tf_should_use
58from tensorflow.python.util.lazy_loader import LazyLoader
59from tensorflow.python.util.tf_export import tf_export
60
# cond_v2 is loaded lazily to avoid a circular dependency:
# cond_v2 -> gradients_util -> control_flow_ops
cond_v2 = LazyLoader("cond_v2", globals(),
                     "tensorflow.python.ops.cond_v2")

# while_v2 is loaded lazily to avoid circular dependencies:
# while_v2 -> control_flow_ops
# while_v2 -> gradients_util -> control_flow_ops
while_v2 = LazyLoader("while_v2", globals(),
                      "tensorflow.python.ops.while_v2")

# This module overrides the name 'tuple' for a control flow op, so keep a
# reference to Python's builtin 'tuple' for later use in this module.
_basetuple = tuple
75
76
77def _summarize_eager(tensor, summarize=None):
78  """Returns a summarized string representation of eager `tensor`.
79
80  Args:
81    tensor: EagerTensor to summarize
82    summarize: Include these many first elements of `array`
83  """
84  # Emulate the behavior of Tensor::SummarizeValue()
85  if summarize is None:
86    summarize = 3
87  elif summarize < 0:
88    summarize = array_ops.size(tensor)
89  # reshape((-1,)) is the fastest way to get a flat array view
90  if tensor._rank():  # pylint: disable=protected-access
91    flat = tensor.numpy().reshape((-1,))
92    lst = [str(x) for x in flat[:summarize]]
93    if len(lst) < flat.size:
94      lst.append("...")
95  else:
96    # tensor.numpy() returns a scalar for zero dimensional arrays
97    if summarize != 0:
98      lst = [str(tensor.numpy())]
99    else:
100      lst = []
101
102  return ", ".join(lst)
103
104
105# pylint: disable=protected-access
106
107
108# Assert and Print are special symbols in python, so we must
109# use an upper-case version of them.
@tf_export("debugging.Assert", "Assert")
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
  """Asserts that the given condition is true.

  If `condition` evaluates to false, print the list of tensors in `data`.
  `summarize` determines how many entries of the tensors to print.

  NOTE: In graph mode, to ensure that Assert executes, one usually attaches
  a dependency:

  ```python
  # Ensure maximum element of x is smaller or equal to 1
  assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
  with tf.control_dependencies([assert_op]):
    ... code using x ...
  ```

  Args:
    condition: The condition to evaluate.
    data: The tensors to print out when condition is false.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).

  Returns:
    assert_op: An `Operation` that, when executed, raises a
    `tf.errors.InvalidArgumentError` if `condition` is not true.
    When executing eagerly, returns None.

  Raises:
    tf.errors.InvalidArgumentError: When executing eagerly and `condition`
      is not true.
  """
  if context.executing_eagerly():
    if not condition:
      xs = ops.convert_n_to_tensor(data)
      data_str = [_summarize_eager(x, summarize) for x in xs]
      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message="Expected '%s' to be true. Summarized data: %s" %
          (condition, "\n".join(data_str)))
    return

  # From here on we are building a graph: the eager case returned above.
  with ops.name_scope(name, "Assert", [condition, data]) as name:
    xs = ops.convert_n_to_tensor(data)
    if all(x.dtype in {dtypes.string, dtypes.int32} for x in xs):
      # As a simple heuristic, we assume that string and int32 are
      # on host to avoid the need to use cond. If it is not case,
      # we will pay the price copying the tensor to host memory.
      return gen_logging_ops._assert(condition, data, summarize, name="Assert")

    condition = ops.convert_to_tensor(condition, name="Condition")

    def true_assert():
      return gen_logging_ops._assert(
          condition, data, summarize, name="Assert")

    # Only reach the assert op (and thus `data`) when `condition` is false.
    guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
    return guarded_assert.op
172
173
def _Identity(data, name=None):
  """Passes `data` through an identity op.

  `data` is first converted to a tensor or composite tensor. A tensor goes
  through `ref_identity` when its dtype is a reference dtype and through
  `identity` otherwise. A composite tensor is rebuilt by applying `_Identity`
  to each of its components.

  Args:
    data: A value convertible to a `Tensor` or `CompositeTensor`.
    name: A name for this operation (optional). It is not forwarded to the
      per-component ops of a composite tensor.

  Returns:
    The identity of the converted `data`.

  Raises:
    TypeError: If the converted `data` is neither a `Tensor` nor a
      `CompositeTensor`.
  """
  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)

  if isinstance(data, ops.Tensor):
    is_ref = data.dtype._is_ref_dtype  # pylint: disable=protected-access
    identity_fn = gen_array_ops.ref_identity if is_ref else array_ops.identity
    return identity_fn(data, name=name)

  if isinstance(data, composite_tensor.CompositeTensor):
    return nest.map_structure(_Identity, data, expand_composites=True)

  raise TypeError("Type %s not supported" % type(data))
194
195
def _NextIteration(data, name=None):
  """Passes `data` through a NextIteration op.

  `data` is first converted to a tensor or composite tensor. A tensor goes
  through `ref_next_iteration` when its dtype is a reference dtype and through
  `next_iteration` otherwise. A composite tensor is rebuilt by applying
  `_NextIteration` to each of its components.

  Args:
    data: A value convertible to a `Tensor` or `CompositeTensor`.
    name: A name for this operation (optional). It is not forwarded to the
      per-component ops of a composite tensor.

  Returns:
    The output of the NextIteration op(s) applied to the converted `data`.

  Raises:
    TypeError: If the converted `data` is neither a `Tensor` nor a
      `CompositeTensor`.
  """
  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)

  if isinstance(data, ops.Tensor):
    is_ref = data.dtype._is_ref_dtype  # pylint: disable=protected-access
    next_fn = ref_next_iteration if is_ref else next_iteration
    return next_fn(data, name=name)

  if isinstance(data, composite_tensor.CompositeTensor):
    return nest.map_structure(_NextIteration, data, expand_composites=True)

  raise TypeError("Type %s not supported" % type(data))
207
208
def _Enter(data,
           frame_name,
           is_constant=False,
           parallel_iterations=10,
           use_ref=True,
           use_input_shape=True,
           name=None):
  """Creates or finds a child frame, and makes `data` available to it.

  The `Executor` identifies frames by their unique `frame_name`. With
  `is_constant` set, `data` is a constant in the child frame; otherwise the
  child frame may change it. The child frame runs at most
  `parallel_iterations` iterations in parallel.

  A composite tensor is entered component by component with the same
  settings; `name` is not forwarded to those per-component ops.

  Args:
    data: The tensor to be made available to the child frame.
    frame_name: The name of the child frame.
    is_constant: If true, the output is constant within the child frame.
    parallel_iterations: The number of iterations allowed to run in parallel.
    use_ref: If true, use ref_enter if data is of ref type.
    use_input_shape: If true, set the result's shape based on data's shape.
    name: A name for this operation (optional).

  Returns:
    The same tensor as `data`.

  Raises:
    TypeError: If the converted `data` is neither a `Tensor` nor a
      `CompositeTensor`.
  """
  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)

  if isinstance(data, ops.Tensor):
    wants_ref = data.dtype._is_ref_dtype and use_ref  # pylint: disable=protected-access
    enter_fn = (
        gen_control_flow_ops.ref_enter if wants_ref
        else gen_control_flow_ops.enter)
    entered = enter_fn(
        data, frame_name, is_constant, parallel_iterations, name=name)
    if use_input_shape:
      entered.set_shape(data.get_shape())
    return entered

  if isinstance(data, composite_tensor.CompositeTensor):
    enter_component = functools.partial(
        _Enter,
        frame_name=frame_name,
        is_constant=is_constant,
        parallel_iterations=parallel_iterations,
        use_ref=use_ref,
        use_input_shape=use_input_shape)
    return nest.map_structure(enter_component, data, expand_composites=True)

  raise TypeError("Type %s not supported" % type(data))
253
254
def exit(data, name=None):  # pylint: disable=redefined-builtin
  """Exits the current frame to its parent frame.

  Exit makes its input `data` available to the parent frame. A composite
  tensor is exited component by component; `name` is not forwarded to those
  per-component ops.

  Args:
    data: The tensor to be made available to the parent frame.
    name: A name for this operation (optional).

  Returns:
    The same tensor as `data`.

  Raises:
    TypeError: If the converted `data` is neither a `Tensor` nor a
      `CompositeTensor`.
  """
  data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)

  if isinstance(data, ops.Tensor):
    is_ref = data.dtype._is_ref_dtype  # pylint: disable=protected-access
    if is_ref:
      return gen_control_flow_ops.ref_exit(data, name)
    return gen_control_flow_ops._exit(data, name)  # pylint: disable=protected-access

  if isinstance(data, composite_tensor.CompositeTensor):
    return nest.map_structure(exit, data, expand_composites=True)

  raise TypeError("Type %s not supported" % type(data))
277
278
def switch(data, pred, dtype=None, name=None):
  """Forwards `data` to an output determined by `pred`.

  If `pred` is false, the `data` input is forwarded to the first output.
  Otherwise, the data goes to the second output.

  This op handles `Tensor`s and composite tensors such as `IndexedSlices`. A
  composite tensor is switched component by component and both outputs are
  packed back into the structure of `data`.

  Args:
    data: The tensor to be forwarded to the appropriate output.
    pred: A scalar that specifies which output port will receive data.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `data`.
    name: A name for this operation (optional).

  Returns:
    `(output_false, output_true)`: If `pred` is true, data will be forwarded
    to `output_true`, otherwise it goes to `output_false`.

  Raises:
    TypeError: If the converted `data` is neither a `Tensor` nor a
      `CompositeTensor`.
  """
  with ops.name_scope(name, "Switch", [data, pred]) as name:
    data = ops.internal_convert_to_tensor_or_composite(
        data, dtype=dtype, name="data", as_ref=True)
    pred = ops.convert_to_tensor(pred, name="pred")

    if isinstance(data, ops.Tensor):
      return gen_control_flow_ops.switch(data, pred, name=name)

    if not isinstance(data, composite_tensor.CompositeTensor):
      raise TypeError("Type %s not supported" % type(data))

    # Switch every component tensor, then regroup the (false, true) pairs into
    # one flat sequence per output branch.
    components = nest.flatten(data, expand_composites=True)
    switched_pairs = [
        gen_control_flow_ops.switch(component, pred)
        for component in components
    ]
    false_outputs, true_outputs = zip(*switched_pairs)
    output_false = nest.pack_sequence_as(
        data, false_outputs, expand_composites=True)
    output_true = nest.pack_sequence_as(
        data, true_outputs, expand_composites=True)
    return (output_false, output_true)
312
313
def _SwitchRefOrTensor(data, pred, name="Switch"):
  """Forwards `data` to an output determined by `pred`.

  If `pred` is false, the `data` input is forwarded to the first output.
  Otherwise, the data goes to the second output.

  A tensor with a reference dtype is routed through `ref_switch`; everything
  else goes through `switch`. The ops are created colocated with `data`.

  Args:
    data: The tensor to be forwarded to the appropriate output.
    pred: A scalar that specifies which output port will receive data.
    name: A name for this operation (optional).

  Returns:
    `(output_false, output_true)`: If `pred` is true, data will be forwarded to
    `output_true`, otherwise it goes to `output_false`.

  Raises:
    TypeError: if data is not a Tensor or IndexedSlices
  """
  data = ops.convert_to_tensor_or_composite(data, name="data")
  # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
  # addresses the following scenario.
  #
  # Assume you execute Optimizer.apply_gradients() in a branch of a cond().
  #
  # 1. The update op is created inside a `with ops.colocate(var):` block
  #
  # 2. Some tensor `data` is captured and a switch is created in a
  #    `with ops.colocate_with(data):` block.
  #
  # with ops.colocate_with(var):
  #  with ops.colocate_with(data):
  #    op = ...
  #
  # var and data may be pinned to different devices, so we want the ops
  # created within ops.colocate_with(data) to ignore the existing stack.
  with ops.colocate_with(data, ignore_existing=True):
    is_ref_tensor = (
        isinstance(data, ops.Tensor) and
        data.dtype._is_ref_dtype)  # pylint: disable=protected-access
    if is_ref_tensor:
      return ref_switch(data, pred, name=name)
    return switch(data, pred, name=name)
356
357
def merge(inputs, name=None):
  """Returns the value of an available element of `inputs`.

  This op tests each of the tensors in `inputs` in turn to determine if any of
  them is available. If it finds an available tensor, it returns it and its
  index in `inputs`.

  It is an error if more than one tensor in `inputs` is available. If no tensor
  in `inputs` is available, the returned tensor and index are not set.

  This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
  `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
  before merging. Other composite tensors are merged component by component
  and must all share the same structure.

  Args:
    inputs: The input tensors, at most one of which is available.
    name: A name for this operation (optional).

  Returns:
    A tuple containing the chosen input tensor and its index in `inputs`.

  Raises:
    ValueError: If any of the inputs is None. The structure check on composite
      inputs (e.g. IndexedSlices of which only some have a dense_shape) may
      raise as well.
    TypeError: If, after conversion, an input is neither a `Tensor` nor a
      `CompositeTensor`.
  """
  if any(inp is None for inp in inputs):
    # Wrap `inputs` in a 1-tuple: when `inputs` is itself a tuple, `%` would
    # otherwise unpack it as the format arguments and raise a TypeError.
    raise ValueError(
        "At least one of the merge inputs is None: %s" % (inputs,))
  with ops.name_scope(name, "Merge", inputs) as name:
    inputs = [
        ops.internal_convert_to_tensor_or_composite(inp, as_ref=True)
        for inp in inputs
    ]
    if all(isinstance(v, ops.Tensor) for v in inputs):
      if all(v.dtype._is_ref_dtype for v in inputs):  # pylint: disable=protected-access
        return gen_control_flow_ops.ref_merge(inputs, name)
      else:
        return gen_control_flow_ops.merge(inputs, name)
    else:
      # If there is a mix of tensors and indexed slices, then convert the
      # tensors to indexed slices.
      if all(isinstance(v, (ops.IndexedSlices, ops.Tensor)) for v in inputs):
        inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)  # pylint: disable=protected-access

      for v in inputs:
        if not isinstance(v, composite_tensor.CompositeTensor):
          raise TypeError("Type %s not supported" % type(v))

      # Every input must decompose into the same components so that they can
      # be merged position by position.
      for v in inputs[1:]:
        nest.assert_same_structure(inputs[0], v, expand_composites=True)

      flat_inputs = [nest.flatten(v, expand_composites=True) for v in inputs]
      merged_results = [gen_control_flow_ops.merge(component)
                        for component in zip(*flat_inputs)]
      flat_merged = [tensor for (tensor, _) in merged_results]
      # The index reported is the one produced by the first component's merge.
      chosen_index = merged_results[0][1]
      merged_inputs = nest.pack_sequence_as(inputs[0], flat_merged,
                                            expand_composites=True)
      return (merged_inputs, chosen_index)
416
417
418# pylint: enable=protected-access
419
420
def _convert_tensorarray_to_flow(tensor_or_tensor_array):
  """Returns the `flow` of a `TensorArray`; any other value is returned as is."""
  is_tensor_array = isinstance(tensor_or_tensor_array,
                               tensor_array_ops.TensorArray)
  return (tensor_or_tensor_array.flow
          if is_tensor_array else tensor_or_tensor_array)
426
427
def _make_tensor_array(ta, t_or_flow):
  """Builds a `TensorArray` like `ta` but using `t_or_flow` as its flow.

  Args:
    ta: The `TensorArray` whose dtype, handle and private settings are reused.
    t_or_flow: The value passed as the `flow` of the new `TensorArray`.

  Returns:
    A new `TensorArray` sharing `ta`'s handle.
  """
  # pylint: disable=protected-access
  ctor_kwargs = dict(
      dtype=ta.dtype,
      handle=ta.handle,
      flow=t_or_flow,
      infer_shape=ta._infer_shape,
      colocate_with_first_write_call=ta._colocate_with_first_write_call)
  rebuilt = tensor_array_ops.TensorArray(**ctor_kwargs)
  # These two are not constructor arguments above, so copy them over by hand.
  rebuilt._colocate_with = ta._colocate_with
  rebuilt._element_shape = ta._element_shape
  # pylint: enable=protected-access
  return rebuilt
440
441
442def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
443  if len(tensors_or_tensorarrays) != len(tensors_or_flows):
444    raise ValueError(
445        "Lengths of original Tensor list and new list do not match: %d vs. %d" %
446        (len(tensors_or_tensorarrays), len(tensors_or_flows)))
447  return [
448      _make_tensor_array(ta, t_or_flow) if isinstance(
449          ta, tensor_array_ops.TensorArray) else t_or_flow
450      for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)
451  ]
452
453
454def _ShapeLessThanOrEqual(shape1, shape2):
455  if shape2.dims is None:
456    return True
457  if shape1.ndims != shape2.ndims:
458    return False
459  for dim1, dim2 in zip(shape1.dims, shape2.dims):
460    if dim2.value is not None and dim1.value != dim2.value:
461      return False
462  return True
463
464
def _get_shape_invariant(var, shape=None):
  """Returns a shape invariant for the given variable.

  A `CompositeTensor` is asked for the invariants of its components through
  `_shape_invariant_to_components()`. For anything else the given `shape` is
  used, falling back to `var.shape` when no shape was given.

  Args:
    var: The tensor whose shape is described.
    shape: The shape invariant for the tensor.  If not specified, then a default
      shape invariant for `var` is returned.

  Returns:
    The shape invariant for `var` (if it is a `Tensor`), or the shape invariants
    for the components that comprise `var` (if it is a `CompositeTensor`).
  """
  if isinstance(var, composite_tensor.CompositeTensor):
    return var._shape_invariant_to_components(shape)  # pylint: disable=protected-access
  return var.shape if shape is None else shape
487
488
489def _SetShapeInvariants(input_vars, enter_vars, shapes):
490  """Set the shapes of the tensors in `enter_vars` to `shapes`.
491
492  Args:
493    input_vars: A list of tensors that are inputs to `enter_vars`.
494    enter_vars: A list of tensors whose shapes will be set.
495    shapes: A (possibly nested) list of shapes.
496
497  Raises:
498    ValueError: If any tensor in `enter_vars` has a less specific shape
499      than its corresponding shape in `shapes`.
500  """
501  if shapes is None:
502    return
503  flat_shapes = nest.flatten(shapes)
504  if not all(isinstance(s, tensor_shape.TensorShape) for s in flat_shapes):
505    raise ValueError("`shapes` must be a (possibly nested) list of shapes.")
506  # Check that the shapes of the inputs are less than the shape invariants,
507  # and set the shapes of `enter_vars` to the shape invariants.
508  for inp, var, shape in zip(input_vars, enter_vars, flat_shapes):
509    if isinstance(var, ops.Tensor):
510      if not _ShapeLessThanOrEqual(inp.get_shape(), shape):
511        raise ValueError(
512            "The shape invariant specified for %s is not compatible with "
513            "the initial shape of the loop variable. It enters the loop "
514            "with shape %s, but the specified shape invariant is %s." %
515            (inp.name, inp.get_shape(), shape))
516      var.set_shape(shape)
517    else:
518      raise TypeError("Type %s not supported" % type(var))
519
520
def _EnforceShapeInvariant(merge_var, next_var):
  """Checks that a loop variable keeps a compatible shape across an iteration.

  Args:
    merge_var: The tensor representing the initial value of the loop variable.
    next_var: The tensor representing the value of the loop variable after one
      loop iteration.

  Raises:
    ValueError: If `merge_var` has a more specific shape than its
      corresponding tensor `next_var`.
    TypeError: If `merge_var` is not a `Tensor`.
  """
  if not isinstance(merge_var, ops.Tensor):
    raise TypeError("Type %s not supported" % type(merge_var))

  merge_shape = merge_var.get_shape()
  next_shape = next_var.get_shape()
  if _ShapeLessThanOrEqual(next_shape, merge_shape):
    return

  # Walk back from the merge op to the tensor that entered the loop so that
  # the error names the input rather than an internal op.
  enter_op = merge_var.op.inputs[0].op
  assert util.IsLoopEnter(enter_op)
  loop_input = enter_op.inputs[0]
  raise ValueError(
      "Input tensor '%s' enters the loop with shape %s, but has shape %s "
      "after one iteration. To allow the shape to vary across iterations, "
      "use the `shape_invariants` argument of tf.while_loop to specify a "
      "less-specific shape." % (loop_input.name, loop_input.shape, next_shape))
548
549
def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
  """Add NextIteration and back edge from v to m.

  Args:
    m: The merge output (a `Tensor` or `CompositeTensor`) whose op receives
      the back edge as input 1.
    v: The value produced by the loop body for this variable.
    enforce_shape_invariant: If true and `m` is a `Tensor`, check the shape of
      the NextIteration output against `m` before adding the back edge.

  Returns:
    For a `Tensor` `m`, the NextIteration output of `v`. For a composite `m`,
    whatever `nest.map_structure` returns for the per-component update (the
    update function itself returns nothing).

  Raises:
    TypeError: If `m` is neither a `Tensor` nor a `CompositeTensor`.
  """
  if isinstance(m, ops.Tensor):
    next_v = _NextIteration(ops.convert_to_tensor(v))
    if enforce_shape_invariant:
      # Make sure the shapes of loop outputs are correct. We do this before
      # calling _update_input, which will raise a less-helpful error message if
      # the types don't match.
      # TODO(skyewm): call this for other cases below (needs testing)
      _EnforceShapeInvariant(m, next_v)
    m.op._update_input(1, next_v)  # pylint: disable=protected-access
    return next_v

  if isinstance(m, composite_tensor.CompositeTensor):
    # pylint: disable=protected-access
    def update_component(m_component, v_component):
      m_component.op._update_input(1, v_component)

    if isinstance(m, ops.IndexedSlices):
      v = math_ops._as_indexed_slices(v, optimize=False)
    # pylint: enable=protected-access
    next_v = _NextIteration(v)
    return nest.map_structure(
        update_component, m, next_v, expand_composites=True)

  raise TypeError("Type %s not supported" % type(m))
574
575
def GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
  """Calculate a max_size for use by stack ops inside an XLA while_loop.

  The result is the product of the `maximum_iterations` of every while context
  from `while_ctxt` outwards, stopping at the current control flow context of
  the default graph.

  Args:
    value: The value inside the while_loop forward context.  Used for printing
      error messages.
    while_ctxt: The forward context inside which value resides.  This does not
      always match the value's immediate context, as `value` may be inside e.g.
      a cond context inside the while_loop.

  Returns:
    A tensor containing the `max_size` to feed to a Stack initializer.

  Raises:
    ValueError: If `value` is nested inside a `while_loop` that either
      lacks a `maximum_iterations` parameter, or the `maximum_iterations`
      parameter:

        - is inside a `while_loop` that is a parent of the calling context, and
        - cannot be evaluated at graph build time to a constant.
  """
  value_name = value.name
  # curr_ctxt is the context that tf.gradients was called in.
  curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access

  curr_ctxt_name = curr_ctxt.name if curr_ctxt is not None else ""
  max_size = constant_op.constant(1)

  # Walk outwards over every while context between `value` and the current
  # context; the maximum stack size is the product of their max_iterations.
  while while_ctxt not in (None, curr_ctxt):
    max_iter = while_ctxt.maximum_iterations
    if max_iter is None:
      raise ValueError(
          "Cannot create a gradient accumulator for tensor '%s' inside "
          "XLA while_loop because maximum_iterations was not passed to "
          "the tf.while_loop call ('%s')." % (value_name, while_ctxt.name))

    max_iter_ctxt = max_iter.op._get_control_flow_context()  # pylint: disable=protected-access

    if util.IsContainingContext(curr_ctxt, max_iter_ctxt):
      # max_iter_ctxt (non-strictly) contains curr_ctxt, so the tensor itself
      # can be used here.
      factor = max_iter
    else:
      # We cannot use max_iter because it's defined in a nested while
      # or cond context, so will fail if we try to use it as input to
      # any ops in curr_ctxt (e.g. max_size or the final accumulator
      # stack). Attempt to get a constant value out to use instead.
      factor = tensor_util.constant_value(max_iter)
      if factor is None:
        raise ValueError(
            "Cannot create a gradient accumulator for tensor '%s' inside XLA "
            "while_loop. maximum_iterations tensor '%s' for while_loop context "
            "'%s' must be statically known (e.g. a constant value or known "
            "shape dimension), or be defined at or outside the while loop "
            "context '%s' (currently defined in '%s')." %
            (value_name, max_iter.name, while_ctxt.name, curr_ctxt_name,
             max_iter_ctxt.name))
    max_size *= factor

    # Move on to the next outer WhileContext (or stop once we reach the
    # tf.gradient's context).
    while_ctxt = util.GetContainingWhileContext(
        while_ctxt.outer_context, stop_ctxt=curr_ctxt)

  return max_size
645
646
647class GradLoopState(object):
648  """The state used for constructing the gradient graph for a while loop.
649
650  We create a GradLoopState for each while loop in forward and its
651  corresponding while loop in backprop. This gives us access to both
652  the forward and the backprop WhileContexts.
653
654  During the construction of gradient graph, any time when we detect
655  a forward value that is needed for backprop, we create a history
656  accumulator and add it to `history_map`. Any time when we backprop
657  a loop switch op (in _SwitchGrad), we add the grad merge op in
658  `switch_map`.
659  """
660
661  def __init__(self, forward_ctxt, outer_grad_state):
662    # The grad loop state for the outer while loop.
663    self._outer_grad_state = None
664
665    # The while loop context for forward.
666    self._forward_context = None
667
668    # The loop counter added by AddForwardLoopCounter. It is the value
669    # of the loop counter for the next iteration.
670    self._forward_index = None
671
672    # A sync op for forward.
673    self._forward_sync = None
674
675    # The while loop context for backprop.
676    self._grad_context = None
677
678    # The loop counter added by AddBackpropLoopCounter. It is the value
679    # of the loop counter for the current iteration.
680    self._grad_index = None
681
682    # A sync op for backprop.
683    self._grad_sync = None
684
685    # Information needed by backprop.
686    self._history_map = {}
687    self._switch_map = {}
688    self._unused_exits = []
689    self._deferred_exits = []
690    self._forward_loop_exits = list(forward_ctxt.loop_exits)
691    self._pending_exits_count = len(forward_ctxt.loop_exits)
692
693    self._outer_grad_state = outer_grad_state
694    if outer_grad_state:
695      outer_forward_ctxt = outer_grad_state.forward_context
696    else:
697      if not hasattr(forward_ctxt, "outer_context"):
698        raise ValueError("Failed to call gradients on a while loop without"
699                         "properly serializing graph via MetaGraphDef")
700      outer_forward_ctxt = forward_ctxt.outer_context
701
702    # Add the forward loop counter.
703    with forward_ctxt._graph.as_default():  # pylint: disable=protected-access
704      if outer_forward_ctxt:
705        outer_forward_ctxt.Enter()
706      cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
707      if outer_forward_ctxt:
708        outer_forward_ctxt.Exit()
709    self._forward_context = forward_ctxt
710    self._forward_index = forward_index
711
712    # Add the backprop WhileContext, and the backprop loop counter.
713    if outer_grad_state:
714      # This is a nested loop. Remember the iteration counts for each
715      # execution of this inner loop.
716      outer_forward_ctxt.AddName(cnt.name)
717      history_cnt = outer_grad_state.AddForwardAccumulator(cnt)
718
719      outer_grad_ctxt = outer_grad_state.grad_context
720      outer_grad_ctxt.Enter()
721      self._grad_context = WhileContext(
722          maximum_iterations=forward_ctxt.maximum_iterations,
723          parallel_iterations=forward_ctxt.parallel_iterations,
724          back_prop=forward_ctxt.back_prop,
725          swap_memory=forward_ctxt.swap_memory,
726          name=forward_ctxt.name,
727          grad_state=self)
728      real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
729      self._grad_index = self._grad_context.AddBackpropLoopCounter(
730          real_cnt, outer_grad_state)
731      outer_grad_ctxt.Exit()
732    else:
733      if outer_forward_ctxt:
734        outer_forward_ctxt.Enter()
735      self._grad_context = WhileContext(
736          maximum_iterations=forward_ctxt.maximum_iterations,
737          parallel_iterations=forward_ctxt.parallel_iterations,
738          back_prop=forward_ctxt.back_prop,
739          swap_memory=forward_ctxt.swap_memory,
740          name=forward_ctxt.name,
741          grad_state=self)
742      self._grad_index = self._grad_context.AddBackpropLoopCounter(
743          cnt, outer_grad_state)
744      if outer_forward_ctxt:
745        outer_forward_ctxt.Exit()
746
747  @property
748  def outer_grad_state(self):
749    """The grad loop state for outer loop."""
750    return self._outer_grad_state
751
752  @property
753  def forward_context(self):
754    """The while loop context for forward."""
755    return self._forward_context
756
757  @property
758  def forward_index(self):
759    """The loop index of forward loop."""
760    return self._forward_index
761
762  @property
763  def forward_sync(self):
764    """A control trigger node for synchronization in the forward loop.
765
766    One main use is to keep the push ops of a stack executed in the
767    iteration order.
768    """
769    if self._forward_sync is None:
770      with ops.control_dependencies(None):
771        self._forward_sync = control_trigger(name="f_sync")
772      self._forward_sync._set_control_flow_context(self._forward_context)
773      self._forward_index.op._add_control_input(self._forward_sync)
774    return self._forward_sync
775
776  @property
  def grad_context(self):
    """The corresponding WhileContext for gradient.

    Stack pops for accumulated forward values are created while this context
    is entered.
    """
    return self._grad_context
780
781  @property
  def grad_index(self):
    """The loop index of backprop loop.

    Once `grad_sync` has been created, the op producing this tensor carries a
    control input from that trigger.
    """
    return self._grad_index
785
786  @property
787  def grad_sync(self):
788    """A control trigger node for synchronization in the grad loop.
789
790    One main use is to keep the pop ops of a stack executed in the
791    iteration order.
792    """
793    if self._grad_sync is None:
794      with ops.control_dependencies(None):
795        self._grad_sync = control_trigger(name="b_sync")
796      self._grad_sync._set_control_flow_context(self._grad_context)
797      self._grad_index.op._add_control_input(self._grad_sync)
798      if self._grad_context.outer_context:
799        self._grad_context.outer_context.AddInnerOp(self._grad_sync)
800    return self._grad_sync
801
802  @property
  def history_map(self):
    """The map that records all the tensors needed for backprop.

    Keys are forward tensor names; each value is the tensor to use in place of
    that forward tensor inside the grad loop (see `GetRealValue`).
    """
    return self._history_map
806
807  @property
  def switch_map(self):
    """The map that records all the Switch ops for the while loop.

    `ControlFlowState.PostProcessing` iterates over this map and inspects
    `value.op.inputs` of each value, naming the values `b_merge`.
    NOTE(review): the key type is not visible from the accessors here --
    confirm against the code that populates the map.
    """
    return self._switch_map
811
812  @property
  def unused_exits(self):
    """The list of "unused" exits.

    `ControlFlowState.ProcessUnusedLoopExits` appends to this list the forward
    loop exits whose pending count is 0 and whose op is not one of the
    requested `ys`. The internal list itself is returned, not a copy.
    """
    return self._unused_exits
816
817  @property
  def deferred_exits(self):
    """The list of "deferred" exits.

    Returns the stored object itself, not a copy.
    NOTE(review): what makes an exit "deferred" is decided by code outside
    this accessor -- confirm against the gradient code that fills the list.
    """
    return self._deferred_exits
821
822  @property
  def forward_loop_exits(self):
    """The list of exits of the forward loop.

    Consumers in this file read `.op` on every element, so the elements are
    the exit tensors rather than the Exit ops themselves.
    """
    return self._forward_loop_exits
826
827  @property
  def pending_exits_count(self):
    """The number of exits we expect to see but haven't.

    `ControlFlowState.ProcessUnusedLoopExits` decrements this for every
    forward loop exit whose pending count is 0.
    """
    return self._pending_exits_count
831
832  @pending_exits_count.setter
  def pending_exits_count(self, cnt):
    """Set the pending count to cnt.

    Args:
      cnt: The new number of loop exits that are still expected.
    """
    self._pending_exits_count = cnt
836
  def AddForwardAccumulator(self, value, dead_branch=False):
    """Add an accumulator for each forward tensor that is needed in backprop.

    This is added to the forward loop at the first time when a tensor
    in the forward loop is used by backprop gradient computation loop.
    We create an accumulator that accumulates the value of tensor at each
    iteration. Called in the control flow context where gradients() is called.

    The pseudocode is:
    ```
      acc = stack();
      while (_pivot) {
        acc = stack_push(acc, value);
      }
    ```

    We make sure that the stack push op in one iteration is executed before
    next iteration. This is achieved by adding a control edge from
    `forward_index.op.inputs[0].op` to the push op, and another control
    edge from the push op to either `forward_index.op` or `forward_sync`.

    Args:
      value: The source tensor in forward that is to be accumulated.
      dead_branch: True iff the tensor is on a dead branch of a cond.

    Returns:
      The stack that contains the accumulated history of the tensor. This is
      the stack handle created outside the loop, not the value made available
      inside the forward context.

    Raises:
      TypeError: For internal errors involving the value condition context.
      ValueError: If `value` is inside a XLA scope and a valid max size
        for the stack can't be found.
    """
    # curr_ctxt is the context that tf.gradients was called in.
    with self._forward_index.graph.as_default():
      curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
      # Build everything below without inheriting pending control dependencies.
      with ops.control_dependencies(None):
        # The stack itself is created in the caller's context (if any), i.e.
        # not inside the forward loop.
        if curr_ctxt:
          curr_ctxt.Enter()
        with ops.colocate_with(value):
          # We only need to pass maximum_iterations to the stack if
          # we're inside an XLA context.
          if not util.IsInXLAContext(value.op):
            max_size = constant_op.constant(-1, dtypes.int32)
          else:
            max_size = GetMaxSizeFromNestedMaximumIterations(
                value, self.forward_context)
          acc = gen_data_flow_ops.stack_v2(
              max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
        if curr_ctxt:
          curr_ctxt.Exit()

        # Make acc available in the forward context.
        enter_acc = self.forward_context.AddValue(acc)

        # Add the stack_push op in the context of value.op.
        swap_enabled = self.forward_context.swap_memory
        value_ctxt = util.GetOutputContext(value.op)
        if value_ctxt == self.forward_context:
          # value is not nested in the forward context.
          self.forward_context.Enter()
          push = gen_data_flow_ops.stack_push_v2(
              enter_acc, value, swap_memory=swap_enabled)
          self.forward_context.Exit()
          # Protect stack push and order it before forward_index.
          self.forward_index.op._add_control_input(push.op)
        else:
          # value is in a cond context within the forward context.
          if not isinstance(value_ctxt, CondContext):
            raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
          if dead_branch:
            # The special case for creating a zero tensor for a dead
            # branch of a switch. See ControlFlowState.ZerosLike().
            # The push is built in the cond's outer context and only
            # afterwards re-tagged as belonging to the cond context.
            value_ctxt.outer_context.Enter()
            push = gen_data_flow_ops.stack_push_v2(
                enter_acc, value, swap_memory=swap_enabled)
            value_ctxt.outer_context.Exit()
            push.op._set_control_flow_context(value_ctxt)
          else:
            value_ctxt.Enter()
            push = gen_data_flow_ops.stack_push_v2(
                enter_acc, value, swap_memory=swap_enabled)
            value_ctxt.Exit()
          # Protect stack push and order it before forward_sync.
          # Accessing forward_sync creates the trigger if it does not exist.
          self.forward_sync._add_control_input(push.op)
        # Order stack push after the successor of forward_index
        add_op = self.forward_index.op.inputs[0].op
        push.op._add_control_input(add_op)
        return acc
926
927  def AddBackpropAccumulatedValue(self, history_value, value,
928                                  dead_branch=False):
929    """Add the getter for an accumulated value in the grad context.
930
931    This is added to the backprop loop. Called in the grad context to
932    get the value of an accumulated value. The stack pop op must be guarded
933    by the pred of the controlling cond.
934
935    Args:
936      history_value: The history (a stack) of a value.
937      value: The value that is pushed onto the stack.
938      dead_branch: True iff the tensor is on a dead branch of a cond.
939
940    Returns:
941      The current value (the top of the stack).
942    """
943    history_ctxt = history_value.op._get_control_flow_context()
944    # Find the cond context that controls history_value if any.
945    cond_ctxt = None
946    value_ctxt = value.op._get_control_flow_context()
947    while value_ctxt and value_ctxt != history_ctxt:
948      if isinstance(value_ctxt, CondContext):
949        cond_ctxt = value_ctxt
950        break
951      value_ctxt = value_ctxt.outer_context
952    with ops.control_dependencies(None):
953      self.grad_context.Enter()
954      if cond_ctxt:
955        # Guard stack pop with a switch if it is controlled by a cond.
956        grad_state = self
957        pred = None
958        while pred is None and grad_state:
959          pred = grad_state.history_map.get(cond_ctxt.pred.name)
960          grad_state = grad_state.outer_grad_state
961        if pred is None:
962          pred = cond_ctxt.pred
963        branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
964        history_value = _SwitchRefOrTensor(history_value, pred)[branch]
965      pop = gen_data_flow_ops.stack_pop_v2(history_value,
966                                           value.dtype.base_dtype)
967      pop.set_shape(value.get_shape())
968      self.grad_context.Exit()
969    parallel_iterations = self.grad_context.parallel_iterations
970    if parallel_iterations > 1:
971      # All pops are ordered after pivot_for_body and before grad_sync.
972      self.grad_sync._add_control_input(pop.op)
973    return pop
974
  def GetRealValue(self, value):
    """Get the real value of `value`.

    If backprop "uses" a value produced by forward inference, an accumulator
    is added in the forward loop to accumulate its values.  We use the
    accumulated value. This method must be called in the grad loop context.
    `value` must be in forward and needed for backprop.

    The result is cached in the history map under `value.name`, so repeated
    calls for the same tensor return the same grad-loop tensor.

    Args:
      value: A tensor to be captured.

    Returns:
      The same tensor obtained from the saved history.
    """
    assert value.op.type not in ["Variable", "VariableV2"]
    real_value = self._history_map.get(value.name)
    if real_value is None:
      cur_value = value
      cur_grad_state = self
      # Peel off constant Enter nodes one loop level at a time. Every pass
      # either breaks out or moves both cur_value and cur_grad_state outward.
      while True:
        enter_op = util.GetLoopConstantEnter(cur_value)
        if enter_op:
          # Special case: cur_value comes from a constant Enter node.
          cur_value = enter_op.inputs[0]
          cur_grad_state = cur_grad_state.outer_grad_state
          if cur_grad_state is None:
            # We are now outside all nested loops for this gradient(),
            # so `value` is a loop invariant and there is no need to
            # save the history of value. Just make cur_value to enter
            # the right control flow context.
            real_value = self._grad_context.AddValue(cur_value)
            break
        elif constant_op.is_constant(cur_value):
          # If the value to be forwarded is a constant, clone the constant in
          # the gradient loop rather than using a stack.
          # TODO(phawkins): consider hoisting the constant out of the loop
          # instead.
          real_value = constant_op.constant(
              tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
          break
        else:
          # Record the history of this value in forward_ctxt.
          # The accumulator is added with the grad context exited, and the
          # grad context is re-entered afterwards.
          self._grad_context.Exit()
          history_value = cur_grad_state.AddForwardAccumulator(cur_value)
          self._grad_context.Enter()
          break

      # real_value is still None only when the loop left through the
      # accumulator branch above, which is also the only branch that binds
      # history_value.
      if real_value is None:
        # Add the stack pop op in the grad context.
        real_value = cur_grad_state.AddBackpropAccumulatedValue(
            history_value, cur_value)
        if cur_grad_state != self:
          # The pop was added for an outer grad state, so make its result
          # available inside this grad context as well.
          real_value = self._grad_context.AddValue(real_value)
      self._history_map[value.name] = real_value
    return real_value
1030
1031
1032def _GetWhileContext(op):
1033  """Get the WhileContext to which this op belongs."""
1034  ctxt = op._get_control_flow_context()
1035  if ctxt:
1036    ctxt = ctxt.GetWhileContext()
1037  return ctxt
1038
1039
class ControlFlowState(object):
  """Maintain the mapping from the loops to their grad states.

  Each forward while loop context that takes part in a gradient computation
  is mapped to one `GradLoopState`. The methods below look that state up for
  an op, enter/exit the matching gradient loop context, and build zero
  gradients that are valid inside (or outside) those loops.
  """

  def __init__(self):
    self._map = {}  # maps forward loop context to GradLoopState

  def _GetOuterGradState(self, forward_ctxt):
    """Return the grad state of the while loop enclosing `forward_ctxt`.

    Args:
      forward_ctxt: A control flow context exposing `outer_context`.

    Returns:
      The `GradLoopState` registered for the nearest while context around
      `forward_ctxt`, or None if there is no such context or it has no
      registered grad state.
    """
    outer_forward_ctxt = forward_ctxt.outer_context
    if outer_forward_ctxt:
      outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
    if outer_forward_ctxt:
      return self._map.get(outer_forward_ctxt)
    return None

  def GetGradState(self, op, before):
    """Return the grad state for this op if it's in a forward loop context.

    Args:
      op: The operation to look up.
      before: If True and `op` is a loop exit, the lookup uses the while
        context enclosing the exit's own loop instead of the loop itself.

    Returns:
      The `GradLoopState` for the selected forward while context, or None if
      there is no such context or no grad state is registered for it.
    """
    if before and util.IsLoopExit(op):
      forward_ctxt = op._get_control_flow_context()
      forward_ctxt = forward_ctxt.outer_context
      if forward_ctxt:
        forward_ctxt = forward_ctxt.GetWhileContext()
    else:
      forward_ctxt = _GetWhileContext(op)
    if forward_ctxt:
      return self._map.get(forward_ctxt)
    return None

  def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
    """Process all the "unused" loop exits.

    The "unused" exits of the loops are added to `unused_exits`. An exit is
    unused if its pending_count is 0. If there is an exit with real gradient,
    all these deferred exits will enter the backprop loop with zero gradient.
    Otherwise, they will enter the backprop loop with None. As an example,
    people often write:

    ```python
    v1, _ = tf.while_loop(p, b, [x1, x2])
    result = gradients(v1, x1)
    ```

    The exit node for x2 is not included by the betweenness analysis. But we
    need to backprop x2 if x2 is involved in computing v1.

    Note that this method mutates `pending_count`: loop Enter ops whose count
    is 0 are bumped to 1.

    Args:
      pending_count: The number of backprop inputs for every op.
      to_ops_set: The set of ops for ys in gradients(ys, xs)

    Returns:
      The set of unused loop exits that we know at this point we need
      to backprop.
    """
    loop_exits = []
    for grad_state in self._map.values():
      for y in grad_state.forward_loop_exits:
        if pending_count[y.op] == 0:
          grad_state.pending_exits_count -= 1
          if y.op not in to_ops_set:
            grad_state.unused_exits.append(y)
          if grad_state.pending_exits_count == 0:
            # Every exit of this loop has been seen; its unused exits can
            # now be handed to backprop.
            loop_exits.extend(grad_state.unused_exits)
      # Need to include Enters in backprop for higher-order gradients.
      for y in grad_state.forward_context.loop_enters:
        if pending_count[y.op] == 0:
          pending_count[y.op] = 1
    return loop_exits

  def EnterGradWhileContext(self, op, before):
    """Enter the WhileContext for gradient computation.

    Does nothing when `GetGradState(op, before)` finds no grad state.
    """
    grad_state = self.GetGradState(op, before)
    if grad_state:
      grad_state.grad_context.Enter()

  def ExitGradWhileContext(self, op, before):
    """Exit the WhileContext for gradient computation.

    Does nothing when `GetGradState(op, before)` finds no grad state.
    """
    grad_state = self.GetGradState(op, before)
    if grad_state:
      grad_state.grad_context.Exit()

  def AddWhileContext(self, op, between_op_list, between_ops):
    """Add the grad state for the while loop that op belongs to.

    Note that op is an Exit, and this method must be called in
    the control flow context where gradients() is called.

    Note that this method modifies `between_op_list` and `between_ops`.

    Args:
      op: An Exit op of a forward while loop.
      between_op_list: List of ops; exits of a newly registered loop that are
        missing from `between_ops` are appended.
      between_ops: Set of ops; kept in sync with `between_op_list`.
    """
    forward_ctxt = _GetWhileContext(op)
    if self._map.get(forward_ctxt) is not None:
      # The loop already has a grad state.
      return
    # This is a new while loop so create a grad state for it.
    outer_grad_state = self._GetOuterGradState(forward_ctxt)
    grad_state = GradLoopState(forward_ctxt, outer_grad_state)
    self._map[forward_ctxt] = grad_state

    # We need to include all exits of a loop for backprop.
    for loop_exit in grad_state.forward_loop_exits:
      if loop_exit.op not in between_ops:
        between_ops.add(loop_exit.op)
        between_op_list.append(loop_exit.op)

  def ZerosLikeForExit(self, val):
    """Create zeros_like gradient for a loop exit.

    If the result of a loop variable is not used but is involved in
    computing the result of some needed loop variable, we create a
    zero-valued tensor that is fed as gradient for the Exit node of that
    loop variable. Note that val.op is an Exit, and this method must be
    called in the control flow context where gradients() is called.

    Args:
      val: The output tensor of an Exit op.

    Returns:
      A zero tensor of the same shape of val.
    """
    val_shape = val.get_shape()
    forward_ctxt = val.op._get_control_flow_context()
    outer_grad_state = self._GetOuterGradState(forward_ctxt)
    if outer_grad_state:
      # This is a nested loop.
      if val_shape.is_fully_defined():
        # If the shape is known statically, just create a zero tensor
        # with the right shape in the right context.
        outer_grad_state.grad_context.Enter()
        result = array_ops.zeros(val_shape.dims, val.dtype)
        outer_grad_state.grad_context.Exit()
      else:
        # Only the shape of value is needed for backprop.
        forward_ctxt.outer_context.Enter()
        shape = array_ops.shape_internal(val, optimize=False)
        forward_ctxt.outer_context.Exit()
        # Save the shape to a stack.
        history_shape = outer_grad_state.AddForwardAccumulator(shape)
        # Get the shape back from the stack.
        outer_grad_ctxt = outer_grad_state.grad_context
        outer_grad_ctxt.Enter()
        real_shape = outer_grad_state.AddBackpropAccumulatedValue(
            history_shape, shape)
        result = array_ops.zeros(real_shape, val.dtype)
        outer_grad_ctxt.Exit()
    else:
      # This is not a nested loop.
      if val_shape.is_fully_defined():
        # If the shape is known statically, just create a zero tensor
        # with the right shape.
        result = array_ops.zeros(val_shape.dims, val.dtype)
      else:
        result = array_ops.zeros_like(val, optimize=False)
    return result

  def ZerosLike(self, op, index):
    """Create zeros_like for the specified output of an op.

    If op is in a while loop that is part of gradients(), this method
    must be called in its grad loop context.

    Args:
      op: A tensorflow operation.
      index: the index for a specific output of the op.

    Returns:
      A zero tensor of the same shape of op.outputs[index], or None if `op`
      is a loop switch.
    """
    if util.IsLoopSwitch(op):
      return None
    if op.graph._building_function:  # pylint: disable=protected-access
      # The optimization here is tricky to apply to functions
      return array_ops.zeros_like(op.outputs[index])
    dead_branch = util.IsSwitch(op)
    forward_ctxt = _GetWhileContext(op)
    grad_state = self._map.get(forward_ctxt)
    if grad_state is None:
      # op is not in a while loop that is part of gradients().
      return ZerosLikeOutsideLoop(op, index)
    op_ctxt = op._get_control_flow_context()
    val = ops.convert_to_tensor(op.outputs[index], name="tensor")
    shape = val.get_shape()
    if shape.is_fully_defined():
      # If the shape is known statically, just create a zero tensor with
      # the right shape in the grad loop context.
      result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
      if dead_branch:
        # op is a cond switch. Guard the zero tensor with a switch.
        pred = grad_state.history_map.get(op_ctxt.pred.name)
        branch = op_ctxt.branch
        result = _SwitchRefOrTensor(result, pred)[1 - branch]
    else:
      # Unknown shape so keep a history of the shape at runtime.
      if dead_branch:
        # Need to add a special switch to guard the value.
        pred = op_ctxt.pred
        branch = op_ctxt.branch
        # The guarded value and its shape are built in the cond's outer
        # context and afterwards re-tagged as belonging to the cond context.
        op_ctxt.outer_context.Enter()
        val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
        zeros_shape = array_ops.shape_internal(val, optimize=False)
        op_ctxt.outer_context.Exit()
        val.op._set_control_flow_context(op_ctxt)
        zeros_shape.op._set_control_flow_context(op_ctxt)
      else:
        op_ctxt.Enter()
        zeros_shape = array_ops.shape_internal(val, optimize=False)
        op_ctxt.Exit()

      # Add forward accumulator for shape.
      grad_state.grad_context.Exit()
      history_zeros_shape = grad_state.AddForwardAccumulator(
          zeros_shape, dead_branch=dead_branch)
      grad_state.grad_context.Enter()

      # Create a zero tensor with the right shape.
      shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
                                                     zeros_shape, dead_branch)
      result = array_ops.zeros(shape, val.dtype)
    return result

  def PostProcessing(self):
    """Perform postprocessing at the end of gradients().

    We have created the gradient graph at this point. So this function
    can be used to perform any postprocessing on the gradient graph.
    We currently perform the following postprocessing:
      1. Patch the gradient graph if the output of a loop variable
         doesn't depend on its input.
    """
    for grad_state in self._map.values():
      for b_merge in grad_state.switch_map.values():
        if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
          # The value of this loop variable at iteration i+1 doesn't
          # depend on its value at iteration i. So use zeros as the
          # gradients for all iterations > 0.
          dtype = b_merge.op.inputs[0].dtype
          shape = b_merge.op.inputs[0].get_shape()
          # pylint: disable=protected-access
          if shape.is_fully_defined():
            grad_state.grad_context.Enter()
            # Create a zeros and use it for iterations > 0.
            grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
            next_grad_val = _NextIteration(grad_val)
            grad_state.grad_context.Exit()
          else:
            # Create a zeros in the outer grad context.
            outer_grad_ctxt = grad_state.grad_context.outer_context
            if outer_grad_ctxt:
              outer_grad_ctxt.Enter()
            enter_grad_op = b_merge.op.inputs[0].op
            enter_grad = enter_grad_op.inputs[0]
            grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
            # The zeros must have the dtype of the Merge input they replace;
            # without an explicit dtype `zeros` would produce float32.
            grad_val = array_ops.zeros(grad_shape, dtype=dtype)
            if outer_grad_ctxt:
              outer_grad_ctxt.Exit()
            # Use the zeros for iterations > 0.
            grad_state.grad_context.Enter()
            next_grad_val = _NextIteration(grad_val)
            grad_state.grad_context.Exit()
          b_merge.op._update_input(1, next_grad_val)
          # pylint: enable=protected-access
1299
1300
def MaybeCreateControlFlowState(between_op_list, between_ops,
                                colocate_gradients_with_ops):
  """Create the state for all the while loops involved in one gradients().

  A `ControlFlowState` is only created once a loop exit is found among the
  ops; gradients() invokes control flow logic only when the returned state is
  not None.

  Note that this method modifies `between_op_list` and `between_ops`.

  Args:
    between_op_list: List of ops to scan for loop exits. It may grow while it
      is being scanned, and the added ops are scanned as well.
    between_ops: Set of ops kept in sync with `between_op_list`.
    colocate_gradients_with_ops: If true, each loop's grad state is created
      while colocated with the exit op that triggered it.

  Returns:
    A `ControlFlowState`, or None if no op in the list is a loop exit.
  """
  state = None
  for candidate in between_op_list:
    if not util.IsLoopExit(candidate):
      continue
    if state is None:
      state = ControlFlowState()
    if not colocate_gradients_with_ops:
      state.AddWhileContext(candidate, between_op_list, between_ops)
      continue
    with ops.colocate_with(candidate):
      state.AddWhileContext(candidate, between_op_list, between_ops)
  return state
1322
1323
def ZerosLikeOutsideLoop(op, index):
  """Create zeros_like for the specified output of an op.

  Args:
    op: The operation whose output needs a zero tensor.
    index: The index of the output of `op`.

  Returns:
    A zero tensor shaped like `op.outputs[index]`. For resource outputs the
    shape is read from the variable. For a switch inside a cond context the
    zeros are only built on the branch selected by the guarding switch.
  """
  val = op.outputs[index]
  if util.IsSwitch(op):
    op_ctxt = op._get_control_flow_context()
    if not op_ctxt:
      return array_ops.zeros_like(val, optimize=False)
    # We are in a cond context. Use a switch to create zeros only when needed.
    pred, branch = op_ctxt.pred, op_ctxt.branch
    switch_val = switch(op.inputs[0], pred)[1 - branch]
    # A op is created along the branch taken as control dependencies are on
    # the whole op and not on the tensor output.
    pivot = array_ops.identity(switch_val)
    if val.dtype == dtypes.resource:
      with ops.control_dependencies([pivot]):
        return array_ops.zeros(
            gen_resource_variable_ops.variable_shape(switch_val))
    zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
    # Ensure ops created within array_ops.zeros are dominated by switch in
    # cond context.
    with ops.control_dependencies([pivot]):
      return array_ops.zeros(zeros_shape, dtype=val.dtype)

  # Not a switch: no guarding is needed.
  if val.dtype == dtypes.resource:
    return array_ops.zeros(gen_resource_variable_ops.variable_shape(val))
  return array_ops.zeros_like(val, optimize=False)
1352
1353
1354@six.add_metaclass(abc.ABCMeta)
1355class ControlFlowContext(object):
1356  """The base class for control flow context.
1357
1358  The usage pattern is a sequence of (Enter, Exit) followed by a final
1359  ExitResult.
1360
1361  We maintain the following state for control flow contexts during graph
1362  construction:
1363   1. graph has _control_flow_context: the current context used to
1364      construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
1365   2. op has _control_flow_context: the context to which the op belongs.
1366      Set at the time the op is created. Immutable.
1367   3. A ControlFlowContext has _outer_context: the context in which this
1368      context is created. Set at the time a context is created. Immutable.
1369   4. A ControlFlowContext has _context_stack.
1370      Pushed and popped by ctxt.Enter() and ctxt.Exit()
1371  """
1372
1373  def __init__(self, values_def=None, import_scope=None):
1374    self._nested_contexts = []
1375    self._outer_context = ops.get_default_graph()._get_control_flow_context()
1376    if self._outer_context:
1377      self._outer_context._nested_contexts.append(self)  # pylint: disable=protected-access
1378    self._context_stack = []
1379    if values_def:
1380      self._init_values_from_proto(values_def, import_scope=import_scope)
1381    else:
1382      # The names of tensors that have been already seen in this context.
1383      self._values = set()
1384      # The keys are the names of tensors referenced by but external to this
1385      # context. Each value is the Tensor that should be used by this context to
1386      # access the key value (e.g. a switch output guarding a cond input value).
1387      self._external_values = {}
1388
1389  def _init_values_from_proto(self, values_def, import_scope=None):
1390    """Initializes values and external_values from `ValuesDef` protocol buffer.
1391
1392    Args:
1393      values_def: `ValuesDef` protocol buffer.
1394      import_scope: Optional `string`. Name scope to add.
1395    """
1396    assert isinstance(values_def, control_flow_pb2.ValuesDef)
1397    self._values = set(
1398        ops.prepend_name_scope(value, import_scope)
1399        for value in values_def.values)
1400    g = ops.get_default_graph()
1401    self._external_values = {}
1402    for k, v in values_def.external_values.items():
1403      k = ops.prepend_name_scope(k, import_scope)
1404      self._external_values[k] = g.as_graph_element(
1405          ops.prepend_name_scope(v, import_scope))
1406    op_names = set([
1407        op.split(":")[0]
1408        for op in self._values - set(self._external_values.keys())
1409    ])
1410    for op in op_names:
1411      # pylint: disable=protected-access
1412      g.as_graph_element(op)._set_control_flow_context(self)
1413      # pylint: enable=protected-access
1414
1415  @property
1416  def name(self):
1417    return self._name
1418
1419  @property
1420  def outer_context(self):
1421    """Return the context containing this context."""
1422    return self._outer_context
1423
1424  @property
1425  def grad_state(self):
1426    raise NotImplementedError("Abstract method")
1427
1428  @property
1429  def back_prop(self):
1430    raise NotImplementedError("Abstract method")
1431
1432  @abc.abstractmethod
1433  def to_control_flow_context_def(self, context_def, export_scope=None):
1434    """Serializes this into `context_def`.
1435
1436    Args:
1437      context_def: a `ControlFlowContextDef` protocol buffer.
1438      export_scope: Optional `string`. Name scope to remove.
1439    """
1440    raise NotImplementedError("Abstract method")
1441
1442  def _to_values_def(self, export_scope=None):
1443    """Converts the values to a `ValuesDef` protocol buffer.
1444
1445    Args:
1446      export_scope: Optional `string`. Name scope to remove.
1447
1448    Returns:
1449      A `ValuesDef` protocol buffer.
1450    """
1451    values_def = control_flow_pb2.ValuesDef()
1452    values_def.values.extend(
1453        [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)])
1454    for k, v in self._external_values.items():
1455      k = ops.strip_name_scope(k, export_scope)
1456      values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
1457    return values_def
1458
1459  def AddName(self, name):
1460    self._values.add(name)
1461
1462  # pylint: disable=protected-access
1463  def Enter(self):
1464    """Enter this control flow context."""
1465    graph = ops.get_default_graph()
1466    self._context_stack.append(graph._get_control_flow_context())
1467    graph._set_control_flow_context(self)
1468
1469  def Exit(self):
1470    """Exit this control flow context."""
1471    graph = ops.get_default_graph()
1472    last_context = self._context_stack.pop()
1473    graph._set_control_flow_context(last_context)
1474
1475  def EnterGradientColocation(self, op, gradient_uid):
1476    """Start building a gradient colocated with an op."""
1477    if self._outer_context:
1478      self._outer_context.EnterGradientColocation(op, gradient_uid)
1479
1480  def ExitGradientColocation(self, op, gradient_uid):
1481    """Start building a gradient colocated with an op."""
1482    if self._outer_context:
1483      self._outer_context.ExitGradientColocation(op, gradient_uid)
1484
1485  def ExitResult(self, result):
1486    """Make a list of tensors available in the outer context."""
1487    if self._outer_context:
1488      nest.map_structure(lambda x: self._outer_context.AddName(x.name), result,
1489                         expand_composites=True)
1490
1491  def GetWhileContext(self):
1492    """Return the while context containing this context."""
1493    if self._outer_context:
1494      return self._outer_context.GetWhileContext()
1495    return None
1496
1497  def _IsInOuterContext(self, op):
1498    op_ctxt = util.GetOutputContext(op)
1499    outer_ctxt = self.outer_context
1500    while outer_ctxt != op_ctxt:
1501      if outer_ctxt is None:
1502        return False
1503      outer_ctxt = outer_ctxt.outer_context
1504    return True
1505
1506  def _RemoveExternalControlEdges(self, op):
1507    """Remove any external control dependency on this op."""
1508    while_ctxt = self.GetWhileContext()
1509    # A control input of `op` is internal if it is in the same while
1510    # loop context as the enclosing while loop context of self.
1511    if while_ctxt is None:
1512      internal_control_inputs = op.control_inputs
1513    else:
1514      internal_control_inputs = []
1515      for x in op.control_inputs:
1516        ctxt = util.GetOutputContext(x)
1517        if ctxt is not None and ctxt.GetWhileContext() == while_ctxt:
1518          internal_control_inputs.append(x)
1519    external_control_inputs = []
1520    if len(internal_control_inputs) != len(op.control_inputs):
1521      external_control_inputs = list(
1522          set(op.control_inputs) - set(internal_control_inputs))
1523      op._remove_all_control_inputs()
1524      op._add_control_inputs(internal_control_inputs)
1525    return internal_control_inputs, external_control_inputs
1526
1527  # pylint: enable=protected-access
1528
1529  def AddInnerOp(self, op):
1530    """Notifies a scope about an operator added to an inner scope."""
1531    if self._outer_context:
1532      self._outer_context.AddInnerOp(op)
1533
  def GetControlPivot(self):
    """Returns the pivot node for this context, or None.

    This base implementation has no pivot and always returns `None`;
    `CondContext` overrides it to return its own pivot.
    """
    return None
1537
  def IsWhileContext(self):
    """Returns False in this base implementation.

    NOTE(review): presumably overridden to return True by while-loop contexts;
    the override is not visible here.
    """
    return False
1540
  def IsCondContext(self):
    """Returns False; `CondContext` overrides this to return True."""
    return False
1543
  def IsXLAContext(self):
    """Returns False in this base implementation.

    NOTE(review): presumably overridden to return True by XLA contexts; the
    override is not visible here.
    """
    return False
1546
  def __str__(self):
    """Returns this context's `name`."""
    return self.name
1549
1550
class CondContext(ControlFlowContext):
  """The context for the conditional construct.

  One `CondContext` is created per branch of a `cond`. It records the
  predicate, the branch-specific pivot tensor and the branch index, and it
  rewires every external tensor used inside the branch through a Switch op
  guarded by the predicate.
  """

  def __init__(self,
               pred=None,
               pivot=None,
               branch=None,
               name="cond_text",
               context_def=None,
               import_scope=None):
    """Creates a `CondContext`.

    Args:
      pred: The `boolean` tensor for the conditional predicate.
      pivot: The predicate tensor in this branch.
      branch: 0 or 1 representing this branch.
      name: Name of the `CondContext` python object.
      context_def: Optional `ContextDef` protocol buffer to initialize the
        `CondContext` object from.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.
    """
    self._name = ops.get_default_graph().unique_name(name)

    if context_def:
      # Note: _init_from_proto overwrites self._name with the serialized name.
      self._init_from_proto(context_def, import_scope=import_scope)
    else:
      # Initializes the default fields.
      ControlFlowContext.__init__(self)
      self._pred = pred  # The boolean tensor for the cond predicate
      self._pivot = pivot  # The predicate tensor in this branch
      self._branch = branch  # 0 or 1 representing this branch

      # Values considered to have been already seen in this context. pred is not
      # included in this context.
      self._values.add(pred.name)
      self._external_values[pred.name] = pred
      self._values.add(pivot.name)
      pivot.op._set_control_flow_context(self)  # pylint: disable=protected-access

  def _init_from_proto(self, context_def, import_scope=None):
    """Creates a new `CondContext` from protocol buffer.

    Args:
      context_def: `CondContextDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(context_def, control_flow_pb2.CondContextDef)
    # Create from context_def.
    g = ops.get_default_graph()
    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
    self._pred = g.as_graph_element(
        ops.prepend_name_scope(context_def.pred_name, import_scope))
    self._pivot = g.as_graph_element(
        ops.prepend_name_scope(context_def.pivot_name, import_scope))
    self._branch = context_def.branch
    super(CondContext, self).__init__(
        values_def=context_def.values_def, import_scope=import_scope)

  @property
  def pred(self):
    """The boolean tensor for the conditional predicate."""
    return self._pred

  @property
  def pivot(self):
    """The predicate tensor in this branch."""
    return self._pivot

  @property
  def branch(self):
    """0 or 1, identifying which branch this context represents."""
    return self._branch

  @property
  def grad_state(self):
    """The `grad_state` of the enclosing while context, or None if there is none."""
    if self.GetWhileContext():
      return self.GetWhileContext().grad_state
    return None

  @property
  def back_prop(self):
    """The `back_prop` of the enclosing while context, or False if there is none."""
    while_ctxt = self.GetWhileContext()
    if while_ctxt:
      # The original evaluated this expression without returning it, so the
      # property always yielded False.
      return while_ctxt.back_prop
    return False

  def GetControlPivot(self):
    """Returns the pivot tensor of this branch."""
    return self._pivot

  def to_proto(self, export_scope=None):
    """Converts a `CondContext` to a `CondContextDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `CondContextDef` protocol buffer, or `None` when `export_scope` is
      given and this context's name does not start with it.
    """
    if (export_scope is None or self.name.startswith(export_scope)):
      context_def = control_flow_pb2.CondContextDef()
      context_def.context_name = ops.strip_name_scope(self.name, export_scope)
      context_def.pred_name = ops.strip_name_scope(self._pred.name,
                                                   export_scope)
      context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
                                                    export_scope)
      context_def.branch = self._branch
      context_def.values_def.MergeFrom(
          super(CondContext, self)._to_values_def(export_scope))
      # NOTE(review): nested contexts are serialized without `export_scope`;
      # confirm whether that is intentional.
      for nested in self._nested_contexts:
        nested_def = context_def.nested_contexts.add()
        nested.to_control_flow_context_def(nested_def)

      return context_def
    else:
      return None

  @staticmethod
  def from_proto(context_def, import_scope=None):
    """Returns a `CondContext` object created from `context_def`."""
    ret = CondContext(context_def=context_def, import_scope=import_scope)

    # Nested contexts are rebuilt while `ret` is the current context.
    ret.Enter()
    for nested_def in context_def.nested_contexts:
      from_control_flow_context_def(nested_def, import_scope=import_scope)
    ret.Exit()
    return ret

  def to_control_flow_context_def(self, context_def, export_scope=None):
    """Copies this context's proto into the `cond_ctxt` field of `context_def`."""
    context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))

  def AddValue(self, val):
    """Add `val` to the current context and its outer context recursively.

    Args:
      val: A tensor that an op inside this branch wants to consume.

    Returns:
      The tensor to use inside this context in place of `val`: the recorded
      external value when `val` was already seen, otherwise the output of a
      newly created Switch (selected by this branch) guarding `val`.
    """
    if val.name in self._values:
      # Use the real value if it comes from outer context. This is needed in
      # particular for nested conds.
      result = self._external_values.get(val.name)
      result = val if result is None else result
    else:
      result = val
      self._values.add(val.name)
      if self._outer_context:
        result = self._outer_context.AddValue(val)
        self._values.add(result.name)
        self._external_values[result.name] = result
      with ops.control_dependencies(None):
        result = _SwitchRefOrTensor(result, self._pred)[self._branch]
        if self._outer_context:
          self._outer_context.AddInnerOp(result.op)

      result.op.graph.prevent_fetching(result.op)
      # pylint: disable=protected-access
      result.op._set_control_flow_context(self)
      # pylint: enable=protected-access

      # Mark Switch output as seen by this context and any outer contexts,
      # just like what we do for normal op outputs in _AddOpInternal() below.
      ctxt = self
      while ctxt is not None:
        # pylint: disable=protected-access
        ctxt._values.add(result.name)
        ctxt = ctxt._outer_context
        # pylint: enable=protected-access

      self._external_values[val.name] = result
    return result

  def AddOp(self, op):
    """Adds `op` to this context; see `_AddOpInternal`."""
    self._AddOpInternal(op)

  def _AddOpInternal(self, op):
    """Add `op` to the current context.

    Ops without data inputs get a control dependency on the branch pivot
    (unless one of their control inputs is already in this context); ops with
    data inputs have every input routed through `AddValue`. In both cases the
    op's outputs are recorded as seen by this and all outer contexts.
    """
    if not op.inputs:
      # If we're in a while loop, remove any control inputs from outside the
      # loop.
      self._RemoveExternalControlEdges(op)

      if not any(
          util.OpInContext(input_op, self) for input_op in op.control_inputs):
        # pylint: disable=protected-access
        op._add_control_input(self._pivot.op)
        # pylint: enable=protected-access
    else:
      # Make each input to 'op' available in this CondContext. If an input is
      # already part of this context there's nothing to do, but if it's
      # external, AddValue() will handle adding the appropriate Switch node and
      # other bookkeeping.
      for index in range(len(op.inputs)):
        x = op.inputs[index]
        if op.type == "Merge" and x.op.type == "NextIteration":
          # Edge case: if we're importing a while loop inside this CondContext,
          # AddValue() will not correctly handle the NextIteration inputs to
          # Merge node. The problem is that the NextIteration should also be
          # part of this context, but if we're importing it won't have been
          # processed and added to the context yet, so AddValue() will try to
          # add a Switch which results in an invalid graph. Instead, we use the
          # NextIteration input as-is here, and it will eventually be added to
          # the context via AddOp().
          real_x = x
        else:
          real_x = self.AddValue(x)
        if real_x != x:
          # pylint: disable=protected-access
          op._update_input(index, real_x)
          # pylint: enable=protected-access
      # Remove any external control dependency on this op.
      self._RemoveExternalControlEdges(op)
      # pylint: disable=protected-access
      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
        op._add_control_input(self._pivot.op)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    ctxt = self
    while ctxt is not None:
      # pylint: disable=protected-access
      ctxt._values.update(output_names)
      ctxt = ctxt._outer_context
      # pylint: enable=protected-access

    if self._outer_context or not util.IsLoopExit(op):
      op.graph.prevent_fetching(op)

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def _ProcessOutputTensor(self, val):
    """Process an output tensor of a conditional branch.

    Args:
      val: A tensor returned by the branch function.

    Returns:
      The tensor to expose as the branch output: a Switch output when `val`
      was not yet known to this context, the recorded external value when one
      exists, and `val` itself otherwise.
    """
    real_val = val
    if val.name not in self._values:
      # Handle the special case of lambda: x
      self._values.add(val.name)
      if self._outer_context:
        real_val = self._outer_context.AddValue(val)
        self._values.add(real_val.name)
        self._external_values[real_val.name] = real_val
      real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
      self._external_values[val.name] = real_val
    else:
      external_val = self._external_values.get(val.name)
      if external_val is not None:
        real_val = external_val
    return real_val

  def _BuildCondTensor(self, v):
    """Converts one branch result `v` into a tensor usable as a cond output."""
    if isinstance(v, ops.Operation):
      # Use pivot as the proxy for this op.
      return with_dependencies([v], self._pivot)
    else:
      v = nest.map_structure(_convert_tensorarray_to_flow, v,
                             expand_composites=True)
      return self._ProcessOutputTensor(ops.convert_to_tensor(v))

  def BuildCondBranch(self, fn):
    """Add the subgraph defined by fn() to the graph.

    Args:
      fn: The branch callable; it is invoked exactly once here.

    Returns:
      A tuple `(original_result, result)`. `original_result` is what `fn`
      returned (passed through identity ops when `fn` added summaries) and
      `result` is the corresponding list/tuple of processed tensors. When `fn`
      returns `None` the tuple is `(None, None)`, or `(no_op(), None)` if `fn`
      added summaries.
    """
    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    original_result = fn()
    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    if len(post_summaries) > len(pre_summaries):
      # Summaries created by fn are removed from the collection and attached
      # to the branch result through control dependencies instead.
      new_summaries = post_summaries[len(pre_summaries):]
      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
      summary_ref[:] = pre_summaries
      with ops.control_dependencies(new_summaries):
        if original_result is None:
          return no_op(), None
        else:
          original_result = nest.map_structure(array_ops.identity,
                                               original_result,
                                               expand_composites=True)
    if original_result is None:
      return None, None

    result = nest.map_structure(self._BuildCondTensor, original_result,
                                expand_composites=True)
    if not isinstance(result, (list, _basetuple)):
      result = [result]
    return original_result, result

  def IsCondContext(self):
    """Returns True: this is a conditional context."""
    return True
1828
1829
def _UnpackIfSingleton(res):
  """Returns `res[0]` if `res` is a one-element list or tuple, else `res`."""
  is_singleton = isinstance(res, (list, _basetuple)) and len(res) == 1
  return res[0] if is_singleton else res
1835
1836
1837# pylint: disable=redefined-outer-name
1838# pylint: disable=g-doc-args
@tf_export(v1=["cond"])
@deprecation.deprecated_args(
    None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
    "fn1", "fn2")
def cond(pred,
         true_fn=None,
         false_fn=None,
         strict=False,
         name=None,
         fn1=None,
         fn2=None):
  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
  `false_fn` must have the same non-zero number and type of outputs.

  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
  `false_fn` will be executed regardless of which branch is selected at runtime.

  Although this behavior is consistent with the dataflow model of TensorFlow,
  it has frequently surprised users who expected a lazier semantics.
  Consider the following simple program:

  ```python
  z = tf.multiply(a, b)
  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  ```

  If `x < y`, the `tf.add` operation will be executed and `tf.square`
  operation will not be executed. Since `z` is needed for at least one
  branch of the `cond`, the `tf.multiply` operation is always executed,
  unconditionally.

  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
  call to `cond`, and not at all during `Session.run()`). `cond`
  stitches together the graph fragments created during the `true_fn` and
  `false_fn` calls with some additional graph nodes to ensure that the right
  branch gets executed depending on the value of `pred`.

  `tf.cond` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
  same (possibly nested) value structure of lists, tuples, and/or named tuples.
  Singleton lists and tuples form the only exceptions to this: when returned by
  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
  This behavior is disabled by passing `strict=True`.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: Optional name prefix for the returned tensors.
    fn1: Deprecated alias for `true_fn`; may not be set together with it.
    fn2: Deprecated alias for `false_fn`; may not be set together with it.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`. If the
    callables return a singleton list, the element is extracted from the list
    (unless `strict=True`).

  Raises:
    TypeError: if `true_fn` or `false_fn` is missing or not callable, if a
      deprecated alias is combined with its replacement, or if `pred` is a
      Python bool when building a graph.
    ValueError: if `true_fn` and `false_fn` do not return the same number of
      tensors, or return tensors of different types.

  Example:

  ```python
  x = tf.constant(2)
  y = tf.constant(5)
  def f1(): return tf.multiply(x, 17)
  def f2(): return tf.add(y, 23)
  r = tf.cond(tf.less(x, y), f1, f2)
  # r is set to f1().
  # Operations in f2 (e.g., tf.add) are not executed.
  ```

  """
  # We needed to make true_fn/false_fn keyword arguments for
  # backwards-compatibility. This check exists so that we can convert back to
  # having them be positional arguments.
  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
  # `fn1` and `fn2` are deleted.
  #
  # The aliases are resolved and validated before any dispatch so that the
  # control flow v2 path below also receives the effective callables.
  if fn1 is not None:
    if true_fn is not None:
      raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
    true_fn = fn1
  elif true_fn is None:
    raise TypeError("cond(): true_fn argument required")
  if fn2 is not None:
    if false_fn is not None:
      raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
    false_fn = fn2
  elif false_fn is None:
    raise TypeError("cond(): false_fn argument required")

  if not callable(true_fn):
    raise TypeError("true_fn must be callable.")
  if not callable(false_fn):
    raise TypeError("false_fn must be callable.")

  # Always enable control flow v2 if building a function, regardless of toggle.
  if (util.EnableControlFlowV2(ops.get_default_graph()) and
      not context.executing_eagerly()):
    return cond_v2.cond_v2(pred, true_fn, false_fn, name)

  with ops.name_scope(name, "cond", [pred]):
    if context.executing_eagerly():
      # Eagerly, only the selected branch runs. Singleton unpacking follows
      # `strict` exactly like the graph path at the end of this function.
      result = true_fn() if pred else false_fn()
      if not strict:
        result = _UnpackIfSingleton(result)
      return result

    # Add the Switch to the graph.
    if isinstance(pred, bool):
      raise TypeError("pred must not be a Python bool")
    p_2, p_1 = switch(pred, pred)
    pivot_1 = array_ops.identity(p_1, name="switch_t")
    pivot_2 = array_ops.identity(p_2, name="switch_f")
    pred = array_ops.identity(pred, name="pred_id")
    # Disable the fetching of tensors that are only on one branch of cond.
    for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:
      tensor.op.graph.prevent_fetching(tensor.op)

    # Build the graph for the true branch in a new context.
    context_t = CondContext(pred, pivot_1, branch=1)
    try:
      context_t.Enter()
      orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
      if orig_res_t is None:
        raise ValueError("true_fn must have a return value.")
      context_t.ExitResult(res_t)
    finally:
      context_t.Exit()

    # Build the graph for the false branch in a new context.
    context_f = CondContext(pred, pivot_2, branch=0)
    try:
      context_f.Enter()
      orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
      if orig_res_f is None:
        raise ValueError("false_fn must have a return value.")
      context_f.ExitResult(res_f)
    finally:
      context_f.Exit()

    if not strict:
      orig_res_t = _UnpackIfSingleton(orig_res_t)
      orig_res_f = _UnpackIfSingleton(orig_res_f)

    # Check that the return values of the two branches have the same structure.
    try:
      nest.assert_same_structure(orig_res_t, orig_res_f,
                                 expand_composites=True)
    except TypeError as e:
      raise TypeError(
          "Incompatible return types of true_fn and false_fn: {}".format(e))
    except ValueError as e:
      raise ValueError(
          "Incompatible return values of true_fn and false_fn: {}".format(e))

    # Add the final merge to the graph.
    if not res_t:
      raise ValueError("true_fn and false_fn must return at least one result.")

    res_t_flat = nest.flatten(res_t, expand_composites=True)
    res_f_flat = nest.flatten(res_f, expand_composites=True)

    for i, (x, y) in enumerate(zip(res_t_flat, res_f_flat)):
      assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
      if x.dtype.base_dtype != y.dtype.base_dtype:
        # A mismatch may be an int32/int64 difference between IndexedSlices
        # indices; try to reconcile those before reporting an error.
        _cast_indexed_slice_indices(res_t, res_t_flat, res_f_flat)
        if res_t_flat[i].dtype.base_dtype != res_f_flat[i].dtype.base_dtype:
          raise ValueError(
              "Outputs of true_fn and false_fn must have the same type: "
              "%s, %s" % (x.dtype.name, y.dtype.name))

    merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
    merges = _convert_flows_to_tensorarrays(
        nest.flatten(orig_res_t, expand_composites=True), merges)

    # Only add non-nested conds to the collection. Any nested control flow will
    # be encapsulated in the root context.
    assert context_t.outer_context == context_f.outer_context
    if context_t.outer_context is None:
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)

    merges = nest.pack_sequence_as(structure=orig_res_t, flat_sequence=merges,
                                   expand_composites=True)

    # Singleton lists and tuples are automatically unpacked if strict == False.
    if not strict:
      merges = _UnpackIfSingleton(merges)
    return merges
2031
2032
def _cast_indexed_slice_indices(structure, flat_a, flat_b):
  """Cast IndexedSlice.indices from int32 to int64 where necessary.

  For each `IndexedSlices` in the nested structure `structure`, find its
  indices `Tensor` in the corresponding flattened lists `flat_a` and `flat_b`
  (where composites have been expanded); and if those indices tensors have
  different dtypes (i.e., if one is int64 but the other is int32), then cast
  them to both be int64.

  Args:
    structure: The nested structure that was flattened.
    flat_a: A flattened list of `Tensors` whose structure matches
        `structure`.  Will be modified in place to cast `IndexedSlices`
        indices tensors to int64, where necessary.
    flat_b: A flattened list of `Tensors` whose structure matches
        `structure`.  Will be modified in place to cast `IndexedSlices`
        indices tensors to int64, where necessary.
  """
  # First pass: record the flat positions (shared by flat_a and flat_b) of
  # every IndexedSlices' indices tensor.
  indices_positions = []
  flat_offset = 0
  for element in nest.flatten(structure, expand_composites=False):
    if isinstance(element, ops.IndexedSlices):
      # indices is the second component of the composite tensor.
      indices_positions.append(flat_offset + 1)
    if nest.is_sequence_or_composite(element):
      flat_offset += len(nest.flatten(element, expand_composites=True))
    else:
      flat_offset += 1
  assert flat_offset == len(flat_a)

  # Second pass: wherever the two sides disagree, widen the int32 side.
  for position in indices_positions:
    assert flat_a[position].dtype in (dtypes.int32, dtypes.int64)
    assert flat_b[position].dtype in (dtypes.int32, dtypes.int64)
    if flat_a[position].dtype != flat_b[position].dtype:
      if flat_b[position].dtype == dtypes.int32:
        flat_b[position] = math_ops.cast(flat_b[position], dtypes.int64)
      else:
        flat_a[position] = math_ops.cast(flat_a[position], dtypes.int64)
2073
2074
2075# pylint: enable=g-doc-args
2076# pylint: enable=redefined-outer-name
2077
2078
@tf_export("cond", v1=[])
def cond_for_tf_v2(pred,
                   true_fn=None,
                   false_fn=None,
                   name=None):
  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.

  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
  `false_fn` must have the same non-zero number and type of outputs.

  **WARNING**: Any Tensors or Operations created outside of `true_fn` and
  `false_fn` will be executed regardless of which branch is selected at runtime.

  Although this behavior is consistent with the dataflow model of TensorFlow,
  it has frequently surprised users who expected a lazier semantics.
  Consider the following simple program:

  ```python
  z = tf.multiply(a, b)
  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
  ```

  If `x < y`, the `tf.add` operation will be executed and `tf.square`
  operation will not be executed. Since `z` is needed for at least one
  branch of the `cond`, the `tf.multiply` operation is always executed,
  unconditionally.

  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
  call to `cond`, and not at all during `Session.run()`). `cond`
  stitches together the graph fragments created during the `true_fn` and
  `false_fn` calls with some additional graph nodes to ensure that the right
  branch gets executed depending on the value of `pred`.

  `tf.cond` supports nested structures as implemented in
  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
  same (possibly nested) value structure of lists, tuples, and/or named tuples.

  This function forwards to the v1 `cond` with `strict=True`, i.e. it requests
  that singleton lists and tuples are *not* implicitly unpacked to single
  values.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix for the returned tensors.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: if `true_fn` or `false_fn` is not callable.
    ValueError: if `true_fn` and `false_fn` do not return the same number of
      tensors, or return tensors of different types.

  Example:

  ```python
  x = tf.constant(2)
  y = tf.constant(5)
  def f1(): return tf.multiply(x, 17)
  def f2(): return tf.add(y, 23)
  r = tf.cond(tf.less(x, y), f1, f2)
  # r is set to f1().
  # Operations in f2 (e.g., tf.add) are not executed.
  ```

  """
  return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
2148
2149
def _resource_safe_shape(t):
  """Returns the shape of t or the variable it points to.

  For a resource-typed tensor, the chain of first inputs is followed back to
  an op without inputs, and that op's "shape" attribute is returned as a
  `TensorShape`. For any other tensor the result of
  `array_ops.shape_internal(t, optimize=False)` is returned.
  """
  if t.dtype == dtypes.resource:
    handle = t
    upstream_inputs = handle.op.inputs
    while upstream_inputs:
      handle = upstream_inputs[0]
      upstream_inputs = handle.op.inputs
    shape_attr = handle.op.get_attr("shape")
    return tensor_shape.TensorShape(shape_attr)
  return array_ops.shape_internal(t, optimize=False)
2157
2158
2159# TODO(yuanbyu): Consider having a unified notion of context for
2160# not only conditionals and loops but also control dependency and
2161# subgraphs.
2162class WhileContext(ControlFlowContext):
2163  """The context for the loop construct."""
2164
2165  def __init__(self,
2166               maximum_iterations=None,
2167               parallel_iterations=10,
2168               back_prop=True,
2169               swap_memory=False,
2170               name="while_context",
2171               grad_state=None,
2172               context_def=None,
2173               import_scope=None):
2174    """"Creates a `WhileContext`.
2175
2176    Args:
2177      maximum_iterations: Optional upper bound on number of loop iterations.
2178      parallel_iterations: The number of iterations allowed to run in parallel.
2179      back_prop: Whether backprop is enabled for this while loop.
2180      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
2181      name: Optional name prefix for the returned tensors.
2182      grad_state: The gradient loop state.
2183      context_def: Optional `WhileContextDef` protocol buffer to initialize the
2184        `Whilecontext` python object from.
2185      import_scope: Optional `string`. Name scope to add. Only used when
2186        initialing from protocol buffer.
2187    """
2188    if context_def:
2189      self._init_from_proto(context_def, import_scope=import_scope)
2190    else:
2191      ControlFlowContext.__init__(self)
2192      self._init_from_args(maximum_iterations, parallel_iterations, back_prop,
2193                           swap_memory, name)
2194    # The gradient loop state.
2195    self._grad_state = grad_state
2196
2197  def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop,
2198                      swap_memory, name):
2199    """Creates a new `WhileContext` from arguments.
2200
2201    Args:
2202      maximum_iterations: Optional upper bound on number of loop iterations.
2203      parallel_iterations: The number of iterations allowed to run in parallel.
2204      back_prop: Whether backprop is enabled for this while loop.
2205      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
2206      name: Optional name prefix for the returned tensors.
2207
2208    Raises:
2209      ValueError: If `parallel_iterations` has invalid value.
2210    """
2211    if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0):
2212      raise ValueError("`parallel_iterations` must be a positive integer: "
2213                       "%s" % parallel_iterations)
2214    self._name = ops.get_default_graph().unique_name(name)
2215    self._maximum_iterations = maximum_iterations
2216    self._parallel_iterations = parallel_iterations
2217    self._back_prop = back_prop
2218    self._swap_memory = swap_memory
2219    # We use this node to control constants created by the pred lambda.
2220    self._pivot_for_pred = None
2221    # We use this node to control constants created by the body lambda.
2222    self._pivot_for_body = None
2223    # The boolean tensor for loop termination condition. Used in code
2224    # generation for gradient computation
2225    self._pivot = None
2226    # The list of exit tensors for loop variables.
2227    self._loop_exits = []
2228    # The list of enter tensors for loop variables.
2229    self._loop_enters = []
2230    self._graph = ops.get_default_graph()
2231
2232  def _init_from_proto(self, context_def, import_scope=None):
2233    """Creates a new `WhileContext` from protocol buffer.
2234
2235    Args:
2236      context_def: `WhileContextDef` protocol buffer.
2237      import_scope: Optional `string`. Name scope to add.
2238    """
2239    assert isinstance(context_def, control_flow_pb2.WhileContextDef)
2240    # Create from context_def.
2241    g = ops.get_default_graph()
2242    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
2243    if context_def.maximum_iterations_name:
2244      self._maximum_iterations = g.as_graph_element(
2245          ops.prepend_name_scope(context_def.maximum_iterations_name,
2246                                 import_scope))
2247    else:
2248      self._maximum_iterations = None
2249    self._parallel_iterations = context_def.parallel_iterations
2250    self._back_prop = context_def.back_prop
2251    self._swap_memory = context_def.swap_memory
2252    self._pivot_for_pred = g.as_graph_element(
2253        ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope))
2254    # We use this node to control constants created by the body lambda.
2255    self._pivot_for_body = g.as_graph_element(
2256        ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope))
2257    # The boolean tensor for loop termination condition. Used in code
2258    # generation for gradient computation.
2259    self._pivot = g.as_graph_element(
2260        ops.prepend_name_scope(context_def.pivot_name, import_scope))
2261    # The list of exit tensors for loop variables.
2262    self._loop_exits = [
2263        g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope))
2264        for exit_name in context_def.loop_exit_names
2265    ]
2266    # The list of enter tensors for loop variables.
2267    self._loop_enters = [
2268        g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope))
2269        for enter_name in context_def.loop_enter_names
2270    ]
2271    super(WhileContext, self).__init__(
2272        values_def=context_def.values_def, import_scope=import_scope)
2273
2274    # import_scope causes self.name to be different from the original serialized
2275    # context's name. Rewrite "frame_name" attrs with the new name.
2276    if import_scope:
2277      for tensor_name in self._values:
2278        op = g.as_graph_element(tensor_name).op
2279        if util.IsLoopEnter(op):
2280          # pylint: disable=protected-access
2281          op._set_attr("frame_name",
2282                       attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
2283          # pylint: enable=protected-access
2284    self._graph = ops.get_default_graph()
2285
2286  @property
2287  def maximum_iterations(self):
2288    """The maximum number of iterations that will be executed."""
2289    return self._maximum_iterations
2290
2291  @property
2292  def parallel_iterations(self):
2293    """The number of iterations allowed to run in parallel."""
2294    return self._parallel_iterations
2295
2296  @property
2297  def back_prop(self):
2298    """True iff backprop is enabled for this while loop."""
2299    return self._back_prop
2300
2301  @property
2302  def swap_memory(self):
2303    """True iff GPU-CPU memory swap is enabled for this while loop."""
2304    return self._swap_memory
2305
2306  @property
2307  def pivot(self):
2308    """The boolean tensor representing the loop termination condition."""
2309    return self._pivot
2310
2311  @property
2312  def loop_enters(self):
2313    """The list of enter tensors for loop variables."""
2314    return self._loop_enters
2315
2316  @property
2317  def loop_exits(self):
2318    """The list of exit tensors for loop variables."""
2319    return self._loop_exits
2320
2321  @property
2322  def grad_state(self):
2323    """The gradient loop state."""
2324    return self._grad_state
2325
2326  def to_proto(self, export_scope=None):
2327    """Converts a `WhileContext` to a `WhileContextDef` protocol buffer.
2328
2329    Args:
2330      export_scope: Optional `string`. Name scope to remove.
2331
2332    Returns:
2333      A `WhileContextDef` protocol buffer.
2334    """
2335    if (export_scope is None or self.name.startswith(export_scope)):
2336      context_def = control_flow_pb2.WhileContextDef()
2337      context_def.context_name = ops.strip_name_scope(self.name, export_scope)
2338      context_def.parallel_iterations = self._parallel_iterations
2339      if self._maximum_iterations is not None:
2340        context_def.maximum_iterations_name = ops.strip_name_scope(
2341            self._maximum_iterations.name, export_scope)
2342      context_def.back_prop = self._back_prop
2343      context_def.swap_memory = self._swap_memory
2344      context_def.pivot_for_pred_name = ops.strip_name_scope(
2345          self._pivot_for_pred.name, export_scope)
2346      context_def.pivot_for_body_name = ops.strip_name_scope(
2347          self._pivot_for_body.name, export_scope)
2348      context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
2349                                                    export_scope)
2350      context_def.loop_exit_names.extend([
2351          ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits
2352      ])
2353      context_def.loop_enter_names.extend([
2354          ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters
2355      ])
2356      context_def.values_def.MergeFrom(
2357          super(WhileContext, self)._to_values_def(export_scope=export_scope))
2358      for nested in self._nested_contexts:
2359        nested_def = context_def.nested_contexts.add()
2360        nested.to_control_flow_context_def(nested_def)
2361
2362      return context_def
2363    else:
2364      return None
2365
2366  def to_control_flow_context_def(self, context_def, export_scope=None):
2367    context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
2368
2369  @staticmethod
2370  def from_proto(context_def, import_scope=None):
2371    """Returns a `WhileContext` object created from `context_def`.
2372
2373    Args:
2374      context_def: A `WhileContextDef` protocol buffer.
2375      import_scope: Optional `string`. Name scope to add.
2376
2377    Returns:
2378      A `WhileContext` Python object.
2379    """
2380    ret = WhileContext(context_def=context_def, import_scope=import_scope)
2381    ret.Enter()
2382    for nested_def in context_def.nested_contexts:
2383      from_control_flow_context_def(nested_def, import_scope=import_scope)
2384    ret.Exit()
2385    return ret
2386
2387  def GetWhileContext(self):
2388    return self
2389
2390  def GetControlPivot(self):
2391    if self._pivot_for_body is not None:
2392      return self._pivot_for_body
2393    return self._pivot_for_pred
2394
  def AddValue(self, val):
    """Add `val` to the current context and its outer context recursively.

    Args:
      val: The value (an object exposing `.name` and `.op`) to make usable
        from inside this loop context.

    Returns:
      The value to use in place of `val` inside this context:
        * `val` itself when it is already known here and no replacement was
          recorded for it;
        * the previously recorded replacement when there is one;
        * the result of `grad_state.GetRealValue(val)` when the graph's current
          while context has a grad state whose forward context is the one
          `val` comes from;
        * otherwise a newly created loop-constant `Enter` of `val`.
    """
    result = val
    new_value = val.name not in self._values
    # Don't treat ops in this context as new values. Usually all known values
    # are in self._values, except when we're importing a while loop inside this
    # WhileContext. Since there's a cycle in this case, `val` may be part of the
    # imported while loop but not yet processed by this context and added to
    # self._values in _AddOpInternal. We only want to process external input
    # tensors to the while loop here.
    new_value &= val.op._control_flow_context is not self  # pylint: disable=protected-access
    if new_value:
      # Record the name first so that the recursive calls below see `val` as
      # already known to this context.
      self._values.add(val.name)

      # If we are in a grad context and val is from its forward context,
      # use GetRealValue(), which adds the logic to save the history of
      # val in forward.
      grad_ctxt = ops.get_default_graph()._get_control_flow_context()
      if grad_ctxt:
        grad_ctxt = grad_ctxt.GetWhileContext()
        if grad_ctxt.grad_state:
          forward_ctxt = _GetWhileContext(val.op)
          # For a loop Exit op, compare against the while context that
          # encloses the loop being exited rather than that loop itself.
          if util.IsLoopExit(val.op):
            forward_ctxt = forward_ctxt.outer_context
            if forward_ctxt:
              forward_ctxt = forward_ctxt.GetWhileContext()
          if forward_ctxt == grad_ctxt.grad_state.forward_context:
            real_val = grad_ctxt.grad_state.GetRealValue(val)
            self._external_values[val.name] = real_val
            return real_val

      # Make the value known to the enclosing contexts first, outermost last.
      if self._outer_context is not None:
        result = self._outer_context.AddValue(val)
      # Create an Enter to make `result` known to this loop context.
      with ops.control_dependencies(None):
        enter = _Enter(
            result,
            self._name,
            is_constant=True,
            parallel_iterations=self._parallel_iterations)
        enter.graph.prevent_feeding(enter)
        if self._outer_context:
          self._outer_context.AddInnerOp(enter.op)
      # Fix the control inputs and control flow context of these enter ops.
      self._FixControlInputsAndContext([enter])

      # Add `enter` in this context.
      self._values.add(enter.name)
      self._external_values[val.name] = enter
      result = enter
    else:
      # Already known: reuse the replacement recorded earlier, if any.
      actual_val = self._external_values.get(val.name)
      if actual_val is not None:
        result = actual_val
    return result
2450
2451  def AddOp(self, op):
2452    """Add `op` to the current context."""
2453    # For a reduction op, if op is in a grad context and its input is from
2454    # its forward context, moving op to the forward context means we would
2455    # store the tensor after the reduction as opposed to the tensor before
2456    # reduction, and therefore could significantly reduce memory consumption.
2457    # For now, we do this only for a few ops.
2458    if op.type in {"Shape", "Size", "Rank"}:
2459      grad_ctxt = ops.get_default_graph()._get_control_flow_context()
2460      if grad_ctxt:
2461        grad_ctxt = grad_ctxt.GetWhileContext()
2462        if grad_ctxt.grad_state:
2463          op_input_forward_ctxt = _GetWhileContext(op.inputs[0].op)
2464          if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context:
2465            op_input_ctxt = op.inputs[0].op._get_control_flow_context()
2466            op._set_control_flow_context(op_input_ctxt)
2467            op_input_ctxt._AddOpInternal(op)
2468            return
2469    self._AddOpInternal(op)
2470
  def _AddOpInternal(self, op):
    """Add `op` to the current context.

    We move any external control dependencies of the op to the loop pivot, to
    ensure they get executed.

    Args:
      op: The op to add. Its inputs and control inputs may be rewritten, and
        the names of its outputs are recorded in `self._values`.
    """
    if not op.inputs:
      # Remove any external control dependency on this op
      control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
      # Add a control edge from the control pivot to this op.
      if not control_inputs:
        # pylint: disable=protected-access
        op._add_control_input(self.GetControlPivot().op)
        # pylint: enable=protected-access
      for x in op.outputs:
        self._values.add(x.name)
    else:
      # Route every data input through AddValue and rewire the op whenever
      # AddValue hands back a different tensor.
      for index in range(len(op.inputs)):
        x = op.inputs[index]
        real_x = self.AddValue(x)
        if real_x != x:
          op._update_input(index, real_x)  # pylint: disable=protected-access
      # Remove any external control dependency on this op.
      _, external_inputs = self._RemoveExternalControlEdges(op)
      # Add a control dependency to prevent loop invariants from
      # enabling ops that should not be executed.
      self._MaybeAddControlDependency(op)
      for x in op.outputs:
        self._values.add(x.name)
    if external_inputs:
      # Use an identity to pull control inputs as data inputs. Note that we
      # ignore ops which don't have outputs. TODO(apassos): fix that
      with ops.control_dependencies(None):
        # The identities are created with this context entered.
        self.Enter()
        external_inputs = [
            array_ops.identity(x.outputs[0]).op
            for x in external_inputs
            if x.outputs
        ]
        self.Exit()
      op._add_control_inputs(external_inputs)  # pylint: disable=protected-access
    # Everything except the Exit ops of an outermost loop is made
    # non-fetchable, and its outputs non-feedable.
    if self._outer_context or not util.IsLoopExit(op):
      op.graph.prevent_fetching(op)
      for x in op.outputs:
        op.graph.prevent_feeding(x)

    # Let the enclosing context know about the op as well.
    if self._outer_context:
      self._outer_context.AddInnerOp(op)
2519
2520  def _MaybeAddControlDependency(self, op):
2521    """Add a control input to the op if it only depends on loop invariants."""
2522
2523    def _IsOpFree(op):
2524      """Determines if `op` needs a control dependency."""
2525      if op.control_inputs:
2526        return False
2527      # pylint: disable=protected-access
2528      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
2529        return True
2530      # pylint: enable=protected-access
2531      for x in op.inputs:
2532        if not util.IsLoopConstantEnter(x.op):
2533          return False
2534      return True
2535
2536    if _IsOpFree(op):
2537      # pylint: disable=protected-access
2538      op._add_control_input(self.GetControlPivot().op)
2539      # pylint: enable=protected-access
2540
2541  def AddForwardLoopCounter(self, outer_grad_state):
2542    """Adds a loop that counts the number of iterations.
2543
2544    This is added to the forward loop at the time when we start to
2545    create the loop for backprop gradient computation. Called in
2546    the outer context of this forward context.
2547
2548    The pseudocode is:
2549      `n = 0; while (_pivot) { n++; }`
2550
2551    Note that a control dependency is added to `n` to ensure the correct
2552    execution order of stack push ops.
2553
2554    Args:
2555      outer_grad_state: The outer grad state. None if not nested.
2556
2557    Returns:
2558      The number of iterations taken by the forward loop and the loop index.
2559    """
2560    n = constant_op.constant(0, name="f_count")
2561    if outer_grad_state is not None:
2562      # Force the stack pushes of i-th execution of an inner loop to be ordered
2563      # before the pushes of (i+1)-th execution of the same inner loop.
2564      outer_add_op = outer_grad_state.forward_index.op.inputs[0].op
2565      n.op._add_control_input(outer_add_op)  # pylint: disable=protected-access
2566
2567    self.Enter()
2568    self.AddName(n.name)
2569    enter_n = _Enter(
2570        n,
2571        self._name,
2572        is_constant=False,
2573        parallel_iterations=self._parallel_iterations,
2574        name="f_count")
2575    self.loop_enters.append(enter_n)
2576
2577    merge_n = merge([enter_n, enter_n])[0]
2578    switch_n = switch(merge_n, self._pivot)
2579
2580    index = math_ops.add(switch_n[1], 1)
2581    next_n = _NextIteration(index)
2582    merge_n.op._update_input(1, next_n)
2583
2584    total_iterations = exit(switch_n[0], name="f_count")
2585    self.loop_exits.append(total_iterations)
2586    self.ExitResult([total_iterations])
2587    self.Exit()
2588    return total_iterations, next_n
2589
2590  def AddBackpropLoopCounter(self, count, outer_grad_state):
2591    """Add the backprop loop that controls the iterations.
2592
2593    This is added to the backprop loop. It is used to control the loop
2594    termination of the backprop loop. Called in the outer context of
2595    this grad context.
2596
2597    The pseudocode is:
2598      `n = count; while (n >= 1) { n--; }`
2599
2600    Note that a control dependency is added to `final_zero` to ensure the
2601    correct execution order of stack pop ops.
2602
2603    Args:
2604      count: The number of iterations for backprop.
2605      outer_grad_state: The outer grad state. None if not nested.
2606
2607    Returns:
2608      The loop index.
2609    """
2610    in_separate_functions = count.graph is not ops.get_default_graph()
2611    if in_separate_functions:
2612      # Brings the count into this graph
2613      count = array_ops.identity(count)
2614    else:
2615      # TODO(apassos) XLA expects this constant to be created outside the loop,
2616      # so doing that for now.
2617      one = constant_op.constant(1, name="b_count")
2618
2619    self.Enter()
2620    self.AddName(count.name)
2621    enter_count = _Enter(
2622        count,
2623        self._name,
2624        is_constant=False,
2625        parallel_iterations=self._parallel_iterations,
2626        name="b_count")
2627    self.loop_enters.append(enter_count)
2628
2629    merge_count = merge([enter_count, enter_count])[0]
2630    self._pivot_for_pred = merge_count
2631
2632    if in_separate_functions:
2633      one = constant_op.constant(1, name="b_count")
2634    pred = math_ops.greater_equal(merge_count, one)
2635    self._pivot = loop_cond(pred, name="b_count")
2636    switch_count = switch(merge_count, self._pivot)
2637
2638    index = math_ops.subtract(switch_count[1], one)
2639    self._pivot_for_body = index
2640    next_count = _NextIteration(index)
2641    merge_count.op._update_input(1, next_count)
2642
2643    final_zero = exit(switch_count[0], name="b_count")
2644    self.loop_exits.append(final_zero)
2645    if outer_grad_state is not None:
2646      # Force the stack pops of i-th execution of an inner loop to be ordered
2647      # before the pops of (i+1)-th execution of the same inner loop.
2648      # pylint: disable=protected-access
2649      outer_grad_state.grad_sync._add_control_input(final_zero.op)
2650      # pylint: enable=protected-access
2651
2652    self.ExitResult([final_zero])
2653    self.Exit()
2654    return next_count
2655
  def AddBackpropAccumulator(self, op, grad):
    """Add an accumulation loop for every loop invariant.

    This is added to the backprop loop. It is used to accumulate partial
    gradients within each loop iteration. Called when in the gradient while
    context.

    The pseudocode is:
      ```
      acc = 0.0;
      while (_pivot) {
        acc += grad;
      }
      ```

    This method exits this context while it builds the initial accumulator
    value and re-enters it afterwards; it does not exit the context again
    before returning.

    Args:
      op: The Enter op for a loop invariant.
      grad: The partial gradient of an iteration for a loop invariant.

    Returns:
      The gradient for a loop invariant.
    """
    # The initial accumulator value is created outside of this loop context.
    self.Exit()
    # Create a zeros tensor with the right shape for acc. If we don't
    # know the full shape statically, we will have to get the shape
    # dynamically from the forward inference. Getting the shape right
    # for the zeros is only needed for the base case when the loop exits
    # without running any iterations.
    shape = grad.get_shape()
    if shape.is_fully_defined():
      # Static shape known: a constant of zeros, built in the outer context
      # when there is one.
      if self.outer_context:
        self.outer_context.Enter()
      acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
      if self.outer_context:
        self.outer_context.Exit()
    else:
      # Otherwise derive the shape at run time from the input of the Enter op.
      value = op.inputs[0]
      if (isinstance(self.outer_context, WhileContext) and
          self.outer_context.grad_state is not None):
        # We are in a nested while loop.
        # The shape is computed in the outer context of the forward loop,
        # saved through the outer grad state, and read back in the outer
        # gradient context.
        forward_ctxt = self.grad_state.forward_context
        forward_ctxt.outer_context.Enter()
        zeros_shape = array_ops.shape_internal(value, optimize=False)
        forward_ctxt.outer_context.Exit()
        outer_grad_state = self.grad_state.outer_grad_state
        history_zeros_shape = outer_grad_state.AddForwardAccumulator(
            zeros_shape)
        self.outer_context.Enter()
        real_shape = outer_grad_state.AddBackpropAccumulatedValue(
            history_zeros_shape, zeros_shape)
        acc = array_ops.zeros(real_shape, grad.dtype)
        self.outer_context.Exit()
      else:
        if self.outer_context:
          self.outer_context.Enter()
        zeros_shape = array_ops.shape_internal(value, optimize=False)
        acc = array_ops.zeros(zeros_shape, grad.dtype)
        if self.outer_context:
          self.outer_context.Exit()

    # Back inside this loop context: build the accumulation loop itself.
    self.Enter()
    self.AddName(acc.name)
    enter_acc = _Enter(
        acc,
        self._name,
        is_constant=False,
        parallel_iterations=self._parallel_iterations,
        name="b_acc")
    self.loop_enters.append(enter_acc)

    # Input 1 of the merge is a placeholder; it is replaced by the back edge
    # (`next_acc`) below.
    merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
    switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)

    add_acc = math_ops.add(switch_acc_true, grad)
    next_acc = _NextIteration(add_acc)
    merge_acc.op._update_input(1, next_acc)  # pylint: disable=protected-access

    result_acc = exit(switch_acc_false, name="b_acc")
    self.loop_exits.append(result_acc)
    self.ExitResult([result_acc])
    return result_acc
2737
  def AddBackpropIndexedSlicesAccumulator(self, op, grad):
    """This is used for accumulating gradients that are IndexedSlices.

    This is essentially the equivalent of AddBackpropAccumulator but optimized
    for things like updating embeddings from within a while loop.

    Indices and values are accumulated by concatenating each iteration's
    contribution along axis 0; for the dense shape (when present) the
    element-wise maximum is kept. Like AddBackpropAccumulator, this method
    re-enters this context before building the loop and does not exit it again
    before returning.

    Args:
      op: The Enter op for a loop invariant.
      grad: The partial gradients represented as an IndexedSlices.

    Returns:
      The accumulated IndexedSlices gradient of the loop invariant.
    """
    values = grad.values
    indices = grad.indices
    dense_shape = grad.dense_shape

    # The initial accumulator values are created outside of this loop context,
    # in the outer context when there is one.
    self.Exit()
    if self.outer_context:
      self.outer_context.Enter()
    if values.get_shape().is_fully_defined():
      # Initial values accumulator: a single row of zeros with the same
      # trailing dimensions as `values`.
      values_shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] +
                                              values.get_shape().dims[1:])
      # NOTE(review): the outer context was already entered above; this inner
      # Enter/Exit pair is balanced -- confirm whether it is still needed.
      if self.outer_context:
        self.outer_context.Enter()
      values_acc = constant_op.constant(
          0, values.dtype, shape=values_shape, name="b_acc")
      if self.outer_context:
        self.outer_context.Exit()
    else:
      # Shape only known at run time: take the trailing dimensions from the
      # input of the Enter op and prepend a leading dimension of 1.
      values_shape = _resource_safe_shape(op.inputs[0])[1:]
      values_shape = array_ops.concat([[1], values_shape], 0)
      values_acc = array_ops.zeros(values_shape, dtype=values.dtype)
    # Initial indices accumulator: the single index 0.
    indices_acc = constant_op.constant([0], indices.dtype)
    shape_acc = None
    if dense_shape is not None:
      # Initial dense-shape accumulator: zeros shaped like `dense_shape`.
      if dense_shape.get_shape().is_fully_defined():
        if self.outer_context:
          self.outer_context.Enter()
        shape_acc = constant_op.constant(
            0, dense_shape.dtype, shape=dense_shape.get_shape())
        if self.outer_context:
          self.outer_context.Exit()
      else:
        shape_acc = array_ops.zeros_like(
            array_ops.shape_internal(
                op.inputs[0], optimize=False, out_type=dense_shape.dtype),
            optimize=False)

    if self.outer_context:
      self.outer_context.Exit()

    # Back inside this loop context: build the accumulation loop itself.
    self.Enter()
    self.AddName(values_acc.name)
    self.AddName(indices_acc.name)
    # Order matters below: index 0 is indices, 1 is values, 2 (optional) is
    # the dense shape.
    init_acc = [indices_acc, values_acc]
    if shape_acc is not None:
      self.AddName(shape_acc.name)
      init_acc.append(shape_acc)

    # Set use_input_shape=False since the accumulator tensors will grow in
    # size. If use_input_shape=True, the _update_input call below will result in
    # incompatible shapes.
    enter_acc = [
        _Enter(
            x,
            self._name,
            is_constant=False,
            parallel_iterations=self._parallel_iterations,
            use_input_shape=False,
            name="b_acc") for x in init_acc
    ]
    # Manually set appropriate partial shapes.
    enter_acc[0].set_shape([None])
    if values_acc.shape.dims is not None:
      enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:])
    self.loop_enters.extend(enter_acc)

    # Input 1 of each merge is a placeholder; it is replaced by the back edge
    # below.
    merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]
    switch_acc = [switch(x, self._pivot) for x in merge_acc]

    # The actual accumulation.
    acc_indexed_slices = [
        array_ops.concat([xa[1], xv], 0)
        for xa, xv in zip(switch_acc[:2], [indices, values])
    ]
    if shape_acc is not None:
      # For the shape we just keep the maximum
      acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1]))

    next_acc = [_NextIteration(x) for x in acc_indexed_slices]
    for xm, xn in zip(merge_acc, next_acc):
      xm.op._update_input(1, xn)  # pylint: disable=protected-access

    exit_acc = [exit(x[0], name="b_acc") for x in switch_acc]
    self.loop_exits.extend(exit_acc)

    self.ExitResult(exit_acc)
    return ops.IndexedSlices(
        indices=exit_acc[0],
        values=exit_acc[1],
        dense_shape=exit_acc[2] if shape_acc is not None else None)
2840
2841  def _InitializeValues(self, values):
2842    """Makes the values known to this context."""
2843    self._values = set()
2844    for x in values:
2845      if isinstance(x, ops.Tensor):
2846        self._values.add(x.name)
2847      else:
2848        raise TypeError("Type %s not supported" % type(x))
2849
  def _BuildLoop(self, pred, body, original_loop_vars, loop_vars,
                 shape_invariants):
    """Core: Add the loop termination condition and body to the graph.

    Args:
      pred: Callable invoked once with the packed loop variables; its result
        is converted to a tensor and used as the loop condition.
      body: Callable invoked once with the packed loop variables; its result
        provides the values for the next iteration.
      original_loop_vars: The loop variables as given by the caller, used as
        the structure for packing and to identify TensorArrays.
      loop_vars: The flat list of loop variable tensors.
      shape_invariants: The shape invariants for the loop variables, or None.

    Returns:
      A tuple `(original_body_result, exit_vars)`: the structure returned by
      `body` (wrapped in a list if it was not a sequence or composite) and the
      list of exit tensors, one per loop variable.

    Raises:
      ValueError: If the number of flattened body outputs differs from the
        number of loop variables.
    """
    flat_loop_vars = nest.flatten(original_loop_vars, expand_composites=True)

    # Let the context know the loop variables so the loop variables
    # would be added in the outer contexts properly.
    self._InitializeValues(loop_vars)
    real_vars = loop_vars
    if self._outer_context:
      real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
    # One Enter per loop variable, created without inherited control
    # dependencies.
    with ops.control_dependencies(None):
      enter_vars = [
          _Enter(
              x,
              self._name,
              is_constant=False,
              parallel_iterations=self._parallel_iterations,
              use_input_shape=(shape_invariants is None)) for x in real_vars
      ]
      for x in enter_vars:
        x.graph.prevent_feeding(x)
        if self._outer_context:
          self._outer_context.AddInnerOp(x.op)

    # Finds the closest enclosing non-None control pivot.
    outer_context = self._outer_context
    control_pivot = None
    while outer_context is not None and control_pivot is None:
      control_pivot = outer_context.GetControlPivot()
      # pylint: disable=protected-access
      outer_context = outer_context._outer_context
      # pylint: enable=protected-access

    # Enter ops fed by a loop-constant Enter of an enclosing loop get a control
    # input from that enclosing pivot.
    if control_pivot is not None:
      for var in enter_vars:
        if util.IsLoopConstantEnter(var.op.inputs[0].op):
          # pylint: disable=protected-access
          var.op._add_control_input(control_pivot.op)
          # pylint: enable=protected-access
    _SetShapeInvariants(real_vars, enter_vars, shape_invariants)

    # Fix the control inputs and control flow context of these enter ops.
    self._FixControlInputsAndContext(enter_vars)
    # From here on the known values of this context are the Enter tensors.
    self._InitializeValues(enter_vars)
    self._loop_enters = enter_vars

    # Input 1 of each merge is a placeholder until the back edge is added.
    merge_vars = [merge([x, x])[0] for x in enter_vars]
    self._pivot_for_pred = merge_vars[0]

    # Build the graph for pred.
    merge_vars_with_tensor_arrays = (
        _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars))
    packed_vars = nest.pack_sequence_as(
        structure=original_loop_vars,
        flat_sequence=merge_vars_with_tensor_arrays,
        expand_composites=True)
    c = ops.convert_to_tensor(pred(*packed_vars))
    self._pivot = loop_cond(c, name="LoopCond")
    switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]

    # Build the graph for body.
    vars_for_body = [_Identity(x[1]) for x in switch_vars]
    self._pivot_for_body = vars_for_body[0]
    # Convert TensorArray flow variables inside the context back into
    # their associated TensorArrays for calling the body.
    vars_for_body_with_tensor_arrays = (
        _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body))
    packed_vars_for_body = nest.pack_sequence_as(
        structure=original_loop_vars,
        flat_sequence=vars_for_body_with_tensor_arrays,
        expand_composites=True)
    # Snapshot the summary collection around the body call so that summaries
    # created by the body can be detected.
    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    body_result = body(*packed_vars_for_body)
    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
    if not nest.is_sequence_or_composite(body_result):
      body_result = [body_result]
    if len(post_summaries) > len(pre_summaries):
      # Summaries added by the body are removed from the collection again and
      # instead become control dependencies of the body outputs.
      new_summaries = post_summaries[len(pre_summaries):]
      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
      summary_ref[:] = pre_summaries
      with ops.control_dependencies(new_summaries):

        def map_fn(x):
          # TODO(apassos) figure out how to trigger with tensor arrays as well
          if isinstance(x, tensor_array_ops.TensorArray):
            return x
          return array_ops.identity(x)

        body_result = nest.map_structure(map_fn, body_result,
                                         expand_composites=True)

    # Compare the structure types of input and output of body.
    # For backwards compatibility, the first layer is forced to a list
    # during this comparison, because inputs are typically lists and
    # outputs of the body are typically tuples.
    nest.assert_same_structure(list(packed_vars_for_body), list(body_result),
                               expand_composites=True)

    # Store body_result to keep track of TensorArrays returned by body
    original_body_result = body_result
    # Convert TensorArrays returned by body into their flow variables
    result = nest.map_structure(
        _convert_tensorarray_to_flow,
        nest.flatten(body_result, expand_composites=True),
        expand_composites=True)
    result = ops.convert_n_to_tensor_or_composite(result)

    # Add NextIteration and the back edges to complete the loop.
    if len(merge_vars) != len(result):
      raise ValueError("Number of inputs and outputs of body must match "
                       "loop_vars: %d, %d" % (len(merge_vars), len(result)))
    # NOTE(review): `next_vars` is collected but not read again in this method.
    next_vars = []
    for m, v in zip(merge_vars, result):
      next_vars.append(_AddNextAndBackEdge(m, v))

    # Add the exit ops.
    exit_vars = [exit(x[0]) for x in switch_vars]
    self._loop_exits = exit_vars

    # Exit the loop.
    self.ExitResult(exit_vars)

    return original_body_result, exit_vars
2974
2975  def BuildLoop(self, pred, body, loop_vars, shape_invariants,
2976                return_same_structure):
2977    """Add the loop termination condition and body to the graph."""
2978
2979    # Keep original_loop_vars to identify which are TensorArrays
2980    original_loop_vars = loop_vars
2981    # Convert TensorArrays to their flow variables
2982    loop_vars = nest.map_structure(
2983        _convert_tensorarray_to_flow,
2984        nest.flatten(loop_vars, expand_composites=False),
2985        expand_composites=True)
2986    loop_vars = ops.convert_n_to_tensor_or_composite(loop_vars)
2987    if shape_invariants is None:
2988      shape_invariants = nest.map_structure(
2989          _get_shape_invariant, loop_vars, expand_composites=False)
2990    loop_vars = nest.flatten(loop_vars, expand_composites=True)
2991    try:
2992      self.Enter()
2993      # _BuildLoop calls _update_input in several places. _mutation_lock()
2994      # ensures a Session.run call cannot occur between creating and mutating
2995      # new ops.
2996      with ops.get_default_graph()._mutation_lock():  # pylint: disable=protected-access
2997        original_body_result, exit_vars = self._BuildLoop(
2998            pred, body, original_loop_vars, loop_vars, shape_invariants)
2999    finally:
3000      self.Exit()
3001
3002    flat_result = nest.flatten(original_body_result, expand_composites=True)
3003    # Convert TensorArray flow variables outside the context back into
3004    # their associated TensorArrays for returning to caller.
3005    exit_vars_with_tensor_arrays = (
3006        _convert_flows_to_tensorarrays(flat_result, exit_vars))
3007    packed_exit_vars = nest.pack_sequence_as(
3008        structure=original_body_result,
3009        flat_sequence=exit_vars_with_tensor_arrays,
3010        expand_composites=True)
3011
3012    if return_same_structure:
3013      return packed_exit_vars
3014    else:
3015      return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars
3016
3017  def _FixControlInputsAndContext(self, enters):
3018    graph = ops.get_default_graph()
3019    # pylint: disable=protected-access
3020    for e in enters:
3021      if isinstance(e, ops.Tensor):
3022        xs = [e]
3023      else:
3024        raise TypeError("Type %s not supported" % type(e))
3025      for x in xs:
3026        inp_op = x.op.inputs[0].op
3027        control_inputs = graph._control_dependencies_for_inputs([inp_op])
3028        outer_control_inputs = [
3029            op for op in control_inputs if self._IsInOuterContext(op)
3030        ]
3031        x.op._set_control_flow_context(self)
3032        x.op._add_control_inputs(outer_control_inputs)
3033        graph._record_op_seen_by_control_dependencies(x.op)
3034    # pylint: enable=protected-access
3035
  def IsWhileContext(self):
    """Returns True unconditionally, marking this as a while-loop context."""
    return True
3038
3039
3040# pylint: disable=redefined-outer-name
@tf_export("while_loop", v1=[])
def while_loop_v2(cond,
                  body,
                  loop_vars,
                  shape_invariants=None,
                  parallel_iterations=10,
                  back_prop=True,
                  swap_memory=False,
                  maximum_iterations=None,
                  name=None):
  """Repeat `body` while the condition `cond` is true.

  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
  `cond` and `body`. `cond` and `body` both take as many arguments as there are
  `loop_vars`.

  In addition to regular Tensors or IndexedSlices, the body may accept and
  return TensorArray objects.  The flows of the TensorArray objects will
  be appropriately forwarded between loops and during gradient calculations.

  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
  stitches together the graph fragments created during the `cond` and `body`
  calls with some additional graph nodes to create the graph flow that
  repeats `body` until `cond` returns false.

  For correctness, `tf.while_loop()` strictly enforces shape invariants for
  the loop variables. A shape invariant is a (possibly partial) shape that
  is unchanged across the iterations of the loop. An error will be raised
  if the shape of a loop variable after an iteration is determined to be more
  general than or incompatible with its shape invariant. For example, a shape
  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
  compatible with [11, 17]. By default (if the argument `shape_invariants` is
  not specified), it is assumed that the initial shape of each tensor in
  `loop_vars` is the same in every iteration. The `shape_invariants` argument
  allows the caller to specify a less specific shape invariant for each loop
  variable, which is needed if the shape varies between iterations. The
  `tf.Tensor.set_shape`
  function may also be used in the `body` function to indicate that
  the output loop variable has a particular shape. The shape invariants for
  SparseTensor and IndexedSlices are treated specially as follows:

  a) If a loop variable is a SparseTensor, the shape invariant must be
  TensorShape([r]) where r is the rank of the dense tensor represented
  by the sparse tensor. It means the shapes of the three tensors of the
  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
  is the shape of the SparseTensor.dense_shape property. It must be the shape of
  a vector.

  b) If a loop variable is an IndexedSlices, the shape invariant must be
  a shape invariant of the values tensor of the IndexedSlices. It means
  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
  [shape.ndims]).

  `while_loop` implements non-strict semantics, enabling multiple iterations
  to run in parallel. The maximum number of parallel iterations can be
  controlled by `parallel_iterations`, which gives users some control over
  memory consumption and execution order. For correct programs, `while_loop`
  should return the same result for any parallel_iterations > 0.

  For training, TensorFlow stores the tensors that are produced in the
  forward inference and are needed in back propagation. These tensors are a
  main source of memory consumption and often cause OOM errors when training
  on GPUs. When the flag swap_memory is true, we swap out these tensors from
  GPU to CPU. This for example allows us to train RNN models with very long
  sequences and large batches.

  Args:
    cond: A callable that represents the termination condition of the loop.
    body: A callable that represents the loop body.
    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
      `Tensor`, and `TensorArray` objects.
    shape_invariants: The shape invariants for the loop variables.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer.
    back_prop: Whether backprop is enabled for this while loop.
    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
    maximum_iterations: Optional maximum number of iterations of the while loop
      to run.  If provided, the `cond` output is AND-ed with an additional
      condition ensuring the number of iterations executed is no greater than
      `maximum_iterations`.
    name: Optional name prefix for the returned tensors.

  Returns:
    The output tensors for the loop variables after the loop. The return value
      has the same structure as `loop_vars`.

  Raises:
    TypeError: if `cond` or `body` is not callable.
    ValueError: if `loop_vars` is empty.

  Example:

  ```python
  i = tf.constant(0)
  c = lambda i: tf.less(i, 10)
  b = lambda i: tf.add(i, 1)
  r = tf.while_loop(c, b, [i])
  ```

  Example with nesting and a namedtuple:

  ```python
  import collections
  Pair = collections.namedtuple('Pair', 'j, k')
  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
  c = lambda i, p: i < 10
  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
  ijk_final = tf.while_loop(c, b, ijk_0)
  ```

  Example using shape_invariants:

  ```python
  i0 = tf.constant(0)
  m0 = tf.ones([2, 2])
  c = lambda i, m: i < 10
  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
  tf.while_loop(
      c, b, loop_vars=[i0, m0],
      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
  ```

  Example which demonstrates non-strict semantics: In the following
  example, the final value of the counter `i` does not depend on `x`. So
  the `while_loop` can increment the counter parallel to updates of `x`.
  However, because the loop counter at one loop iteration depends
  on the value at the previous iteration, the loop counter itself cannot
  be incremented in parallel. Hence if we just want the final value of the
  counter (which we print on the line `print(sess.run(i))`), then
  `x` will never be incremented, but the counter will be updated on a
  single thread. Conversely, if we want the value of the output (which we
  print on the line `print(sess.run(out).shape)`), then the counter may be
  incremented on its own thread, while `x` can be incremented in
  parallel on a separate thread. In the extreme case, it is conceivable
  that the thread incrementing the counter runs until completion before
  `x` is incremented even a single time. The only thing that can never
  happen is that the thread updating `x` gets ahead of the counter thread,
  because the thread incrementing `x` depends on the value of the counter.

  The example below uses the graph-mode `tf.compat.v1.Print` and
  `tf.compat.v1.Session` APIs to make the interleaving visible.

  ```python
  import tensorflow as tf

  n = 10000
  x = tf.constant(list(range(n)))
  c = lambda i, x: i < n
  b = lambda i, x: (tf.compat.v1.Print(i + 1, [i]),
                    tf.compat.v1.Print(x + 1, [i], "x:"))
  i, out = tf.while_loop(c, b, (0, x))
  with tf.compat.v1.Session() as sess:
      print(sess.run(i))  # prints [0] ... [9999]

      # The following line may increment the counter and x in parallel.
      # The counter thread may get ahead of the other thread, but not the
      # other way around. So you may see things like
      # [9996] x:[9987]
      # meaning that the counter thread is on iteration 9996,
      # while the other thread is on iteration 9987
      print(sess.run(out).shape)
  ```

  """
  # Delegate to the v1 `while_loop`, always requesting a result with the same
  # structure as `loop_vars`.
  return while_loop(
      cond=cond,
      body=body,
      loop_vars=loop_vars,
      shape_invariants=shape_invariants,
      parallel_iterations=parallel_iterations,
      back_prop=back_prop,
      swap_memory=swap_memory,
      name=name,
      maximum_iterations=maximum_iterations,
      return_same_structure=True)
3217
3218
3219# pylint: disable=redefined-outer-name
@tf_export(v1=["while_loop"])
def while_loop(cond,
               body,
               loop_vars,
               shape_invariants=None,
               parallel_iterations=10,
               back_prop=True,
               swap_memory=False,
               name=None,
               maximum_iterations=None,
               return_same_structure=False):
  """Repeat `body` while the condition `cond` is true.

  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
  `cond` and `body`. `cond` and `body` both take as many arguments as there are
  `loop_vars`.

  In addition to regular Tensors or IndexedSlices, the body may accept and
  return TensorArray objects.  The flows of the TensorArray objects will
  be appropriately forwarded between loops and during gradient calculations.

  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
  stitches together the graph fragments created during the `cond` and `body`
  calls with some additional graph nodes to create the graph flow that
  repeats `body` until `cond` returns false.

  For correctness, `tf.while_loop()` strictly enforces shape invariants for
  the loop variables. A shape invariant is a (possibly partial) shape that
  is unchanged across the iterations of the loop. An error will be raised
  if the shape of a loop variable after an iteration is determined to be more
  general than or incompatible with its shape invariant. For example, a shape
  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
  compatible with [11, 17]. By default (if the argument `shape_invariants` is
  not specified), it is assumed that the initial shape of each tensor in
  `loop_vars` is the same in every iteration. The `shape_invariants` argument
  allows the caller to specify a less specific shape invariant for each loop
  variable, which is needed if the shape varies between iterations. The
  `tf.Tensor.set_shape`
  function may also be used in the `body` function to indicate that
  the output loop variable has a particular shape. The shape invariants for
  SparseTensor and IndexedSlices are treated specially as follows:

  a) If a loop variable is a SparseTensor, the shape invariant must be
  TensorShape([r]) where r is the rank of the dense tensor represented
  by the sparse tensor. It means the shapes of the three tensors of the
  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
  is the shape of the SparseTensor.dense_shape property. It must be the shape of
  a vector.

  b) If a loop variable is an IndexedSlices, the shape invariant must be
  a shape invariant of the values tensor of the IndexedSlices. It means
  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
  [shape.ndims]).

  `while_loop` implements non-strict semantics, enabling multiple iterations
  to run in parallel. The maximum number of parallel iterations can be
  controlled by `parallel_iterations`, which gives users some control over
  memory consumption and execution order. For correct programs, `while_loop`
  should return the same result for any parallel_iterations > 0.

  For training, TensorFlow stores the tensors that are produced in the
  forward inference and are needed in back propagation. These tensors are a
  main source of memory consumption and often cause OOM errors when training
  on GPUs. When the flag swap_memory is true, we swap out these tensors from
  GPU to CPU. This for example allows us to train RNN models with very long
  sequences and large batches.

  Args:
    cond: A callable that represents the termination condition of the loop.
    body: A callable that represents the loop body.
    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
      `Tensor`, and `TensorArray` objects.
    shape_invariants: The shape invariants for the loop variables.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer.
    back_prop: Whether backprop is enabled for this while loop.
    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
    name: Optional name prefix for the returned tensors.
    maximum_iterations: Optional maximum number of iterations of the while loop
      to run.  If provided, the `cond` output is AND-ed with an additional
      condition ensuring the number of iterations executed is no greater than
      `maximum_iterations`.
    return_same_structure: If True, output has same structure as `loop_vars`. If
      eager execution is enabled, this is ignored (and always treated as True).

  Returns:
    The output tensors for the loop variables after the loop.
     If `return_same_structure` is True, the return value has the same
     structure as `loop_vars`.
     If `return_same_structure` is False, the return value is a Tensor,
     TensorArray or IndexedSlices if the length of `loop_vars` is 1, or a list
     otherwise.

  Raises:
    TypeError: if `cond` or `body` is not callable, or if
      `parallel_iterations` is less than 1.
    ValueError: if `loop_vars` is empty, or if `maximum_iterations` is given
      and is not a scalar.

  Example:

  ```python
  i = tf.constant(0)
  c = lambda i: tf.less(i, 10)
  b = lambda i: tf.add(i, 1)
  r = tf.while_loop(c, b, [i])
  ```

  Example with nesting and a namedtuple:

  ```python
  import collections
  Pair = collections.namedtuple('Pair', 'j, k')
  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
  c = lambda i, p: i < 10
  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
  ijk_final = tf.while_loop(c, b, ijk_0)
  ```

  Example using shape_invariants:

  ```python
  i0 = tf.constant(0)
  m0 = tf.ones([2, 2])
  c = lambda i, m: i < 10
  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
  tf.while_loop(
      c, b, loop_vars=[i0, m0],
      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
  ```

  Example which demonstrates non-strict semantics: In the following
  example, the final value of the counter `i` does not depend on `x`. So
  the `while_loop` can increment the counter parallel to updates of `x`.
  However, because the loop counter at one loop iteration depends
  on the value at the previous iteration, the loop counter itself cannot
  be incremented in parallel. Hence if we just want the final value of the
  counter (which we print on the line `print(sess.run(i))`), then
  `x` will never be incremented, but the counter will be updated on a
  single thread. Conversely, if we want the value of the output (which we
  print on the line `print(sess.run(out).shape)`), then the counter may be
  incremented on its own thread, while `x` can be incremented in
  parallel on a separate thread. In the extreme case, it is conceivable
  that the thread incrementing the counter runs until completion before
  `x` is incremented even a single time. The only thing that can never
  happen is that the thread updating `x` gets ahead of the counter thread,
  because the thread incrementing `x` depends on the value of the counter.

  ```python
  import tensorflow as tf

  n = 10000
  x = tf.constant(list(range(n)))
  c = lambda i, x: i < n
  b = lambda i, x: (tf.Print(i + 1, [i]), tf.Print(x + 1, [i], "x:"))
  i, out = tf.while_loop(c, b, (0, x))
  with tf.Session() as sess:
      print(sess.run(i))  # prints [0] ... [9999]

      # The following line may increment the counter and x in parallel.
      # The counter thread may get ahead of the other thread, but not the
      # other way around. So you may see things like
      # [9996] x:[9987]
      # meaning that the counter thread is on iteration 9996,
      # while the other thread is on iteration 9987
      print(sess.run(out).shape)
  ```

  """
  # Always enable control flow v2 if building a function, regardless of toggle.
  # Note that `back_prop` and `swap_memory` are not forwarded on this path, and
  # none of the argument validation below runs before delegating.
  if (util.EnableControlFlowV2(ops.get_default_graph()) and
      not context.executing_eagerly()):
    return while_v2.while_loop(
        cond,
        body,
        loop_vars,
        shape_invariants=shape_invariants,
        parallel_iterations=parallel_iterations,
        maximum_iterations=maximum_iterations,
        name=name,
        return_same_structure=return_same_structure)

  with ops.name_scope(name, "while", loop_vars):
    if not loop_vars:
      raise ValueError("No loop variables provided")
    if not callable(cond):
      raise TypeError("cond must be callable.")
    if not callable(body):
      raise TypeError("body must be callable.")
    if parallel_iterations < 1:
      raise TypeError("parallel_iterations must be a positive integer.")

    if maximum_iterations is not None:
      maximum_iterations = ops.convert_to_tensor(
          maximum_iterations, name="maximum_iterations")
      if maximum_iterations.shape.ndims != 0:
        raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
                         maximum_iterations.shape)

      # Implement the iteration cap by prepending a counter to the loop
      # variables and wrapping `cond`/`body` so that they thread the counter
      # through and AND the user's condition with `counter < maximum_iterations`.
      counter = constant_op.constant(
          0, dtype=maximum_iterations.dtype, name="iteration_counter")
      orig_cond = cond
      orig_body = body
      if len(loop_vars) == 1:
        # A single loop variable is unwrapped, so the user's callables receive
        # it as their only argument.
        loop_vars = (counter, loop_vars[0])
        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
            math_ops.logical_and(i < maximum_iterations, orig_cond(lv)))
        body = lambda i, lv: (i + 1, orig_body(lv))
      else:
        loop_vars = (counter, loop_vars)
        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
            math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
        body = lambda i, lv: (i + 1, orig_body(*lv))

    if context.executing_eagerly():
      # In eager mode the loop is executed directly as a Python `while` loop;
      # `shape_invariants`, `parallel_iterations`, `back_prop` and
      # `swap_memory` are not consulted on this path.
      try_to_pack = len(loop_vars) == 1
      packed = False  # whether the body result was packed into a 1-item tuple

      while cond(*loop_vars):
        loop_vars = body(*loop_vars)
        # With a single loop variable the body may return a bare value; wrap it
        # so that it can be splatted into `cond`/`body` on the next iteration.
        if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
          packed = True
          loop_vars = (loop_vars,)

      def convert(x):
        # TensorArrays are returned as-is; everything else becomes a tensor.
        if isinstance(x, tensor_array_ops.TensorArray):
          return x
        return ops.convert_to_tensor(x)
      loop_vars = nest.map_structure(convert, loop_vars)
      if maximum_iterations is not None:
        # Drop the iteration counter that was prepended above.
        return loop_vars[1]
      else:
        return loop_vars[0] if packed else loop_vars

    if shape_invariants is not None:
      if maximum_iterations is not None:
        # Prepend a scalar shape invariant for the iteration counter.
        # NOTE(review): when `loop_vars` had a single entry it was unwrapped
        # above, but `shape_invariants` is nested here as given -- confirm the
        # structure check below still passes for that combination.
        shape_invariants = (tensor_shape.TensorShape([]), shape_invariants)

      nest.assert_same_structure(loop_vars, shape_invariants,
                                 expand_composites=False)
      shape_invariants = nest.map_structure(
          _get_shape_invariant, loop_vars, shape_invariants,
          expand_composites=False)

    loop_context = WhileContext(
        maximum_iterations=maximum_iterations,
        parallel_iterations=parallel_iterations,
        back_prop=back_prop,
        swap_memory=swap_memory)
    # Only add non-nested loops to the collection. Any nested control flow will
    # be encapsulated in the root context.
    if loop_context.outer_context is None:
      ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
                                    return_same_structure)
    if maximum_iterations is not None:
      # Drop the iteration counter that was prepended above.
      return result[1]
    else:
      return result
3482
3483
3484# pylint: enable=redefined-outer-name
3485
3486
def _AsTensorList(x, p):
  """Wraps every entry of `x` in identity ops and returns them as a list.

  A value that is not a list or tuple is treated as a one-element list.
  Entries that are `Operation`s are replaced by `p` gated on that operation
  (via `with_dependencies`). Each entry is then converted to a tensor or
  composite; a `Tensor` is passed through `identity`, and anything else is
  rebuilt as an `IndexedSlices` from identities of its `values` and `indices`.

  Args:
    x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
    p: A Tensor to return for entries in `x` that are Operations.

  Returns:
    A list of Tensors or IndexedSlices.
  """
  entries = x if isinstance(x, (list, _basetuple)) else [x]

  def _identity_of(entry):
    if isinstance(entry, ops.Operation):
      entry = with_dependencies([entry], p)
    converted = ops.convert_to_tensor_or_composite(entry)
    if isinstance(converted, ops.Tensor):
      return array_ops.identity(converted)
    return ops.IndexedSlices(
        array_ops.identity(converted.values),
        array_ops.identity(converted.indices))

  return [_identity_of(entry) for entry in entries]
3515
3516
def _CheckResults(a, b):
  """Verifies that two result sequences match in length and element dtypes.

  The checks raise `AssertionError` explicitly rather than relying on `assert`
  statements, so they are still enforced when Python runs with `-O`.

  Args:
    a: Sequence of values exposing `dtype` and `name` attributes.
    b: Sequence of values exposing `dtype` and `name` attributes.

  Raises:
    AssertionError: If `a` and `b` differ in length, or if a pair of
      corresponding elements has different dtypes.
  """
  if not len(a) == len(b):
    raise AssertionError(
        "Values returned by a() and b() must have the same length.")
  for x, y in zip(a, b):
    if not x.dtype == y.dtype:
      # The message is only built on failure, as with the original assert.
      raise AssertionError(
          "Values returned by a() [%s] and b() [%s] must have "
          "the same type: %s, %s." %
          (x.name, y.name, x.dtype.name, y.dtype.name))
3524
3525
def with_dependencies(dependencies, output_tensor, name=None):
  """Produces the content of `output_tensor` only after `dependencies`.

  In some cases, a user may want the output of an operation to be
  consumed externally only after some other dependencies have run
  first. This function returns `output_tensor`, but only after all
  operations in `dependencies` have run. Note that this means that there is
  no guarantee that `output_tensor` will be evaluated after any `dependencies`
  have run.

  When executing eagerly, `output_tensor` is returned unchanged and
  `dependencies` is not inspected.

  See also `tf.tuple` and `tf.group`.

  Args:
    dependencies: Iterable of operations to run before this op finishes. It is
      materialized into a list once, so one-shot iterables (e.g. generators)
      are supported.
    output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
    name: (Optional) A name for this operation.

  Returns:
    Same as `output_tensor`.

  Raises:
    TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
  """
  if context.executing_eagerly():
    return output_tensor
  # `dependencies` is needed twice below (name scope values and control
  # dependencies). Materialize it once so that a one-shot iterable is not
  # exhausted by the first use, which would silently drop every dependency.
  dependencies = list(dependencies)
  with ops.name_scope(name, "control_dependency",
                      dependencies + [output_tensor]) as name:
    with ops.colocate_with(output_tensor):
      with ops.control_dependencies(dependencies):
        output_tensor = ops.convert_to_tensor_or_composite(output_tensor)
        if isinstance(output_tensor, ops.Tensor):
          return _Identity(output_tensor, name=name)
        else:
          # Only the values are gated; indices and dense_shape are reused.
          return ops.IndexedSlices(
              _Identity(output_tensor.values, name=name), output_tensor.indices,
              output_tensor.dense_shape)
3562
3563
def _GroupControlDeps(dev, deps, name=None):
  """Creates a no-op that runs after `deps`, optionally pinned to `dev`.

  Args:
    dev: Device to place the no-op on, or `None` to leave it unplaced.
    deps: Control dependencies for the no-op.
    name: Optional name for the no-op.

  Returns:
    The result of `no_op(name=name)`.
  """
  with ops.control_dependencies(deps):
    if dev is not None:
      with ops.device(dev):
        return no_op(name=name)
    return no_op(name=name)
3571
3572
3573# TODO(touts): Accept "inputs" as a list.
@tf_export("group")
def group(*inputs, **kwargs):
  """Create an op that groups multiple operations.

  When this op finishes, all ops in `inputs` have finished. This op has no
  output.

  See also `tf.tuple` and
  `tf.control_dependencies`.

  Args:
    *inputs: Zero or more tensors to group.
    name: A name for this operation (optional).

  Returns:
    An Operation that executes all its inputs. When executing eagerly, `None`
    is returned instead.

  Raises:
    ValueError: If an unknown keyword argument is provided.
    TypeError: If one of the (flattened) inputs has no `device` attribute.
  """
  if context.executing_eagerly():
    return None
  name = kwargs.pop("name", None)
  if kwargs:
    raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))
  with ops.name_scope(name, "group_deps", inputs) as name:
    # Grouping no inputs means do nothing
    if not inputs:
      return no_op(name=name)

    # Sorts *inputs according to their devices.
    ops_on_device = {}  # device -> operations specified on the device.
    for inp in nest.flatten(inputs, expand_composites=True):
      if not hasattr(inp, "device"):
        raise TypeError("tf.group() expected Tensor arguments not "
                        "'%s' with type '%s'" % (inp, type(inp)))
      ops_on_device.setdefault(inp.device, []).append(inp)
    if len(ops_on_device) == 1:
      # 1-level tree. The root node is the returned NoOp node.
      (dev, deps), = ops_on_device.items()
      return _GroupControlDeps(dev, deps, name=name)

    # 2-level tree. The root node is the returned NoOp node.
    # deps contains 1 NoOp node for each device.
    deps = []

    def device_key(dev):
      """A sort key that allows None to be compared to strings."""
      return "" if dev is None else dev

    # Sorting by device makes the order of the per-device NoOps deterministic.
    for dev in sorted(ops_on_device, key=device_key):
      deps.append(_GroupControlDeps(dev, ops_on_device[dev]))

    with ops.control_dependencies(deps):
      return no_op(name=name)
3633
3634
3635@tf_export("tuple", v1=[])
3636def tuple_v2(tensors, control_inputs=None, name=None):
3637  """Group tensors together.
3638
3639  This creates a tuple of tensors with the same values as the `tensors`
3640  argument, except that the value of each tensor is only returned after the
3641  values of all tensors have been computed.
3642
3643  `control_inputs` contains additional ops that have to finish before this op
3644  finishes, but whose outputs are not returned.
3645
3646  This can be used as a "join" mechanism for parallel computations: all the
3647  argument tensors can be computed in parallel, but the values of any tensor
3648  returned by `tuple` are only available after all the parallel computations
3649  are done.
3650
3651  See also `tf.group` and
3652  `tf.control_dependencies`.
3653
3654  Args:
3655    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
3656    control_inputs: List of additional ops to finish before returning.
3657    name: (optional) A name to use as a `name_scope` for the operation.
3658
3659  Returns:
3660    Same as `tensors`.
3661
3662  Raises:
3663    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
3664    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
3665      objects.
3666
3667  """
3668  return tuple(tensors=tensors, name=name, control_inputs=control_inputs)  # pylint: disable=redefined-builtin
3669
3670
3671@tf_export(v1=["tuple"])
3672def tuple(tensors, name=None, control_inputs=None):  # pylint: disable=redefined-builtin
3673  """Group tensors together.
3674
3675  This creates a tuple of tensors with the same values as the `tensors`
3676  argument, except that the value of each tensor is only returned after the
3677  values of all tensors have been computed.
3678
3679  `control_inputs` contains additional ops that have to finish before this op
3680  finishes, but whose outputs are not returned.
3681
3682  This can be used as a "join" mechanism for parallel computations: all the
3683  argument tensors can be computed in parallel, but the values of any tensor
3684  returned by `tuple` are only available after all the parallel computations
3685  are done.
3686
3687  See also `tf.group` and
3688  `tf.control_dependencies`.
3689
3690  Args:
3691    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
3692    name: (optional) A name to use as a `name_scope` for the operation.
3693    control_inputs: List of additional ops to finish before returning.
3694
3695  Returns:
3696    Same as `tensors`.
3697
3698  Raises:
3699    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
3700    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
3701      objects.
3702
3703  """
3704  if context.executing_eagerly():
3705    return tensors
3706  with ops.name_scope(name, "tuple", tensors) as name:
3707    tensors = [
3708        t if (isinstance(t, ops.Operation) or tensor_util.is_tensor(t) or
3709              t is None) else ops.convert_to_tensor(t) for t in tensors
3710    ]
3711    gating_ops = [
3712        t if isinstance(t, ops.Operation) else t.op
3713        for t in tensors
3714        if t is not None
3715    ]
3716    if control_inputs:
3717      for c in control_inputs:
3718        if isinstance(c, ops.Tensor):
3719          c = c.op
3720        elif not isinstance(c, ops.Operation):
3721          raise TypeError("Control input must be Operation or Tensor: %s" % c)
3722        gating_ops.append(c)
3723    # Note that in order to ensure ordering in the pbtxt, we must take care to
3724    # ensure the order here.
3725    gating_ops = sorted(set(gating_ops), key=lambda op: op._id)  # Uniquify ops.
3726    if not gating_ops:
3727      raise ValueError("Must have at least one Tensor: %s" % tensors)
3728    gate = group(*gating_ops)
3729    tpl = []
3730    for t in tensors:
3731      if tensor_util.is_tensor(t):
3732        tpl.append(with_dependencies([gate], t))
3733      elif isinstance(t, ops.Operation):
3734        with ops.control_dependencies([gate]):
3735          tpl.append(group(t))
3736      else:
3737        tpl.append(None)
3738    return tpl
3739
3740
def _assert_at_most_n_true(predicates, n, msg):
  """Builds an Assert op that fails when more than `n` predicates are True.

  Args:
    predicates: list of bool scalar tensors.
    n: maximum number of true predicates allowed.
    msg: Error message, used as the prefix of the assertion text.

  Returns:
    The result of `Assert` on the condition "number of true predicates <= n".
    Its data holds the formatted message and the stacked predicates.
  """
  stacked_preds = array_ops.stack(predicates, name="preds_c")
  as_int = math_ops.cast(stacked_preds, dtypes.int32)
  true_count = math_ops.reduce_sum(as_int, name="num_true_conds")
  limit = constant_op.constant(n, name="n_true_conds")
  within_limit = math_ops.less_equal(true_count, limit)
  # Predicates without a `name` attribute are shown as "?".
  names = [getattr(pred, "name", "?") for pred in predicates]
  header = "%s: more than %d conditions (%s) evaluated as True:" % (
      msg, n, ", ".join(names))
  return Assert(
      within_limit,
      data=[header, stacked_preds],
      summarize=len(predicates))
3760
3761
3762def _case_create_default_action(predicates, actions):
3763  """Creates default action for a list of actions and their predicates.
3764
3765  It uses the input actions to select an arbitrary as default and makes sure
3766  that corresponding predicates have valid values.
3767
3768  Args:
3769    predicates: a list of bool scalar tensors
3770    actions: a list of callable objects which return tensors.
3771
3772  Returns:
3773    a callable
3774  """
3775  k = len(predicates) - 1  # could pick any
3776  predicate, action = predicates[k], actions[k]
3777  other_predicates, other_actions = predicates[:k], actions[:k]
3778
3779  def default_action():
3780    others_msg = ("Implementation error: "
3781                  "selected default action #%d was called, but some of other "
3782                  "predicates are True: " % k)
3783    default_msg = ("Input error: "
3784                   "None of conditions evaluated as True:",
3785                   array_ops.stack(predicates, name="preds_c"))
3786    with ops.control_dependencies([
3787        _assert_at_most_n_true(other_predicates, n=0, msg=others_msg),
3788        Assert(predicate, data=default_msg)
3789    ]):
3790      return action()
3791
3792  return default_action, other_predicates, other_actions
3793
3794
def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name,
                                       allow_python_preds):
  """Verifies input arguments for the case function.

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
      callable which returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for the case operation.
    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
      addition to boolean Tensors

  Raises:
    TypeError: If `pred_fn_pairs` is not a list, tuple, or dict.
    TypeError: If an entry of `pred_fn_pairs` is not a 2-tuple.
    TypeError: If a predicate is a Tensor whose dtype is not bool, or is
               neither a Tensor nor (when `allow_python_preds` is true) a
               Python bool.
    TypeError: If the callable of any pair is not callable.
    ValueError: If `pred_fn_pairs` is an unordered dict, `exclusive` is False
                and eager execution is enabled.
    ValueError: If `pred_fn_pairs` is empty.

  Returns:
    a tuple <tuple of predicates, tuple of callables>.
  """
  if not isinstance(pred_fn_pairs, (list, _basetuple, dict)):
    raise TypeError("fns must be a list, tuple, or dict")

  if isinstance(pred_fn_pairs, collections.OrderedDict):
    pred_fn_pairs = pred_fn_pairs.items()
  elif isinstance(pred_fn_pairs, dict):
    if context.executing_eagerly():
      # No name to sort on in eager mode. Use dictionary traversal order,
      # which is nondeterministic in versions of Python < 3.6
      if not exclusive:
        raise ValueError("Unordered dictionaries are not supported for the "
                         "`pred_fn_pairs` argument when `exclusive=False` and "
                         "eager mode is enabled.")
      pred_fn_pairs = list(pred_fn_pairs.items())
    else:
      # Keys without a `name` (e.g. Python bools) sort under "" so that they
      # reach the per-pair validation below instead of failing here with an
      # AttributeError.
      pred_fn_pairs = sorted(
          pred_fn_pairs.items(),
          key=lambda item: getattr(item[0], "name", ""))
      if not exclusive:
        logging.warn(
            "%s: An unordered dictionary of predicate/fn pairs was "
            "provided, but exclusive=False. The order of conditional "
            "tests is deterministic but not guaranteed.", name)

  def _describe(pred):
    # Python bools have no `name`, and no name is available in eager mode
    # either; fall back to the predicate itself so that building the error
    # message never masks the intended TypeError with an AttributeError.
    return getattr(pred, "name", pred)

  for pred_fn_pair in pred_fn_pairs:
    if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2:
      raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")
    pred, fn = pred_fn_pair

    if isinstance(pred, ops.Tensor):
      if pred.dtype != dtypes.bool:
        raise TypeError("pred must be Tensor of type bool: %s" %
                        (_describe(pred),))
    elif not allow_python_preds:
      raise TypeError("pred must be a Tensor, got: %s" % pred)
    elif not isinstance(pred, bool):
      raise TypeError("pred must be a Tensor or bool, got: %s" % pred)

    if not callable(fn):
      raise TypeError("fn for pred %s must be callable." % (_describe(pred),))

  if not pred_fn_pairs:
    # Without this check the unpacking below fails with an opaque
    # "not enough values to unpack" ValueError.
    raise ValueError("pred_fn_pairs must contain at least one "
                     "(predicate, callable) pair.")

  predicates, actions = zip(*pred_fn_pairs)
  return predicates, actions
3856
3857
def _case_helper(cond_fn,
                 pred_fn_pairs,
                 default,
                 exclusive,
                 name,
                 allow_python_preds=False,
                 **cond_kwargs):
  """Shared implementation of `case`, parameterized by the cond function.

  The validated (predicate, callable) pairs are folded into a chain of nested
  `cond_fn` calls. When `default` is None, one of the pairs is turned into the
  default branch by `_case_create_default_action`.

  Args:
    cond_fn: method that has signature and semantics of `cond` above.
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
      callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for this operation (optional).
    allow_python_preds: if true, pred_fn_pairs may contain Python bools in
      addition to boolean Tensors
    **cond_kwargs: keyword arguments that will be passed to `cond_fn`.

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
               callable.
  """
  predicates, actions = _case_verify_and_canonicalize_args(
      pred_fn_pairs, exclusive, name, allow_python_preds)
  with ops.name_scope(name, "case", [predicates]):
    if default is None:
      default, predicates, actions = _case_create_default_action(
          predicates, actions)

    # The chain is assembled from the innermost branch outwards, so that the
    # first predicate ends up as the outermost (first evaluated) condition:
    #   cond_fn(c[0], true_fn=.., false_fn=cond_fn(c[1], ...))
    pending = list(zip(predicates, actions))
    chained_fn = default
    while pending:
      predicate, action = pending.pop()
      chained_fn = functools.partial(
          cond_fn,
          predicate,
          true_fn=action,
          false_fn=chained_fn,
          **cond_kwargs)

    if not exclusive:
      return chained_fn()

    at_most_one_true = _assert_at_most_n_true(
        predicates, n=1, msg="Input error: exclusive=True")
    with ops.control_dependencies([at_most_one_true]):
      return chained_fn()
3908
3909
@tf_export("case")
def case(pred_fn_pairs,
         default=None,
         exclusive=False,
         strict=False,
         name="case"):
  """Create a case operation.

  `pred_fn_pairs` is a dict or a list of N pairs. Each pair couples a boolean
  scalar tensor with a python callable that creates the tensors to be returned
  when that boolean evaluates to True. `default` is a callable generating a
  list of tensors. Every callable in `pred_fn_pairs`, as well as `default` (if
  provided), should return the same number and types of tensors.

  With `exclusive==True`, all predicates are evaluated, and an exception is
  thrown if more than one of the predicates evaluates to `True`.
  With `exclusive==False`, execution stops at the first predicate which
  evaluates to True, and the tensors generated by the corresponding function
  are returned immediately. If none of the predicates evaluate to True, this
  operation returns the tensors generated by `default`.

  If `default` is omitted, the callable of the last pair (in the order the
  pairs are tested) is used as the default branch, guarded by assertions that
  its predicate is True and that no other predicate is True.

  `tf.case` supports nested structures as implemented in
  `tf.contrib.framework.nest`. All of the callables must return the same
  (possibly nested) value structure of lists, tuples, and/or named tuples.
  The only exception is singleton lists and tuples: when returned by a
  callable, they are implicitly unpacked to single values. Passing
  `strict=True` disables this unpacking.

  When `pred_fn_pairs` is an unordered dictionary, the order of the
  conditional tests is not guaranteed. The order is nevertheless
  deterministic, so variables created in conditional branches are created in
  a fixed order across runs.

  @compatibility(eager)
  Unordered dictionaries are not supported in eager mode when `exclusive=False`.
  Use a list of tuples instead.
  @end_compatibility


  **Example 1:**

  Pseudocode:

  ```
  if (x < y) return 17;
  else return 23;
  ```

  Expressions:

  ```python
  f1 = lambda: tf.constant(17)
  f2 = lambda: tf.constant(23)
  r = tf.case([(tf.less(x, y), f1)], default=f2)
  ```

  **Example 2:**

  Pseudocode:

  ```
  if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
  if (x < y) return 17;
  else if (x > z) return 23;
  else return -1;
  ```

  Expressions:

  ```python
  def f1(): return tf.constant(17)
  def f2(): return tf.constant(23)
  def f3(): return tf.constant(-1)
  r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
           default=f3, exclusive=True)
  ```

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
      callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    strict: A boolean that enables/disables 'strict' mode; see above.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
               callable.
    ValueError: If `pred_fn_pairs` is an unordered dict, `exclusive` is False
                and eager execution is enabled.
  """
  # Python bool predicates are not accepted here; `strict` is forwarded to
  # `cond` through the helper's keyword pass-through.
  return _case_helper(
      cond_fn=cond,
      pred_fn_pairs=pred_fn_pairs,
      default=default,
      exclusive=exclusive,
      name=name,
      allow_python_preds=False,
      strict=strict)
4014
4015
class XLAControlFlowContext(ControlFlowContext):
  """Base class for XLA and TPU control flow contexts."""

  def __init__(self):
    """Runs the parent initializer and sets the context name."""
    super(XLAControlFlowContext, self).__init__()
    self._name = "XLAControlFlowContext"

  def IsXLAContext(self):
    """Always reports True for this context type."""
    return True

  def AddValue(self, x):
    """Hands `x` back unchanged."""
    return x

  def AddOp(self, _):
    """Ignores the given op; nothing is recorded."""

  def to_control_flow_context_def(self, context_def, export_scope=None):
    """Forwards to the parent class implementation with the same arguments."""
    # pylint: disable=useless-super-delegation
    # NOTE(slebedev): the method is required by `ControlFlowContext`.
    parent = super(XLAControlFlowContext, self)
    parent.to_control_flow_context_def(context_def, export_scope)
4037
4038
def from_control_flow_context_def(context_def, import_scope=None):
  """Rebuilds the control flow context described by `context_def`.

  Args:
    context_def: ControlFlowContextDef proto
    import_scope: Optional `string`. Name scope to add.

  Returns:
    The context produced by `CondContext.from_proto` when `cond_ctxt` is set,
    or by `WhileContext.from_proto` when `while_ctxt` is set.

  Raises:
    NotImplementedError: If neither `cond_ctxt` nor `while_ctxt` is set.
  """
  if context_def.HasField("cond_ctxt"):
    restored = CondContext.from_proto(
        context_def.cond_ctxt, import_scope=import_scope)
  elif context_def.HasField("while_ctxt"):
    restored = WhileContext.from_proto(
        context_def.while_ctxt, import_scope=import_scope)
  else:
    # Report whichever member of the "ctxt" oneof is actually populated.
    unknown_field = context_def.WhichOneof("ctxt")
    raise NotImplementedError(
        "Unknown ControlFlowContextDef field: %s" % unknown_field)
  return restored
4057
4058
# Register `CondContext`'s proto conversion functions, together with its proto
# type, under the `COND_CONTEXT` graph collection key.
# NOTE(review): presumably consulted when graph collections are exported to or
# imported from protos -- confirm against `ops.register_proto_function`.
ops.register_proto_function(
    ops.GraphKeys.COND_CONTEXT,
    proto_type=control_flow_pb2.CondContextDef,
    to_proto=CondContext.to_proto,
    from_proto=CondContext.from_proto)

# Same registration for `WhileContext`, under the `WHILE_CONTEXT` key.
ops.register_proto_function(
    ops.GraphKeys.WHILE_CONTEXT,
    proto_type=control_flow_pb2.WhileContextDef,
    to_proto=WhileContext.to_proto,
    from_proto=WhileContext.from_proto)
4070