1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Control Flow Operations.
16
17See the @{$python/control_flow_ops} guide.
18
19@@identity
20@@identity_n
21@@tuple
22@@group
23@@no_op
24@@count_up_to
25@@cond
26@@smart_cond
27@@case
28@@while_loop
29@@logical_and
30@@logical_not
31@@logical_or
32@@logical_xor
33@@equal
34@@not_equal
35@@less
36@@less_equal
37@@greater
38@@greater_equal
39@@where
40@@is_finite
41@@is_inf
42@@is_nan
43@@verify_tensor_all_finite
44@@check_numerics
45@@add_check_numerics_ops
46@@Assert
47@@Print
48"""
49# pylint: disable=g-bad-name
50from __future__ import absolute_import
51from __future__ import division
52from __future__ import print_function
53
54import abc
55import collections
56import functools
57
58import six
59
60from tensorflow.core.framework import attr_value_pb2
61from tensorflow.core.protobuf import control_flow_pb2
62from tensorflow.python.eager import context
63from tensorflow.python.framework import constant_op
64from tensorflow.python.framework import dtypes
65from tensorflow.python.framework import errors
66from tensorflow.python.framework import ops
67from tensorflow.python.framework import sparse_tensor
68from tensorflow.python.framework import tensor_shape
69from tensorflow.python.framework import tensor_util
70from tensorflow.python.ops import array_ops
71from tensorflow.python.ops import control_flow_util as util
72from tensorflow.python.ops import gen_array_ops
73from tensorflow.python.ops import gen_control_flow_ops
74from tensorflow.python.ops import gen_data_flow_ops
75from tensorflow.python.ops import gen_logging_ops
76from tensorflow.python.ops import math_ops
77from tensorflow.python.ops import tensor_array_ops
78# go/tf-wildcard-import
79# pylint: disable=wildcard-import,undefined-variable
80from tensorflow.python.ops.gen_control_flow_ops import *
81# pylint: enable=wildcard-import
82from tensorflow.python.platform import tf_logging as logging
83from tensorflow.python.util import compat
84from tensorflow.python.util import deprecation
85from tensorflow.python.util import nest
86from tensorflow.python.util import tf_should_use
87from tensorflow.python.util.tf_export import tf_export
88
89# We override the 'tuple' for a control flow op, so we keep python's
90# existing 'tuple' for later use in this module.
91_basetuple = tuple
92
93
94def _summarize_eager(tensor, summarize=None):
95  """Returns a summarized string representation of eager `tensor`.
96
97  Args:
98    tensor: EagerTensor to summarize
99    summarize: Include these many first elements of `array`
100  """
101  # reshape((-1,)) is the fastest way to get a flat array view
102  if tensor._rank():  # pylint: disable=protected-access
103    flat = tensor.numpy().reshape((-1,))
104    lst = [str(x) for x in flat[:summarize]]
105    if len(lst) < flat.size:
106      lst.append("...")
107  else:
108    # tensor.numpy() returns a scalar for zero dimensional arrays
109    if summarize != 0:
110      lst = [str(tensor.numpy())]
111    else:
112      lst = []
113
114  return ", ".join(lst)
115
116
117# pylint: disable=protected-access
118
119
120# Assert and Print are special symbols in python, so we must
121# use an upper-case version of them.
122@tf_export("Assert")
123@tf_should_use.should_use_result
124def Assert(condition, data, summarize=None, name=None):
125  """Asserts that the given condition is true.
126
127  If `condition` evaluates to false, print the list of tensors in `data`.
128  `summarize` determines how many entries of the tensors to print.
129
130  NOTE: In graph mode, to ensure that Assert executes, one usually attaches
131  a dependency:
132
133  ```python
134  # Ensure maximum element of x is smaller or equal to 1
135  assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
136  with tf.control_dependencies([assert_op]):
137    ... code using x ...
138  ```
139
140  Args:
141    condition: The condition to evaluate.
142    data: The tensors to print out when condition is false.
143    summarize: Print this many entries of each tensor.
144    name: A name for this operation (optional).
145
146  Returns:
147    assert_op: An `Operation` that, when executed, raises a
148    `tf.errors.InvalidArgumentError` if `condition` is not true.
149    @compatibility{eager} returns None.
150
151  Raises:
152    @compatibility{eager} `tf.errors.InvalidArgumentError` if `condition`
153    is not true
154  """
155  if context.in_eager_mode():
156    if not condition:
157      xs = ops.convert_n_to_tensor(data)
158      data_str = [_summarize_eager(x, summarize) for x in xs]
159      raise errors.InvalidArgumentError(
160          node_def=None,
161          op=None,
162          message="Expected '%s' to be true. Summarized data: %s" %
163          (condition, "\n".join(data_str)))
164    return
165
166  with ops.name_scope(name, "Assert", [condition, data]) as name:
167    xs = ops.convert_n_to_tensor(data)
168    if all([x.dtype in {dtypes.string, dtypes.int32} for x in xs]):
169      # As a simple heuristic, we assume that string and int32 are
170      # on host to avoid the need to use cond. If it is not case,
171      # we will pay the price copying the tensor to host memory.
172      return gen_logging_ops._assert(condition, data, summarize, name="Assert")
173    else:
174      condition = ops.convert_to_tensor(condition, name="Condition")
175
176      def true_assert():
177        return gen_logging_ops._assert(
178            condition, data, summarize, name="Assert")
179
180      guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
181      return guarded_assert.op
182
183
184def _Identity(data, name=None):
185  """Return a tensor with the same shape and contents as the input tensor.
186
187  Args:
188    data: A Tensor.
189    name: A name for this operation (optional).
190
191  Returns:
192    A Tensor with the same type and value as the input Tensor.
193  """
194  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
195  if isinstance(data, ops.Tensor):
196    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
197      return gen_array_ops._ref_identity(data, name=name)
198    else:
199      return array_ops.identity(data, name=name)
200  else:
201    if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
202      raise TypeError("Type %s not supported" % type(data))
203    values = _Identity(data.values, name=name)
204    indices = array_ops.identity(data.indices, name="indices")
205    if isinstance(data, ops.IndexedSlices):
206      dense_shape = data.dense_shape
207      if dense_shape is not None:
208        dense_shape = array_ops.identity(dense_shape, name="dense_shape")
209      return ops.IndexedSlices(values, indices, dense_shape)
210    else:
211      dense_shape = array_ops.identity(data.dense_shape, name="dense_shape")
212      return sparse_tensor.SparseTensor(indices, values, dense_shape)
213
214
215def _NextIteration(data, name=None):
216  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
217  if isinstance(data, ops.Tensor):
218    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
219      return ref_next_iteration(data, name=name)
220    else:
221      return next_iteration(data, name=name)
222  else:
223    if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
224      raise TypeError("Type %s not supported" % type(data))
225    values = _NextIteration(data.values, name=name)
226    indices = next_iteration(data.indices, name="indices")
227    if isinstance(data, ops.IndexedSlices):
228      dense_shape = data.dense_shape
229      if dense_shape is not None:
230        dense_shape = next_iteration(dense_shape, name="dense_shape")
231      return ops.IndexedSlices(values, indices, dense_shape)
232    else:
233      dense_shape = next_iteration(data.dense_shape, name="dense_shape")
234      return sparse_tensor.SparseTensor(indices, values, dense_shape)
235
236
237def _Enter(data,
238           frame_name,
239           is_constant=False,
240           parallel_iterations=10,
241           use_ref=True,
242           use_input_shape=True,
243           name=None):
244  """Creates or finds a child frame, and makes `data` available to it.
245
246  The unique `frame_name` is used by the `Executor` to identify frames. If
247  `is_constant` is true, `data` is a constant in the child frame; otherwise
248  it may be changed in the child frame. At most `parallel_iterations`
249  iterations are run in parallel in the child frame.
250
251  Args:
252    data: The tensor to be made available to the child frame.
253    frame_name: The name of the child frame.
254    is_constant: If true, the output is constant within the child frame.
255    parallel_iterations: The number of iterations allowed to run in parallel.
256    use_ref: If true, use ref_enter if data is of ref type.
257    name: A name for this operation (optional).
258
259  Returns:
260    The same tensor as `data`.
261  """
262  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
263  if isinstance(data, ops.Tensor):
264    if data.dtype._is_ref_dtype and use_ref:  # pylint: disable=protected-access
265      result = gen_control_flow_ops._ref_enter(
266          data, frame_name, is_constant, parallel_iterations, name=name)
267    else:
268      result = gen_control_flow_ops._enter(
269          data, frame_name, is_constant, parallel_iterations, name=name)
270    if use_input_shape:
271      result.set_shape(data.get_shape())
272    return result
273  else:
274    if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
275      raise TypeError("Type %s not supported" % type(data))
276    values = _Enter(
277        data.values,
278        frame_name,
279        is_constant,
280        parallel_iterations=parallel_iterations,
281        use_input_shape=use_input_shape,
282        name=name)
283    indices = gen_control_flow_ops._enter(
284        data.indices,
285        frame_name,
286        is_constant,
287        parallel_iterations,
288        name="indices")
289    if use_input_shape:
290      indices.set_shape(data.indices.get_shape())
291    if isinstance(data, ops.IndexedSlices):
292      dense_shape = data.dense_shape
293      if dense_shape is not None:
294        dense_shape = gen_control_flow_ops._enter(
295            dense_shape,
296            frame_name,
297            is_constant,
298            parallel_iterations,
299            name="dense_shape")
300        if use_input_shape:
301          dense_shape.set_shape(data.dense_shape.get_shape())
302      return ops.IndexedSlices(values, indices, dense_shape)
303    else:
304      dense_shape = gen_control_flow_ops._enter(
305          data.dense_shape,
306          frame_name,
307          is_constant,
308          parallel_iterations,
309          name="dense_shape")
310      if use_input_shape:
311        dense_shape.set_shape(data.dense_shape.get_shape())
312      return sparse_tensor.SparseTensor(indices, values, dense_shape)
313
314
315def exit(data, name=None):  # pylint: disable=redefined-builtin
316  """Exits the current frame to its parent frame.
317
318  Exit makes its input `data` available to the parent frame.
319
320  Args:
321    data: The tensor to be made available to the parent frame.
322    name: A name for this operation (optional).
323
324  Returns:
325    The same tensor as `data`.
326  """
327  data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
328  if isinstance(data, ops.Tensor):
329    if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
330      return gen_control_flow_ops._ref_exit(data, name)
331    else:
332      return gen_control_flow_ops._exit(data, name)
333  else:
334    if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
335      raise TypeError("Type %s not supported" % type(data))
336    values = exit(data.values, name=name)
337    indices = gen_control_flow_ops._exit(data.indices, name="indices")
338    if isinstance(data, ops.IndexedSlices):
339      dense_shape = data.dense_shape
340      if dense_shape is not None:
341        dense_shape = gen_control_flow_ops._exit(dense_shape, name)
342      return ops.IndexedSlices(values, indices, dense_shape)
343    else:
344      dense_shape = gen_control_flow_ops._exit(data.dense_shape, name)
345      return sparse_tensor.SparseTensor(indices, values, dense_shape)
346
347
348def switch(data, pred, dtype=None, name=None):
349  """Forwards `data` to an output determined by `pred`.
350
351  If `pred` is false, the `data` input is forwarded to the first output.
352  Otherwise, the data goes to the second output.
353
354  This op handles `Tensor`s and `IndexedSlices`.
355
356  Args:
357    data: The tensor to be forwarded to the appropriate output.
358    pred: A scalar that specifies which output port will receive data.
359    dtype: Optional element type for the returned tensor. If missing,
360           the type is inferred from the type of `value`.
361    name: A name for this operation (optional).
362
363  Returns:
364    `(output_false, output_true)`: If `pred` is true, data will be forwarded
365    to `output_true`, otherwise it goes to `output_false`.
366  """
367  with ops.name_scope(name, "Switch", [data, pred]) as name:
368    data = ops.internal_convert_to_tensor_or_indexed_slices(
369        data, dtype=dtype, name="data", as_ref=True)
370    pred = ops.convert_to_tensor(pred, name="pred")
371    if isinstance(data, ops.Tensor):
372      return gen_control_flow_ops._switch(data, pred, name=name)
373    else:
374      if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
375        raise TypeError("Type %s not supported" % type(data))
376      val, ind = data.values, data.indices
377      val_f, val_t = gen_control_flow_ops._switch(val, pred, name=name)
378      ind_f, ind_t = gen_control_flow_ops._switch(ind, pred, name="indices")
379      if isinstance(data, ops.IndexedSlices):
380        dense_shape = data.dense_shape
381        if dense_shape is not None:
382          dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(
383              dense_shape, pred, name="dense_shape")
384        else:
385          dense_shape_f, dense_shape_t = None, None
386        return (ops.IndexedSlices(val_f, ind_f, dense_shape_f),
387                ops.IndexedSlices(val_t, ind_t, dense_shape_t))
388      else:
389        dense_shape = data.dense_shape
390        dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(
391            data.dense_shape, pred, name="dense_shape")
392        return (sparse_tensor.SparseTensor(ind_f, val_f, dense_shape_f),
393                sparse_tensor.SparseTensor(ind_t, val_t, dense_shape_t))
394
395
396def _SwitchRefOrTensor(data, pred, name="Switch"):
397  """Forwards `data` to an output determined by `pred`.
398
399  If `pred` is false, the `data` input is forwarded to the first output.
400  Otherwise, the data goes to the second output.
401
402  This op handles `Tensor`s and `IndexedSlices`.
403
404  Args:
405    data: The tensor to be forwarded to the appropriate output.
406    pred: A scalar that specifies which output port will receive data.
407    name: A name for this operation (optional).
408
409  Returns:
410    `(output_false, output_true)`: If `pred` is true, data will be forwarded to
411    `output_true`, otherwise it goes to `output_false`.
412
413  Raises:
414    TypeError: if data is not a Tensor or IndexedSlices
415  """
416  data = ops.convert_to_tensor_or_indexed_slices(data, name="data")
417  # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
418  # addresses the following scenario.
419  #
420  # Assume you execute Optimizer.apply_gradients() in a branch of a cond().
421  #
422  # 1. The update op is created inside a `with ops.colocate(var):` block
423  #
424  # 2. Some tensor `data` is captured and a switch is created in a
425  #    `with ops.colocate_with(data):` block.
426  #
427  # with ops.colocate_with(var):
428  #  with ops.colocate_with(data):
429  #    op = ...
430  #
431  # var and data may be pinned to different devices, so we want to ops
432  # created within ops.colocate_with(data) to ignore the existing stack.
433  with ops.colocate_with(data, ignore_existing=True):
434    if isinstance(data, ops.Tensor):
435      if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
436        return ref_switch(data, pred, name=name)
437    return switch(data, pred, name=name)
438
439
440def merge(inputs, name=None):
441  """Returns the value of an available element of `inputs`.
442
443  This op tests each of the tensors in `inputs` in turn to determine if any of
444  them is available. If it finds an available tensor, it returns it and its
445  index in `inputs`.
446
447  It is an error if more than one tensor in `inputs` is available. If no tensor
448  in `inputs` is available, the returned tensor and index are not set.
449
450  This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
451  `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
452  before merging.
453
454  Args:
455    inputs: The input tensors, at most one of which is available.
456    name: A name for this operation (optional).
457
458  Returns:
459    A tuple containing the chosen input tensor and its index in `inputs`.
460
461  Raises:
462    ValueError: If any of the inputs is None, or inputs are IndexedSlices and
463      some but not all have a dense_shape property.
464  """
465  if any([inp is None for inp in inputs]):
466    raise ValueError("At least one of the merge inputs is None: %s" % inputs)
467  with ops.name_scope(name, "Merge", inputs) as name:
468    inputs = [
469        ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref=True)
470        for inp in inputs
471    ]
472    if all([isinstance(v, ops.Tensor) for v in inputs]):
473      if all([v.dtype._is_ref_dtype for v in inputs]):  # pylint: disable=protected-access
474        return gen_control_flow_ops._ref_merge(inputs, name)
475      else:
476        return gen_control_flow_ops._merge(inputs, name)
477    elif all([isinstance(v, sparse_tensor.SparseTensor) for v in inputs]):
478      # Only handle the case when all inputs are SparseTensor.
479      values, _ = merge([inp.values for inp in inputs], name=name)
480      indices, chosen_index = gen_control_flow_ops._merge(
481          [inp.indices for inp in inputs], name="indices")
482      dense_shape, _ = gen_control_flow_ops._merge(
483          [inp.dense_shape for inp in inputs], name="dense_shape")
484      return (sparse_tensor.SparseTensor(indices, values, dense_shape),
485              chosen_index)
486    else:
487      # For now convert all the inputs as IndexedSlices.
488      inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)
489      values, _ = merge([inp.values for inp in inputs], name=name)
490      indices, chosen_index = gen_control_flow_ops._merge(
491          [inp.indices for inp in inputs], name="indices")
492      if any(inp.dense_shape is not None for inp in inputs):
493        if any(inp.dense_shape is None for inp in inputs):
494          raise ValueError("Either all merged IndexedSlices must have a "
495                           "dense_shape, or none must have a dense_shape.")
496        dense_shape, _ = gen_control_flow_ops._merge(
497            [inp.dense_shape for inp in inputs], name="dense_shape")
498      else:
499        dense_shape = None
500      return ops.IndexedSlices(values, indices, dense_shape), chosen_index
501
502
503# pylint: enable=protected-access
504
505
506def _convert_tensorarray_to_flow(tensor_or_tensor_array):
507  if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
508    return tensor_or_tensor_array.flow
509  else:
510    return tensor_or_tensor_array
511
512
513def _make_tensor_array(ta, t_or_flow):
514  # pylint: disable=protected-access
515  new_ta = tensor_array_ops.TensorArray(
516      dtype=ta.dtype,
517      handle=ta.handle,
518      flow=t_or_flow,
519      infer_shape=ta._infer_shape,
520      colocate_with_first_write_call=ta._colocate_with_first_write_call)
521  new_ta._colocate_with = ta._colocate_with
522  new_ta._element_shape = ta._element_shape
523  # pylint: enable=protected-access
524  return new_ta
525
526
527def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
528  if len(tensors_or_tensorarrays) != len(tensors_or_flows):
529    raise ValueError(
530        "Lengths of original Tensor list and new list do not match: %d vs. %d" %
531        (len(tensors_or_tensorarrays), len(tensors_or_flows)))
532  return [
533      _make_tensor_array(ta, t_or_flow)
534      if isinstance(ta, tensor_array_ops.TensorArray) else t_or_flow
535      for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)
536  ]
537
538
539def _ShapeLessThanOrEqual(shape1, shape2):
540  if shape2.dims is None:
541    return True
542  if shape1.ndims != shape2.ndims:
543    return False
544  for dim1, dim2 in zip(shape1.dims, shape2.dims):
545    if dim2.value is not None and dim1.value != dim2.value:
546      return False
547  return True
548
549
550def _SetShapeInvariants(input_vars, enter_vars, shapes):
551  """Set the shapes of the tensors in `enter_vars` to `shapes`.
552
553  Args:
554    input_vars: A list of tensors that are inputs to `enter_vars`.
555    enter_vars: A list of tensors whose shapes will be set.
556    shapes: A (possibly nested) list of shapes.
557
558  Raises:
559    ValueError: If any tensor in `enter_vars` has a less specific shape
560      than its corresponding shape in `shapes`.
561  """
562  if shapes is None:
563    return
564  flat_shapes = nest.flatten(shapes)
565  if not all([isinstance(s, tensor_shape.TensorShape) for s in flat_shapes]):
566    raise ValueError("`shapes` must be a (possibly nested) list of shapes.")
567  # Check that the shapes of the inputs are less than the shape invariants,
568  # and set the shapes of `enter_vars` to the shape invariants.
569  for inp, var, shape in zip(input_vars, enter_vars, flat_shapes):
570    if isinstance(var, ops.Tensor):
571      if not _ShapeLessThanOrEqual(inp.get_shape(), shape):
572        raise ValueError(
573            "The shape invariant specified for %s is not compatible with "
574            "the initial shape of the loop variable. It enters the loop "
575            "with shape %s, but the specified shape invariant is %s." %
576            (inp.name, inp.get_shape(), shape))
577      var.set_shape(shape)
578    else:
579      if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
580        raise TypeError("Type %s not supported" % type(var))
581      if isinstance(var, ops.IndexedSlices):
582        if not _ShapeLessThanOrEqual(inp.values.get_shape(), shape):
583          raise ValueError(
584              "The shape invariant specified for %s is not compatible with "
585              "the initial shape of the values tensor of this IndexedSlices. "
586              "It enters the loop with shape %s, but the specified shape "
587              "invariant is %s." % (inp.values.name, inp.values.get_shape(),
588                                    shape))
589        var.values.set_shape(shape)
590        var.indices.set_shape(tensor_shape.TensorShape([shape[0]]))
591        if var.dense_shape is not None:
592          var.dense_shape.set_shape(tensor_shape.TensorShape([shape.ndims]))
593      else:
594        if not _ShapeLessThanOrEqual(inp.dense_shape.get_shape(), shape):
595          raise ValueError(
596              "The shape invariant specified for %s is not compatible with "
597              "the initial shape of the shape tensor of this SparseTensor. "
598              "It enters the loop with shape %s, but the specified shape "
599              "invariant is %s." % (inp.dense_shape.name,
600                                    inp.dense_shape.get_shape(), shape))
601        var.values.set_shape(tensor_shape.TensorShape([None]))
602        var.indices.set_shape(tensor_shape.TensorShape([None, shape.ndims]))
603        var.dense_shape.set_shape(shape)
604
605
606def _EnforceShapeInvariant(merge_var, next_var):
607  """Check if the shapes of the loops variables are invariants.
608
609  Args:
610    merge_vars: The list of tensors representing the initial values of the
611      loop variables.
612    next_vars: The list of tensors representing the values of the loop
613      variables after one loop iteration.
614
615  Raises:
616    ValueError: If any tensor in `merge_vars` has a more specific shape than
617      its correspnding tensor in `next_var`.
618  """
619  if isinstance(merge_var, ops.Tensor):
620    m_shape = merge_var.get_shape()
621    n_shape = next_var.get_shape()
622    if not _ShapeLessThanOrEqual(n_shape, m_shape):
623      # TODO(skyewm): get original loop input that caused the shape error and
624      # report its name instead of the merge node's.
625      raise ValueError(
626          "The shape for %s is not an invariant for the loop. It enters "
627          "the loop with shape %s, but has shape %s after one iteration. "
628          "Provide shape invariants using either the `shape_invariants` "
629          "argument of tf.while_loop or set_shape() on the loop variables." %
630          (merge_var.name, m_shape, n_shape))
631  else:
632    if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
633      raise TypeError("Type %s not supported" % type(var))
634    if isinstance(var, ops.IndexedSlices):
635      m_values_shape = merge_var.values.get_shape()
636      m_indices_shape = merge_var.indices.get_shape()
637      m_shape_shape = tensor_shape.TensorShape(None)
638      if merge_var.dense_shape is not None:
639        m_shape_shape = merge_var.dense_shape.get_shape()
640      n_values_shape = next_var.values.get_shape()
641      n_indices_shape = next_var.indices.get_shape()
642      n_shape_shape = tensor_shape.TensorShape(None)
643      if next_var.dense_shape is not None:
644        n_shape_shape = next_var.dense_shape.get_shape()
645      if (not _ShapeLessThanOrEqual(n_values_shape, m_values_shape) or
646          not _ShapeLessThanOrEqual(n_indices_shape, m_indices_shape)):
647        if not _ShapeLessThanOrEqual(n_values_shape, m_values_shape):
648          raise ValueError(
649              "The shape for %s is not an invariant for the loop. It enters "
650              "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) "
651              "after one iteration. Provide shape invariants using either the "
652              "`shape_invariants` argument of tf.while_loop or set_shape() "
653              "on the loop variables." %
654              (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,
655               n_values_shape, n_indices_shape, n_shape_shape))
656    else:
657      m_values_shape = merge_var.values.get_shape()
658      m_indices_shape = merge_var.indices.get_shape()
659      m_shape_shape = merge_var.dense_shape.get_shape()
660      n_values_shape = next_var.values.get_shape()
661      n_indices_shape = next_var.indices.get_shape()
662      n_shape_shape = next_var.dense_shape.get_shape()
663      if (not _ShapeLessThanOrEqual(n_values_shape, m_values_shape) or
664          not _ShapeLessThanOrEqual(n_indices_shape, m_indices_shape) or
665          not _ShapeLessThanOrEqual(n_shape_shape, m_shape_shape)):
666        raise ValueError(
667            "The shape for %s is not an invariant for the loop. It enters "
668            "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) "
669            "after one iteration. Provide shape invariants using either "
670            "the `shape_invariants` argument of tf.while_loop or set_shape() "
671            "on the loop variables." %
672            (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,
673             n_values_shape, n_indices_shape, n_shape_shape))
674
675
676def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
677  """Add NextIteration and back edge from v to m."""
678  if isinstance(m, ops.Tensor):
679    v = ops.convert_to_tensor(v)
680    v = _NextIteration(v)
681    if enforce_shape_invariant:
682      # Make sure the shapes of loop outputs are correct. We do this before
683      # calling _update_input, which will raise a less-helpful error message if
684      # the types don't match.
685      # TODO(skyewm): call this for other cases below (needs testing)
686      _EnforceShapeInvariant(m, v)
687    m.op._update_input(1, v)  # pylint: disable=protected-access
688  elif isinstance(m, ops.IndexedSlices):
689    # pylint: disable=protected-access
690    v = math_ops._as_indexed_slices(v, optimize=False)
691    v = _NextIteration(v)
692    m.values.op._update_input(1, v.values)
693    m.indices.op._update_input(1, v.indices)
694    # pylint: enable=protected-access
695    if m.dense_shape is not None:
696      if v.dense_shape is None:
697        raise ValueError("Must have dense shape: %s" % v.name)
698      m.dense_shape.op._update_input(1, v.dense_shape)
699  elif isinstance(m, sparse_tensor.SparseTensor):
700    if not isinstance(v, sparse_tensor.SparseTensor):
701      raise ValueError("Must be a sparse tensor: %s" % v.name)
702    v = _NextIteration(v)
703    # pylint: disable=protected-access
704    m.values.op._update_input(1, v.values)
705    m.indices.op._update_input(1, v.indices)
706    m.dense_shape.op._update_input(1, v.dense_shape)
707    # pylint: enable=protected-access
708  else:
709    raise TypeError("Type %s not supported" % type(m))
710  return v
711
712
713def GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
714  """Calculate a max_size for use by stack ops inside an XLA while_loop.
715
716  Args:
717    value: The value inside the while_loop forward context.  Used for printing
718      error messages.
719    while_ctxt: The forward context inside which value resides.  This does
720      not always match the value's immediate context, as `value` may be
721      inside e.g. a cond context inside the while_loop.
722
723  Returns:
724    A tensor containing the `max_size` to feed to a Stack initializer.
725
726  Raises:
727    ValueError: If `value` is nested inside a `while_loop` that either
728      lacks a `maximum_iterations` parameter, or the `maximum_iterations`
729      parameter:
730
731        - is inside a `while_loop` that is a parent of the calling context, and
732        - cannot be evaluated at graph build time to a constant.
733  """
734  value_name = value.name
735  # curr_ctxt is the context that tf.gradients was called in.
736  curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
737
738  curr_ctxt_name = curr_ctxt.name if curr_ctxt is not None else ""
739  max_size = constant_op.constant(1)
740
741  # Loop through all containing while contexts between value and the
742  # current context, multiplying together each context's
743  # max_iterations to get the maximum stack size.
744  while while_ctxt not in (None, curr_ctxt):
745    max_iter = while_ctxt.maximum_iterations
746    if max_iter is None:
747      raise ValueError(
748          "Cannot create a gradient accumulator for tensor '%s' inside "
749          "XLA while_loop because maximum_iterations was not passed to "
750          "the tf.while_loop call ('%s')." % (value_name, while_ctxt.name))
751
752    # pylint: disable=protected-access
753    max_iter_ctxt = max_iter.op._get_control_flow_context()
754    # pylint: enable=protected-access
755
756    # If max_iter_ctxt (non-strictly) contains curr_ctxt, then it's OK to use.
757    if util.IsContainingContext(curr_ctxt, max_iter_ctxt):
758      max_size *= max_iter
759    else:
760      # We cannot use max_iter because it's defined in a nested while
761      # or cond context, so will fail if we try to use it as input to
762      # any ops in curr_ctxt (e.g. max_size or the final accumulator
763      # stack). Attempt to get a constant value out to use instead.
764      const_max_iter = tensor_util.constant_value(max_iter)
765      if const_max_iter is None:
766        raise ValueError(
767            "Cannot create a gradient accumulator for tensor '%s' inside XLA "
768            "while_loop. maximum_iterations tensor '%s' for while_loop context "
769            "'%s' must be statically known (e.g. a constant value or known "
770            "shape dimension), or be defined at or outside the while loop "
771            "context '%s' (currently defined in '%s')." %
772            (value_name, max_iter.name, while_ctxt.name, curr_ctxt_name,
773             max_iter_ctxt.name))
774      max_size *= const_max_iter
775
776    # Find the next outer WhileContext (or stop if we reach the
777    # tf.gradient's context).
778    while_ctxt = util.GetContainingWhileContext(
779        while_ctxt.outer_context, stop_ctxt=curr_ctxt)
780
781  return max_size
782
783
784class GradLoopState(object):
785  """The state used for constructing the gradient graph for a while loop.
786
787  We create a GradLoopState for each while loop in forward and its
788  corresponding while loop in backprop. This gives us access to both
789  the forward and the backprop WhileContexts.
790
791  During the construction of gradient graph, any time when we detect
792  a forward value that is needed for backprop, we create a history
793  accumulator and add it to `history_map`. Any time when we backprop
794  a loop switch op (in _SwitchGrad), we add the grad merge op in
795  `switch_map`.
796  """
797
798  def __init__(self, forward_ctxt, outer_grad_state):
799    # The grad loop state for the outer while loop.
800    self._outer_grad_state = None
801
802    # The while loop context for forward.
803    self._forward_context = None
804
805    # The loop counter added by AddForwardLoopCounter. It is the value
806    # of the loop counter for the next iteration.
807    self._forward_index = None
808
809    # A sync op for forward.
810    self._forward_sync = None
811
812    # The while loop context for backprop.
813    self._grad_context = None
814
815    # The loop counter added by AddBackpropLoopCounter. It is the value
816    # of the loop counter for the current iteration.
817    self._grad_index = None
818
819    # A sync op for backprop.
820    self._grad_sync = None
821
822    # Information needed by backprop.
823    self._history_map = {}
824    self._switch_map = {}
825    self._unused_exits = []
826    self._deferred_exits = []
827    self._forward_loop_exits = list(forward_ctxt.loop_exits)
828    self._pending_exits_count = len(forward_ctxt.loop_exits)
829
830    self._outer_grad_state = outer_grad_state
831    if outer_grad_state:
832      outer_forward_ctxt = outer_grad_state.forward_context
833    else:
834      outer_forward_ctxt = forward_ctxt.outer_context
835
836    # Add the forward loop counter.
837    if outer_forward_ctxt:
838      outer_forward_ctxt.Enter()
839    cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
840    if outer_forward_ctxt:
841      outer_forward_ctxt.Exit()
842    self._forward_context = forward_ctxt
843    self._forward_index = forward_index
844
845    # Add the backprop WhileContext, and the backprop loop counter.
846    if outer_grad_state:
847      # This is a nested loop. Remember the iteration counts for each
848      # execution of this inner loop.
849      outer_forward_ctxt.AddName(cnt.name)
850      history_cnt = outer_grad_state.AddForwardAccumulator(cnt)
851
852      outer_grad_ctxt = outer_grad_state.grad_context
853      outer_grad_ctxt.Enter()
854      self._grad_context = WhileContext(
855          maximum_iterations=forward_ctxt.maximum_iterations,
856          parallel_iterations=forward_ctxt.parallel_iterations,
857          back_prop=forward_ctxt.back_prop,
858          swap_memory=forward_ctxt.swap_memory,
859          name=forward_ctxt.name,
860          grad_state=self)
861      real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
862      self._grad_index = self._grad_context.AddBackpropLoopCounter(
863          real_cnt, outer_grad_state)
864      outer_grad_ctxt.Exit()
865    else:
866      if outer_forward_ctxt:
867        outer_forward_ctxt.Enter()
868      self._grad_context = WhileContext(
869          maximum_iterations=forward_ctxt.maximum_iterations,
870          parallel_iterations=forward_ctxt.parallel_iterations,
871          back_prop=forward_ctxt.back_prop,
872          swap_memory=forward_ctxt.swap_memory,
873          name=forward_ctxt.name,
874          grad_state=self)
875      self._grad_index = self._grad_context.AddBackpropLoopCounter(
876          cnt, outer_grad_state)
877      if outer_forward_ctxt:
878        outer_forward_ctxt.Exit()
879
880  @property
881  def outer_grad_state(self):
882    """The grad loop state for outer loop."""
883    return self._outer_grad_state
884
885  @property
886  def forward_context(self):
887    """The while loop context for forward."""
888    return self._forward_context
889
890  @property
891  def forward_index(self):
892    """The loop index of forward loop."""
893    return self._forward_index
894
895  @property
896  def forward_sync(self):
897    """A control trigger node for synchronization in the forward loop.
898
899    One main use is to keep the push ops of a stack executed in the
900    iteration order.
901    """
902    if self._forward_sync is None:
903      with ops.control_dependencies(None):
904        self._forward_sync = control_trigger(name="f_sync")
905      self._forward_sync._set_control_flow_context(self._forward_context)
906      self._forward_index.op._add_control_input(self._forward_sync)
907    return self._forward_sync
908
909  @property
910  def grad_context(self):
911    """The corresponding WhileContext for gradient."""
912    return self._grad_context
913
914  @property
915  def grad_index(self):
916    """The loop index of backprop loop."""
917    return self._grad_index
918
919  @property
920  def grad_sync(self):
921    """A control trigger node for synchronization in the grad loop.
922
923    One main use is to keep the pop ops of a stack executed in the
924    iteration order.
925    """
926    if self._grad_sync is None:
927      with ops.control_dependencies(None):
928        self._grad_sync = control_trigger(name="b_sync")
929      self._grad_sync._set_control_flow_context(self._grad_context)
930      self._grad_index.op._add_control_input(self._grad_sync)
931      if self._grad_context.outer_context:
932        self._grad_context.outer_context.AddInnerOp(self._grad_sync)
933    return self._grad_sync
934
935  @property
936  def history_map(self):
937    """The map that records all the tensors needed for backprop."""
938    return self._history_map
939
940  @property
941  def switch_map(self):
942    """The map that records all the Switch ops for the while loop."""
943    return self._switch_map
944
945  @property
946  def unused_exits(self):
947    """The list of "unused" exits."""
948    return self._unused_exits
949
950  @property
951  def deferred_exits(self):
952    """The list of "deferred" exits."""
953    return self._deferred_exits
954
955  @property
956  def forward_loop_exits(self):
957    """The list of exits of the forward loop."""
958    return self._forward_loop_exits
959
960  @property
961  def pending_exits_count(self):
962    """The number of exits we expect to see but haven't."""
963    return self._pending_exits_count
964
965  @pending_exits_count.setter
966  def pending_exits_count(self, cnt):
967    """Set the pending count to cnt."""
968    self._pending_exits_count = cnt
969
970  def AddForwardAccumulator(self, value, dead_branch=False):
971    """Add an accumulator for each forward tensor that is needed in backprop.
972
973    This is added to the forward loop at the first time when a tensor
974    in the forward loop is used by backprop gradient computation loop.
975    We create an accumulator that accumulates the value of tensor at each
976    iteration. Called in the control flow context where gradients() is called.
977
978    The pseudocode is:
979    ```
980      acc = stack();
981      while (_pivot) {
982        acc = stack_push(acc, value);
983      }
984    ```
985
986    We make sure that the stack push op in one iteration is executed before
987    next iteration. This is achieved by adding a control edge from
988    `forward_index.op.inputs[0].op` to the push op, and another control
989    edge from the push op to either `forward_index.op` or `forward_sync`.
990
991    Args:
992      value: The source tensor in forward that is to be accumulated.
993      dead_branch: True iff the tensor is on a dead branch of a cond.
994
995    Returns:
996      The stack that contains the accumulated history of the tensor.
997
998    Raises:
999      TypeError: For internal errors involving the value condition context.
1000      ValueError: If `value` is inside a XLA scope and a valid max size
1001        for the stack can't be found.
1002    """
1003    # curr_ctxt is the context that tf.gradients was called in.
1004    curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
1005    with ops.control_dependencies(None):
1006      if curr_ctxt:
1007        curr_ctxt.Enter()
1008      with ops.colocate_with(value):
1009        # We only need to pass maximum_iterations to the stack if
1010        # we're inside an XLA context.
1011        if not util.IsInXLAContext(value.op):
1012          max_size = constant_op.constant(-1, dtypes.int32)
1013        else:
1014          max_size = GetMaxSizeFromNestedMaximumIterations(
1015              value, self.forward_context)
1016        # pylint: disable=protected-access
1017        acc = gen_data_flow_ops._stack_v2(
1018            max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
1019        # pylint: enable=protected-access
1020      if curr_ctxt:
1021        curr_ctxt.Exit()
1022
1023      # Make acc available in the forward context.
1024      enter_acc = self.forward_context.AddValue(acc)
1025
1026      # Add the stack_push op in the context of value.op.
1027      swap_enabled = self.forward_context.swap_memory
1028      value_ctxt = util.GetOutputContext(value.op)
1029      if value_ctxt == self.forward_context:
1030        # value is not nested in the forward context.
1031        self.forward_context.Enter()
1032        # pylint: disable=protected-access
1033        push = gen_data_flow_ops._stack_push_v2(
1034            enter_acc, value, swap_memory=swap_enabled)
1035        # pylint: enable=protected-access
1036        self.forward_context.Exit()
1037        # Protect stack push and order it before forward_index.
1038        self.forward_index.op._add_control_input(push.op)
1039      else:
1040        # value is in a cond context within the forward context.
1041        if not isinstance(value_ctxt, CondContext):
1042          raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
1043        if dead_branch:
1044          # The special case for creating a zero tensor for a dead
1045          # branch of a switch. See ControlFlowState.ZerosLike().
1046          value_ctxt.outer_context.Enter()
1047          # pylint: disable=protected-access
1048          push = gen_data_flow_ops._stack_push_v2(
1049              enter_acc, value, swap_memory=swap_enabled)
1050          # pylint: enable=protected-access
1051          value_ctxt.outer_context.Exit()
1052          push.op._set_control_flow_context(value_ctxt)
1053        else:
1054          value_ctxt.Enter()
1055          # pylint: disable=protected-access
1056          push = gen_data_flow_ops._stack_push_v2(
1057              enter_acc, value, swap_memory=swap_enabled)
1058          # pylint: enable=protected-access
1059          value_ctxt.Exit()
1060        # Protect stack push and order it before forward_sync.
1061        self.forward_sync._add_control_input(push.op)
1062      # Order stack push after the successor of forward_index
1063      add_op = self.forward_index.op.inputs[0].op
1064      push.op._add_control_input(add_op)
1065      return acc
1066
1067  def AddBackpropAccumulatedValue(self, history_value, value,
1068                                  dead_branch=False):
1069    """Add the getter for an accumulated value in the grad context.
1070
1071    This is added to the backprop loop. Called in the grad context to
1072    get the value of an accumulated value. The stack pop op must be guarded
1073    by the pred of the controlling cond.
1074
1075    Args:
1076      history_value: The history (a stack) of a value.
1077      value: The value that is pushed onto the stack.
1078      dead_branch: True iff the tensor is on a dead branch of a cond.
1079
1080    Returns:
1081      The current value (the top of the stack).
1082    """
1083    history_ctxt = history_value.op._get_control_flow_context()
1084    # Find the cond context that controls history_value if any.
1085    cond_ctxt = None
1086    value_ctxt = value.op._get_control_flow_context()
1087    while value_ctxt and value_ctxt != history_ctxt:
1088      if isinstance(value_ctxt, CondContext):
1089        cond_ctxt = value_ctxt
1090        break
1091      value_ctxt = value_ctxt.outer_context
1092    with ops.control_dependencies(None):
1093      self.grad_context.Enter()
1094      if cond_ctxt:
1095        # Guard stack pop with a switch if it is controlled by a cond.
1096        grad_state = self
1097        pred = None
1098        while pred is None and grad_state:
1099          pred = grad_state.history_map.get(cond_ctxt.pred.name)
1100          grad_state = grad_state.outer_grad_state
1101        if pred is None:
1102          pred = cond_ctxt.pred
1103        branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
1104        history_value = _SwitchRefOrTensor(history_value, pred)[branch]
1105      # pylint: disable=protected-access
1106      pop = gen_data_flow_ops._stack_pop_v2(history_value,
1107                                            value.dtype.base_dtype)
1108      # pylint: enable=protected-access
1109      pop.set_shape(value.get_shape())
1110      self.grad_context.Exit()
1111    parallel_iterations = self.grad_context.parallel_iterations
1112    if parallel_iterations > 1:
1113      # All pops are ordered after pivot_for_body and before grad_sync.
1114      self.grad_sync._add_control_input(pop.op)
1115    return pop
1116
1117  def GetRealValue(self, value):
1118    """Get the real value of `value`.
1119
1120    If backprop "uses" a value produced by forward inference, an accumulator
1121    is added in the forward loop to accumulate its values.  We use the
1122    accumulated value. This method must be called in the grad loop context.
1123    `value` must be in forward and needed for backprop.
1124
1125    Args:
1126      value: A tensor to be captured.
1127
1128    Returns:
1129      The same tensor obtained from the saved history.
1130    """
1131    assert value.op.type not in ["Variable", "VariableV2"]
1132    real_value = self._history_map.get(value.name)
1133    if real_value is None:
1134      cur_value = value
1135      cur_grad_state = self
1136      while True:
1137        enter_op = util.GetLoopConstantEnter(cur_value)
1138        if enter_op:
1139          # Special case: cur_value comes from a constant Enter node.
1140          cur_value = enter_op.inputs[0]
1141          cur_grad_state = cur_grad_state.outer_grad_state
1142          if cur_grad_state is None:
1143            # We are now outside all nested loops for this gradient(),
1144            # so `value` is a loop invariant and there is no need to
1145            # save the history of value. Just make cur_value to enter
1146            # the right control flow context.
1147            real_value = self._grad_context.AddValue(cur_value)
1148            break
1149        elif constant_op.is_constant(cur_value):
1150          # If the value to be forwarded is a constant, clone the constant in
1151          # the gradient loop rather than using a stack.
1152          # TODO(phawkins): consider hoisting the constant out of the loop
1153          # instead.
1154          real_value = constant_op.constant(
1155              tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
1156          break
1157        else:
1158          # Record the history of this value in forward_ctxt.
1159          self._grad_context.Exit()
1160          history_value = cur_grad_state.AddForwardAccumulator(cur_value)
1161          self._grad_context.Enter()
1162          break
1163
1164      if real_value is None:
1165        # Add the stack pop op in the grad context.
1166        real_value = cur_grad_state.AddBackpropAccumulatedValue(
1167            history_value, cur_value)
1168        if cur_grad_state != self:
1169          real_value = self._grad_context.AddValue(real_value)
1170      self._history_map[value.name] = real_value
1171    return real_value
1172
1173
1174def _GetWhileContext(op):
1175  """Get the WhileContext to which this op belongs."""
1176  ctxt = op._get_control_flow_context()
1177  if ctxt:
1178    ctxt = ctxt.GetWhileContext()
1179  return ctxt
1180
1181
1182class ControlFlowState(object):
1183  """Maintain the mapping from the loops to their grad states."""
1184
1185  def __init__(self):
1186    self._map = {}  # maps forward loop context to GradLoopState
1187
1188  def GetGradState(self, op, before):
1189    """Return the grad state for this op if it's in a forward loop context."""
1190    if before and util.IsLoopExit(op):
1191      forward_ctxt = op._get_control_flow_context()
1192      forward_ctxt = forward_ctxt.outer_context
1193      if forward_ctxt:
1194        forward_ctxt = forward_ctxt.GetWhileContext()
1195    else:
1196      forward_ctxt = _GetWhileContext(op)
1197    if forward_ctxt:
1198      return self._map.get(forward_ctxt)
1199    return None
1200
1201  def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
1202    """Process all the "unused" loop exits.
1203
1204    The "unused" exits of the loops are added to `unused_exits`. An exit is
1205    unused if its pending_count is 0. If there is an exit with real gradient,
1206    all these deferred exits will enter the backprop loop with zero gradient.
1207    Otherwise, they will enter the backprop loop with None. As an example,
1208    people often write:
1209
1210    ```python
1211    v1, _ = tf.while_loop(p, b, [x1, x2])
1212    result = gradients(v1, x1)
1213    ```
1214
1215    The exit node for x2 is not included by the betweenness analysis. But we
1216    need to backprop x2 if x2 is involved in computing v1.
1217
1218    Args:
1219      pending_count: The number of backprop inputs for every op.
1220      to_ops_set: The set of ops for ys in gradients(ys, xs)
1221
1222    Returns:
1223      The set of unused loop exits that we know at this point we need
1224      to backprop.
1225    """
1226    loop_exits = []
1227    for _, grad_state in self._map.items():
1228      # pylint: disable=protected-access
1229      for y in grad_state.forward_loop_exits:
1230        if pending_count[y.op._id] == 0:
1231          grad_state.pending_exits_count -= 1
1232          if y.op._id not in to_ops_set:
1233            grad_state.unused_exits.append(y)
1234          if grad_state.pending_exits_count == 0:
1235            loop_exits.extend(grad_state.unused_exits)
1236      # Need to include Enters in backprop for higher-order gradients.
1237      for y in grad_state.forward_context.loop_enters:
1238        if pending_count[y.op._id] == 0:
1239          pending_count[y.op._id] = 1
1240      # pylint: enable=protected-access
1241    return loop_exits
1242
1243  def EnterGradWhileContext(self, op, before):
1244    """Enter the WhileContext for gradient computation."""
1245    grad_state = self.GetGradState(op, before)
1246    if grad_state:
1247      grad_state.grad_context.Enter()
1248
1249  def ExitGradWhileContext(self, op, before):
1250    """Exit the WhileContext for gradient computation."""
1251    grad_state = self.GetGradState(op, before)
1252    if grad_state:
1253      grad_state.grad_context.Exit()
1254
1255  def AddWhileContext(self, op, between_op_list, between_ops):
1256    """Add the grad state for the while loop that op belongs to.
1257
1258    Note that op is an Exit, and this method must be called in
1259    the control flow context where gradients() is called.
1260
1261    Note that this method modifies `between_op_list` and `between_ops`.
1262    """
1263    forward_ctxt = _GetWhileContext(op)
1264    grad_state = self._map.get(forward_ctxt)
1265    if grad_state is None:
1266      # This is a new while loop so create a grad state for it.
1267      outer_forward_ctxt = forward_ctxt.outer_context
1268      if outer_forward_ctxt:
1269        outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
1270      outer_grad_state = None
1271      if outer_forward_ctxt:
1272        outer_grad_state = self._map.get(outer_forward_ctxt)
1273      grad_state = GradLoopState(forward_ctxt, outer_grad_state)
1274      self._map[forward_ctxt] = grad_state
1275
1276      # We need to include all exits of a loop for backprop.
1277      for loop_exit in grad_state.forward_loop_exits:
1278        if not between_ops[loop_exit.op._id]:
1279          between_ops[loop_exit.op._id] = True
1280          between_op_list.append(loop_exit.op)
1281
1282  def ZerosLikeForExit(self, val):
1283    """Create zeros_like gradient for a loop exit.
1284
1285    If the result of a loop variable is not used but is involved in
1286    computing the result of some needed loop variable, we create a
1287    zero-valued tensor that is fed as gradient for the Exit node of that
1288    loop variable. Note that val.op is an Exit, and this method must be
1289    called in the control flow context where gradients() is called.
1290
1291    Args:
1292      val: The output tensor of an Exit op.
1293
1294    Returns:
1295      A zero tensor of the same shape of val.
1296    """
1297    val_shape = val.get_shape()
1298    forward_ctxt = val.op._get_control_flow_context()
1299    outer_forward_ctxt = forward_ctxt.outer_context
1300    if outer_forward_ctxt:
1301      outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
1302    outer_grad_state = None
1303    if outer_forward_ctxt:
1304      outer_grad_state = self._map.get(outer_forward_ctxt)
1305    if outer_grad_state:
1306      # This is a nested loop.
1307      if val_shape.is_fully_defined():
1308        # If the shape is known statically, just create a zero tensor
1309        # with the right shape in the right context.
1310        outer_grad_state.grad_context.Enter()
1311        result = array_ops.zeros(val_shape.dims, val.dtype)
1312        outer_grad_state.grad_context.Exit()
1313      else:
1314        # Only the shape of value is needed for backprop.
1315        forward_ctxt.outer_context.Enter()
1316        shape = array_ops.shape_internal(val, optimize=False)
1317        forward_ctxt.outer_context.Exit()
1318        # Save the shape to a stack.
1319        history_shape = outer_grad_state.AddForwardAccumulator(shape)
1320        # Get the shape back from the stack.
1321        outer_grad_ctxt = outer_grad_state.grad_context
1322        outer_grad_ctxt.Enter()
1323        real_shape = outer_grad_state.AddBackpropAccumulatedValue(
1324            history_shape, shape)
1325        result = array_ops.zeros(real_shape, val.dtype)
1326        outer_grad_ctxt.Exit()
1327    else:
1328      # This is not a nested loop.
1329      if val_shape.is_fully_defined():
1330        # If the shape is known statically, just create a zero tensor
1331        # with the right shape.
1332        result = array_ops.zeros(val_shape.dims, val.dtype)
1333      else:
1334        result = array_ops.zeros_like(val, optimize=False)
1335    return result
1336
1337  def ZerosLike(self, op, index):
1338    """Create zeros_like for the specified output of an op.
1339
1340    If op is in a while loop that is part of gradients(), this method
1341    must be called in its grad loop context.
1342
1343    Args:
1344      op: A tensorflow operation.
1345      index: the index for a specific output of the op.
1346
1347    Returns:
1348      A zero tensor of the same shape of op.outputs[index].
1349    """
1350    if util.IsLoopSwitch(op):
1351      return None
1352    dead_branch = util.IsSwitch(op)
1353    forward_ctxt = _GetWhileContext(op)
1354    grad_state = self._map.get(forward_ctxt)
1355    if grad_state is None:
1356      # op is not in a while loop that is part of gradients().
1357      return ZerosLikeOutsideLoop(op, index)
1358    op_ctxt = op._get_control_flow_context()
1359    val = ops.convert_to_tensor(op.outputs[index], name="tensor")
1360    shape = val.get_shape()
1361    if shape.is_fully_defined():
1362      # If the shape is known statically, just create a zero tensor with
1363      # the right shape in the grad loop context.
1364      result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
1365      if dead_branch:
1366        # op is a cond switch. Guard the zero tensor with a switch.
1367        pred = grad_state.history_map.get(op_ctxt.pred.name)
1368        branch = op_ctxt.branch
1369        result = _SwitchRefOrTensor(result, pred)[1 - branch]
1370    else:
1371      # Unknown shape so keep a history of the shape at runtime.
1372      if dead_branch:
1373        # Need to add a special switch to guard the value.
1374        pred = op_ctxt.pred
1375        branch = op_ctxt.branch
1376        op_ctxt.outer_context.Enter()
1377        val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
1378        zeros_shape = array_ops.shape_internal(val, optimize=False)
1379        op_ctxt.outer_context.Exit()
1380        val.op._set_control_flow_context(op_ctxt)
1381        zeros_shape.op._set_control_flow_context(op_ctxt)
1382      else:
1383        op_ctxt.Enter()
1384        zeros_shape = array_ops.shape_internal(val, optimize=False)
1385        op_ctxt.Exit()
1386
1387      # Add forward accumulator for shape.
1388      grad_state.grad_context.Exit()
1389      history_zeros_shape = grad_state.AddForwardAccumulator(
1390          zeros_shape, dead_branch=dead_branch)
1391      grad_state.grad_context.Enter()
1392
1393      # Create a zero tensor with the right shape.
1394      shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
1395                                                     zeros_shape, dead_branch)
1396      result = array_ops.zeros(shape, val.dtype)
1397    return result
1398
1399  def PostProcessing(self):
1400    """Perform postprocessing at the end of gradients().
1401
1402    We have created the gradient graph at this point. So this function
1403    can be used to perform any postprocessing on the gradient graph.
1404    We currently perform the following postprocessing:
1405      1. Patch the gradient graph if the output of a loop variable
1406         doesn't depend on its input.
1407    """
1408    for _, grad_state in self._map.items():
1409      for _, b_merge in grad_state.switch_map.items():
1410        if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
1411          # The value of this loop variable at iteration i+1 doesn't
1412          # depend on its value at iteration i. So use zeros as the
1413          # gradients for all iterations > 0.
1414          dtype = b_merge.op.inputs[0].dtype
1415          shape = b_merge.op.inputs[0].get_shape()
1416          # pylint: disable=protected-access
1417          if shape.is_fully_defined():
1418            grad_state.grad_context.Enter()
1419            # Create a zeros and use it for iterations > 0.
1420            grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
1421            next_grad_val = _NextIteration(grad_val)
1422            grad_state.grad_context.Exit()
1423          else:
1424            # Create a zeros in the outer grad context.
1425            outer_grad_ctxt = grad_state.grad_context.outer_context
1426            if outer_grad_ctxt:
1427              outer_grad_ctxt.Enter()
1428            enter_grad_op = b_merge.op.inputs[0].op
1429            enter_grad = enter_grad_op.inputs[0]
1430            grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
1431            grad_val = array_ops.zeros(grad_shape)
1432            if outer_grad_ctxt:
1433              outer_grad_ctxt.Exit()
1434            # Use the zeros for iterations > 0.
1435            grad_state.grad_context.Enter()
1436            next_grad_val = _NextIteration(grad_val)
1437            grad_state.grad_context.Exit()
1438          b_merge.op._update_input(1, next_grad_val)
1439          # pylint: enable=protected-access
1440
1441
1442def MaybeCreateControlFlowState(between_op_list, between_ops,
1443                                colocate_gradients_with_ops):
1444  """Create the state for all the while loops involved in one gradients().
1445
1446  We create a ControlFlowState when there are while loops involved in
1447  gradients(). In gradients(), control flow logic is only invoked when
1448  the ControlFlowState is not None.
1449
1450  Note that this method modifies `between_op_list` and `between_ops`.
1451  """
1452  loop_state = None
1453  for op in between_op_list:
1454    if util.IsLoopExit(op):
1455      if loop_state is None:
1456        loop_state = ControlFlowState()
1457      if colocate_gradients_with_ops:
1458        with ops.colocate_with(op):
1459          loop_state.AddWhileContext(op, between_op_list, between_ops)
1460      else:
1461        loop_state.AddWhileContext(op, between_op_list, between_ops)
1462  return loop_state
1463
1464
1465def ZerosLikeOutsideLoop(op, index):
1466  """Create zeros_like for the specified output of an op."""
1467  val = op.outputs[index]
1468  if not util.IsSwitch(op):
1469    return array_ops.zeros_like(val, optimize=False)
1470  else:
1471    op_ctxt = op._get_control_flow_context()
1472    if op_ctxt:
1473      # We are in a cond context. Use a switch to create zeros only when needed.
1474      pred = op_ctxt.pred
1475      branch = op_ctxt.branch
1476      switch_val = switch(op.inputs[0], pred)[1 - branch]
1477      zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
1478      return array_ops.zeros(zeros_shape, dtype=val.dtype)
1479    else:
1480      return array_ops.zeros_like(val, optimize=False)
1481
1482
1483class ControlFlowContext(object):
1484  """The base class for control flow context.
1485
1486  The usage pattern is a sequence of (Enter, Exit) followed by a final
1487  ExitResult.
1488
1489  We maintain the following state for control flow contexts during graph
1490  construction:
1491   1. graph has _control_flow_context: the current context used to
1492      construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
1493   2. op has _control_flow_context: the context to which the op belongs.
1494      Set at the time the op is created. Immutable.
1495   3. A ControlFlowContext has _outer_context: the context in which this
1496      context is created. Set at the time a context is created. Immutable.
1497   4. A ControlFlowContext has _context_stack.
1498      Pushed and popped by ctxt.Enter() and ctxt.Exit()
1499  """
1500
1501  def __init__(self, values_def=None, import_scope=None):
1502    self._nested_contexts = []
1503    self._outer_context = ops.get_default_graph()._get_control_flow_context()
1504    if self._outer_context:
1505      self._outer_context._nested_contexts.append(self)  # pylint: disable=protected-access
1506    self._context_stack = []
1507    if values_def:
1508      self._init_values_from_proto(values_def, import_scope=import_scope)
1509    else:
1510      # Values that have been already seen in this context.
1511      self._values = set()
1512      # Values referenced by but external to this context.
1513      self._external_values = {}
1514
1515  def _init_values_from_proto(self, values_def, import_scope=None):
1516    """Initializes values and external_values from `ValuesDef` protocol buffer.
1517
1518    Args:
1519      values_def: `ValuesDef` protocol buffer.
1520      import_scope: Optional `string`. Name scope to add.
1521    """
1522    assert isinstance(values_def, control_flow_pb2.ValuesDef)
1523    self._values = set(
1524        ops.prepend_name_scope(value, import_scope)
1525        for value in values_def.values)
1526    g = ops.get_default_graph()
1527    self._external_values = {}
1528    for k, v in values_def.external_values.items():
1529      k = ops.prepend_name_scope(k, import_scope)
1530      self._external_values[k] = g.as_graph_element(
1531          ops.prepend_name_scope(v, import_scope))
1532    op_names = set([
1533        op.split(":")[0]
1534        for op in self._values - set(self._external_values.keys())
1535    ])
1536    for op in op_names:
1537      # pylint: disable=protected-access
1538      g.as_graph_element(op)._set_control_flow_context(self)
1539      # pylint: enable=protected-access
1540
1541  @property
1542  def name(self):
1543    return self._name
1544
1545  @property
1546  def outer_context(self):
1547    """Return the context containing this context."""
1548    return self._outer_context
1549
1550  @property
1551  def grad_state(self):
1552    raise NotImplementedError("Abstract method")
1553
1554  @property
1555  def back_prop(self):
1556    raise NotImplementedError("Abstract method")
1557
1558  @abc.abstractmethod
1559  def to_control_flow_context_def(self, context_def, export_scope=None):
1560    """Serializes this into `context_def`.
1561
1562    Args:
1563      context_def: a `ControlFlowContextDef` protocol buffer.
1564      export_scope: Optional `string`. Name scope to remove.
1565    """
1566    raise NotImplementedError("Abstract method")
1567
1568  def _to_values_def(self, export_scope=None):
1569    """Converts the values to a `ValuesDef` protocol buffer.
1570
1571    Args:
1572      export_scope: Optional `string`. Name scope to remove.
1573
1574    Returns:
1575      A `ValuesDef` protocol buffer.
1576    """
1577    values_def = control_flow_pb2.ValuesDef()
1578    values_def.values.extend(
1579        [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)])
1580    for k, v in self._external_values.items():
1581      k = ops.strip_name_scope(k, export_scope)
1582      values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
1583    return values_def
1584
1585  def AddName(self, name):
1586    self._values.add(name)
1587
1588  # pylint: disable=protected-access
1589  def Enter(self):
1590    """Enter this control flow context."""
1591    graph = ops.get_default_graph()
1592    self._context_stack.append(graph._get_control_flow_context())
1593    graph._set_control_flow_context(self)
1594
1595  def Exit(self):
1596    """Exit this control flow context."""
1597    graph = ops.get_default_graph()
1598    last_context = self._context_stack.pop()
1599    graph._set_control_flow_context(last_context)
1600
1601  def ExitResult(self, result):
1602    """Make a list of tensors available in the outer context."""
1603    if self._outer_context:
1604      nest.map_structure(lambda x: self._outer_context.AddName(x.name), result)
1605
1606  def GetWhileContext(self):
1607    """Return the while context containing this context."""
1608    if self._outer_context:
1609      return self._outer_context.GetWhileContext()
1610    return None
1611
1612  def _IsInOuterContext(self, op):
1613    op_ctxt = util.GetOutputContext(op)
1614    outer_ctxt = self.outer_context
1615    while outer_ctxt != op_ctxt:
1616      if outer_ctxt is None:
1617        return False
1618      outer_ctxt = outer_ctxt.outer_context
1619    return True
1620
1621  def _RemoveExternalControlEdges(self, op):
1622    """Remove any external control dependency on this op."""
1623    while_ctxt = self.GetWhileContext()
1624    # A control input of `op` is internal if it is in the same while
1625    # loop context as the enclosing while loop context of self.
1626    if while_ctxt is None:
1627      internal_control_inputs = op.control_inputs
1628    else:
1629      internal_control_inputs = []
1630      for x in op.control_inputs:
1631        ctxt = util.GetOutputContext(x)
1632        if ctxt is not None and ctxt.GetWhileContext() == while_ctxt:
1633          internal_control_inputs.append(x)
1634    external_control_inputs = []
1635    if len(internal_control_inputs) != len(op.control_inputs):
1636      external_control_inputs = list(set(op.control_inputs)
1637                                     - set(internal_control_inputs))
1638      op._remove_all_control_inputs()
1639      op._add_control_inputs(internal_control_inputs)
1640    return internal_control_inputs, external_control_inputs
1641
1642  # pylint: enable=protected-access
1643
1644  def AddInnerOp(self, op):
1645    """Notifies a scope about an operator added to an inner scope."""
1646    if self._outer_context:
1647      self._outer_context.AddInnerOp(op)
1648
1649  def GetControlPivot(self):
1650    """Returns the pivot node for this context, or None."""
1651    return None
1652
1653  def IsWhileContext(self):
1654    return False
1655
1656  def IsCondContext(self):
1657    return False
1658
1659  def IsXLAContext(self):
1660    return False
1661
1662  def __str__(self):
1663    return self.name
1664
1665
1666class CondContext(ControlFlowContext):
1667  """The context for the conditional construct."""
1668
1669  def __init__(self,
1670               pred=None,
1671               pivot=None,
1672               branch=None,
1673               name="cond_text",
1674               context_def=None,
1675               import_scope=None):
1676    """Creates a `CondContext`.
1677
1678    Args:
1679      pred: The `boolean` tensor for the conditional predicate.
1680      pivot: The predicate tensor in this branch.
1681      branch: 0 or 1 representing this branch.
1682      name: Name of the `CondContext` python object.
1683      context_def: Optional `ContextDef` protocol buffer to initialize the
1684        `CondContext` object from.
1685      import_scope: Optional `string`. Name scope to add. Only used when
1686        initialing from protocol buffer.
1687    """
1688    self._name = ops.get_default_graph().unique_name(name)
1689
1690    if context_def:
1691      self._init_from_proto(context_def, import_scope=import_scope)
1692    else:
1693      # Initializes the default fields.
1694      ControlFlowContext.__init__(self)
1695      self._pred = pred  # The boolean tensor for the cond predicate
1696      self._pivot = pivot  # The predicate tensor in this branch
1697      self._branch = branch  # 0 or 1 representing this branch
1698
1699      # Values considered to have been already seen in this context.
1700      self._values.add(pred.name)
1701      self._values.add(pivot.name)
1702
1703  def _init_from_proto(self, context_def, import_scope=None):
1704    """Creates a new `CondContext` from protocol buffer.
1705
1706    Args:
1707      context_def: `CondContextDef` protocol buffer.
1708      import_scope: Optional `string`. Name scope to add.
1709    """
1710    assert isinstance(context_def, control_flow_pb2.CondContextDef)
1711    # Create from context_def.
1712    g = ops.get_default_graph()
1713    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
1714    self._pred = g.as_graph_element(
1715        ops.prepend_name_scope(context_def.pred_name, import_scope))
1716    self._pivot = g.as_graph_element(
1717        ops.prepend_name_scope(context_def.pivot_name, import_scope))
1718    self._branch = context_def.branch
1719    super(CondContext, self).__init__(
1720        values_def=context_def.values_def, import_scope=import_scope)
1721
1722  @property
1723  def pred(self):
1724    return self._pred
1725
1726  @property
1727  def pivot(self):
1728    return self._pivot
1729
1730  @property
1731  def branch(self):
1732    return self._branch
1733
1734  @property
1735  def grad_state(self):
1736    if self.GetWhileContext():
1737      return self.GetWhileContext().grad_state
1738    return None
1739
1740  @property
1741  def back_prop(self):
1742    if self.GetWhileContext():
1743      self.GetWhileContext().back_prop
1744    return False
1745
1746  def GetControlPivot(self):
1747    return self._pivot
1748
1749  def to_proto(self, export_scope=None):
1750    """Converts a `CondContext` to a `CondContextDef` protocol buffer.
1751
1752    Args:
1753      export_scope: Optional `string`. Name scope to remove.
1754
1755    Returns:
1756      A `CondContextDef` protocol buffer.
1757    """
1758    if (export_scope is None or self.name.startswith(export_scope)):
1759      context_def = control_flow_pb2.CondContextDef()
1760      context_def.context_name = ops.strip_name_scope(self.name, export_scope)
1761      context_def.pred_name = ops.strip_name_scope(self._pred.name,
1762                                                   export_scope)
1763      context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
1764                                                    export_scope)
1765      context_def.branch = self._branch
1766      context_def.values_def.MergeFrom(super(CondContext, self)._to_values_def(
1767          export_scope))
1768      # TODO(b/72868227): enable this once the corresponding control_flow.proto
1769      # changes have been checked in (they aren't checked in and this is
1770      # disabled for now to ensure forwards compatibility).
1771      if False:  # pylint: disable=using-constant-test
1772        for nested in self._nested_contexts:
1773          nested_def = context_def.nested_contexts.add()
1774          nested.to_control_flow_context_def(nested_def)
1775
1776      return context_def
1777    else:
1778      return None
1779
1780  @staticmethod
1781  def from_proto(context_def, import_scope=None):
1782    """Returns a `CondContext` object created from `context_def`."""
1783    ret = CondContext(context_def=context_def,
1784                      import_scope=import_scope)
1785
1786    # TODO(b/72868227): remove "if hasattr(...)" once the corresponding
1787    # control_flow.proto changes have been checked in (they aren't checked in
1788    # and this is here for now to ensure forwards compatibility).
1789    if hasattr(context_def, "nested_contexts"):
1790      ret.Enter()
1791      for nested_def in context_def.nested_contexts:
1792        from_control_flow_context_def(nested_def)
1793      ret.Exit()
1794    return ret
1795
1796  def to_control_flow_context_def(self, context_def, export_scope=None):
1797    context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
1798
1799  def AddValue(self, val):
1800    """Add `val` to the current context and its outer context recursively."""
1801    if val.name in self._values:
1802      # Use the real value if it comes from outer context. This is needed in
1803      # particular for nested conds.
1804      result = self._external_values.get(val.name)
1805      result = val if result is None else result
1806    else:
1807      result = val
1808      self._values.add(val.name)
1809      if self._outer_context:
1810        result = self._outer_context.AddValue(val)
1811        self._values.add(result.name)
1812      with ops.control_dependencies(None):
1813        result = _SwitchRefOrTensor(result, self._pred)[self._branch]
1814        if self._outer_context:
1815          self._outer_context.AddInnerOp(result.op)
1816
1817      result.op.graph.prevent_fetching(result.op)
1818      # pylint: disable=protected-access
1819      result.op._set_control_flow_context(self)
1820      # pylint: enable=protected-access
1821
1822      self._values.add(result.name)
1823      self._external_values[val.name] = result
1824    return result
1825
1826  def AddOp(self, op):
1827    self._AddOpInternal(op)
1828
1829  def _AddOpInternal(self, op):
1830    """Add `op` to the current context."""
1831    if not op.inputs:
1832      # Remove any external control dependency on this op
1833      self._RemoveExternalControlEdges(op)
1834      # pylint: disable=protected-access
1835      op._add_control_input(self._pivot.op)
1836      # pylint: enable=protected-access
1837      for x in op.outputs:
1838        self._values.add(x.name)
1839    else:
1840      for index in range(len(op.inputs)):
1841        x = op.inputs[index]
1842        real_x = self.AddValue(x)
1843        if real_x != x:
1844          # pylint: disable=protected-access
1845          op._update_input(index, real_x)
1846          # pylint: enable=protected-access
1847      # Remove any external control dependency on this op.
1848      self._RemoveExternalControlEdges(op)
1849      for x in op.outputs:
1850        self._values.add(x.name)
1851      # pylint: disable=protected-access
1852      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
1853        op._add_control_input(self._pivot.op)
1854      # pylint: enable=protected-access
1855
1856    if self._outer_context or not util.IsLoopExit(op):
1857      op.graph.prevent_fetching(op)
1858
1859    if self._outer_context:
1860      self._outer_context.AddInnerOp(op)
1861
1862  def _ProcessOutputTensor(self, val):
1863    """Process an output tensor of a conditional branch."""
1864    real_val = val
1865    if val.name not in self._values:
1866      # Handle the special case of lambda: x
1867      self._values.add(val.name)
1868      if self._outer_context:
1869        real_val = self._outer_context.AddValue(val)
1870        self._values.add(real_val.name)
1871      real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
1872      self._external_values[val.name] = real_val
1873    else:
1874      external_val = self._external_values.get(val.name)
1875      if external_val is not None:
1876        real_val = external_val
1877    return real_val
1878
1879  def _BuildCondTensor(self, v):
1880    if isinstance(v, ops.Operation):
1881      # Use pivot as the proxy for this op.
1882      return with_dependencies([v], self._pivot)
1883    elif isinstance(v, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
1884      values = self._ProcessOutputTensor(v.values)
1885      indices = self._ProcessOutputTensor(v.indices)
1886      if isinstance(v, ops.IndexedSlices):
1887        dense_shape = v.dense_shape
1888        if dense_shape is not None:
1889          dense_shape = self._ProcessOutputTensor(dense_shape)
1890        return ops.IndexedSlices(values, indices, dense_shape)
1891      else:
1892        dense_shape = self._ProcessOutputTensor(v.dense_shape)
1893        return sparse_tensor.SparseTensor(indices, values, dense_shape)
1894    else:
1895      v = nest.map_structure(_convert_tensorarray_to_flow, v)
1896      return self._ProcessOutputTensor(ops.convert_to_tensor(v))
1897
1898  def BuildCondBranch(self, fn):
1899    """Add the subgraph defined by fn() to the graph."""
1900    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
1901    original_result = fn()
1902    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
1903    if len(post_summaries) > len(pre_summaries):
1904      new_summaries = post_summaries[len(pre_summaries):]
1905      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
1906      summary_ref[:] = pre_summaries
1907      with ops.control_dependencies(new_summaries):
1908        if original_result is None:
1909          return no_op(), None
1910        else:
1911          original_result = nest.map_structure(array_ops.identity,
1912                                               original_result)
1913    if original_result is None:
1914      return None, None
1915
1916    result = nest.map_structure(self._BuildCondTensor, original_result)
1917    if not isinstance(result, (list, _basetuple)):
1918      result = [result]
1919    return original_result, result
1920
1921  def IsCondContext(self):
1922    return True
1923
1924
1925def _UnpackIfSingleton(res):
1926  if isinstance(res, (list, _basetuple)) and len(res) == 1:
1927    return res[0]
1928  else:
1929    return res
1930
1931
1932# pylint: disable=redefined-outer-name
1933# pylint: disable=g-doc-args
1934@tf_export("cond")
1935@deprecation.deprecated_args(
1936    None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
1937    "fn1", "fn2")
1938def cond(pred,
1939         true_fn=None,
1940         false_fn=None,
1941         strict=False,
1942         name=None,
1943         fn1=None,
1944         fn2=None):
1945  """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
1946
1947  `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
1948  `false_fn` must have the same non-zero number and type of outputs.
1949
1950  Note that the conditional execution applies only to the operations defined in
1951  `true_fn` and `false_fn`. Consider the following simple program:
1952
1953  ```python
1954  z = tf.multiply(a, b)
1955  result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
1956  ```
1957
1958  If `x < y`, the `tf.add` operation will be executed and `tf.square`
1959  operation will not be executed. Since `z` is needed for at least one
1960  branch of the `cond`, the `tf.multiply` operation is always executed,
1961  unconditionally.
1962  Although this behavior is consistent with the dataflow model of TensorFlow,
1963  it has occasionally surprised some users who expected a lazier semantics.
1964
1965  Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
1966  call to `cond`, and not at all during `Session.run()`). `cond`
1967  stitches together the graph fragments created during the `true_fn` and
1968  `false_fn` calls with some additional graph nodes to ensure that the right
1969  branch gets executed depending on the value of `pred`.
1970
1971  `tf.cond` supports nested structures as implemented in
1972  `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
1973  same (possibly nested) value structure of lists, tuples, and/or named tuples.
1974  Singleton lists and tuples form the only exceptions to this: when returned by
1975  `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
1976  This behavior is disabled by passing `strict=True`.
1977
1978  Args:
1979    pred: A scalar determining whether to return the result of `true_fn` or
1980      `false_fn`.
1981    true_fn: The callable to be performed if pred is true.
1982    false_fn: The callable to be performed if pred is false.
1983    strict: A boolean that enables/disables 'strict' mode; see above.
1984    name: Optional name prefix for the returned tensors.
1985
1986  Returns:
1987    Tensors returned by the call to either `true_fn` or `false_fn`. If the
1988    callables return a singleton list, the element is extracted from the list.
1989
1990  Raises:
1991    TypeError: if `true_fn` or `false_fn` is not callable.
1992    ValueError: if `true_fn` and `false_fn` do not return the same number of
1993      tensors, or return tensors of different types.
1994
1995  Example:
1996
1997  ```python
1998  x = tf.constant(2)
1999  y = tf.constant(5)
2000  def f1(): return tf.multiply(x, 17)
2001  def f2(): return tf.add(y, 23)
2002  r = tf.cond(tf.less(x, y), f1, f2)
2003  # r is set to f1().
2004  # Operations in f2 (e.g., tf.add) are not executed.
2005  ```
2006
2007  """
2008  # We needed to make true_fn/false_fn keyword arguments for
2009  # backwards-compatibility. This check exists so that we can convert back to
2010  # having them be positional arguments.
2011  # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
2012  # `fn1` and `fn2` are deleted.
2013  if fn1 is not None:
2014    if true_fn is not None:
2015      raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
2016    true_fn = fn1
2017  elif true_fn is None:
2018    raise TypeError("cond(): true_fn argument required")
2019  if fn2 is not None:
2020    if false_fn is not None:
2021      raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
2022    false_fn = fn2
2023  elif false_fn is None:
2024    raise TypeError("cond(): false_fn argument required")
2025
2026  if not callable(true_fn):
2027    raise TypeError("true_fn must be callable.")
2028  if not callable(false_fn):
2029    raise TypeError("false_fn must be callable.")
2030
2031  with ops.name_scope(name, "cond", [pred]):
2032    if context.in_eager_mode():
2033      if pred:
2034        return _UnpackIfSingleton(true_fn())
2035      return _UnpackIfSingleton(false_fn())
2036
2037    # Add the Switch to the graph.
2038    if isinstance(pred, bool):
2039      raise TypeError("pred must not be a Python bool")
2040    p_2, p_1 = switch(pred, pred)
2041    pivot_1 = array_ops.identity(p_1, name="switch_t")
2042    pivot_2 = array_ops.identity(p_2, name="switch_f")
2043    pred = array_ops.identity(pred, name="pred_id")
2044    # Disable the fetching of tensors that are only on one branch of cond.
2045    for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:
2046      tensor.op.graph.prevent_fetching(tensor.op)
2047
2048    # Build the graph for the true branch in a new context.
2049    context_t = CondContext(pred, pivot_1, branch=1)
2050    context_t.Enter()
2051    orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
2052    if orig_res_t is None:
2053      raise ValueError("true_fn must have a return value.")
2054    context_t.ExitResult(res_t)
2055    context_t.Exit()
2056
2057    # Build the graph for the false branch in a new context.
2058    context_f = CondContext(pred, pivot_2, branch=0)
2059    context_f.Enter()
2060    orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
2061    if orig_res_f is None:
2062      raise ValueError("false_fn must have a return value.")
2063    context_f.ExitResult(res_f)
2064    context_f.Exit()
2065
2066    if not strict:
2067      orig_res_t = _UnpackIfSingleton(orig_res_t)
2068      orig_res_f = _UnpackIfSingleton(orig_res_f)
2069
2070    # Check that the return values of the two branches have the same structure.
2071    try:
2072      nest.assert_same_structure(orig_res_t, orig_res_f)
2073    except TypeError as e:
2074      raise TypeError(
2075          "Incompatible return types of true_fn and false_fn: {}".format(e))
2076    except ValueError as e:
2077      raise ValueError(
2078          "Incompatible return values of true_fn and false_fn: {}".format(e))
2079
2080    # Add the final merge to the graph.
2081    if not res_t:
2082      raise ValueError("true_fn and false_fn must return at least one result.")
2083
2084    res_t_flat = nest.flatten(res_t)
2085    res_f_flat = nest.flatten(res_f)
2086
2087    for x, y in zip(res_t_flat, res_f_flat):
2088      assert ((isinstance(x, ops.IndexedSlices) and
2089               isinstance(y, ops.IndexedSlices)) or
2090              (isinstance(x, sparse_tensor.SparseTensor) and
2091               isinstance(y, sparse_tensor.SparseTensor)) or
2092              (isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)))
2093      val_x = x if isinstance(x, ops.Tensor) else x.values
2094      val_y = y if isinstance(y, ops.Tensor) else y.values
2095      if val_x.dtype.base_dtype != val_y.dtype.base_dtype:
2096        raise ValueError(
2097            "Outputs of true_fn and false_fn must have the same type: %s, %s" %
2098            (val_x.dtype.name, val_y.dtype.name))
2099
2100    merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
2101    merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges)
2102
2103    # Only add non-nested conds to the collection. Any nested control flow will
2104    # be encapsulated in the root context.
2105    assert context_t.outer_context == context_f.outer_context
2106    # TODO(b/72868227): remove "if True..." once the corresponding
2107    # control_flow.proto changes have been checked in (they aren't checked in
2108    # and this is disabled for now to ensure forwards compatibility).
2109    if True or context_t.outer_context is None:
2110      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
2111      ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)
2112
2113    merges = nest.pack_sequence_as(structure=orig_res_t, flat_sequence=merges)
2114
2115    # Singleton lists and tuples are automatically unpacked if strict == False.
2116    if not strict:
2117      merges = _UnpackIfSingleton(merges)
2118    return merges
2119
2120
2121# pylint: enable=g-doc-args
2122# pylint: enable=redefined-outer-name
2123
2124
2125def smart_cond(pred, true_fn=None, false_fn=None, name=None):
2126  """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
2127
2128  If `pred` is a bool or has a constant value, we return either `true_fn()`
2129  or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
2130
2131  Arguments:
2132    pred: A scalar determining whether to return the result of `true_fn` or
2133      `false_fn`.
2134    true_fn: The callable to be performed if pred is true.
2135    false_fn: The callable to be performed if pred is false.
2136    name: Optional name prefix when using `tf.cond`.
2137
2138  Returns:
2139    Tensors returned by the call to either `true_fn` or `false_fn`.
2140
2141  Raises:
2142    TypeError: If `true_fn` or `false_fn` is not callable.
2143  """
2144  if not callable(true_fn):
2145    raise TypeError('`true_fn` must be callable.')
2146  if not callable(false_fn):
2147    raise TypeError('`false_fn` must be callable.')
2148
2149  pred_value = smart_constant_value(pred)
2150  if pred_value is not None:
2151    if pred_value:
2152      return true_fn()
2153    else:
2154      return false_fn()
2155  else:
2156    return cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)
2157
2158
2159def smart_constant_value(pred):
2160  """Return the bool value for `pred`, or None if `pred` had a dynamic value.
2161
2162  Arguments:
2163    pred: A scalar, either a Python bool or tensor.
2164
2165  Returns:
2166    True or False if `pred` has a constant boolean value, None otherwise.
2167
2168  Raises:
2169    TypeError: If `pred` is not a Tensor or bool.
2170  """
2171  if isinstance(pred, bool):
2172    pred_value = pred
2173  elif isinstance(pred, ops.Tensor):
2174    pred_value = tensor_util.constant_value(pred)
2175  else:
2176    raise TypeError('`pred` must be a Tensor or a Python bool.')
2177  return pred_value
2178
2179
2180def _resource_safe_shape(t):
2181  """Returns the shape of t or the variable it points to."""
2182  if t.dtype == dtypes.resource:
2183    while t.op.inputs:
2184      t = t.op.inputs[0]
2185    return tensor_shape.TensorShape(t.op.get_attr("shape"))
2186  return array_ops.shape_internal(t, optimize=False)
2187
2188
2189# TODO(yuanbyu): Consider having a unified notion of context for
2190# not only conditionals and loops but also control dependency and
2191# subgraphs.
2192class WhileContext(ControlFlowContext):
2193  """The context for the loop construct."""
2194
2195  def __init__(self,
2196               maximum_iterations=None,
2197               parallel_iterations=10,
2198               back_prop=True,
2199               swap_memory=False,
2200               name="while_context",
2201               grad_state=None,
2202               context_def=None,
2203               import_scope=None):
2204    """"Creates a `WhileContext`.
2205
2206    Args:
2207      maximum_iterations: Optional upper bound on number of loop iterations.
2208      parallel_iterations: The number of iterations allowed to run in parallel.
2209      back_prop: Whether backprop is enabled for this while loop.
2210      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
2211      name: Optional name prefix for the returned tensors.
2212      grad_state: The gradient loop state.
2213      context_def: Optional `WhileContextDef` protocol buffer to initialize
2214        the `Whilecontext` python object from.
2215      import_scope: Optional `string`. Name scope to add. Only used when
2216        initialing from protocol buffer.
2217    """
2218    if context_def:
2219      self._init_from_proto(context_def, import_scope=import_scope)
2220    else:
2221      ControlFlowContext.__init__(self)
2222      self._init_from_args(maximum_iterations, parallel_iterations, back_prop,
2223                           swap_memory, name)
2224    # The gradient loop state.
2225    self._grad_state = grad_state
2226
2227  def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop,
2228                      swap_memory, name):
2229    """Creates a new `WhileContext` from arguments.
2230
2231    Args:
2232      maximum_iterations: Optional upper bound on number of loop iterations.
2233      parallel_iterations: The number of iterations allowed to run in parallel.
2234      back_prop: Whether backprop is enabled for this while loop.
2235      swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
2236      name: Optional name prefix for the returned tensors.
2237
2238    Raises:
2239      ValueError: If `parallel_iterations` has invalid value.
2240    """
2241    if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0):
2242      raise ValueError("`parallel_iterations` must be a positive integer: "
2243                       "%s" % parallel_iterations)
2244    self._name = ops.get_default_graph().unique_name(name)
2245    self._maximum_iterations = maximum_iterations
2246    self._parallel_iterations = parallel_iterations
2247    self._back_prop = back_prop
2248    self._swap_memory = swap_memory
2249    # We use this node to control constants created by the pred lambda.
2250    self._pivot_for_pred = None
2251    # We use this node to control constants created by the body lambda.
2252    self._pivot_for_body = None
2253    # The boolean tensor for loop termination condition. Used in code
2254    # generation for gradient computation
2255    self._pivot = None
2256    # The list of exit tensors for loop variables.
2257    self._loop_exits = []
2258    # The list of enter tensors for loop variables.
2259    self._loop_enters = []
2260
2261  def _init_from_proto(self, context_def, import_scope=None):
2262    """Creates a new `WhileContext` from protocol buffer.
2263
2264    Args:
2265      context_def: `WhileContextDef` protocol buffer.
2266      import_scope: Optional `string`. Name scope to add.
2267    """
2268    assert isinstance(context_def, control_flow_pb2.WhileContextDef)
2269    # Create from context_def.
2270    g = ops.get_default_graph()
2271    self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
2272    if context_def.maximum_iterations_name:
2273      self._maximum_iterations = g.as_graph_element(
2274          ops.prepend_name_scope(context_def.maximum_iterations_name,
2275                                 import_scope))
2276    else:
2277      self._maximum_iterations = None
2278    self._parallel_iterations = context_def.parallel_iterations
2279    self._back_prop = context_def.back_prop
2280    self._swap_memory = context_def.swap_memory
2281    self._pivot_for_pred = g.as_graph_element(
2282        ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope))
2283    # We use this node to control constants created by the body lambda.
2284    self._pivot_for_body = g.as_graph_element(
2285        ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope))
2286    # The boolean tensor for loop termination condition. Used in code
2287    # generation for gradient computation.
2288    self._pivot = g.as_graph_element(
2289        ops.prepend_name_scope(context_def.pivot_name, import_scope))
2290    # The list of exit tensors for loop variables.
2291    self._loop_exits = [
2292        g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope))
2293        for exit_name in context_def.loop_exit_names
2294    ]
2295    # The list of enter tensors for loop variables.
2296    self._loop_enters = [
2297        g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope))
2298        for enter_name in context_def.loop_enter_names
2299    ]
2300    super(WhileContext, self).__init__(
2301        values_def=context_def.values_def, import_scope=import_scope)
2302
2303    # import_scope causes self.name to be different from the original serialized
2304    # context's name. Rewrite "frame_name" attrs with the new name.
2305    if import_scope:
2306      for tensor_name in self._values:
2307        op = g.as_graph_element(tensor_name).op
2308        if util.IsLoopEnter(op):
2309          # pylint: disable=protected-access
2310          op._set_attr("frame_name",
2311                       attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
2312          # pylint: enable=protected-access
2313
2314  @property
2315  def maximum_iterations(self):
2316    """The maximum number of iterations that will be executed."""
2317    return self._maximum_iterations
2318
2319  @property
2320  def parallel_iterations(self):
2321    """The number of iterations allowed to run in parallel."""
2322    return self._parallel_iterations
2323
2324  @property
2325  def back_prop(self):
2326    """True iff backprop is enabled for this while loop."""
2327    return self._back_prop
2328
2329  @property
2330  def swap_memory(self):
2331    """True iff GPU-CPU memory swap is enabled for this while loop."""
2332    return self._swap_memory
2333
2334  @property
2335  def pivot(self):
2336    """The boolean tensor representing the loop termination condition."""
2337    return self._pivot
2338
2339  @property
2340  def loop_enters(self):
2341    """The list of enter tensors for loop variables."""
2342    return self._loop_enters
2343
2344  @property
2345  def loop_exits(self):
2346    """The list of exit tensors for loop variables."""
2347    return self._loop_exits
2348
2349  @property
2350  def grad_state(self):
2351    """The gradient loop state."""
2352    return self._grad_state
2353
2354  def to_proto(self, export_scope=None):
2355    """Converts a `WhileContext` to a `WhileContextDef` protocol buffer.
2356
2357    Args:
2358      export_scope: Optional `string`. Name scope to remove.
2359
2360    Returns:
2361      A `WhileContextDef` protocol buffer.
2362    """
2363    if (export_scope is None or self.name.startswith(export_scope)):
2364      context_def = control_flow_pb2.WhileContextDef()
2365      context_def.context_name = ops.strip_name_scope(self.name, export_scope)
2366      context_def.parallel_iterations = self._parallel_iterations
2367      if self._maximum_iterations is not None:
2368        context_def.maximum_iterations_name = ops.strip_name_scope(
2369            self._maximum_iterations.name, export_scope)
2370      context_def.back_prop = self._back_prop
2371      context_def.swap_memory = self._swap_memory
2372      context_def.pivot_for_pred_name = ops.strip_name_scope(
2373          self._pivot_for_pred.name, export_scope)
2374      context_def.pivot_for_body_name = ops.strip_name_scope(
2375          self._pivot_for_body.name, export_scope)
2376      context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
2377                                                    export_scope)
2378      context_def.loop_exit_names.extend([
2379          ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits
2380      ])
2381      context_def.loop_enter_names.extend([
2382          ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters
2383      ])
2384      context_def.values_def.MergeFrom(
2385          super(WhileContext, self)._to_values_def(
2386              export_scope=export_scope))
2387      # TODO(b/72868227): remove "if True..." once the corresponding
2388      # control_flow.proto changes have been checked in (they aren't checked in
2389      # and this is disabled for now to ensure forwards compatibility).
2390      if False:  # pylint: disable=using-constant-test
2391        for nested in self._nested_contexts:
2392          nested_def = context_def.nested_contexts.add()
2393          nested.to_control_flow_context_def(nested_def)
2394
2395      return context_def
2396    else:
2397      return None
2398
2399  def to_control_flow_context_def(self, context_def, export_scope=None):
2400    context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
2401
2402  @staticmethod
2403  def from_proto(context_def, import_scope=None):
2404    """Returns a `WhileContext` object created from `context_def`.
2405
2406    Args:
2407      context_def: A `WhileContextDef` protocol buffer.
2408      import_scope: Optional `string`. Name scope to add.
2409
2410    Returns:
2411      A `WhileContext` Python object.
2412    """
2413    ret = WhileContext(context_def=context_def,
2414                       import_scope=import_scope)
2415    # TODO(b/72868227): remove "if hasattr(...)" once the corresponding
2416    # control_flow.proto changes have been checked in (they aren't checked in
2417    # and this is disabled for now to ensure forwards compatibility).
2418    if hasattr(context_def, "nested_contexts"):
2419      ret.Enter()
2420      for nested_def in context_def.nested_contexts:
2421        from_control_flow_context_def(nested_def, import_scope=import_scope)
2422      ret.Exit()
2423    return ret
2424
2425  def GetWhileContext(self):
2426    return self
2427
2428  def GetControlPivot(self):
2429    if self._pivot_for_body is not None:
2430      return self._pivot_for_body
2431    return self._pivot_for_pred
2432
2433  def AddValue(self, val):
2434    """Add `val` to the current context and its outer context recursively."""
2435    result = val
2436    if val.name not in self._values:
2437      self._values.add(val.name)
2438
2439      # If we are in a grad context and val is from its forward context,
2440      # use GetRealValue(), which adds the logic to save the history of
2441      # val in forward.
2442      grad_ctxt = ops.get_default_graph()._get_control_flow_context()
2443      if grad_ctxt:
2444        grad_ctxt = grad_ctxt.GetWhileContext()
2445        if grad_ctxt.grad_state:
2446          forward_ctxt = _GetWhileContext(val.op)
2447          if util.IsLoopExit(val.op):
2448            forward_ctxt = forward_ctxt.outer_context
2449            if forward_ctxt:
2450              forward_ctxt = forward_ctxt.GetWhileContext()
2451          if forward_ctxt == grad_ctxt.grad_state.forward_context:
2452            real_val = grad_ctxt.grad_state.GetRealValue(val)
2453            self._external_values[val.name] = real_val
2454            return real_val
2455
2456      if self._outer_context is not None:
2457        result = self._outer_context.AddValue(val)
2458      # Create an Enter to make `result` known to this loop context.
2459      with ops.control_dependencies(None):
2460        enter = _Enter(
2461            result,
2462            self._name,
2463            is_constant=True,
2464            parallel_iterations=self._parallel_iterations)
2465        enter.graph.prevent_feeding(enter)
2466        if self._outer_context:
2467          self._outer_context.AddInnerOp(enter.op)
2468      # Fix the control inputs and control flow context of these enter ops.
2469      self._FixControlInputsAndContext([enter])
2470
2471      # Add `enter` in this context.
2472      self._values.add(enter.name)
2473      self._external_values[val.name] = enter
2474      result = enter
2475    else:
2476      actual_val = self._external_values.get(val.name)
2477      if actual_val is not None:
2478        result = actual_val
2479    return result
2480
2481  def AddOp(self, op):
2482    """Add `op` to the current context."""
2483    # For a reduction op, if op is in a grad context and its input is from
2484    # its forward context, moving op to the forward context means we would
2485    # store the tensor after the reduction as opposed to the tensor before
2486    # reduction, and therefore could significantly reduce memory consumption.
2487    # For now, we do this only for a few ops.
2488    if op.type in {"Shape", "Size", "Rank"}:
2489      grad_ctxt = ops.get_default_graph()._get_control_flow_context()
2490      if grad_ctxt:
2491        grad_ctxt = grad_ctxt.GetWhileContext()
2492        if grad_ctxt.grad_state:
2493          op_input_forward_ctxt = _GetWhileContext(op.inputs[0].op)
2494          if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context:
2495            op_input_ctxt = op.inputs[0].op._get_control_flow_context()
2496            op._set_control_flow_context(op_input_ctxt)
2497            op_input_ctxt._AddOpInternal(op)
2498            return
2499    self._AddOpInternal(op)
2500
2501  def _AddOpInternal(self, op):
2502    """Add `op` to the current context.
2503
2504    We move any external control dependencies of the op to the loop pivot, to
2505    ensure they get executed.
2506    """
2507    if not op.inputs:
2508      # Remove any external control dependency on this op
2509      control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
2510      # Add a control edge from the control pivot to this op.
2511      if not control_inputs:
2512        # pylint: disable=protected-access
2513        op._add_control_input(self.GetControlPivot().op)
2514        # pylint: enable=protected-access
2515      for x in op.outputs:
2516        self._values.add(x.name)
2517    else:
2518      for index in range(len(op.inputs)):
2519        x = op.inputs[index]
2520        real_x = self.AddValue(x)
2521        if real_x != x:
2522          op._update_input(index, real_x)  # pylint: disable=protected-access
2523      # Remove any external control dependency on this op.
2524      _, external_inputs = self._RemoveExternalControlEdges(op)
2525      # Add a control dependency to prevent loop invariants from
2526      # enabling ops that should not be executed.
2527      self._MaybeAddControlDependency(op)
2528      for x in op.outputs:
2529        self._values.add(x.name)
2530    if external_inputs:
2531      # Use an identity to pull control inputs as data inputs. Note that we
2532      # ignore ops which don't have outputs. TODO(apassos): fix that
2533      with ops.control_dependencies(None):
2534        self.Enter()
2535        external_inputs = [array_ops.identity(x.outputs[0]).op
2536                           for x in external_inputs if x.outputs]
2537        self.Exit()
2538      op._add_control_inputs(external_inputs)  # pylint: disable=protected-access
2539    if self._outer_context or not util.IsLoopExit(op):
2540      op.graph.prevent_fetching(op)
2541      for x in op.outputs:
2542        op.graph.prevent_feeding(x)
2543
2544    if self._outer_context:
2545      self._outer_context.AddInnerOp(op)
2546
2547  def _MaybeAddControlDependency(self, op):
2548    """Add a control input to the op if it only depends on loop invariants."""
2549
2550    def _IsOpFree(op):
2551      """Determines if `op` needs a control dependency."""
2552      if op.control_inputs:
2553        return False
2554      # pylint: disable=protected-access
2555      if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
2556        return True
2557      # pylint: enable=protected-access
2558      for x in op.inputs:
2559        if not util.IsLoopConstantEnter(x.op):
2560          return False
2561      return True
2562
2563    if _IsOpFree(op):
2564      # pylint: disable=protected-access
2565      op._add_control_input(self.GetControlPivot().op)
2566      # pylint: enable=protected-access
2567
2568  def AddForwardLoopCounter(self, outer_grad_state):
2569    """Adds a loop that counts the number of iterations.
2570
2571    This is added to the forward loop at the time when we start to
2572    create the loop for backprop gradient computation. Called in
2573    the outer context of this forward context.
2574
2575    The pseudocode is:
2576      `n = 0; while (_pivot) { n++; }`
2577
2578    Note that a control dependency is added to `n` to ensure the correct
2579    execution order of stack push ops.
2580
2581    Args:
2582      outer_grad_state: The outer grad state. None if not nested.
2583
2584    Returns:
2585      The number of iterations taken by the forward loop and the loop index.
2586    """
2587    n = constant_op.constant(0, name="f_count")
2588    if outer_grad_state is not None:
2589      # Force the stack pushes of i-th execution of an inner loop to be ordered
2590      # before the pushes of (i+1)-th execution of the same inner loop.
2591      outer_add_op = outer_grad_state.forward_index.op.inputs[0].op
2592      n.op._add_control_input(outer_add_op)  # pylint: disable=protected-access
2593
2594    self.Enter()
2595    self.AddName(n.name)
2596    enter_n = _Enter(
2597        n,
2598        self._name,
2599        is_constant=False,
2600        parallel_iterations=self._parallel_iterations,
2601        name="f_count")
2602    self.loop_enters.append(enter_n)
2603
2604    merge_n = merge([enter_n, enter_n])[0]
2605    switch_n = switch(merge_n, self._pivot)
2606
2607    index = math_ops.add(switch_n[1], 1)
2608    next_n = _NextIteration(index)
2609    merge_n.op._update_input(1, next_n)
2610
2611    total_iterations = exit(switch_n[0], name="f_count")
2612    self.loop_exits.append(total_iterations)
2613    self.ExitResult([total_iterations])
2614    self.Exit()
2615    return total_iterations, next_n
2616
2617  def AddBackpropLoopCounter(self, count, outer_grad_state):
2618    """Add the backprop loop that controls the iterations.
2619
2620    This is added to the backprop loop. It is used to control the loop
2621    termination of the backprop loop. Called in the outer context of
2622    this grad context.
2623
2624    The pseudocode is:
2625      `n = count; while (n >= 1) { n--; }`
2626
2627    Note that a control dependency is added to `final_zero` to ensure the
2628    correct execution order of stack pop ops.
2629
2630    Args:
2631      count: The number of iterations for backprop.
2632      outer_grad_state: The outer grad state. None if not nested.
2633
2634    Returns:
2635      The loop index.
2636    """
2637    one = constant_op.constant(1, name="b_count")
2638
2639    self.Enter()
2640    self.AddName(count.name)
2641    enter_count = _Enter(
2642        count,
2643        self._name,
2644        is_constant=False,
2645        parallel_iterations=self._parallel_iterations,
2646        name="b_count")
2647    self.loop_enters.append(enter_count)
2648
2649    merge_count = merge([enter_count, enter_count])[0]
2650    self._pivot_for_pred = merge_count
2651
2652    pred = math_ops.greater_equal(merge_count, one)
2653    self._pivot = loop_cond(pred, name="b_count")
2654    switch_count = switch(merge_count, self._pivot)
2655
2656    index = math_ops.subtract(switch_count[1], one)
2657    self._pivot_for_body = index
2658    next_count = _NextIteration(index)
2659    merge_count.op._update_input(1, next_count)
2660
2661    final_zero = exit(switch_count[0], name="b_count")
2662    self.loop_exits.append(final_zero)
2663    if outer_grad_state is not None:
2664      # Force the stack pops of i-th execution of an inner loop to be ordered
2665      # before the pops of (i+1)-th execution of the same inner loop.
2666      # pylint: disable=protected-access
2667      outer_grad_state.grad_sync._add_control_input(final_zero.op)
2668      # pylint: enable=protected-access
2669
2670    self.ExitResult([final_zero])
2671    self.Exit()
2672    return next_count
2673
2674  def AddBackpropAccumulator(self, op, grad):
2675    """Add an accumulation loop for every loop invariant.
2676
2677    This is added to the backprop loop. It is used to accumulate partial
2678    gradients within each loop iteration. Called when in the gradient while
2679    context.
2680
2681    The pseudocode is:
2682      ```
2683      acc = 0.0;
2684      while (_pivot) {
2685        acc += grad;
2686      }
2687      ```
2688
2689    Args:
2690      op: The Enter op for a loop invariant.
2691      grad: The partial gradient of an iteration for a loop invariant.
2692
2693    Returns:
2694      The gradient for a loop invariant.
2695    """
2696    self.Exit()
2697    # Create a zeros tensor with the right shape for acc. If we don't
2698    # know the full shape statically, we will have to get the shape
2699    # dynamically from the forward inference. Getting the shape right
2700    # for the zeros is only needed for the base case when the loop exits
2701    # without running any iterations.
2702    shape = grad.get_shape()
2703    if shape.is_fully_defined():
2704      if self.outer_context:
2705        self.outer_context.Enter()
2706      acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
2707      if self.outer_context:
2708        self.outer_context.Exit()
2709    else:
2710      value = op.inputs[0]
2711      if (isinstance(self.outer_context, WhileContext) and
2712          self.outer_context.grad_state is not None):
2713        # We are in a nested while loop.
2714        forward_ctxt = self.grad_state.forward_context
2715        forward_ctxt.outer_context.Enter()
2716        zeros_shape = array_ops.shape_internal(value, optimize=False)
2717        forward_ctxt.outer_context.Exit()
2718        outer_grad_state = self.grad_state.outer_grad_state
2719        history_zeros_shape = outer_grad_state.AddForwardAccumulator(
2720            zeros_shape)
2721        self.outer_context.Enter()
2722        real_shape = outer_grad_state.AddBackpropAccumulatedValue(
2723            history_zeros_shape, zeros_shape)
2724        acc = array_ops.zeros(real_shape, grad.dtype)
2725        self.outer_context.Exit()
2726      else:
2727        if self.outer_context:
2728          self.outer_context.Enter()
2729        zeros_shape = array_ops.shape_internal(value, optimize=False)
2730        acc = array_ops.zeros(zeros_shape, grad.dtype)
2731        if self.outer_context:
2732          self.outer_context.Exit()
2733
2734    self.Enter()
2735    self.AddName(acc.name)
2736    enter_acc = _Enter(
2737        acc,
2738        self._name,
2739        is_constant=False,
2740        parallel_iterations=self._parallel_iterations,
2741        name="b_acc")
2742    self.loop_enters.append(enter_acc)
2743
2744    merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
2745    switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)
2746
2747    add_acc = math_ops.add(switch_acc_true, grad)
2748    next_acc = _NextIteration(add_acc)
2749    merge_acc.op._update_input(1, next_acc)  # pylint: disable=protected-access
2750
2751    result_acc = exit(switch_acc_false, name="b_acc")
2752    self.loop_exits.append(result_acc)
2753    self.ExitResult([result_acc])
2754    return result_acc
2755
2756  def AddBackpropIndexedSlicesAccumulator(self, op, grad):
2757    """This is used for accumulating gradients that are IndexedSlices.
2758
2759    This is essentially the equivalent of AddBackpropAccumulator but optimized
2760    for things like updating embeddings from within a while loop.
2761
2762    Args:
2763      op: The Enter op for a loop invariant.
2764      grad: The partial gradients represented as an IndexedSlices.
2765
2766    Returns:
2767      The accumulated IndexedSlices gradient of the loop invariant.
2768    """
2769    values = grad.values
2770    indices = grad.indices
2771    dense_shape = grad.dense_shape
2772
2773    self.Exit()
2774    if self.outer_context:
2775      self.outer_context.Enter()
2776    if values.get_shape().is_fully_defined():
2777      values_shape = tensor_shape.TensorShape(
2778          [tensor_shape.Dimension(1)] + values.get_shape().dims[1:])
2779      if self.outer_context:
2780        self.outer_context.Enter()
2781      values_acc = constant_op.constant(
2782          0, values.dtype, shape=values_shape, name="b_acc")
2783      if self.outer_context:
2784        self.outer_context.Exit()
2785    else:
2786      values_shape = _resource_safe_shape(op.inputs[0])[1:]
2787      values_shape = array_ops.concat([[1], values_shape], 0)
2788      values_acc = array_ops.zeros(values_shape, dtype=values.dtype)
2789    indices_acc = constant_op.constant([0], indices.dtype)
2790    shape_acc = None
2791    if dense_shape is not None:
2792      if dense_shape.get_shape().is_fully_defined():
2793        if self.outer_context:
2794          self.outer_context.Enter()
2795        shape_acc = constant_op.constant(
2796            0, dense_shape.dtype, shape=dense_shape.get_shape())
2797        if self.outer_context:
2798          self.outer_context.Exit()
2799      else:
2800        shape_acc = array_ops.zeros_like(
2801            array_ops.shape_internal(op.inputs[0], optimize=False),
2802            optimize=False)
2803
2804    if self.outer_context:
2805      self.outer_context.Exit()
2806
2807    self.Enter()
2808    self.AddName(values_acc.name)
2809    self.AddName(indices_acc.name)
2810    init_acc = [indices_acc, values_acc]
2811    if shape_acc is not None:
2812      self.AddName(shape_acc.name)
2813      init_acc.append(shape_acc)
2814
2815    # Set use_input_shape=False since the accumulator tensors will grow in
2816    # size. If use_input_shape=True, the _update_input call below will result in
2817    # incompatible shapes.
2818    enter_acc = [
2819        _Enter(
2820            x,
2821            self._name,
2822            is_constant=False,
2823            parallel_iterations=self._parallel_iterations,
2824            use_input_shape=False,
2825            name="b_acc") for x in init_acc
2826    ]
2827    # Manually set appropriate partial shapes.
2828    enter_acc[0].set_shape([None])
2829    if values_acc.shape.dims is not None:
2830      enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:])
2831    self.loop_enters.extend(enter_acc)
2832
2833    merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]
2834    switch_acc = [switch(x, self._pivot) for x in merge_acc]
2835
2836    # The actual accumulation.
2837    acc_indexed_slices = [
2838        array_ops.concat([xa[1], xv], 0)
2839        for xa, xv in zip(switch_acc[:2], [indices, values])
2840    ]
2841    if shape_acc is not None:
2842      # For the shape we just keep the maximum
2843      acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1]))
2844
2845    next_acc = [_NextIteration(x) for x in acc_indexed_slices]
2846    for xm, xn in zip(merge_acc, next_acc):
2847      xm.op._update_input(1, xn)  # pylint: disable=protected-access
2848
2849    exit_acc = [exit(x[0], name="b_acc") for x in switch_acc]
2850    self.loop_exits.extend(exit_acc)
2851
2852    self.ExitResult(exit_acc)
2853    return ops.IndexedSlices(
2854        indices=exit_acc[0],
2855        values=exit_acc[1],
2856        dense_shape=exit_acc[2] if shape_acc is not None else None)
2857
2858  def _InitializeValues(self, values):
2859    """Makes the values known to this context."""
2860    self._values = set()
2861    for x in values:
2862      if isinstance(x, ops.Tensor):
2863        self._values.add(x.name)
2864      else:
2865        self._values.add(x.values.name)
2866        self._values.add(x.indices.name)
2867        if isinstance(x, ops.IndexedSlices):
2868          dense_shape = x.dense_shape
2869        elif isinstance(x, sparse_tensor.SparseTensor):
2870          dense_shape = x.dense_shape
2871        else:
2872          raise TypeError("Type %s not supported" % type(x))
2873        if dense_shape is not None:
2874          self._values.add(dense_shape.name)
2875
2876  def _BuildLoop(self, pred, body, original_loop_vars, loop_vars,
2877                 shape_invariants):
2878    """Core: Add the loop termination condition and body to the graph."""
2879    flat_loop_vars = nest.flatten(original_loop_vars)
2880
2881    # Let the context know the loop variables so the loop variables
2882    # would be added in the outer contexts properly.
2883    self._InitializeValues(loop_vars)
2884    real_vars = loop_vars
2885    if self._outer_context:
2886      real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
2887    with ops.control_dependencies(None):
2888      enter_vars = [
2889          _Enter(
2890              x,
2891              self._name,
2892              is_constant=False,
2893              parallel_iterations=self._parallel_iterations,
2894              use_input_shape=(shape_invariants is None)) for x in real_vars
2895      ]
2896      for x in enter_vars:
2897        x.graph.prevent_feeding(x)
2898        if self._outer_context:
2899          self._outer_context.AddInnerOp(x.op)
2900
2901    # Finds the closest enclosing non-None control pivot.
2902    outer_context = self._outer_context
2903    control_pivot = None
2904    while outer_context is not None and control_pivot is None:
2905      control_pivot = outer_context.GetControlPivot()
2906      # pylint: disable=protected-access
2907      outer_context = outer_context._outer_context
2908      # pylint: enable=protected-access
2909
2910    if control_pivot is not None:
2911      for var in enter_vars:
2912        if util.IsLoopConstantEnter(var.op.inputs[0].op):
2913          # pylint: disable=protected-access
2914          var.op._add_control_input(control_pivot.op)
2915          # pylint: enable=protected-access
2916    _SetShapeInvariants(real_vars, enter_vars, shape_invariants)
2917
2918    # Fix the control inputs and control flow context of these enter ops.
2919    self._FixControlInputsAndContext(enter_vars)
2920    self._InitializeValues(enter_vars)
2921    self._loop_enters = enter_vars
2922
2923    merge_vars = [merge([x, x])[0] for x in enter_vars]
2924    self._pivot_for_pred = merge_vars[0]
2925
2926    # Build the graph for pred.
2927    merge_vars_with_tensor_arrays = (
2928        _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars))
2929    packed_vars = nest.pack_sequence_as(
2930        structure=original_loop_vars,
2931        flat_sequence=merge_vars_with_tensor_arrays)
2932    c = ops.convert_to_tensor(pred(*packed_vars))
2933    self._pivot = loop_cond(c, name="LoopCond")
2934    switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]
2935
2936    # Build the graph for body.
2937    vars_for_body = [_Identity(x[1]) for x in switch_vars]
2938    self._pivot_for_body = vars_for_body[0]
2939    # Convert TensorArray flow variables inside the context back into
2940    # their associated TensorArrays for calling the body.
2941    vars_for_body_with_tensor_arrays = (
2942        _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body))
2943    packed_vars_for_body = nest.pack_sequence_as(
2944        structure=original_loop_vars,
2945        flat_sequence=vars_for_body_with_tensor_arrays)
2946    pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
2947    body_result = body(*packed_vars_for_body)
2948    post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
2949    if not nest.is_sequence(body_result):
2950      body_result = [body_result]
2951    if len(post_summaries) > len(pre_summaries):
2952      new_summaries = post_summaries[len(pre_summaries):]
2953      summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
2954      summary_ref[:] = pre_summaries
2955      with ops.control_dependencies(new_summaries):
2956
2957        def map_fn(x):
2958          # TODO(apassos) figure out how to trigger with tensor arrays as well
2959          if isinstance(x, tensor_array_ops.TensorArray):
2960            return x
2961          return array_ops.identity(x)
2962
2963        body_result = nest.map_structure(map_fn, body_result)
2964
2965    # Compare the structure types of input and output of body.
2966    # For backwards compatibility, the first layer is forced to a list
2967    # during this comparison, because inputs are typically lists and
2968    # outputs of the body are typically tuples.
2969    nest.assert_same_structure(list(packed_vars_for_body), list(body_result))
2970
2971    # Store body_result to keep track of TensorArrays returned by body
2972    original_body_result = body_result
2973    # Convert TensorArrays returned by body into their flow variables
2974    result = nest.map_structure(_convert_tensorarray_to_flow,
2975                                nest.flatten(body_result))
2976    result = ops.convert_n_to_tensor_or_indexed_slices(result)
2977
2978    # Add NextIteration and the back edges to complete the loop.
2979    if len(merge_vars) != len(result):
2980      raise ValueError("Number of inputs and outputs of body must match "
2981                       "loop_vars: %d, %d" % (len(merge_vars), len(result)))
2982    next_vars = []
2983    for m, v in zip(merge_vars, result):
2984      next_vars.append(_AddNextAndBackEdge(m, v))
2985
2986    # Add the exit ops.
2987    exit_vars = [exit(x[0]) for x in switch_vars]
2988    self._loop_exits = exit_vars
2989
2990    # Exit the loop.
2991    self.ExitResult(exit_vars)
2992
2993    return original_body_result, exit_vars
2994
2995  def BuildLoop(self, pred, body, loop_vars, shape_invariants):
2996    """Add the loop termination condition and body to the graph."""
2997
2998    # Keep original_loop_vars to identify which are TensorArrays
2999    original_loop_vars = loop_vars
3000    # Convert TensorArrays to their flow variables
3001    loop_vars = nest.map_structure(_convert_tensorarray_to_flow,
3002                                   nest.flatten(loop_vars))
3003    loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars)
3004    try:
3005      self.Enter()
3006      original_body_result, exit_vars = self._BuildLoop(
3007          pred, body, original_loop_vars, loop_vars, shape_invariants)
3008    finally:
3009      self.Exit()
3010
3011    flat_result = nest.flatten(original_body_result)
3012    # Convert TensorArray flow variables outside the context back into
3013    # their associated TensorArrays for returning to caller.
3014    exit_vars_with_tensor_arrays = (
3015        _convert_flows_to_tensorarrays(flat_result, exit_vars))
3016    packed_exit_vars = nest.pack_sequence_as(
3017        structure=original_body_result,
3018        flat_sequence=exit_vars_with_tensor_arrays)
3019    return (packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars)
3020
3021  def _FixControlInputsAndContext(self, enters):
3022    graph = ops.get_default_graph()
3023    # pylint: disable=protected-access
3024    for e in enters:
3025      if isinstance(e, ops.Tensor):
3026        xs = [e]
3027      else:
3028        if not isinstance(e, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
3029          raise TypeError("Type %s not supported" % type(e))
3030        xs = [e.values, e.indices]
3031        shape = e.dense_shape
3032        if shape is not None:
3033          xs.append(shape)
3034      for x in xs:
3035        inp_op = x.op.inputs[0].op
3036        control_inputs = graph._control_dependencies_for_inputs([inp_op])
3037        outer_control_inputs = [
3038            op for op in control_inputs if self._IsInOuterContext(op)
3039        ]
3040        x.op._set_control_flow_context(self)
3041        x.op._add_control_inputs(outer_control_inputs)
3042        graph._record_op_seen_by_control_dependencies(x.op)
3043    # pylint: enable=protected-access
3044
3045  def IsWhileContext(self):
3046    return True
3047
3048
3049# pylint: disable=redefined-outer-name
3050@tf_export("while_loop")
3051def while_loop(cond,
3052               body,
3053               loop_vars,
3054               shape_invariants=None,
3055               parallel_iterations=10,
3056               back_prop=True,
3057               swap_memory=False,
3058               name=None,
3059               maximum_iterations=None):
3060  """Repeat `body` while the condition `cond` is true.
3061
3062  `cond` is a callable returning a boolean scalar tensor. `body` is a callable
3063  returning a (possibly nested) tuple, namedtuple or list of tensors of the same
3064  arity (length and structure) and types as `loop_vars`. `loop_vars` is a
3065  (possibly nested) tuple, namedtuple or list of tensors that is passed to both
3066  `cond` and `body`. `cond` and `body` both take as many arguments as there are
3067  `loop_vars`.
3068
3069  In addition to regular Tensors or IndexedSlices, the body may accept and
3070  return TensorArray objects.  The flows of the TensorArray objects will
3071  be appropriately forwarded between loops and during gradient calculations.
3072
3073  Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
3074  call to `while_loop`, and not at all during `Session.run()`). `while_loop`
3075  stitches together the graph fragments created during the `cond` and `body`
3076  calls with some additional graph nodes to create the graph flow that
3077  repeats `body` until `cond` returns false.
3078
3079  For correctness, `tf.while_loop()` strictly enforces shape invariants for
3080  the loop variables. A shape invariant is a (possibly partial) shape that
3081  is unchanged across the iterations of the loop. An error will be raised
3082  if the shape of a loop variable after an iteration is determined to be more
3083  general than or incompatible with its shape invariant. For example, a shape
3084  of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
3085  compatible with [11, 17]. By default (if the argument `shape_invariants` is
3086  not specified), it is assumed that the initial shape of each tensor in
3087  `loop_vars` is the same in every iteration. The `shape_invariants` argument
3088  allows the caller to specify a less specific shape invariant for each loop
3089  variable, which is needed if the shape varies between iterations. The
3090  @{tf.Tensor.set_shape}
3091  function may also be used in the `body` function to indicate that
3092  the output loop variable has a particular shape. The shape invariant for
3093  SparseTensor and IndexedSlices are treated specially as follows:
3094
3095  a) If a loop variable is a SparseTensor, the shape invariant must be
3096  TensorShape([r]) where r is the rank of the dense tensor represented
3097  by the sparse tensor. It means the shapes of the three tensors of the
3098  SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
3099  is the shape of the SparseTensor.dense_shape property. It must be the shape of
3100  a vector.
3101
3102  b) If a loop variable is an IndexedSlices, the shape invariant must be
3103  a shape invariant of the values tensor of the IndexedSlices. It means
3104  the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
3105  [shape.ndims]).
3106
3107  `while_loop` implements non-strict semantics, enabling multiple iterations
3108  to run in parallel. The maximum number of parallel iterations can be
3109  controlled by `parallel_iterations`, which gives users some control over
3110  memory consumption and execution order. For correct programs, `while_loop`
3111  should return the same result for any parallel_iterations > 0.
3112
3113  For training, TensorFlow stores the tensors that are produced in the
3114  forward inference and are needed in back propagation. These tensors are a
3115  main source of memory consumption and often cause OOM errors when training
3116  on GPUs. When the flag swap_memory is true, we swap out these tensors from
3117  GPU to CPU. This for example allows us to train RNN models with very long
3118  sequences and large batches.
3119
3120  Args:
3121    cond: A callable that represents the termination condition of the loop.
3122    body: A callable that represents the loop body.
3123    loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
3124      `Tensor`, and `TensorArray` objects.
3125    shape_invariants: The shape invariants for the loop variables.
3126    parallel_iterations: The number of iterations allowed to run in parallel.
3127      It must be a positive integer.
3128    back_prop: Whether backprop is enabled for this while loop.
3129    swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
3130    name: Optional name prefix for the returned tensors.
3131    maximum_iterations: Optional maximum number of iterations of the while loop
3132      to run.  If provided, the `cond` output is AND-ed with an additional
3133      condition ensuring the number of iterations executed is no greater than
3134      `maximum_iterations`.
3135
3136  Returns:
3137    The output tensors for the loop variables after the loop. When the length
3138    of `loop_vars` is 1 this is a Tensor, TensorArray or IndexedSlice and when
3139    the length of `loop_vars` is greater than 1 it returns a list.
3140
3141  Raises:
3142    TypeError: if `cond` or `body` is not callable.
3143    ValueError: if `loop_vars` is empty.
3144
3145  Example:
3146
3147  ```python
3148  i = tf.constant(0)
3149  c = lambda i: tf.less(i, 10)
3150  b = lambda i: tf.add(i, 1)
3151  r = tf.while_loop(c, b, [i])
3152  ```
3153
3154  Example with nesting and a namedtuple:
3155
3156  ```python
3157  import collections
3158  Pair = collections.namedtuple('Pair', 'j, k')
3159  ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
3160  c = lambda i, p: i < 10
3161  b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
3162  ijk_final = tf.while_loop(c, b, ijk_0)
3163  ```
3164
3165  Example using shape_invariants:
3166
3167  ```python
3168  i0 = tf.constant(0)
3169  m0 = tf.ones([2, 2])
3170  c = lambda i, m: i < 10
3171  b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
3172  tf.while_loop(
3173      c, b, loop_vars=[i0, m0],
3174      shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
3175  ```
3176
3177  Example which demonstrates non-strict semantics: In the following
3178  example, the final value of the counter `i` does not depend on `x`. So
3179  the `while_loop` can increment the counter parallel to updates of `x`.
3180  However, because the loop counter at one loop iteration depends
3181  on the value at the previous iteration, the loop counter itself cannot
3182  be incremented in parallel. Hence if we just want the final value of the
3183  counter (which we print on the line `print(sess.run(i))`), then
3184  `x` will never be incremented, but the counter will be updated on a
3185  single thread. Conversely, if we want the value of the output (which we
3186  print on the line `print(sess.run(out).shape)`), then the counter may be
3187  incremented on its own thread, while `x` can be incremented in
3188  parallel on a separate thread. In the extreme case, it is conceivable
3189  that the thread incrementing the counter runs until completion before
3190  `x` is incremented even a single time. The only thing that can never
3191  happen is that the thread updating `x` can never get ahead of the
3192  counter thread because the thread incrementing `x` depends on the value
3193  of the counter.
3194  ```python
3195  import tensorflow as tf
3196
3197  n = 10000
3198  x = tf.constant(list(range(n)))
3199  c = lambda i, x: i < n
3200  b = lambda i, x: (tf.Print(i + 1, [i]), tf.Print(x + 1, [i], "x:"))
3201  i, out = tf.while_loop(c, b, (0, x))
3202  with tf.Session() as sess:
3203      print(sess.run(i))  # prints [0] ... [9999]
3204
3205      # The following line may increment the counter and x in parallel.
3206      # The counter thread may get ahead of the other thread, but not the
3207      # other way around. So you may see things like
3208      # [9996] x:[9987]
3209      # meaning that the counter thread is on iteration 9996,
3210      # while the other thread is on iteration 9987
3211      print(sess.run(out).shape)
3212  ```
3213
3214  """
3215  with ops.name_scope(name, "while", loop_vars):
3216    if not loop_vars:
3217      raise ValueError("No loop variables provided")
3218    if not callable(cond):
3219      raise TypeError("cond must be callable.")
3220    if not callable(body):
3221      raise TypeError("body must be callable.")
3222    if parallel_iterations < 1:
3223      raise TypeError("parallel_iterations must be a positive integer.")
3224
3225    if maximum_iterations is not None:
3226      maximum_iterations = ops.convert_to_tensor(
3227          maximum_iterations, name="maximum_iterations")
3228      if maximum_iterations.shape.ndims != 0:
3229        raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
3230                         maximum_iterations.shape)
3231
3232      counter = constant_op.constant(
3233          0, dtype=maximum_iterations.dtype, name="iteration_counter")
3234      orig_cond = cond
3235      orig_body = body
3236      if len(loop_vars) == 1:
3237        loop_vars = (counter, loop_vars[0])
3238        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
3239            math_ops.logical_and(i < maximum_iterations, orig_cond(lv)))
3240        body = lambda i, lv: (i + 1, orig_body(lv))
3241      else:
3242        loop_vars = (counter, loop_vars)
3243        cond = lambda i, lv: (  # pylint: disable=g-long-lambda
3244            math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
3245        body = lambda i, lv: (i + 1, orig_body(*lv))
3246
3247    if context.in_eager_mode():
3248      while cond(*loop_vars):
3249        loop_vars = body(*loop_vars)
3250      if maximum_iterations is not None:
3251        return loop_vars[1]
3252      else:
3253        return loop_vars
3254
3255    if shape_invariants is not None:
3256      if maximum_iterations is not None:
3257        shape_invariants = (tensor_shape.TensorShape([]), shape_invariants)
3258      nest.assert_same_structure(loop_vars, shape_invariants)
3259
3260    loop_context = WhileContext(
3261        maximum_iterations=maximum_iterations,
3262        parallel_iterations=parallel_iterations,
3263        back_prop=back_prop,
3264        swap_memory=swap_memory)
3265    # Only add non-nested loops to the collection. Any nested control flow will
3266    # be encapsulated in the root context.
3267    # TODO(b/72868227): enable condition once the corresponding
3268    # control_flow.proto changes have been checked in (they aren't checked in
3269    # and this is disabled for now to ensure forwards compatibility).
3270    if True or loop_context.outer_context is None:
3271      ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
3272    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
3273    if maximum_iterations is not None:
3274      return result[1]
3275    else:
3276      return result
3277
3278
3279# pylint: enable=redefined-outer-name
3280
3281
3282def _AsTensorList(x, p):
3283  """Return x as a list of Tensors or IndexedSlices.
3284
3285  For entries of `x` that are Operations, this returns an Identity of `p`
3286  with a dependency on the operation.
3287
3288  Args:
3289    x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
3290    p: A Tensor to return for entries in `x` that are Operations.
3291
3292  Returns:
3293    A list of Tensors or IndexedSlices.
3294  """
3295  if not isinstance(x, (list, _basetuple)):
3296    x = [x]
3297
3298  l = []
3299  for v in x:
3300    if isinstance(v, ops.Operation):
3301      v = with_dependencies([v], p)
3302    v = ops.convert_to_tensor_or_indexed_slices(v)
3303    if isinstance(v, ops.Tensor):
3304      l.append(array_ops.identity(v))
3305    else:
3306      l.append(
3307          ops.IndexedSlices(
3308              array_ops.identity(v.values), array_ops.identity(v.indices)))
3309  return l
3310
3311
3312def _CheckResults(a, b):
3313  assert len(a) == len(b), (
3314      "Values returned by a() and b() must have the same length.")
3315  for x, y in zip(a, b):
3316    assert x.dtype == y.dtype, (
3317        "Values returned by a() [%s] and b() [%s] must have "
3318        "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name))
3319
3320
3321def with_dependencies(dependencies, output_tensor, name=None):
3322  """Produces the content of `output_tensor` only after `dependencies`.
3323
3324  In some cases, a user may want the output of an operation to be
3325  consumed externally only after some other dependencies have run
3326  first. This function ensures returns `output_tensor`, but only after all
3327  operations in `dependencies` have run. Note that this means that there is
3328  no guarantee that `output_tensor` will be evaluated after any `dependencies`
3329  have run.
3330
3331  See also @{tf.tuple$tuple} and @{tf.group$group}.
3332
3333  Args:
3334    dependencies: Iterable of operations to run before this op finishes.
3335    output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
3336    name: (Optional) A name for this operation.
3337
3338  Returns:
3339    Same as `output_tensor`.
3340
3341  Raises:
3342    TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
3343  """
3344  if context.in_eager_mode():
3345    return output_tensor
3346  with ops.name_scope(name, "control_dependency",
3347                      list(dependencies) + [output_tensor]) as name:
3348    with ops.colocate_with(output_tensor):
3349      with ops.control_dependencies(dependencies):
3350        output_tensor = ops.convert_to_tensor_or_indexed_slices(output_tensor)
3351        if isinstance(output_tensor, ops.Tensor):
3352          return _Identity(output_tensor, name=name)
3353        else:
3354          return ops.IndexedSlices(
3355              _Identity(output_tensor.values, name=name), output_tensor.indices,
3356              output_tensor.dense_shape)
3357
3358
3359def _GroupControlDeps(dev, deps, name=None):
3360  with ops.control_dependencies(deps):
3361    if dev is None:
3362      return no_op(name=name)
3363    else:
3364      with ops.device(dev):
3365        return no_op(name=name)
3366
3367
3368# TODO(touts): Accept "inputs" as a list.
3369@tf_export("group")
3370def group(*inputs, **kwargs):
3371  """Create an op that groups multiple operations.
3372
3373  When this op finishes, all ops in `inputs` have finished. This op has no
3374  output.
3375
3376  See also @{tf.tuple$tuple} and
3377  @{tf.control_dependencies$control_dependencies}.
3378
3379  Args:
3380    *inputs: Zero or more tensors to group.
3381    name: A name for this operation (optional).
3382
3383  Returns:
3384    An Operation that executes all its inputs.
3385
3386  Raises:
3387    ValueError: If an unknown keyword argument is provided.
3388  """
3389  if context.in_eager_mode():
3390    return None
3391  name = kwargs.pop("name", None)
3392  if kwargs:
3393    raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))
3394  with ops.name_scope(name, "group_deps", inputs) as name:
3395    # Grouping no inputs means do nothing
3396    if not inputs:
3397      return no_op(name=name)
3398
3399    # Sorts *inputs according to their devices.
3400    ops_on_device = {}  # device -> operations specified on the device.
3401    for inp in nest.flatten(inputs):
3402      if not hasattr(inp, "device"):
3403        raise TypeError("Expected tf.group() expected Tensor arguments not "
3404                        "'%s' with type '%s'" % (inp, type(inp)))
3405      if not hasattr(inp, "device"):
3406        if isinstance(inp, list):
3407          raise TypeError("To call tf.group() with a list, use "
3408                          "tf.group(*[...]) not tf.group([...]).")
3409        raise TypeError("Expected tf.group() expected Tensor arguments not "
3410                        "'%s' with type '%s'" % (inp, type(inp)))
3411      dev = inp.device
3412      if dev in ops_on_device:
3413        ops_on_device[dev].append(inp)
3414      else:
3415        ops_on_device[dev] = [inp]
3416    if len(ops_on_device) == 1:
3417      # 1-level tree. The root node is the returned NoOp node.
3418      (dev, deps), = ops_on_device.items()
3419      return _GroupControlDeps(dev, deps, name=name)
3420
3421    # 2-level tree. The root node is the returned NoOp node.
3422    # deps contains 1 NoOp node for each device.
3423    deps = []
3424
3425    def device_key(dev):
3426      """A sort key that allows None to be compared to strings."""
3427      return "" if dev is None else dev
3428
3429    for dev in sorted(six.iterkeys(ops_on_device), key=device_key):
3430      deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
3431
3432    with ops.control_dependencies(deps):
3433      return no_op(name=name)
3434
3435
3436@tf_export("tuple")
3437def tuple(tensors, name=None, control_inputs=None):  # pylint: disable=redefined-builtin
3438  """Group tensors together.
3439
3440  This creates a tuple of tensors with the same values as the `tensors`
3441  argument, except that the value of each tensor is only returned after the
3442  values of all tensors have been computed.
3443
3444  `control_inputs` contains additional ops that have to finish before this op
3445  finishes, but whose outputs are not returned.
3446
3447  This can be used as a "join" mechanism for parallel computations: all the
3448  argument tensors can be computed in parallel, but the values of any tensor
3449  returned by `tuple` are only available after all the parallel computations
3450  are done.
3451
3452  See also @{tf.group$group} and
3453  @{tf.control_dependencies$control_dependencies}.
3454
3455  Args:
3456    tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
3457    name: (optional) A name to use as a `name_scope` for the operation.
3458    control_inputs: List of additional ops to finish before returning.
3459
3460  Returns:
3461    Same as `tensors`.
3462
3463  Raises:
3464    ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
3465    TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
3466      objects.
3467
3468  """
3469  if context.in_eager_mode():
3470    return tensors
3471  with ops.name_scope(name, "tuple", tensors) as name:
3472    gating_ops = [t.op for t in tensors if t is not None]
3473    if control_inputs:
3474      for c in control_inputs:
3475        if isinstance(c, ops.Tensor):
3476          c = c.op
3477        elif not isinstance(c, ops.Operation):
3478          raise TypeError("Control input must be Operation or Tensor: %s" % c)
3479        gating_ops.append(c)
3480    # Note that in order to ensure ordering in the pbtxt, we must take care to
3481    # ensure the order here.
3482    gating_ops = sorted(set(gating_ops), key=lambda op: op._id)  # Uniquify ops.
3483    if not gating_ops:
3484      raise ValueError("Must have at least one Tensor: %s" % tensors)
3485    gate = group(*gating_ops)
3486    tpl = []
3487    for t in tensors:
3488      if t is not None:
3489        tpl.append(with_dependencies([gate], t))
3490      else:
3491        tpl.append(None)
3492    return tpl
3493
3494
3495def _assert_at_most_n_true(predicates, n, msg):
3496  """Returns an Assert op that checks that at most n predicates are True.
3497
3498  Args:
3499    predicates: list of bool scalar tensors.
3500    n: maximum number of true predicates allowed.
3501    msg: Error message.
3502  """
3503  preds_c = array_ops.stack(predicates, name="preds_c")
3504  num_true_conditions = math_ops.reduce_sum(
3505      math_ops.cast(preds_c, dtypes.int32), name="num_true_conds")
3506  condition = math_ops.less_equal(num_true_conditions,
3507                                  constant_op.constant(n, name="n_true_conds"))
3508  preds_names = ", ".join(getattr(p, "name", "?") for p in predicates)
3509  error_msg = [
3510      "%s: more than %d conditions (%s) evaluated as True:" %
3511      (msg, n, preds_names), preds_c
3512  ]
3513  return Assert(condition, data=error_msg, summarize=len(predicates))
3514
3515
3516def _case_create_default_action(predicates, actions):
3517  """Creates default action for a list of actions and their predicates.
3518
3519  It uses the input actions to select an arbitrary as default and makes sure
3520  that corresponding predicates have valid values.
3521
3522  Args:
3523    predicates: a list of bool scalar tensors
3524    actions: a list of callable objects which return tensors.
3525
3526  Returns:
3527    a callable
3528  """
3529  k = len(predicates) - 1  # could pick any
3530  predicate, action = predicates[k], actions[k]
3531  other_predicates, other_actions = predicates[:k], actions[:k]
3532
3533  def default_action():
3534    others_msg = ("Implementation error: "
3535                  "selected default action #%d was called, but some of other "
3536                  "predicates are True: " % k)
3537    default_msg = ("Input error: "
3538                   "None of conditions evaluated as True:",
3539                   array_ops.stack(predicates, name="preds_c"))
3540    with ops.control_dependencies([
3541        _assert_at_most_n_true(other_predicates, n=0, msg=others_msg),
3542        Assert(predicate, data=default_msg)
3543    ]):
3544      return action()
3545
3546  return default_action, other_predicates, other_actions
3547
3548
3549def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name):
3550  """Verifies input arguments for the case function.
3551
3552  Args:
3553    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
3554                   callable which returns a list of tensors.
3555    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3556    name: A name for the case operation.
3557
3558  Raises:
3559    TypeError: If `pred_fn_pairs` is not a list/dictionary.
3560    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3561    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3562               callable.
3563
3564  Returns:
3565    a tuple <list of scalar bool tensors, list of callables>.
3566  """
3567  if not isinstance(pred_fn_pairs, (list, _basetuple, dict)):
3568    raise TypeError("fns must be a list, tuple, or dict")
3569
3570  if isinstance(pred_fn_pairs, collections.OrderedDict):
3571    pred_fn_pairs = pred_fn_pairs.items()
3572  elif isinstance(pred_fn_pairs, dict):
3573    pred_fn_pairs = sorted(pred_fn_pairs.items(), key=lambda item: item[0].name)
3574    if not exclusive:
3575      logging.warn("%s: An unordered dictionary of predicate/fn pairs was "
3576                   "provided, but exclusive=False. The order of conditional "
3577                   "tests is deterministic but not guaranteed.", name)
3578  for pred_fn_pair in pred_fn_pairs:
3579    if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2:
3580      raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")
3581    pred, fn = pred_fn_pair
3582    if pred.dtype != dtypes.bool:
3583      raise TypeError("pred must be of type bool: %s", pred.name)
3584    if not callable(fn):
3585      raise TypeError("fn for pred %s must be callable." % pred.name)
3586  predicates, actions = zip(*pred_fn_pairs)
3587  return predicates, actions
3588
3589
3590@tf_export("case")
3591def case(pred_fn_pairs,
3592         default=None,
3593         exclusive=False,
3594         strict=False,
3595         name="case"):
3596  """Create a case operation.
3597
3598  The `pred_fn_pairs` parameter is a dict or list of pairs of size N.
3599  Each pair contains a boolean scalar tensor and a python callable that
3600  creates the tensors to be returned if the boolean evaluates to True.
3601  `default` is a callable generating a list of tensors. All the callables
3602  in `pred_fn_pairs` as well as `default` (if provided) should return the same
3603  number and types of tensors.
3604
3605  If `exclusive==True`, all predicates are evaluated, and an exception is
3606  thrown if more than one of the predicates evaluates to `True`.
3607  If `exclusive==False`, execution stops at the first predicate which
3608  evaluates to True, and the tensors generated by the corresponding function
3609  are returned immediately. If none of the predicates evaluate to True, this
3610  operation returns the tensors generated by `default`.
3611
3612  `tf.case` supports nested structures as implemented in
3613  `tensorflow.python.util.nest`. All of the callables must return the same
3614  (possibly nested) value structure of lists, tuples, and/or named tuples.
3615  Singleton lists and tuples form the only exceptions to this: when returned by
3616  a callable, they are implicitly unpacked to single values. This
3617  behavior is disabled by passing `strict=True`.
3618
3619  If an unordered dictionary is used for `pred_fn_pairs`, the order of the
3620  conditional tests is not guaranteed. However, the order is guaranteed to be
3621  deterministic, so that variables created in conditional branches are created
3622  in fixed order across runs.
3623
3624  **Example 1:**
3625
3626  Pseudocode:
3627
3628  ```
3629  if (x < y) return 17;
3630  else return 23;
3631  ```
3632
3633  Expressions:
3634
3635  ```python
3636  f1 = lambda: tf.constant(17)
3637  f2 = lambda: tf.constant(23)
3638  r = case([(tf.less(x, y), f1)], default=f2)
3639  ```
3640
3641  **Example 2:**
3642
3643  Pseudocode:
3644
3645  ```
3646  if (x < y && x > z) raise OpError("Only one predicate may evaluate true");
3647  if (x < y) return 17;
3648  else if (x > z) return 23;
3649  else return -1;
3650  ```
3651
3652  Expressions:
3653
3654  ```python
3655  def f1(): return tf.constant(17)
3656  def f2(): return tf.constant(23)
3657  def f3(): return tf.constant(-1)
3658  r = case({tf.less(x, y): f1, tf.greater(x, z): f2},
3659           default=f3, exclusive=True)
3660  ```
3661
3662  Args:
3663    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
3664                   callable which returns a list of tensors.
3665    default: Optional callable that returns a list of tensors.
3666    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
3667    strict: A boolean that enables/disables 'strict' mode; see above.
3668    name: A name for this operation (optional).
3669
3670  Returns:
3671    The tensors returned by the first pair whose predicate evaluated to True, or
3672    those returned by `default` if none does.
3673
3674  Raises:
3675    TypeError: If `pred_fn_pairs` is not a list/dictionary.
3676    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
3677    TypeError: If `fns[i]` is not callable for any i, or `default` is not
3678               callable.
3679  """
3680  predicates, actions = _case_verify_and_canonicalize_args(
3681      pred_fn_pairs, exclusive, name)
3682  with ops.name_scope(name, "case", [predicates]):
3683    if default is None:
3684      default, predicates, actions = _case_create_default_action(
3685          predicates, actions)
3686    fn = default
3687    # To eval conditions in direct order we create nested conditions in reverse:
3688    #   cond(c[0], true_fn=.., false_fn=cond(c[1], ...))
3689    for predicate, action in reversed(list(zip(predicates, actions))):
3690      fn = functools.partial(
3691          cond, predicate, true_fn=action, false_fn=fn, strict=strict)
3692    if exclusive:
3693      with ops.control_dependencies([
3694          _assert_at_most_n_true(
3695              predicates, n=1, msg="Input error: exclusive=True")
3696      ]):
3697        return fn()
3698    else:
3699      return fn()
3700
3701
3702class XLAControlFlowContext(ControlFlowContext):
3703  """Base class for XLA and TPU control flow contexts."""
3704
3705  def __init__(self):
3706    super(XLAControlFlowContext, self).__init__()
3707    self._name = "XLAControlFlowContext"
3708
3709  def IsXLAContext(self):
3710    return True
3711
3712  def AddOp(self, _):
3713    pass
3714
3715  def AddValue(self, x):
3716    return x
3717
3718
3719def from_control_flow_context_def(context_def, import_scope=None):
3720  """Deserializes `context_def` into the appropriate ControlFlowContext.
3721
3722  Args:
3723    context_def: ControlFlowContextDef proto
3724    import_scope: Optional `string`. Name scope to add.
3725
3726  Returns:
3727    A ControlFlowContext subclass
3728  """
3729  if context_def.HasField("cond_ctxt"):
3730    return CondContext.from_proto(context_def.cond_ctxt,
3731                                  import_scope=import_scope)
3732  if context_def.HasField("while_ctxt"):
3733    return WhileContext.from_proto(context_def.while_ctxt,
3734                                   import_scope=import_scope)
3735  raise NotImplementedError("Unknown ControlFlowContextDef field: %s"
3736                            % context_def.WhichOneof("ctxt"))
3737
3738
3739ops.register_proto_function(
3740    ops.GraphKeys.COND_CONTEXT,
3741    proto_type=control_flow_pb2.CondContextDef,
3742    to_proto=CondContext.to_proto,
3743    from_proto=CondContext.from_proto)
3744
3745ops.register_proto_function(
3746    ops.GraphKeys.WHILE_CONTEXT,
3747    proto_type=control_flow_pb2.WhileContextDef,
3748    to_proto=WhileContext.to_proto,
3749    from_proto=WhileContext.from_proto)
3750