1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OiR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16# pylint: disable=g-long-lambda
17"""Tests for tensorflow.ops.control_flow_ops."""
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import collections
24import math
25import sys
26import time
27
28import numpy as np
29from six.moves import xrange  # pylint: disable=redefined-builtin
30
31from tensorflow.core.protobuf import config_pb2
32from tensorflow.python.client import device_lib
33from tensorflow.python.client import session
34from tensorflow.python.eager import context
35from tensorflow.python.eager import def_function
36from tensorflow.python.eager import function as eager_function
37from tensorflow.python.eager import wrap_function
38from tensorflow.python.framework import constant_op
39from tensorflow.python.framework import dtypes
40from tensorflow.python.framework import errors_impl
41from tensorflow.python.framework import function
42from tensorflow.python.framework import ops
43from tensorflow.python.framework import sparse_tensor
44from tensorflow.python.framework import tensor_shape
45from tensorflow.python.framework import test_util
46from tensorflow.python.ops import array_ops
47from tensorflow.python.ops import control_flow_ops
48from tensorflow.python.ops import control_flow_util
49from tensorflow.python.ops import data_flow_ops
50from tensorflow.python.ops import functional_ops
51from tensorflow.python.ops import gen_array_ops
52from tensorflow.python.ops import gen_control_flow_ops
53from tensorflow.python.ops import gen_data_flow_ops
54from tensorflow.python.ops import gen_logging_ops
55from tensorflow.python.ops import gen_state_ops
56from tensorflow.python.ops import gradient_checker_v2
57from tensorflow.python.ops import gradients_impl
58from tensorflow.python.ops import init_ops
59from tensorflow.python.ops import linalg_ops
60from tensorflow.python.ops import logging_ops
61from tensorflow.python.ops import map_fn
62from tensorflow.python.ops import math_ops
63from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
64from tensorflow.python.ops import nn_ops
65from tensorflow.python.ops import random_ops
66from tensorflow.python.ops import resource_variable_ops
67from tensorflow.python.ops import script_ops
68from tensorflow.python.ops import sparse_ops
69from tensorflow.python.ops import state_ops
70from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
71from tensorflow.python.ops import tensor_array_ops
72from tensorflow.python.ops import variable_scope
73from tensorflow.python.ops import variables
74from tensorflow.python.ops import while_v2  # pylint: disable=unused-import
75# pylint: disable=unused-import
76from tensorflow.python.ops.ragged import ragged_factory_ops
77import tensorflow.python.ops.tensor_array_grad
78# pylint: enable=unused-import
79from tensorflow.python.platform import test
80from tensorflow.python.training import adam
81from tensorflow.python.training import gradient_descent
82from tensorflow.python.util import nest
83
84
85def check_consumers(graph):
86  """Sanity check on the consumer list of the tensors."""
87
88  consumer_count = {}
89  for op in graph.get_operations():
90    for v in op.inputs:
91      cnt = consumer_count.get(v, 0)
92      consumer_count[v] = cnt + 1
93  for k, v in consumer_count.items():
94    if len(k.consumers()) != v:
95      return False
96  return True
97
98
99def all_fetchables():
100  tensor_names = []
101  graph = ops.get_default_graph()
102  for op in graph.get_operations():
103    for t in op.outputs:
104      if graph.is_fetchable(t):
105        tensor_names.append(t.name)
106  return tensor_names
107
108
109def all_feedables():
110  feedable_tensors = []
111  graph = ops.get_default_graph()
112  for op in graph.get_operations():
113    for t in op.inputs:
114      if graph.is_feedable(t):
115        feedable_tensors.append(t)
116  return feedable_tensors
117
118
119def opt_cfg():
120  return config_pb2.ConfigProto(
121      allow_soft_placement=True,
122      graph_options=config_pb2.GraphOptions(
123          optimizer_options=config_pb2.OptimizerOptions(
124              opt_level=config_pb2.OptimizerOptions.L1,
125              do_function_inlining=True,
126              do_constant_folding=True)))
127
128
129def isum(s, maximum_iterations=None):
130  i = constant_op.constant(0, name="i")
131  c = lambda i, s: math_ops.less(i, 10)
132  b = lambda i, s: [math_ops.add(i, 1), math_ops.add(i, s)]
133  _, r_s = control_flow_ops.while_loop(
134      c, b, [i, s], maximum_iterations=maximum_iterations)
135  return r_s
136
137
138@test_util.with_control_flow_v2
139class ControlFlowTest(test.TestCase):
140
141  @test_util.run_v1_only("b/120545219")
142  def testRefIdentity(self):
143    with self.cached_session():
144      v = variables.VariableV1(7)
145
146      v = control_flow_ops._Identity(v)
147      op = state_ops.assign(v, 9)
148      v2 = control_flow_ops.with_dependencies([op], v)
149
150      self.assertTrue(isinstance(v2, ops.Tensor))
151      self.evaluate(variables.global_variables_initializer())
152      self.assertEqual(9, self.evaluate(v2))
153
154  @test_util.run_v1_only("b/120545219")
155  def testRefEnter(self):
156    with self.cached_session():
157      v = variables.VariableV1(7)
158
159      enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True)
160      nine = constant_op.constant(9)
161      enter_nine = gen_control_flow_ops.enter(nine, "foo_1")
162      op = state_ops.assign(enter_v, enter_nine)
163      v2 = control_flow_ops.with_dependencies([op], enter_v)
164      v3 = control_flow_ops.exit(v2)
165      self.evaluate(variables.global_variables_initializer())
166      self.assertEqual(9, self.evaluate(v3))
167
168  @test_util.run_v1_only("b/120545219")
169  def testRefSwitch(self):
170    with self.cached_session():
171      v = variables.VariableV1(7)
172
173      p = constant_op.constant(True)
174      v1 = control_flow_ops._SwitchRefOrTensor(v._ref(), p)  # pylint: disable=protected-access
175      v2 = state_ops.assign(v1[1], 9)
176      self.evaluate(variables.global_variables_initializer())
177      self.assertEqual(9, self.evaluate(v2))
178
179  def testEnterMulExit(self):
180    with self.cached_session():
181      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
182      enter_data = gen_control_flow_ops.enter(data, "foo_1", False)
183      five = constant_op.constant(5)
184      enter_five = gen_control_flow_ops.enter(five, "foo_1", False)
185      mul_op = math_ops.multiply(enter_data, enter_five)
186      exit_op = control_flow_ops.exit(mul_op)
187
188      result = self.evaluate(exit_op)
189    self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
190
191  @test_util.run_deprecated_v1
192  def testEnterShapePropagation(self):
193    with self.cached_session():
194      v = variables.Variable([0.0, 0.0], dtype=dtypes.float32)
195
196      # If is_constant=True, the shape information should be propagated.
197      enter_v_constant = gen_control_flow_ops.enter(
198          v, "frame1", is_constant=True)
199      self.assertEqual(enter_v_constant.shape, [2])
200
201      # Otherwise, the shape should be unknown.
202      enter_v_non_constant = gen_control_flow_ops.enter(
203          v, "frame2", is_constant=False)
204      self.assertEqual(enter_v_non_constant.shape, None)
205
206  @test_util.run_v1_only("b/120545219")
207  def testSwitchMergeIndexedSlices(self):
208    with self.cached_session():
209      values = constant_op.constant([1, 2, 3, 4, 5, 6])
210      indices = constant_op.constant([0, 2, 4, 6, 8, 10])
211      data = ops.IndexedSlices(values, indices)
212      pred = ops.convert_to_tensor(True)
213      switch_op = control_flow_ops.switch(data, pred)
214      merge_op = control_flow_ops.merge(switch_op)[0]
215
216      val = merge_op.values
217      ind = merge_op.indices
218    self.assertAllEqual(np.arange(1, 7), val)
219    self.assertAllEqual(np.arange(0, 12, 2), ind)
220
221  @test_util.run_v1_only("b/120545219")
222  def testSwitchDeadBranch(self):
223    with self.cached_session():
224      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
225      ports = ops.convert_to_tensor(True, name="ports")
226      switch_op = control_flow_ops.switch(data, ports)
227      dead_branch = array_ops.identity(switch_op[0])
228
229      with self.assertRaisesWithPredicateMatch(
230          errors_impl.InvalidArgumentError,
231          lambda e: "Retval[0] does not have value" in str(e)):
232        self.evaluate(dead_branch)
233
234  @test_util.run_v1_only("b/120545219")
235  def testSwitchMergeLess(self):
236    with self.cached_session():
237      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
238      zero = ops.convert_to_tensor(0)
239      one = ops.convert_to_tensor(1)
240      less_op = math_ops.less(zero, one)
241      switch_op = control_flow_ops.switch(data, less_op)
242      merge_op = control_flow_ops.merge(switch_op)[0]
243
244      result = self.evaluate(merge_op)
245    self.assertAllEqual(np.arange(1, 7), result)
246
247  @test_util.run_v1_only("b/120545219")
248  def testSwitchMergeAddIdentity(self):
249    with self.cached_session():
250      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
251      ports = ops.convert_to_tensor(False, name="ports")
252      switch_op = control_flow_ops.switch(data, ports)
253      one = constant_op.constant(1)
254      add_op = math_ops.add(switch_op[0], one)
255      id_op = array_ops.identity(switch_op[1])
256      merge_op = control_flow_ops.merge([add_op, id_op])[0]
257
258      result = self.evaluate(merge_op)
259    self.assertAllEqual(np.array([x + 1 for x in [1, 2, 3, 4, 5, 6]]), result)
260
261  @test_util.run_v1_only("b/120545219")
262  def testSwitchMergeAddMul(self):
263    with self.cached_session():
264      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
265      ports = ops.convert_to_tensor(True, name="ports")
266      switch_op = control_flow_ops.switch(data, ports)
267      one = constant_op.constant(1)
268      add_op = math_ops.add(switch_op[0], one)
269      five = constant_op.constant(5)
270      mul_op = math_ops.multiply(switch_op[1], five)
271      merge_op = control_flow_ops.merge([add_op, mul_op])[0]
272
273      result = self.evaluate(merge_op)
274    self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)
275
276  @test_util.run_v1_only("b/120545219")
277  def testLoop_false(self):
278    with self.cached_session():
279      false = ops.convert_to_tensor(False)
280      n = constant_op.constant(10)
281
282      enter_false = gen_control_flow_ops.enter(false, "foo_1", False)
283      enter_n = gen_control_flow_ops.enter(n, "foo_1", False)
284
285      merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0]
286      switch_n = control_flow_ops.switch(merge_n, enter_false)
287      exit_n = control_flow_ops.exit(switch_n[0])
288      next_n = control_flow_ops.next_iteration(switch_n[0])
289      merge_n.op._update_input(1, next_n)
290
291      result = self.evaluate(exit_n)
292    self.assertAllEqual(10, result)
293
294  @test_util.run_deprecated_v1
295  def testLoop_1(self):
296    with self.cached_session():
297      zero = constant_op.constant(0)
298      one = constant_op.constant(1)
299      n = constant_op.constant(10)
300
301      enter_i = gen_control_flow_ops.enter(zero, "foo", False)
302      enter_one = gen_control_flow_ops.enter(one, "foo", True)
303      enter_n = gen_control_flow_ops.enter(n, "foo", True)
304
305      with ops.device(test.gpu_device_name()):
306        merge_i = control_flow_ops.merge([enter_i, enter_i])[0]
307
308      less_op = math_ops.less(merge_i, enter_n)
309      cond_op = control_flow_ops.loop_cond(less_op)
310      switch_i = control_flow_ops.switch(merge_i, cond_op)
311
312      add_i = math_ops.add(switch_i[1], enter_one)
313
314      next_i = control_flow_ops.next_iteration(add_i)
315      merge_i.op._update_input(1, next_i)
316
317      exit_i = control_flow_ops.exit(switch_i[0])
318      result = self.evaluate(exit_i)
319    self.assertAllEqual(10, result)
320
321  @test_util.run_v1_only("b/120545219")
322  def testLoop_2(self):
323    with self.cached_session():
324      zero = constant_op.constant(0)
325      one = constant_op.constant(1)
326      n = constant_op.constant(10)
327
328      enter_i = gen_control_flow_ops.enter(zero, "foo", False)
329      enter_one = gen_control_flow_ops.enter(one, "foo", True)
330      enter_n = gen_control_flow_ops.enter(n, "foo", True)
331
332      merge_i = control_flow_ops.merge([enter_i, enter_i])[0]
333
334      less_op = math_ops.less(merge_i, enter_n)
335      cond_op = control_flow_ops.loop_cond(less_op)
336      switch_i = control_flow_ops.switch(merge_i, cond_op)
337
338      add_i = math_ops.add(switch_i[1], enter_one)
339
340      with ops.device(test.gpu_device_name()):
341        next_i = control_flow_ops.next_iteration(add_i)
342      merge_i.op._update_input(1, next_i)
343
344      exit_i = control_flow_ops.exit(switch_i[0])
345      result = self.evaluate(exit_i)
346    self.assertAllEqual(10, result)
347
348  @test_util.run_v1_only("b/120545219")
349  def testDifferentFrame(self):
350    with self.cached_session():
351      data = array_ops.placeholder(dtypes.float32, shape=[])
352      enter_1 = gen_control_flow_ops.enter(data, "foo_1", False)
353      enter_2 = gen_control_flow_ops.enter(data, "foo_2", False)
354      res = math_ops.add(enter_1, enter_2)
355      with self.assertRaisesOpError("has inputs from different frames"):
356        res.eval(feed_dict={data: 1.0})
357
358  @test_util.run_deprecated_v1
359  def testCondBool(self):
360    values = constant_op.constant(10)
361    fn1 = lambda: math_ops.add(values, 1)
362    fn2 = lambda: math_ops.subtract(values, 1)
363    with self.assertRaisesRegexp(TypeError, "must not be a Python bool"):
364      _ = control_flow_ops.cond(False, fn1, fn2)
365
366  @test_util.run_deprecated_v1
367  def testCondInt(self):
368    p = array_ops.placeholder(dtypes.bool, shape=[])
369    v = constant_op.constant(10)
370    fn1 = lambda: math_ops.add(v, 1)
371    fn2 = lambda: math_ops.subtract(v, 1)
372    y = control_flow_ops.cond(p, fn1, fn2)
373    grad = gradients_impl.gradients(y, [v])
374    self.assertAllEqual([None], grad)
375
376  def testCondOutputShape(self):
377    x = constant_op.constant(1.0)
378    b = control_flow_ops.cond(
379        constant_op.constant(True), lambda: math_ops.square(x),
380        lambda: math_ops.subtract(x, 1.))
381    self.assertEqual(b.shape, tensor_shape.scalar())
382
383  @test_util.run_v1_only("b/120545219")
384  def testFetchable(self):
385    with self.cached_session() as sess:
386      x = array_ops.placeholder(dtypes.float32)
387      control_flow_ops.cond(
388          constant_op.constant(True), lambda: x + 2, lambda: x + 0)
389      graph = ops.get_default_graph()
390      for op in graph.get_operations():
391        for t in op.inputs:
392          if graph.is_fetchable(t.op):
393            sess.run(t, feed_dict={x: 3})
394          else:
395            with self.assertRaisesRegexp(ValueError,
396                                         "has been marked as not fetchable"):
397              sess.run(t, feed_dict={x: 3})
398
399  @test_util.disable_control_flow_v2("Not relevant")
400  @test_util.run_v1_only("b/120545219")
401  def testFeedable(self):
402    with self.cached_session() as sess:
403      c = constant_op.constant(2)
404      i0 = constant_op.constant(0)
405      r = control_flow_ops.while_loop(lambda i: i < 1000,
406                                      lambda i: math_ops.square(c) + i, [i0])
407      self.assertEqual(1000, r.eval(feed_dict={i0: 0}))
408      feedable_tensors = all_feedables()
409      for t in feedable_tensors:
410        sess.run(r, feed_dict={t: 3})
411      graph = ops.get_default_graph()
412      for op in graph.get_operations():
413        for t in op.inputs:
414          if t not in feedable_tensors and t.dtype is dtypes.int32:
415            with self.assertRaisesRegexp(ValueError, "may not be fed"):
416              sess.run(r, feed_dict={t: 3})
417
418  @test_util.run_v1_only("b/120545219")
419  def testCondIndexedSlices(self):
420    with self.cached_session():
421      values = constant_op.constant(10)
422      indices = constant_op.constant(0)
423      x = ops.IndexedSlices(values, indices)
424      pred = math_ops.less(1, 2)
425      fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices)
426      fn2 = lambda: ops.IndexedSlices(math_ops.subtract(x.values, 1), indices)
427      r = control_flow_ops.cond(pred, fn1, fn2)
428
429      val = r.values
430      ind = r.indices
431    self.assertAllEqual(11, val)
432    self.assertAllEqual(0, ind)
433
434  def testCondMismatchedIndexedSlices(self):
435    @def_function.function
436    def foo():
437      values = constant_op.constant(10)
438      indices = constant_op.constant(0)
439      x = ops.IndexedSlices(values, indices)
440      v1_msg = "The two structures don't have the same nested structure"
441      v2_msg = ("true_fn and false_fn arguments to tf.cond must have the same "
442                "number, type, and overall structure of return values.")
443      with self.assertRaisesRegexp(
444          TypeError,
445          v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
446        control_flow_ops.cond(
447            constant_op.constant(True),
448            lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices),
449            lambda: math_ops.add(x.values, 1), indices)
450    foo()
451
452  def testCondSparseTensor(self):
453    values = constant_op.constant([2.0, 4.0], name="values")
454    indices = constant_op.constant([[0], [3]],
455                                   dtype=dtypes.int64,
456                                   name="indices")
457    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
458    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
459    pred = math_ops.less(1, 2)
460    fn1 = lambda: sparse_tensor.SparseTensor(
461        indices + 1, x.values + 1, dense_shape=shape)
462    fn2 = lambda: sparse_tensor.SparseTensor(
463        indices, x.values - 1, dense_shape=shape)
464    r = control_flow_ops.cond(pred, fn1, fn2)
465    self.assertAllEqual([3.0, 5.0], r.values)
466    self.assertAllEqual([[1], [4]], r.indices)
467    self.assertAllEqual(r.values.get_shape(), (2,))
468
469  def testCondRaggedTensor(self):
470    rt = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
471    pred = math_ops.less(1, 2)
472    fn1 = lambda: array_ops.concat([rt + 2, [[100]]], axis=0)
473    fn2 = lambda: rt[:2] - 2
474    result = control_flow_ops.cond(pred, fn1, fn2)
475    self.assertAllEqual([3, 4, 5, 6, 7, 8, 100], result.values)
476    self.assertAllEqual([0, 2, 3, 6, 7], result.row_splits)
477
478  @test_util.run_v1_only("b/120545219")
479  def testCondResource(self):
480
481    with self.cached_session():
482      rv = resource_variable_ops.ResourceVariable(True)
483      self.evaluate(variables.global_variables_initializer())
484      t = ops.convert_to_tensor(1.0)
485
486      def case():
487        assign = resource_variable_ops.assign_variable_op(rv.handle, False)
488        with ops.control_dependencies([assign]):
489          return array_ops.identity(t)
490
491      self.assertEqual(
492          1.0, self.evaluate(control_flow_ops.cond(rv, case, lambda: t)))
493
494  @test_util.run_v1_only("b/120545219")
495  def testCondWithTensorArrayGrad(self):
496    with self.cached_session() as sess:
497      with ops.device(test.gpu_device_name()):
498        pred = array_ops.placeholder(dtypes.bool, [])
499        x = constant_op.constant([1.0, 2.0, 3.0])
500        y = control_flow_ops.cond(
501            pred, lambda: map_fn.map_fn(lambda z: z * 2.0, x),
502            lambda: constant_op.constant([1.0, 1.0, 1.0]))
503        g = gradients_impl.gradients(y, x)[0]
504
505      self.assertAllEqual(sess.run(g, {pred: True}), [2.0, 2.0, 2.0])
506      self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0])
507
508  @test_util.disable_control_flow_v2("b/113293074")
509  @test_util.run_v1_only("b/120545219")
510  def testCondIndexedSlicesDifferentTypes(self):
511    with self.cached_session():
512      values = constant_op.constant(10)
513      i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32)
514      i_64 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int64)
515      x = ops.IndexedSlices(values, i_32)
516      pred = math_ops.less(1, 2)
517      fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), i_32)
518      fn2 = lambda: ops.IndexedSlices(math_ops.subtract(x.values, 1), i_64)
519      r = control_flow_ops.cond(pred, fn1, fn2)
520
521      val = r.values
522      ind = r.indices
523    self.assertAllEqual(11, val)
524    self.assertAllEqual(0, ind)
525    self.assertTrue(ind.dtype == np.int64)
526
527  @test_util.run_v1_only("b/120545219")
528  def testCondColocation(self):
529    with self.session(use_gpu=True):
530      with ops.device("/cpu:0"):
531        v = variables.Variable(7.0)
532
533      x = constant_op.constant(10.0)
534      pred = math_ops.less(1.0, 2.0)
535      fn1 = lambda: math_ops.add(v, 1.0)
536      fn2 = lambda: math_ops.subtract(x, 1.0)
537      r = control_flow_ops.cond(pred, fn1, fn2)
538
539      for op in x.graph.get_operations():
540        if op.name == "cond/Add/Switch":
541          self.assertDeviceEqual(op.device, "/cpu:0")
542
543  def _testCond_1(self, use_gpu):
544    with self.cached_session(use_gpu=use_gpu):
545      x = constant_op.constant(10)
546      pred = math_ops.less(1, 2)
547      fn1 = lambda: math_ops.add(x, 1)
548      fn2 = lambda: math_ops.subtract(x, 1)
549      r = control_flow_ops.cond(pred, fn1, fn2)
550
551      result = self.evaluate(r)
552    self.assertAllEqual(11, result)
553
554  def testCond_1(self):
555
556    self._testCond_1(use_gpu=False)
557    # TODO(b/116526896): Enable GPU tests.
558    # self._testCond_1(use_gpu=True)
559
560  def testCond_2(self):
561
562    with self.cached_session():
563      x = constant_op.constant(10)
564      r = control_flow_ops.cond(
565          math_ops.less(1, 0), lambda: math_ops.add(x, 1),
566          lambda: math_ops.subtract(x, 1))
567      result = self.evaluate(r)
568    self.assertAllEqual(9, result)
569
570  def testCond_3(self):
571
572    with self.cached_session():
573      x = constant_op.constant(10)
574      pred = math_ops.less(1, 2)
575      fn1 = lambda: math_ops.add(x, 1)
576      fn2 = lambda: math_ops.subtract(x, 1)
577      fn3 = lambda: math_ops.add(control_flow_ops.cond(pred, fn1, fn2), 1)
578      r = control_flow_ops.cond(pred, fn3, fn2)
579
580      result = self.evaluate(r)
581    self.assertAllEqual(12, result)
582
583  @test_util.disable_xla("b/128638446")
584  @test_util.run_in_graph_and_eager_modes
585  def testCondPruning(self):
586    v1 = variables.Variable(7)
587    v2 = variables.Variable(7)
588    v3 = variables.Variable(7)
589
590    def f():
591      age = constant_op.constant(3)
592      max_age = constant_op.constant(2)
593      pred = math_ops.greater(age, max_age)
594      fn1 = lambda: [state_ops.assign(v1, 1).op, state_ops.assign(v2, 2).op]
595      fn2 = lambda: [state_ops.assign(v3, 3).op, constant_op.constant(10).op]
596      r = control_flow_ops.cond(pred, fn1, fn2)
597      self.assertEqual(len(r), 2)
598      return r[1]
599
600    f_defun = eager_function.defun(f)
601
602    if not context.executing_eagerly():
603      with self.cached_session():
604        self.evaluate(variables.global_variables_initializer())
605        result = self.evaluate(f())
606        self.assertEqual(True, result)
607        # Only second cond result was fetched, so v1 assign shouldn't run.
608        self.assertEqual(7, self.evaluate(v1))
609        self.assertEqual(2, self.evaluate(v2))
610        self.assertEqual(7, self.evaluate(v3))
611
612    result = f_defun()
613    self.assertEqual(True, self.evaluate(result))
614    # Both v1 and v2 branch assignments should be run in defun.
615    self.assertEqual(1, self.evaluate(v1))
616    self.assertEqual(2, self.evaluate(v2))
617    self.assertEqual(7, self.evaluate(v3))
618
619  def testCond_5(self):
620    with self.cached_session():
621      alive = constant_op.constant(True, name="alive")
622      count = constant_op.constant(0, name="count")
623
624      def body(i):
625        return control_flow_ops.cond(
626            alive, lambda: [math_ops.less(i, 3), math_ops.add(count, 1)],
627            lambda: [alive, count])
628
629      for i in range(10):
630        alive, count = body(i)
631      self.assertAllEqual(4, self.evaluate(count))
632
633  @test_util.run_v1_only("b/120545219")
634  def testCond_6(self):
635    with self.cached_session():
636      v1 = variables.Variable([7])
637
638      age = constant_op.constant(3)
639      pred = math_ops.greater(age, 4)
640      fn1 = lambda: age
641      fn2 = lambda: v1
642      r = control_flow_ops.cond(pred, fn1, fn2)
643
644      self.evaluate(variables.global_variables_initializer())
645      result = self.evaluate(r)
646      self.assertAllEqual(np.array([7]), result)
647
648  def testCond_7(self):
649    with self.cached_session() as sess:
650      x = constant_op.constant(10)
651      y = constant_op.constant(200)
652      pred = math_ops.less(1, 2)
653      fn1 = lambda: [math_ops.add(x, 1), math_ops.add(x, 2)]
654      fn2 = lambda: [y, y]
655      r = control_flow_ops.cond(pred, fn1, fn2)
656      self.assertAllEqual([11, 12], self.evaluate(r))
657
658  @test_util.run_gpu_only
659  @test_util.run_deprecated_v1
660  def testCond_Device(self):
661    x = constant_op.constant(-10.)
662
663    # True branch function defined outside of device scope
664    def true_fn():
665      return math_ops.exp(x)
666
667    with ops.device("CPU:0"):
668      r = control_flow_ops.cond(
669          constant_op.constant(True), true_fn, lambda: 0.)
670      self.assertIn("cpu", r.device.lower())
671
672    with session.Session() as sess:
673      options = config_pb2.RunOptions(output_partition_graphs=True)
674      run_metadata = config_pb2.RunMetadata()
675      sess.run(r, options=options, run_metadata=run_metadata)
676      # We expect that everything runs on CPU, even if GPU is available.
677      self.assertEqual(len(run_metadata.partition_graphs), 1)
678
679  def _count_matching_switch_nodes_on_device(self, run_metadata, device_str):
680    # Returns the number of Switch nodes with type float32 placed on
681    # `device_str`.
682    device_graphs = [
683        g for g in run_metadata.partition_graphs
684        if device_str in g.node[0].device
685    ]
686    self.assertLen(device_graphs, 1)
687    switch_nodes = [
688        n for n in device_graphs[0].node if n.op == "Switch" and
689        n.attr["T"].type == dtypes.float32.as_datatype_enum
690    ]
691    return len(switch_nodes)
692
693  @test_util.run_gpu_only
694  @test_util.run_deprecated_v1
695  def testCondSwitchColocatedWithInputWhenInputOnCPU(self):
696    x = array_ops.placeholder(dtypes.float32)
697
698    # `arg` is used in the cond then branch so a Switch node is created for it.
699    # We test that the Switch node gets placed on the same device as `arg`.
700    # We force `arg` to be on CPU here.
701    with ops.device("CPU:0"):
702      arg = x + 10.
703
704    def true_fn():
705      with ops.device("CPU:0"):
706        return arg + 1
707
708    r = control_flow_ops.cond(constant_op.constant(True), true_fn, lambda: 0.)
709
710    with session.Session() as sess:
711      run_metadata = config_pb2.RunMetadata()
712      options = config_pb2.RunOptions(output_partition_graphs=True)
713      sess.run(
714          r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
715      self.assertEqual(len(run_metadata.partition_graphs), 2)
716      # Check that the Switch for `arg` gets placed on CPU.
717      self.assertEqual(
718          self._count_matching_switch_nodes_on_device(run_metadata, "CPU"), 1)
719      self.assertEqual(
720          self._count_matching_switch_nodes_on_device(run_metadata, "GPU"), 0)
721
722  @test_util.run_gpu_only
723  @test_util.run_deprecated_v1
724  def testCondSwitchColocatedWithInputWhenInputOnGPU(self):
725    x = array_ops.placeholder(dtypes.float32)
726
727    # `arg` is used in the cond then branch so a Switch node is created for it.
728    # We test that the Switch node gets placed on the same device as `arg`.
729    # Note: `arg` gets placed on GPU by default by the placer.
730    arg = x + 10.
731
732    def true_fn():
733      with ops.device("CPU:0"):
734        return arg + 1
735
736    r = control_flow_ops.cond(constant_op.constant(True), true_fn, lambda: 0.)
737
738    with session.Session() as sess:
739      run_metadata = config_pb2.RunMetadata()
740      options = config_pb2.RunOptions(output_partition_graphs=True)
741      sess.run(
742          r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
743      self.assertEqual(len(run_metadata.partition_graphs), 2)
744      # Check that the Switch for `arg` gets placed on GPU.
745      self.assertEqual(
746          self._count_matching_switch_nodes_on_device(run_metadata, "CPU"), 0)
747      self.assertEqual(
748          self._count_matching_switch_nodes_on_device(run_metadata, "GPU"), 1)
749
750  def testCondListOutput(self):
751    with self.cached_session() as sess:
752      x = constant_op.constant(10)
753      y = constant_op.constant(200)
754      pred = math_ops.less(1, 2)
755      fn1 = lambda: [math_ops.add(x, y), math_ops.add(x, y)]
756      fn2 = lambda: [y, y]
757      r = control_flow_ops.cond(pred, fn1, fn2)
758      test_result = self.evaluate(r)
759      self.assertListEqual([210, 210], test_result)
760
761  def testTupleOutput(self):
762    with self.cached_session() as sess:
763      x = constant_op.constant(10)
764      y = constant_op.constant(200)
765      pred = math_ops.less(1, 2)
766      fn1 = lambda: (math_ops.add(x, y), math_ops.add(x, y))
767      fn2 = lambda: (y, y)
768      r = control_flow_ops.cond(pred, fn1, fn2)
769      test_result = self.evaluate(r)
770      self.assertTupleEqual((210, 210), test_result)
771
772  def testDictOutput(self):
773    with self.cached_session() as sess:
774      x = constant_op.constant(10)
775      y = constant_op.constant(200)
776      pred = math_ops.less(1, 2)
777      fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)}
778      fn2 = lambda: {"a": y, "b": y}
779      r = control_flow_ops.cond(pred, fn1, fn2)
780      test_result = self.evaluate(r)
781      self.assertDictEqual({"a": 210, "b": 210}, test_result)
782
783  @test_util.run_deprecated_v1
784  def testEmbeddedListOutput(self):
785    with self.cached_session() as sess:
786      x = constant_op.constant(10)
787      y = constant_op.constant(200)
788      pred = math_ops.less(1, 2)
789      fn1 = lambda: [[math_ops.add(x, y), math_ops.add(x, y)]]
790      fn2 = lambda: [[y, y]]
791      # Pass strict=True flag as cond_v2 allows for tensors to be
792      # in nested output structures as singletons
793      r = control_flow_ops.cond(pred, fn1, fn2, strict=True)
794      test_result = self.evaluate(r)
795      self.assertListEqual([[210, 210]], test_result)
796
797  def testEmbeddedTupleOutput(self):
798    with self.cached_session() as sess:
799      x = constant_op.constant(10)
800      y = constant_op.constant(200)
801      pred = math_ops.less(1, 2)
802      fn1 = lambda: ((math_ops.add(x, y), math_ops.add(x, y)))
803      fn2 = lambda: ((y, y))
804      r = control_flow_ops.cond(pred, fn1, fn2)
805      test_result = self.evaluate(r)
806      self.assertTupleEqual(((210, 210)), test_result)
807
808  def testEmbeddedDictOutput(self):
809    with self.cached_session() as sess:
810      x = constant_op.constant(10)
811      y = constant_op.constant(200)
812      pred = math_ops.less(1, 2)
813      fn1 = lambda: {"a": {"c": math_ops.add(x, y)},
814                     "b": {"d": math_ops.add(x, y)}}
815      fn2 = lambda: {"a": {"c": y},
816                     "b": {"d": y}}
817      r = control_flow_ops.cond(pred, fn1, fn2)
818      test_result = self.evaluate(r)
819      self.assertDictEqual({"a": {"c": 210}, "b": {"d": 210}}, test_result)
820
821  @test_util.run_v1_only("b/120545219")
822  def testCheckNestedOutputStruct(self):
823    with self.cached_session() as sess:
824      x = constant_op.constant(10)
825      y = constant_op.constant(200)
826      pred = math_ops.less(1, 2)
827      fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)}
828      fn2 = lambda: {"c": y, "d": y}
829      v1_msg = "The two structures don't have the same nested structure"
830      v2_msg = ("true_fn and false_fn arguments to tf.cond must have the same "
831                "number, type, and overall structure of return values.")
832      with self.assertRaisesRegexp(
833          TypeError if control_flow_util.ENABLE_CONTROL_FLOW_V2 else ValueError,
834          v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
835        control_flow_ops.cond(pred, fn1, fn2)
836
837  @test_util.run_deprecated_v1
838  def testCondRef(self):
839
840    with self.cached_session():
841      x = gen_state_ops.variable(
842          shape=[1],
843          dtype=dtypes.float32,
844          name="x",
845          container="",
846          shared_name="")
847      true_fn = lambda: x
848      false_fn = lambda: constant_op.constant([2.0])
849      r = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn)
850      self.assertAllEqual([2.0], self.evaluate(r))
851
852  @test_util.disable_control_flow_v2("b/79881896 (placeholder)")
853  @test_util.run_v1_only("b/120545219")
854  def testCondWithControl(self):
855    with self.cached_session():
856      control_holder = array_ops.placeholder(dtypes.float32, shape=())
857      a = constant_op.constant(3)
858
859      def true_branch():
860        with ops.control_dependencies([control_holder]):
861          _ = a + 1
862        return a + 2
863
864      r = control_flow_ops.cond(
865          constant_op.constant(True), true_branch,
866          lambda: constant_op.constant(1))
867      self.assertEqual(5, self.evaluate(r))
868
869  @test_util.run_v1_only("b/120545219")
870  def testUninitializedRefIdentity(self):
871    with self.cached_session() as sess:
872      v = gen_state_ops.variable(
873          shape=[1],
874          dtype=dtypes.float32,
875          name="v",
876          container="",
877          shared_name="")
878      inited = state_ops.is_variable_initialized(v)
879      v_f, v_t = control_flow_ops.ref_switch(v, inited)
880      # Both v_f and v_t are uninitialized references. However, an actual use
881      # of the reference in the 'true' branch in the 'tf.identity' op will
882      # not 'fire' when v is uninitialized, so this is a valid construction.
883      # This test tests that ref_identity allows uninitialized ref as input
884      # so that this construction is allowed.
885      v_f_op = gen_array_ops.ref_identity(v_f)
886      v_t_op = gen_array_ops.ref_identity(v_t)
887      with ops.control_dependencies([v_f_op]):
888        assign_v = state_ops.assign(v, [1.0])
889      with ops.control_dependencies([v_t_op]):
890        orig_v = array_ops.identity(v)
891      merged_op = control_flow_ops.merge([assign_v, orig_v])
892      self.assertAllEqual([1.0], self.evaluate(merged_op.output))
893
894  def testCondSwitchIdentity(self):
895    # Make sure the recv identity is not removed by optimization.
896    with session.Session(config=opt_cfg()) as sess:
897      pred = constant_op.constant(True)
898
899      def fn1():
900        return control_flow_ops.no_op()
901
902      def fn2():
903        return control_flow_ops.Assert(False, ["Wrong branch!!!"])
904
905      r = control_flow_ops.cond(pred, fn1, fn2)
906      self.evaluate(r)
907
908  def testCondRecvIdentity(self):
909    # Make sure the switch identity is not removed by optimization.
910    with session.Session(config=opt_cfg()) as sess:
911      with ops.device(test.gpu_device_name()):
912        pred = constant_op.constant(True)
913
914      def fn1():
915        return control_flow_ops.no_op()
916
917      def fn2():
918        with ops.device("/cpu:0"):
919          return control_flow_ops.Assert(False, ["Wrong branch!!!"])
920
921      r = control_flow_ops.cond(pred, fn1, fn2)
922      self.evaluate(r)
923
924  @test_util.run_v1_only("b/120545219")
925  def testCondGrad_1(self):
926    with self.cached_session():
927      x = constant_op.constant(10.0, name="x")
928      pred = math_ops.less(1, 2)
929      fn1 = lambda: array_ops.identity(x)
930      fn2 = lambda: array_ops.identity(x)
931      r = control_flow_ops.cond(pred, fn1, fn2)
932
933      grad = gradients_impl.gradients(r, [x])[0]
934      self.assertAllEqual(1.0, self.evaluate(grad))
935
936  @test_util.run_deprecated_v1
937  def testCondGrad_2(self):
938    with self.cached_session():
939      c = array_ops.placeholder(dtypes.int32, shape=[])
940      x = constant_op.constant(10.0)
941      pred = math_ops.less(c, 2)
942      fn1 = lambda: math_ops.multiply(x, 42.0)
943      fn2 = lambda: math_ops.multiply(x, 3.0)
944      r = control_flow_ops.cond(pred, fn1, fn2)
945
946      grad = gradients_impl.gradients(r, [x])[0]
947      self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
948      self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
949
950  @test_util.disable_control_flow_v2(
951      "b/110550782 (gradient w.r.t external variable)")
952  @test_util.run_deprecated_v1
953  def testCondGrad_3(self):
954    with self.cached_session():
955      c = array_ops.placeholder(dtypes.int32, shape=[])
956      ox = constant_op.constant(10.0)
957      pred = math_ops.less(c, 2)
958
959      def fn1(x):
960        m = x * x
961        return gradients_impl.gradients(m, [ox])[0]
962
963      fn2 = lambda: math_ops.multiply(ox, 3.0)
964      y = math_ops.multiply(7.0, ox)
965      r = control_flow_ops.cond(pred, lambda: fn1(y), fn2)
966
967      self.assertAllEqual(980.0, r.eval(feed_dict={c: 1}))
968      self.assertAllEqual(30.0, r.eval(feed_dict={c: 3}))
969
970  @test_util.run_deprecated_v1
971  def testCondGradMultiDevice(self):
972    config = config_pb2.ConfigProto(device_count={"CPU": 2},
973                                    allow_soft_placement=True)
974    with self.cached_session(use_gpu=True, config=config) as sess:
975      pred = array_ops.placeholder(dtypes.bool, [])
976      x = array_ops.placeholder(dtypes.float32)
977      y = array_ops.placeholder(dtypes.float32)
978
979      with ops.device("/cpu:0"):
980        z = control_flow_ops.cond(pred, lambda: x * y * 2.0, lambda: 2.0)
981
982      with ops.device("/cpu:1"):
983        grad = gradients_impl.gradients(z, x)[0]
984
985      with ops.device("/cpu:0"):
986        grad_grad = gradients_impl.gradients(grad, x)[0]
987
988      self.assertEqual(sess.run(grad, {pred: True, x: 1.0, y: 2.0}), 4.0)
989      self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0)
990
991      # v1 control flow gets None second derivative for some reason.
992      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
993        self.assertIsNone(grad_grad)
994        return
995
996      self.assertEqual(sess.run(grad_grad, {pred: True, x: 1.0, y: 2.0}), 0.0)
997      self.assertEqual(sess.run(grad_grad, {pred: False, x: 1.0, y: 2.0}), 0.0)
998
999  @test_util.run_v1_only("b/120545219")
1000  def testNestedCond_Simple(self):
1001    with self.cached_session():
1002      x = constant_op.constant(0., name="X")
1003      y = control_flow_ops.cond(
1004          constant_op.constant(True), lambda: x,
1005          lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x))
1006      result = gradients_impl.gradients(y, x)[0]
1007      self.assertEqual(1.0, self.evaluate(result))
1008
1009      z = control_flow_ops.cond(
1010          constant_op.constant(False), lambda: x,
1011          lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x))
1012      result = gradients_impl.gradients(z, x)[0]
1013      self.assertEqual(1.0, self.evaluate(result))
1014
1015  @test_util.disable_control_flow_v2("b/113327884")
1016  @test_util.run_v1_only("b/120545219")
1017  def testCondGrad_Gather(self):
1018    with self.cached_session() as sess:
1019      v1 = variables.Variable([1.0, 42.0])
1020      c = array_ops.placeholder(dtypes.int32, shape=[])
1021      pred = math_ops.less(c, 2)
1022      fn1 = lambda: array_ops.identity(v1)
1023      fn2 = lambda: array_ops.gather(v1, [1, 1])
1024      r = control_flow_ops.cond(pred, fn1, fn2)
1025      grad = gradients_impl.gradients(r, [v1])[0]
1026      self.evaluate(variables.global_variables_initializer())
1027      # Should just be [1, 1], but possibly a sparse representation
1028      gv, gi = sess.run([grad.values, grad.indices], feed_dict={c: 1})
1029      dense_gv = [
1030          sum(y for (x, y) in zip(gi, gv) if x == i) for i in range(2)
1031      ]
1032      self.assertAllEqual(dense_gv, [1.0, 1.0])
1033      # Should be [0, 2], as the else forwards v1[1] twice
1034      gv, gi = sess.run([grad.values, grad.indices], feed_dict={c: 3})
1035      dense_gv = [
1036          sum(y for (x, y) in zip(gi, gv) if x == i) for i in range(2)
1037      ]
1038      self.assertAllEqual(dense_gv, [0.0, 2.0])
1039
1040  @test_util.run_deprecated_v1
1041  def testCondGrad_ResourceVarSparseRead(self):
1042    # NOTE(skyewm): this test is interesting because the
1043    # ResourceVariable.sparse_read gradient function returns IndexedSlices.
1044    var = resource_variable_ops.ResourceVariable(
1045        np.ones((4, 2), dtype=np.float32))
1046    x = constant_op.constant(1.0)
1047    r = control_flow_ops.cond(
1048        constant_op.constant(True),
1049        lambda: x * math_ops.reduce_sum(var.sparse_read([1, 2])),
1050        lambda: constant_op.constant(np.zeros((2, 3)),
1051                                     dtype=dtypes.float32))
1052    grad = gradients_impl.gradients(r, var)[0]
1053
1054    self.evaluate(variables.global_variables_initializer())
1055    grad_val = self.evaluate(grad)
1056    self.assertIsInstance(grad_val, ops.IndexedSlicesValue)
1057    self.assertAllEqual(gradient_checker_v2._to_numpy(grad_val), [[0., 0.],
1058                                                                  [1., 1.],
1059                                                                  [1., 1.],
1060                                                                  [0., 0.]])
1061
1062  @test_util.disable_xla("b/128643464")
1063  def testCondGrad_MultiGather(self):
1064    # NOTE(skyewm): this test is interesting because the array_ops.gather and
1065    # ResourceVariable.sparse_read gradient functions returns IndexedSlices.
1066    var = resource_variable_ops.ResourceVariable(
1067        np.ones((4, 2), dtype=np.float32))
1068    x1 = constant_op.constant(np.ones((3, 3), dtype=np.float32))
1069    x2 = constant_op.constant(2.0)
1070
1071    def true_fn():
1072      y1 = var.sparse_read([1, 2])
1073      y2 = array_ops.gather(x1, [2]) * x2
1074      y3 = x2 * [1., 1., 1.]
1075      return y1, y2, y3
1076
1077    def false_fn():
1078      y1 = np.zeros((2, 2), dtype=np.float32)
1079      y2 = array_ops.gather(x1, [2]) * x2
1080      y3 = array_ops.gather(x1, [2])
1081      return y1, y2, y3
1082
1083    @def_function.function
1084    def foo():
1085      r = control_flow_ops.cond(constant_op.constant(True), true_fn, false_fn)
1086      return gradients_impl.gradients(r, [var, x1, x2])
1087
1088    grad = foo()
1089    self.evaluate(variables.global_variables_initializer())
1090    var_grad, x1_grad, x2_grad = self.evaluate(grad)
1091    self.assertIsInstance(var_grad, ops.IndexedSlicesValue)
1092    self.assertAllEqual(gradient_checker_v2._to_numpy(var_grad), [[0., 0.],
1093                                                                  [1., 1.],
1094                                                                  [1., 1.],
1095                                                                  [0., 0]])
1096    self.assertIsInstance(x1_grad, ops.IndexedSlicesValue)
1097    self.assertAllEqual(gradient_checker_v2._to_numpy(x1_grad), [[0., 0., 0.],
1098                                                                 [0., 0., 0.],
1099                                                                 [2., 2., 2.]])
1100    self.assertIsInstance(x1_grad, ops.IndexedSlicesValue)
1101    self.assertEqual(gradient_checker_v2._to_numpy(x2_grad), 6.)
1102
1103  @test_util.run_v1_only("b/120545219")
1104  def testCondPredicateTensor(self):
1105    """Regression test for lowering predicate from non-first output of an op."""
1106
1107    @eager_function.defun
1108    def foo():
1109      return constant_op.constant("foo"), constant_op.constant(True)
1110
1111    r = control_flow_ops.cond(foo()[1], lambda: 1.0, lambda: 2.0)
1112    self.assertEqual(self.evaluate(r), 1.0)
1113
1114  @test_util.run_v1_only("Tests Session.run() pruning logic.")
1115  def testCondFeedConstantPredicate(self):
1116    with self.cached_session() as sess:
1117      value = constant_op.constant(37.0)
1118      predicate = constant_op.constant(True)
1119      cond_output = control_flow_ops.cond(
1120          predicate, lambda: constant_op.constant(0.0), lambda: value)
1121      result = array_ops.identity(cond_output)
1122      self.assertEqual(37.0, sess.run(result, feed_dict={predicate: False}))
1123      self.assertEqual(0.0, sess.run(result, feed_dict={predicate: True}))
1124      self.assertEqual(0.0, sess.run(result))
1125
1126  @test_util.run_v1_only("Tests Session.run() pruning logic.")
1127  def testCondFeedPlaceholderWithDefaultPredicate(self):
1128    with self.cached_session() as sess:
1129      value = constant_op.constant(37.0)
1130      predicate = array_ops.placeholder_with_default(
1131          constant_op.constant(True), [])
1132      cond_output = control_flow_ops.cond(
1133          predicate, lambda: constant_op.constant(0.0), lambda: value)
1134      result = array_ops.identity(cond_output)
1135      self.assertAllEqual(37.0, sess.run(result, feed_dict={predicate: False}))
1136      self.assertAllEqual(0.0, sess.run(result, feed_dict={predicate: True}))
1137      self.assertAllEqual(0.0, sess.run(result))
1138
1139  @test_util.disable_xla("b/128644469 PrintV2")
1140  @test_util.run_in_graph_and_eager_modes
1141  def testCondAutoControlDeps(self):
1142
1143    def branch_fn():
1144      logging_ops.print_v2("A")
1145      logging_ops.print_v2("B")
1146      with ops.control_dependencies([logging_ops.print_v2("C")]):
1147        return constant_op.constant(10)
1148
1149    def build_cond():
1150      return control_flow_ops.cond(
1151          constant_op.constant(True), branch_fn, lambda: 0)
1152
1153    def build_nested_cond():
1154      return control_flow_ops.cond(
1155          constant_op.constant(True), build_cond, lambda: 0)
1156
1157    # In v1 graph mode, pruning should make only "C" print.
1158    if not context.executing_eagerly():
1159      with self.cached_session():
1160        with self.captureWritesToStream(sys.stderr) as printed:
1161          self.assertEqual(self.evaluate(build_cond()), 10)
1162        self.assertEqual(printed.contents(), "C\n")
1163
1164        with self.captureWritesToStream(sys.stderr) as printed:
1165          self.assertEqual(self.evaluate(build_nested_cond()), 10)
1166        self.assertEqual(printed.contents(), "C\n")
1167
1168    # In defuns, all prints should execute in program order.
1169    # This doesn't work with legacy control flow.
1170    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
1171
1172      @eager_function.defun
1173      def cond():
1174        return build_cond()
1175
1176      with self.captureWritesToStream(sys.stderr) as printed:
1177        self.assertEqual(self.evaluate(cond()), 10)
1178      self.assertTrue(printed.contents().endswith("A\nB\nC\n"),
1179                      printed.contents())
1180
1181      @eager_function.defun
1182      def nested_cond():
1183        return build_nested_cond()
1184
1185      with self.captureWritesToStream(sys.stderr) as printed:
1186        self.assertEqual(self.evaluate(nested_cond()), 10)
1187      self.assertTrue(printed.contents().endswith("A\nB\nC\n"),
1188                      printed.contents())
1189
1190    # wrap_function should prune.
1191    def pruned_cond():
1192      return build_cond()
1193    pruned_cond = wrap_function.wrap_function(pruned_cond, [])
1194
1195    with self.captureWritesToStream(sys.stderr) as printed:
1196      self.assertEqual(self.evaluate(pruned_cond()), 10)
1197    self.assertEqual(printed.contents(), "C\n")
1198
1199    def pruned_nested_cond():
1200      return build_nested_cond()
1201    pruned_nested_cond = wrap_function.wrap_function(pruned_nested_cond, [])
1202
1203    with self.captureWritesToStream(sys.stderr) as printed:
1204      self.assertEqual(self.evaluate(pruned_nested_cond()), 10)
1205    self.assertEqual(printed.contents(), "C\n")
1206
1207  @test_util.disable_xla("b/128643646 PrintV2")
1208  @test_util.run_in_graph_and_eager_modes
1209  def testWhileAutoControlDeps(self):
1210    # Legacy while_loop fails this test because it produces deprecation notices
1211    # in stderr.
1212    if not control_flow_util.ENABLE_CONTROL_FLOW_V2: return
1213
1214    def cond(i, unused_x):
1215      logging_ops.print_v2("A")
1216      return i < 2
1217
1218    def body(i, x):
1219      logging_ops.print_v2("B")
1220      with ops.control_dependencies([logging_ops.print_v2("C")]):
1221        x = array_ops.identity(x)
1222      with ops.control_dependencies([logging_ops.print_v2("D")]):
1223        return i + 1, x
1224
1225    def build_while():
1226      return control_flow_ops.while_loop(
1227          cond, body, [constant_op.constant(0), constant_op.constant(0)])
1228
1229    def build_nested_while():
1230      return control_flow_ops.cond(
1231          constant_op.constant(True), build_while, lambda: [0, 0])
1232
1233    # In v1 graph mode, pruning should make only "D" print.
1234    if not context.executing_eagerly():
1235      with self.cached_session():
1236        with self.captureWritesToStream(sys.stderr) as printed:
1237          self.assertEqual(self.evaluate(build_while()[0]), 2)
1238        self.assertTrue(printed.contents().endswith("D\nD\n"),
1239                        printed.contents())
1240
1241        with self.captureWritesToStream(sys.stderr) as printed:
1242          self.assertEqual(self.evaluate(build_nested_while()[0]), 2)
1243        self.assertTrue(printed.contents().endswith("D\nD\n"),
1244                        printed.contents())
1245
1246    # In defuns, all prints should execute in program order.
1247    @eager_function.defun
1248    def while_loop():
1249      return build_while()[0]
1250
1251    with self.captureWritesToStream(sys.stderr) as printed:
1252      self.assertEqual(self.evaluate(while_loop()), 2)
1253    self.assertTrue(printed.contents().endswith("A\nB\nC\nD\nA\nB\nC\nD\nA\n"),
1254                    printed.contents())
1255
1256    @eager_function.defun
1257    def nested_while_loop():
1258      return build_nested_while()[0]
1259
1260    # TODO(b/117840611): calling nested_while_loop fails in eager
1261    if not context.executing_eagerly():
1262      with self.captureWritesToStream(sys.stderr) as printed:
1263        self.assertEqual(self.evaluate(nested_while_loop()), 2)
1264      self.assertTrue(
1265          printed.contents().endswith("A\nB\nC\nD\nA\nB\nC\nD\nA\n"),
1266          printed.contents())
1267
1268    # wrap_function should prune.
1269    def pruned_while():
1270      return build_while()[0]
1271    pruned_while = wrap_function.wrap_function(pruned_while, [])
1272
1273    with self.captureWritesToStream(sys.stderr) as printed:
1274      self.assertEqual(self.evaluate(pruned_while()), 2)
1275    self.assertTrue(printed.contents().endswith("D\nD\n"), printed.contents())
1276
1277    def pruned_nested_while():
1278      return build_nested_while()[0]
1279    pruned_nested_while = wrap_function.wrap_function(pruned_nested_while, [])
1280
1281    # TODO(b/117840611): calling nested_while_loop fails in eager
1282    if not context.executing_eagerly():
1283      with self.captureWritesToStream(sys.stderr) as printed:
1284        self.assertEqual(self.evaluate(pruned_nested_while()), 2)
1285      self.assertTrue(printed.contents().endswith("D\nD\n"), printed.contents())
1286
1287  # Microbenchmark: 256,000 iterations/s.
1288  def testWhile_1(self):
1289    with self.cached_session():
1290      n = constant_op.constant(0)
1291      c = lambda x: math_ops.less(x, 10000)
1292      b = lambda x: math_ops.add(x, 1)
1293      r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
1294      self.assertEqual(10000, self.evaluate(r))
1295
1296  @test_util.run_v1_only("b/120545219")
1297  def testWhileExternalControlDependencies(self):
1298    with self.cached_session():
1299      v = variables.Variable(0.0)
1300      v.initializer.run()
1301      increment = v.assign_add(1.0).read_value()
1302
1303      def body_fn(i):
1304        with ops.control_dependencies([increment]):
1305          return i + 1
1306
1307      result = control_flow_ops.while_loop(cond=lambda i: i < 2,
1308                                           body=body_fn, loop_vars=[1])
1309      self.assertAllEqual(result, 2)
1310      self.assertAllEqual(v.read_value(), 1.0)
1311
1312  @test_util.run_v1_only("b/120545219")
1313  def testWhileExternalControlDependenciesNoInput(self):
1314    with self.cached_session():
1315      v = variables.Variable(0.0)
1316      v.initializer.run()
1317      # TODO(apassos): figure out why the reading is necessary here.
1318      increment = v.assign_add(1.0).read_value()
1319
1320      def body_fn(unused_i):
1321        with ops.control_dependencies([increment]):
1322          return constant_op.constant(5, name="five")
1323
1324      result = control_flow_ops.while_loop(cond=lambda i: i < 5,
1325                                           body=body_fn, loop_vars=[0])
1326      self.evaluate(result)
1327      self.assertAllEqual(self.evaluate(v), 1.0)
1328
1329  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
1330  @test_util.run_v1_only("b/120545219")
1331  def testWhileWithRefs_1(self):
1332    with self.cached_session() as sess:
1333      x = variables.VariableV1(0)._ref()  # pylint: disable=protected-access
1334      i = constant_op.constant(0)
1335      c = lambda i, x: math_ops.less(i, 100)
1336
1337      self.assertEqual(x.dtype, dtypes.int32_ref)
1338
1339      def b(i, x):
1340        self.assertEqual(x.dtype, dtypes.int32_ref)
1341        return (i + 1, gen_array_ops.ref_identity(x))
1342
1343      r = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=5)
1344
1345      self.evaluate(variables.global_variables_initializer())
1346
1347      self.assertEqual(r[0].dtype, dtypes.int32)
1348      self.assertEqual(r[1].dtype, dtypes.int32_ref)
1349
1350      value_i, value_x = self.evaluate(r)
1351
1352    self.assertEqual(100, value_i)
1353    self.assertEqual(0, value_x)
1354
1355  def testWhile_2(self):
1356    with self.cached_session():
1357      s = constant_op.constant(0)
1358      r = isum(s)
1359      self.assertAllEqual(45, self.evaluate(r))
1360
1361  def testWhileWithMaximumIterations(self):
1362    with self.cached_session():
1363      s = constant_op.constant([1, 2, 3, 4, 5])
1364      r = isum(s, maximum_iterations=3)
1365      self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], self.evaluate(r))
1366
1367  @test_util.run_v1_only("b/120545219")
1368  def testWhileWithMaximumIterationsAndSingleArgument(self):
1369    with self.cached_session():
1370      r = control_flow_ops.while_loop(
1371          lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
1372      self.assertEqual(1, self.evaluate(r))
1373
1374  @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
1375  @test_util.run_v1_only("b/120545219")
1376  def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
1377    v = constant_op.constant(1.0)
1378
1379    def training_loop_with_gradient(i):
1380      out = control_flow_ops.while_loop(
1381          lambda i_, _: i_ < 3,
1382          lambda i_, j: [i_ + 1, j * v], [0, 1.0],
1383          maximum_iterations=i)
1384      g = gradients_impl.gradients(out, v)
1385      with ops.control_dependencies(g):
1386        return i + 1
1387
1388    xla_context = control_flow_ops.XLAControlFlowContext()
1389    xla_context.Enter()
1390    # Create training loop, ensure we can call gradient() of
1391    # while_loop inside the training loop.
1392    loop = control_flow_ops.while_loop(lambda i: i < 3,
1393                                       training_loop_with_gradient, [0])
1394    xla_context.Exit()
1395
1396    loop_execute = array_ops.identity(loop)  # Because loop is not fetchable.
1397
1398    # Should execute without issue.
1399    self.assertEqual(3, self.evaluate(loop_execute))
1400
1401  @test_util.run_v1_only("b/120545219")
1402  def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
1403    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
1404      self.skipTest("WhileV2 does lazy evaluation of maximum_iterations")
1405    v = constant_op.constant(1.0)
1406
1407    def inner_body(i, x):
1408      out = control_flow_ops.while_loop(
1409          lambda i, _: i < 3,
1410          lambda i, j: [i + 1, j * v], [0, x],
1411          maximum_iterations=i)
1412      return out
1413
1414    def create_while_loop(maximum_iterations=None):
1415      return control_flow_ops.while_loop(
1416          lambda i, _: i < 3,
1417          inner_body, [0, 1.0],
1418          maximum_iterations=maximum_iterations)
1419
1420    loop_no_xla = create_while_loop(maximum_iterations=5)
1421    # maximum_iterations is fine outside of an XLA scope
1422    gs = gradients_impl.gradients(loop_no_xla, v)
1423    self.evaluate(gs)  # This should execute without error.
1424
1425    xla_context = control_flow_ops.XLAControlFlowContext()
1426    xla_context.Enter()
1427    loop_no_maxiter = create_while_loop()
1428    loop_with_maxiter = create_while_loop(maximum_iterations=2)
1429    xla_context.Exit()
1430
1431    with self.assertRaisesRegexp(
1432        ValueError,
1433        r"Cannot create a gradient accumulator for tensor '.+' inside "
1434        r"XLA while_loop because maximum_iterations was not passed to "
1435        r"the tf.while_loop call \('.+'\)."):
1436      _ = gradients_impl.gradients(loop_no_maxiter, v)
1437
1438    with self.assertRaisesRegexp(
1439        ValueError,
1440        r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
1441        r"while_loop. maximum_iterations tensor '.+' for while_loop context "
1442        r"'.+' must be statically known \(e.g. a constant value or known "
1443        r"shape dimension\), or be defined at or outside the while loop "
1444        r"context '.*' \(currently defined in '.*'\)"):
1445      _ = gradients_impl.gradients(loop_with_maxiter, v)
1446
1447  @test_util.run_v1_only("b/120545219")
1448  def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
1449    v = constant_op.constant(1.0)
1450
1451    def create_while_loop():
1452      max_iter_holder = []
1453
1454      def create_mi():
1455        max_iter_holder.append(array_ops.placeholder(dtypes.int32, shape=()))
1456        return 1.0
1457
1458      _ = control_flow_ops.cond(
1459          constant_op.constant(True), create_mi, create_mi)
1460
1461      return control_flow_ops.while_loop(
1462          lambda i, _: i < 3,
1463          lambda i, x: (i + 1, v * x), (0, 1.0),
1464          maximum_iterations=max_iter_holder[0])
1465
1466    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
1467      xla_context = control_flow_ops.XLAControlFlowContext()
1468      xla_context.Enter()
1469      with self.assertRaisesRegexp(
1470          ValueError, r"Tensor.*Placeholder:0.* must be from the same graph.*"):
1471        loop = create_while_loop()
1472      xla_context.Exit()
1473    else:
1474      xla_context = control_flow_ops.XLAControlFlowContext()
1475      xla_context.Enter()
1476      loop = create_while_loop()
1477      xla_context.Exit()
1478      with self.assertRaisesRegexp(
1479          ValueError,
1480          r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
1481          r"while_loop. maximum_iterations tensor '.*Placeholder:0' for "
1482          r"while_loop context '.+' must be statically known \(e.g. a constant "
1483          r"value or known shape dimension\), or be defined at or outside the "
1484          r"while loop context '' \(currently defined in 'cond/.+'\)"):
1485        _ = gradients_impl.gradients(loop, v)
1486
1487  @test_util.run_v1_only("b/120545219")
1488  def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
1489    if test_util.is_gpu_available():
1490      self.skipTest("b/128646372, b/128645947 fails in opensource build")
1491
1492    v = constant_op.constant(1.0)
1493
1494    p = array_ops.placeholder(dtype=dtypes.int32)
1495
1496    def mid_body_builder(iterations):
1497
1498      def mid_body(i, x):
1499        r = control_flow_ops.while_loop(
1500            lambda *_: True,
1501            lambda i, x: (i + 1, v * x), (0, x),
1502            maximum_iterations=iterations,
1503            name="inner")
1504        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])
1505
1506      return mid_body
1507
1508    def outer_body(i, x):
1509      iterations = array_ops.size(p, name="iterations")
1510      return (i + 1, x + control_flow_ops.while_loop(
1511          lambda *_: True,
1512          mid_body_builder(iterations), (0, x),
1513          maximum_iterations=iterations,
1514          name="mid")[1])
1515
1516    def create_while_loop():
1517      with ops.device("/cpu:0"):
1518        r = control_flow_ops.while_loop(
1519            lambda *_: True,
1520            outer_body, (0, 1.0),
1521            maximum_iterations=5,
1522            name="outer")
1523        return array_ops.identity(r[1])
1524
1525    xla_context = control_flow_ops.XLAControlFlowContext()
1526    xla_context.Enter()
1527    final_with_xla_context = create_while_loop()
1528    xla_context.Exit()
1529
1530    final_without_xla_context = create_while_loop()
1531
1532    with self.session(use_gpu=False) as sess:
1533      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
1534      run_metadata_without_xla_context = config_pb2.RunMetadata()
1535      run_metadata = config_pb2.RunMetadata()
1536
1537      final_value_without_xla_context = sess.run(
1538          final_without_xla_context,
1539          feed_dict={p: [0, 0, 0]},
1540          options=opts,
1541          run_metadata=run_metadata_without_xla_context)
1542
1543      final_value_with_xla_context = sess.run(
1544          final_with_xla_context,
1545          feed_dict={p: [0, 0, 0]},
1546          options=opts,
1547          run_metadata=run_metadata)
1548
1549      if control_flow_util.ENABLE_CONTROL_FLOW_V2:
1550        # With while_v2 on xla, run_metadata only contains the unlowered While
1551        # op so node_stats does not have statistics for the pushes. So as a
1552        # loose check we check the pushes in the lowered version.
1553        node_stats = run_metadata_without_xla_context.step_stats.dev_stats[
1554            0].node_stats
1555        stack_push_op = "TensorListPushBack"
1556      else:
1557        node_stats = run_metadata.step_stats.dev_stats[0].node_stats
1558        stack_push_op = "StackPushV2"
1559      stack_push_count = len(
1560          [x for x in node_stats if x.node_name.endswith(stack_push_op)])
1561      # Pushes to the stack = product of maximum_iterations values;
1562      # the last two "3"s comes from size(p), when p == [0, 0, 0].
1563      self.assertEqual(stack_push_count, 5 * 3 * 3, str(node_stats))
1564
1565      self.assertAllClose(final_value_with_xla_context,
1566                          final_value_without_xla_context)
1567
1568  # Have more than 10 parallel iterations and hence exercise k-bound
1569  # most of the time.
1570  @test_util.run_deprecated_v1
1571  def testWhile_3(self):
1572    with self.cached_session():
1573
1574      def compute(i, m, c, o):
1575        m, c = [math_ops.add(m, 1), math_ops.add(c, 1)]
1576        o = math_ops.add(o, m)
1577        o = math_ops.add(o, c)
1578        i = math_ops.add(i, 1)
1579        return [i, m, c, o]
1580
1581      i = ops.convert_to_tensor(0)
1582      m = ops.convert_to_tensor(0)
1583      c = ops.convert_to_tensor(0)
1584      o = ops.convert_to_tensor(0)
1585      d = ops.convert_to_tensor(100)
1586      r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, d),
1587                                      compute, [i, m, c, o])
1588      result = r[3]
1589    self.assertAllEqual(10100, result)
1590
1591  @test_util.run_deprecated_v1
1592  def testWhile_4(self):
1593    with self.cached_session():
1594
1595      def compute(i, m, c, o):
1596        m, c = [array_ops.gather(x, i), array_ops.gather(x, i)]
1597        o = math_ops.add(o, m)
1598        o = math_ops.add(o, c)
1599        i = math_ops.add(i, 1)
1600        return [i, m, c, o]
1601
1602      i = ops.convert_to_tensor(0)
1603      m = ops.convert_to_tensor(0)
1604      c = ops.convert_to_tensor(0)
1605      o = ops.convert_to_tensor(0)
1606      x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
1607      s = array_ops.size(x)
1608      r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, s),
1609                                      compute, [i, m, c, o])
1610      result = r[3]
1611    self.assertAllEqual(42, result)
1612
1613  @test_util.run_v1_only("b/120545219")
1614  def testWhile_5(self):
1615    with self.cached_session():
1616
1617      def compute(i, c, o):
1618        c = array_ops.strided_slice(x, array_ops.expand_dims(i, 0),
1619                                    [1] + array_ops.expand_dims(i, 0))
1620        o = array_ops.concat([o, c], 0)
1621        i = math_ops.add(i, 1)
1622        return [i, c, o]
1623
1624      i = ops.convert_to_tensor(0)
1625      c = ops.convert_to_tensor([0])
1626      o = ops.convert_to_tensor([0])
1627      x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
1628      s = array_ops.size(x)
1629      r = control_flow_ops.while_loop(lambda i, c, o: math_ops.less(i, s),
1630                                      compute, [i, c, o], [
1631                                          i.get_shape(),
1632                                          tensor_shape.unknown_shape(),
1633                                          tensor_shape.unknown_shape()
1634                                      ])
1635      result = r[2]
1636    self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)
1637
1638  @test_util.run_gpu_only
1639  @test_util.run_deprecated_v1
1640  def testWhile_Device(self):
1641
1642    # Body function defined outside of device scope
1643    def body(x):
1644      return math_ops.exp(x)
1645
1646    with ops.device("CPU:0"):
1647      r = control_flow_ops.while_loop(
1648          lambda x: x < 10, body, [constant_op.constant(-10.)])
1649      self.assertIn("cpu", r.device.lower())
1650
1651    with session.Session() as sess:
1652      options = config_pb2.RunOptions(output_partition_graphs=True)
1653      run_metadata = config_pb2.RunMetadata()
1654      sess.run(r, options=options, run_metadata=run_metadata)
1655      # We expect that everything runs on CPU, even if GPU is available.
1656      self.assertEqual(len(run_metadata.partition_graphs), 1)
1657
1658  @test_util.disable_control_flow_v2("b/116338794 (buffer_reuse)")
1659  @test_util.run_v1_only("b/120545219")
1660  def testBufferForwarding(self):
1661    run_options = config_pb2.RunOptions(
1662        trace_level=config_pb2.RunOptions.FULL_TRACE)
1663    run_metadata = config_pb2.RunMetadata()
1664
1665    with self.cached_session() as sess:
1666      with ops.device("/cpu:0"):
1667        c = constant_op.constant(2)
1668        i0 = constant_op.constant(0)
1669        r = control_flow_ops.while_loop(lambda i: i < 1000,
1670                                        lambda i: math_ops.square(c) + i, [i0])
1671      r_val = sess.run(r, options=run_options, run_metadata=run_metadata)
1672      self.assertEqual(1000, r_val)
1673      self.assertTrue(run_metadata.HasField("step_stats"))
1674      unique_allocs = set()
1675      for node_stat in run_metadata.step_stats.dev_stats[0].node_stats:
1676        for output in node_stat.output:
1677          unique_allocs.add(
1678              output.tensor_description.allocation_description.ptr)
1679      # Prior to cl/147536680, the number of unique allocations was about 1005.
1680      self.assertLess(len(unique_allocs), 756)
1681
1682  def _testWhile_Gpu_1(self, use_gpu):
1683    with self.cached_session(use_gpu=use_gpu):
1684      n = constant_op.constant(1.0)
1685      c = lambda x: math_ops.less(x, 10.0)
1686      b = lambda x: math_ops.add(x, 1.0)
1687      r = control_flow_ops.while_loop(c, b, [n])
1688      self.assertAllClose(10.0, self.evaluate(r))
1689
1690  def testWhile_Gpu_1(self):
1691    self._testWhile_Gpu_1(use_gpu=False)
1692    self._testWhile_Gpu_1(use_gpu=True)
1693
1694  def _testWhile_Gpu_2(self, use_gpu):
1695    with self.cached_session(use_gpu=use_gpu):
1696      n = constant_op.constant(1.0)
1697      c = lambda x: math_ops.less(x, 10.0)
1698
1699      def b(x):
1700        with ops.device("/cpu:0"):
1701          return math_ops.add(x, 1.0)
1702
1703      r = control_flow_ops.while_loop(c, b, [n])
1704      self.assertAllClose(10.0, self.evaluate(r))
1705
1706  def testWhile_Gpu_2(self):
1707    self._testWhile_Gpu_2(use_gpu=False)
1708    self._testWhile_Gpu_2(use_gpu=True)
1709
1710  def testWhileShape(self):
1711    with self.cached_session():
1712      i = constant_op.constant(0)
1713      m = array_ops.ones([2, 2])
1714      c = lambda i, j: math_ops.less(i, 2)
1715
1716      def _b(i, j):
1717        new_i = math_ops.add(i, 1)
1718        new_j = array_ops.tile(j, [2, 2])
1719        return [new_i, new_j]
1720
1721      r = control_flow_ops.while_loop(
1722          c, _b, [i, m],
1723          [i.get_shape(), tensor_shape.unknown_shape()])
1724      r = r[1] * array_ops.ones([8, 8])
1725      self.assertAllEqual(np.ones((8, 8)), self.evaluate(r))
1726
1727  @test_util.run_deprecated_v1
1728  def testWhileWithNonTensorInput_Scalar(self):
1729    with self.cached_session():
1730      n = 0
1731      c = lambda x: x < 10000
1732      b = lambda x: x + 1
1733      r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
1734      self.assertEqual(10000, self.evaluate(r))
1735
1736  def testWhileWithNonTensorInput_Vector(self):
1737    with self.cached_session():
1738      n = np.array([0])  # Note, [0] would not work here; that is a list
1739      c = lambda x: x[0] < 10000
1740      b = lambda x: array_ops.stack([x[0] + 1])
1741      r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
1742      self.assertEqual([10000], self.evaluate(r))
1743
1744  @test_util.run_v1_only("b/120545219")
1745  def testWhileShapeInference(self):
1746    with self.cached_session():
1747      i = constant_op.constant(0)
1748      m = array_ops.ones([2, 2])
1749      c = lambda i, j: math_ops.less(i, 2)
1750
1751      def b(i, j):
1752        new_i = math_ops.add(i, 1)
1753        new_j = array_ops.concat([j, j], 0)
1754        return [new_i, new_j]
1755
1756      r = control_flow_ops.while_loop(
1757          c, b, [i, m],
1758          [i.get_shape(), tensor_shape.TensorShape([None, 2])])
1759      self.assertIsNone(r[1].shape.dims[0].value)
1760      self.assertEqual(r[1].shape.dims[1], tensor_shape.Dimension(2))
1761
1762      with self.assertRaisesRegexp(
1763          ValueError,
1764          r"Input tensor 'ones:0' enters the loop with shape \(2, 2\), but has "
1765          r"shape \(4, 2\) after one iteration. To allow the shape to vary "
1766          r"across iterations, use the `shape_invariants` argument of "
1767          r"tf.while_loop to specify a less-specific shape."):
1768        r = control_flow_ops.while_loop(c, b, [i, m])
1769
1770  @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
1771  @test_util.run_v1_only("b/120545219")
1772  def testWhileShapeInferenceSparseTensor(self):
1773    values = constant_op.constant([2.0, 4.0], name="values")
1774    indices = constant_op.constant([[0], [3]],
1775                                   dtype=dtypes.int64,
1776                                   name="indices")
1777    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
1778    i = constant_op.constant(0)
1779    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
1780
1781    def c(i, _):
1782      return i < 10
1783
1784    def b1(i, x):  # modifies values.  (shape of components is not changed.)
1785      return [
1786          i + 1,
1787          sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape)
1788      ]
1789
1790    def b2(i, x):  # adds new values.  (shape of components is changed.)
1791      return [
1792          i + 1,
1793          sparse_ops.sparse_add(
1794              x,
1795              sparse_tensor.SparseTensor(
1796                  indices=math_ops.cast(
1797                      array_ops.fill([1, 1], i), dtypes.int64),
1798                  values=array_ops.fill([1], 1.0),
1799                  dense_shape=x.dense_shape))
1800      ]
1801
1802    def b3(i, x):  # modifies rank.  (shape of all components is changed.)
1803      return [
1804          i + 1,
1805          sparse_tensor.SparseTensor(
1806              array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0,
1807              array_ops.concat([x.dense_shape, [10]], axis=0))
1808      ]
1809
1810    # Default shape invariant; b1 only modifies values.
1811    _, r = control_flow_ops.while_loop(c, b1, [i, x])
1812    self.assertEqual(r.indices.get_shape().as_list(), [None, 1])
1813    self.assertEqual(r.values.get_shape().as_list(), [None])
1814    self.assertEqual(r.dense_shape.get_shape().as_list(), [1])
1815
1816    # Default shape invariant; b2 adds new values
1817    _, r = control_flow_ops.while_loop(c, b2, [i, x])
1818    self.assertEqual(r.indices.get_shape().as_list(), [None, 1])
1819    self.assertEqual(r.values.get_shape().as_list(), [None])
1820    self.assertEqual(r.dense_shape.get_shape().as_list(), [1])
1821
1822    # Default shape invariant; b3 modifies rank (which is not allowed).
1823    with self.assertRaises(ValueError):
1824      _, r = control_flow_ops.while_loop(c, b3, [i, x])
1825
1826    # Explicit shape invariant, allowing any rank; b1 only modifies values.
1827    _, r = control_flow_ops.while_loop(
1828        c, b1, [i, x],
1829        [i.get_shape(), tensor_shape.TensorShape([None])])
1830    self.assertEqual(r.indices.get_shape().as_list(), [None, None])
1831    self.assertEqual(r.values.get_shape().as_list(), [None])
1832    self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
1833
1834    # Explicit shape invariant, allowing any rank; b3 modifies rank.
1835    _, r = control_flow_ops.while_loop(
1836        c, b3, [i, x],
1837        [i.get_shape(), tensor_shape.TensorShape([None])])
1838    self.assertEqual(r.indices.get_shape().as_list(), [None, None])
1839    self.assertEqual(r.values.get_shape().as_list(), [None])
1840    self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
1841
1842    # Shape invariant with ndims=None.  Technically, this isn't supported
1843    # according to the docs, but we support it for backwards compatibility.
1844    _, r = control_flow_ops.while_loop(
1845        c, b1, [i, x],
1846        [i.get_shape(), tensor_shape.TensorShape(None)])
1847    self.assertEqual(r.indices.get_shape().as_list(), [None, None])
1848    self.assertEqual(r.values.get_shape().as_list(), [None])
1849    self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
1850    _, r = control_flow_ops.while_loop(
1851        c, b3, [i, x],
1852        [i.get_shape(), tensor_shape.TensorShape(None)])
1853    self.assertEqual(r.indices.get_shape().as_list(), [None, None])
1854    self.assertEqual(r.values.get_shape().as_list(), [None])
1855    self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
1856
1857    # Explicit shape invariant, with a specific (incompatible) rank.
1858    with self.assertRaisesRegexp(ValueError, "is not compatible with"):
1859      _, r = control_flow_ops.while_loop(
1860          c, b1, [i, x],
1861          [i.get_shape(), tensor_shape.TensorShape([5])])
1862
1863  @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
1864  @test_util.run_v1_only("b/120545219")
1865  def testWhileShapeInferenceIndexedSlices(self):
1866    with self.cached_session():
1867      values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
1868      indices = constant_op.constant([0, 3], name="indices")
1869      shape = constant_op.constant([10, 2], name="dense_shape")
1870      i = constant_op.constant(0)
1871      x = ops.IndexedSlices(values, indices, dense_shape=shape)
1872
1873      def c(i, _):
1874        return i < 10
1875
1876      def b(i, x):
1877        return [
1878            i + 1,
1879            ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape)
1880        ]
1881
1882      _, r = control_flow_ops.while_loop(c, b, [i, x])
1883      self.assertEqual(r.dense_shape.get_shape()[0], 2)
1884      self.assertEqual(r.values.get_shape(), tensor_shape.TensorShape([2, 2]))
1885
1886      _, r = control_flow_ops.while_loop(
1887          c, b, [i, x],
1888          [i.get_shape(), tensor_shape.TensorShape([None, 2])])
1889      self.assertEqual(r.dense_shape.get_shape()[0], 2)
1890      self.assertEqual(r.values.get_shape().as_list(), [None, 2])
1891
1892      with self.assertRaisesRegexp(ValueError, "is not compatible with"):
1893        _, r = control_flow_ops.while_loop(
1894            c, b, [i, x],
1895            [i.get_shape(), tensor_shape.TensorShape([None, 5])])
1896
1897  @test_util.disable_control_flow_v2("b/116328420 (RaggedTensor)")
1898  def testWhileShapeInferenceRaggedTensor(self):
1899    if context.executing_eagerly():
1900      self.skipTest("b/116328420")
1901    i = constant_op.constant(0)
1902    x = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
1903    c = lambda i, _: i < 10
1904
1905    def b1(i, x):  # Adds new values to rows (but doesn't create new rows)
1906      return [
1907          i + 1,
1908          array_ops.concat([x, x], axis=1)
1909      ]
1910
1911    def b2(i, x):  # Adds new rows.
1912      return [
1913          i + 1,
1914          array_ops.concat([x, x], axis=0)
1915      ]
1916
1917    # Default shape invariant; b1 adds new values to rows.
1918    _, r = control_flow_ops.while_loop(c, b1, [i, x])
1919    self.assertEqual(r.row_splits.shape.as_list(), [4])
1920
1921    self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))
1922
1923    # Default shape invariant; b2 adds new rows (not allowed).
1924    if not context.executing_eagerly():
1925      with self.assertRaises(ValueError):
1926        _, r = control_flow_ops.while_loop(c, b2, [i, x])
1927
1928    # Explicit shape invariant; b1 adds new values to rows.
1929    _, r = control_flow_ops.while_loop(
1930        c, b1, [i, x],
1931        [i.get_shape(), tensor_shape.TensorShape([None, None])])
1932    self.assertTrue(r.row_splits.shape.as_list() in ([4], [None]))
1933    self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))
1934
1935    # Explicit shape invariant; b2 adds new rows.
1936    _, r = control_flow_ops.while_loop(
1937        c, b2, [i, x],
1938        [i.get_shape(), tensor_shape.TensorShape([None, None])])
1939    self.assertTrue(r.row_splits.shape.as_list() in ([3 * 2**10 + 1], [None]))
1940    self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))
1941
1942  @test_util.disable_control_flow_v2("b/116328420 (RaggedTensor)")
1943  def testWhileShapeInferenceRaggedTensorRaggedRank2(self):
1944    if context.executing_eagerly():
1945      self.skipTest("b/116328420")
1946    i = constant_op.constant(0)
1947    x = ragged_factory_ops.constant([[[1, 2], [3], [4, 5, 6]],
1948                                     [[], [8, 9, 10]]])
1949    c = lambda i, _: i < 10
1950    def b(i, x):
1951      return [
1952          i + 1,
1953          array_ops.concat([x, x[..., i:i+1]], axis=-1)
1954      ]
1955    _, r = control_flow_ops.while_loop(c, b, [i, x])
1956    self.assertEqual(r.row_splits.shape.as_list(), [3])
1957    self.assertTrue(r.values.row_splits.shape.as_list() in ([6], [None]))
1958    self.assertTrue(r.values.values.shape.as_list() in ([49], [None]))
1959
1960  def _testNestedWhile_1(self, use_gpu):
1961    with self.cached_session(use_gpu=use_gpu):
1962      n = constant_op.constant(0)
1963
1964      def cpu_sum(s):
1965        c = lambda i, s: math_ops.less(i, 10)
1966
1967        def b(i, s):
1968          i1 = math_ops.add(i, 1)
1969          with ops.device("/cpu:0"):
1970            s1 = math_ops.add(i, s)
1971          return i1, s1
1972
1973        _, r_s = control_flow_ops.while_loop(c, b, [n, s])
1974        return r_s
1975
1976      c = lambda x: math_ops.less(x, 200)
1977      b = lambda x: math_ops.add(x, cpu_sum(n))
1978      r = control_flow_ops.while_loop(c, b, [n])
1979      self.assertEqual(225, self.evaluate(r))
1980
1981  def testNestedWhile_1(self):
1982    self._testNestedWhile_1(use_gpu=False)
1983    self._testNestedWhile_1(use_gpu=True)
1984
1985  def _testNestedWhile_2(self, use_gpu):
1986    # Test the cases that A -> Enter and Exit -> A are partitioned.
1987    with self.cached_session(use_gpu=use_gpu):
1988      s0 = constant_op.constant(2.0)
1989
1990      def inner_loop(s):
1991        c = lambda s: math_ops.less(s, 20.0)
1992
1993        def b(s):
1994          s1 = math_ops.add(s, s)
1995          return s1
1996
1997        r_s = control_flow_ops.while_loop(c, b, [s], parallel_iterations=1)
1998        return r_s
1999
2000      outer_c = lambda x: math_ops.less(x, 3000.0)
2001
2002      def outer_b(x):
2003        x = logging_ops.Print(x, [x])  # Edge "Print -> Enter" is partitioned
2004        x = inner_loop(x)
2005        with ops.device("/cpu:0"):
2006          x = math_ops.square(x)  # Edge "Exit -> Square" is partitioned
2007        return x
2008
2009      r = control_flow_ops.while_loop(
2010          outer_c, outer_b, [s0], parallel_iterations=1)
2011      self.assertEqual(1048576.0, self.evaluate(r))
2012
2013  def testNestedWhile_2(self):
2014    self._testNestedWhile_2(use_gpu=False)
2015    self._testNestedWhile_2(use_gpu=True)
2016
2017  @test_util.run_v1_only("b/120545219")
2018  def testWhileWithControl_1(self):
2019    with self.cached_session():
2020      n = constant_op.constant(0)
2021      r = constant_op.constant(0)
2022      condition = lambda n_, r_: math_ops.less(n_, 10)
2023
2024      def body(n_, r_):
2025        n_ = math_ops.add(n_, 1)
2026        with r_.graph.control_dependencies([r_]):
2027          r_ = constant_op.constant(12)
2028        return [n_, r_]
2029
2030      res = control_flow_ops.while_loop(
2031          condition, body, [n, r], parallel_iterations=1)
2032      self.assertAllEqual(12, res[1])
2033
2034  @test_util.run_deprecated_v1
2035  def testWhileWithControl_2(self):
2036    with self.cached_session():
2037      r = constant_op.constant(0)
2038      condition = lambda r_: math_ops.less(r_, 10)
2039
2040      def body(r_):
2041        with r_.graph.control_dependencies([r_]):
2042          r_ = constant_op.constant(12)
2043        return [r_]
2044
2045      res = control_flow_ops.while_loop(
2046          condition, body, [r], parallel_iterations=1)
2047      self.assertAllEqual(12, self.evaluate(res))
2048
2049  @test_util.run_v1_only("b/120545219")
2050  def testWhileWithControl_3(self):
2051    with self.cached_session() as sess:
2052      b = array_ops.placeholder(dtypes.bool)
2053      c = constant_op.constant(1)
2054      x0 = constant_op.constant(0)
2055      with ops.control_dependencies([b]):
2056        r = control_flow_ops.while_loop(lambda x: x < 10, lambda x: x + c, [x0])
2057      self.assertEqual(10, sess.run(r, {b: True}))
2058
2059  @test_util.run_v1_only("b/120545219")
2060  def testWhileWithControl_4(self):
2061    with self.cached_session() as sess:
2062      b = array_ops.placeholder(dtypes.bool)
2063      c = constant_op.constant(1)
2064      x0 = constant_op.constant(0)
2065      with ops.control_dependencies([b]):
2066        r = control_flow_ops.while_loop(
2067            lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0])
2068      self.assertEqual(10, sess.run(r, {b: True}))
2069
2070  @test_util.run_v1_only("b/120545219")
2071  def testWhileWithControl_5(self):
2072    with self.cached_session() as sess:
2073      b = array_ops.placeholder(dtypes.bool)
2074      c = constant_op.constant(1)
2075      x0 = constant_op.constant(0)
2076
2077      def body(x):
2078        with ops.control_dependencies([b]):
2079          return x + c
2080
2081      r = control_flow_ops.while_loop(lambda x: x < 10, body, [x0])
2082      self.assertEqual(10, sess.run(r, {b: True}))
2083
2084  def testWhileCondWithControl(self):
2085    # Ensure that no control edges by an outer control dependency context are
2086    # added to nodes inside cond/while contexts.
2087    with self.cached_session() as sess:
2088      const_true = lambda: constant_op.constant(True)
2089      const_false = lambda: constant_op.constant(False)
2090      cond = lambda i: control_flow_ops.cond(i > 0, const_true, const_false)
2091      body = lambda i: control_flow_ops.cond(i > 0, lambda: i - 1, lambda: i)
2092
2093      with ops.control_dependencies([control_flow_ops.no_op()]):
2094        loop = control_flow_ops.while_loop(cond, body,
2095                                           (constant_op.constant(5),))
2096      self.assertEqual(0, self.evaluate(loop))
2097
2098  @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
2099  @test_util.run_v1_only("b/120545219")
2100  def testWhileCondWithControl_1(self):
2101    with self.cached_session():
2102      v = variable_scope.get_variable(
2103          "v", [], initializer=init_ops.constant_initializer(2))
2104      i0 = constant_op.constant(0)
2105      with ops.control_dependencies([i0]):
2106
2107        def loop_condition(i):
2108          return i < 4
2109
2110        def loop_body(i):
2111          some_cond = control_flow_ops.cond(
2112              constant_op.constant(True),
2113              lambda: state_ops.assign(v, math_ops.square(v)), lambda: v)
2114          with ops.control_dependencies([some_cond]):
2115            return i + 1
2116
2117      r = control_flow_ops.while_loop(loop_condition, loop_body, (i0,))
2118      self.evaluate(variables.global_variables_initializer())
2119      self.assertEqual(4, self.evaluate(r))
2120      self.assertAllClose(65536.0, self.evaluate(v))
2121
2122  @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
2123  @test_util.run_v1_only("b/120545219")
2124  def testWhileCondExitControl(self):
2125
2126    with self.cached_session():
2127      v = variables.Variable(1)
2128
2129      def false_branch():
2130        cond = lambda i: i < 100
2131
2132        def body(i):
2133          x = state_ops.assign(v, i)
2134          return x + 1
2135
2136        loop = control_flow_ops.while_loop(cond, body, [0])
2137        # Make sure to handle correctly control edge from Exit to a node.
2138        with ops.control_dependencies([loop]):
2139          return constant_op.constant(6.0)
2140
2141      r = control_flow_ops.cond(
2142          constant_op.constant(False), lambda: constant_op.constant(1.0),
2143          false_branch)
2144      self.evaluate(variables.global_variables_initializer())
2145      self.assertEqual(6.0, self.evaluate(r))
2146      self.assertEqual(99, self.evaluate(v))
2147
2148  def testCondWhile_1(self):
2149
2150    with self.cached_session():
2151      n = ops.convert_to_tensor(0, name="n")
2152      c = lambda x: math_ops.less(x, 10)
2153      b = lambda x: math_ops.add(x, 1)
2154      r = control_flow_ops.cond(
2155          math_ops.less(0, 1), lambda: control_flow_ops.while_loop(c, b, [n]),
2156          lambda: n)
2157      self.assertAllEqual(10, self.evaluate(r))
2158
2159  def testCondWhile_2(self):
2160
2161    with self.cached_session():
2162      n = ops.convert_to_tensor(0)
2163      c = lambda x: math_ops.less(x, 10)
2164      b = lambda x: math_ops.add(x, 1)
2165      r = control_flow_ops.cond(
2166          math_ops.less(1, 0), lambda: math_ops.add(n, 1),
2167          lambda: control_flow_ops.while_loop(c, b, [n]))
2168      self.assertAllEqual(10, self.evaluate(r))
2169
2170  def _testCondWhile_3(self, use_gpu):
2171    with self.cached_session(use_gpu=use_gpu) as sess:
2172      p = array_ops.placeholder(dtypes.bool)
2173      n = constant_op.constant(0.0)
2174
2175      def c(x):
2176        return math_ops.less(x, 10.0)
2177
2178      def b(x):
2179        with ops.device("/cpu:0"):
2180          x1 = math_ops.add(x, 1.0)
2181        return x1
2182
2183      r = control_flow_ops.cond(p,
2184                                lambda: control_flow_ops.while_loop(c, b, [n]),
2185                                lambda: math_ops.multiply(n, 2.0))
2186      r1 = gradients_impl.gradients(r, [n])
2187      self.assertEqual(10., sess.run(r, {p: True}))
2188      self.assertEqual([1.0], sess.run(r1, {p: True}))
2189      self.assertEqual(0.0, sess.run(r, {p: False}))
2190      self.assertEqual([2.0], sess.run(r1, {p: False}))
2191
2192  @test_util.run_deprecated_v1
2193  def testCondWhile_3(self):
2194    self._testCondWhile_3(use_gpu=False)
2195    self._testCondWhile_3(use_gpu=True)
2196
2197  def testWhileCond_1(self):
2198
2199    with self.cached_session():
2200      i = ops.convert_to_tensor(0, name="i")
2201      n = ops.convert_to_tensor(10, name="n")
2202      one = ops.convert_to_tensor(1, name="one")
2203      c = lambda x: math_ops.less(x, n)
2204      # pylint: disable=undefined-variable
2205      # for OSS build
2206      b = lambda x: control_flow_ops.cond(
2207          constant_op.constant(True),
2208          lambda: math_ops.add(x, one), lambda: math_ops.subtract(x, one))
2209      # pylint: enable=undefined-variable
2210      r = control_flow_ops.while_loop(c, b, [i])
2211      self.assertAllEqual(10, self.evaluate(r))
2212
2213  def testWhileCond_2(self):
2214
2215    with self.cached_session():
2216      n = ops.convert_to_tensor(0, name="n")
2217      c = lambda x: math_ops.less(x, 10)
2218      b = lambda x: control_flow_ops.cond(constant_op.constant(True), lambda: math_ops.add(x, 1), lambda: n)
2219      r = control_flow_ops.while_loop(c, b, [n])
2220      self.assertAllEqual(10, self.evaluate(r))
2221
2222  def testWhileCond_3(self):
2223
2224    with self.cached_session():
2225      n = ops.convert_to_tensor(0)
2226      c = lambda x: math_ops.less(x, 10)
2227      # pylint: disable=undefined-variable
2228      # for OSS build
2229      b = lambda x: control_flow_ops.cond(math_ops.less(0, 1),
2230                                          lambda: math_ops.add(x, 1),
2231                                          lambda: math_ops.subtract(x, 1))
2232      # pylint: enable=undefined-variable
2233      r = control_flow_ops.while_loop(c, b, [n])
2234      self.assertAllEqual(10, self.evaluate(r))
2235
2236  @test_util.run_deprecated_v1
2237  def testWhileCondGradMultiDevice(self):
2238    config = config_pb2.ConfigProto(device_count={"CPU": 2},
2239                                    allow_soft_placement=True)
2240    with self.cached_session(use_gpu=True, config=config) as sess:
2241      pred = array_ops.placeholder(dtypes.bool, [])
2242      x_init = constant_op.constant(1.0)
2243
2244      with ops.device("/cpu:0"):
2245        z = control_flow_ops.while_loop(
2246            lambda i, _: i < 3,
2247            lambda i, x: (i + 1, control_flow_ops.cond(
2248                pred, lambda: x * 2.0, lambda: 10.0)),
2249            [0, x_init])
2250
2251      with ops.device("/cpu:1"):
2252        grad = gradients_impl.gradients(z, x_init)[0]
2253
2254      with ops.device("/cpu:0"):
2255        grad_grad = gradients_impl.gradients(grad, x_init)[0]
2256
2257      self.assertEqual(sess.run(grad, {pred: True}), 8.0)
2258      self.assertEqual(sess.run(grad, {pred: False}), 0.0)
2259
2260      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
2261        return
2262
2263      self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0)
2264      self.assertEqual(sess.run(grad_grad, {pred: False}), 0.0)
2265
2266  # NOTE: It is ok to have parallel_iterations > 1
2267  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2268  @test_util.run_deprecated_v1
2269  def testWhileUpdateVariable_1(self):
2270    with self.cached_session():
2271      select = variables.Variable([3.0, 4.0, 5.0])
2272      n = constant_op.constant(0)
2273
2274      def loop_iterator(j):
2275        return math_ops.less(j, 3)
2276
2277      def loop_body(j):
2278        ns = state_ops.scatter_update(select, j, 10.0)
2279        nj = math_ops.add(j, 1)
2280        op = control_flow_ops.group(ns)
2281        nj = control_flow_ops.with_dependencies([op], nj)
2282        return [nj]
2283
2284      r = control_flow_ops.while_loop(
2285          loop_iterator, loop_body, [n], parallel_iterations=1)
2286      self.evaluate(variables.global_variables_initializer())
2287      self.assertEqual(3, self.evaluate(r))
2288      result = self.evaluate(select)
2289      self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
2290
2291  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2292  @test_util.run_v1_only("b/120545219")
2293  def testWhileUpdateVariable_2(self):
2294    with self.cached_session():
2295      select1 = variables.Variable([3.0, 4.0, 5.0])
2296      select2 = variables.Variable([3.0, 4.0, 5.0])
2297      n = constant_op.constant(0)
2298
2299      def loop_iterator(j):
2300        return math_ops.less(j, 3)
2301
2302      def loop_body(j):
2303        ns1 = state_ops.scatter_update(select1, j, 10.0)
2304        ns2 = state_ops.scatter_update(select2, j, 10.0)
2305        nj = math_ops.add(j, 1)
2306        op = control_flow_ops.group(ns1, ns2)
2307        nj = control_flow_ops.with_dependencies([op], nj)
2308        return [nj]
2309
2310      r = control_flow_ops.while_loop(
2311          loop_iterator, loop_body, [n], parallel_iterations=1)
2312      self.evaluate(variables.global_variables_initializer())
2313      self.assertEqual(3, self.evaluate(r))
2314      result1 = self.evaluate(select1)
2315      self.assertAllClose(np.array([10.0, 10.0, 10.0]), result1)
2316      result2 = self.evaluate(select2)
2317      self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
2318
2319  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2320  @test_util.run_v1_only("b/120545219")
2321  def testWhileUpdateVariable_3(self):
2322    with self.cached_session():
2323      select = variables.Variable([3.0, 4.0, 5.0])
2324      n = constant_op.constant(0)
2325
2326      def loop_iterator(j, _):
2327        return math_ops.less(j, 3)
2328
2329      def loop_body(j, _):
2330        ns = state_ops.scatter_update(select, j, 10.0)
2331        nj = math_ops.add(j, 1)
2332        return [nj, ns]
2333
2334      r = control_flow_ops.while_loop(
2335          loop_iterator,
2336          loop_body, [n, array_ops.identity(select)],
2337          parallel_iterations=1)
2338      self.evaluate(variables.global_variables_initializer())
2339      result = r[1]
2340    self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
2341
2342  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2343  @test_util.run_v1_only("b/120545219")
2344  def testWhileUpdateVariable_4(self):
2345    with self.cached_session():
2346      var_a = variables.Variable(0, name="a")
2347      var_b = variables.Variable(0, name="b")
2348      self.evaluate(variables.global_variables_initializer())
2349
2350      c = constant_op.constant(0, name="c")
2351      asn1 = state_ops.assign_add(var_a, 1, name="a_add")
2352
2353      # Loop condition
2354      def pred(i):
2355        return math_ops.less(i, 10)
2356
2357      # Loop body
2358      def loop_body(i):
2359        asn2 = state_ops.assign_add(var_b, asn1, name="b_add")
2360        with ops.control_dependencies([asn2]):
2361          ni = math_ops.add(i, 1, name="i_add")
2362        return ni
2363
2364      lpa = control_flow_ops.while_loop(
2365          pred, loop_body, [c], parallel_iterations=1)
2366
2367      self.assertEqual(0, self.evaluate(var_b))
2368      self.evaluate(lpa)  # Run the loop
2369      self.assertEqual(10, self.evaluate(var_b))
2370
2371  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2372  @test_util.run_v1_only("b/120545219")
2373  def testWhileUpdateVariable_5(self):
2374    with self.cached_session():
2375      # Create some variables.
2376      var_a = variables.Variable(0, name="a")
2377      var_b = variables.Variable(0, name="b")
2378      self.evaluate(variables.global_variables_initializer())
2379
2380      # Change condition to check var_b
2381      def pred(_):
2382        return math_ops.less(var_b, 10)
2383
2384      # Change body to increment var_b
2385      def loop_body(i):
2386        asn1 = state_ops.assign_add(
2387            var_a, constant_op.constant(1), name="a_add")
2388        asn2 = state_ops.assign_add(
2389            var_b, constant_op.constant(1), name="b_add")
2390        with ops.control_dependencies([asn1, asn2]):
2391          inc_b = array_ops.identity(var_b)
2392        return inc_b
2393
2394      lpa = control_flow_ops.while_loop(
2395          pred, loop_body, [var_b], parallel_iterations=1, name="loop")
2396
2397      self.assertEqual(0, self.evaluate(var_b))
2398      self.evaluate(lpa)  # Run the loop
2399      self.assertEqual(10, self.evaluate(var_a))
2400      self.assertEqual(10, self.evaluate(var_b))
2401
2402  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
2403  @test_util.run_v1_only("b/120545219")
2404  def testWhileUpdateVariable_6(self):
2405    with self.cached_session():
2406      # Create some variables.
2407      var_a = variables.Variable(0, name="a")
2408      var_b = variables.Variable(0, name="b")
2409      c = constant_op.constant(0)
2410      self.evaluate(variables.global_variables_initializer())
2411
2412      # Loop condition
2413      def pred(i):
2414        return math_ops.less(i, 10)
2415
2416      # Loop body
2417      def loop_body(i):
2418        asn1 = state_ops.assign_add(var_a, 1, name="a_add")
2419        with ops.control_dependencies([asn1]):
2420          asn2 = state_ops.assign_add(var_b, var_a, name="b_add")
2421        with ops.control_dependencies([asn2]):
2422          ni = math_ops.add(i, 1, name="i_add")
2423          return ni
2424
2425      lpa = control_flow_ops.while_loop(
2426          pred, loop_body, [c], parallel_iterations=1, name="loop")
2427
2428      self.assertEqual(0, self.evaluate(var_b))
2429      self.evaluate(lpa)  # Run the loop
2430      self.assertEqual(55, self.evaluate(var_b))
2431      self.assertEqual(10, self.evaluate(var_a))
2432
2433  @test_util.run_v1_only("b/120545219")
2434  def testWhileQueue_1(self):
2435    with self.cached_session():
2436      q = data_flow_ops.FIFOQueue(-1, dtypes.int32)
2437      i = constant_op.constant(0)
2438
2439      def c(i):
2440        return math_ops.less(i, 10)
2441
2442      def b(i):
2443        ni = math_ops.add(i, 1)
2444        ni = control_flow_ops.with_dependencies([q.enqueue((i,))], ni)
2445        return ni
2446
2447      r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1)
2448      self.assertEqual([10], self.evaluate(r))
2449      for i in xrange(10):
2450        self.assertEqual([i], self.evaluate(q.dequeue()))
2451
2452  @test_util.run_v1_only("b/120545219")
2453  def testWhileTimeOut(self):
2454    run_options = config_pb2.RunOptions(timeout_in_ms=1)
2455    with self.cached_session() as sess:
2456      n = constant_op.constant(0)
2457      c = lambda x: True
2458      b = lambda x: math_ops.add(x, 1)
2459      r = control_flow_ops.while_loop(c, b, [n])
2460      with self.assertRaises(errors_impl.DeadlineExceededError):
2461        sess.run(r, options=run_options)
2462
2463  @test_util.disable_control_flow_v2("b/117119329 (stack)")
2464  @test_util.run_v1_only("b/120545219")
2465  def testWhileStack_1(self):
2466    with self.cached_session():
2467      s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo")
2468      i = constant_op.constant(0)
2469
2470      def c(i):
2471        return math_ops.less(i, 10)
2472
2473      def b(i):
2474        ni = math_ops.add(i, 1)
2475        ni = control_flow_ops.with_dependencies(
2476            [gen_data_flow_ops.stack_push_v2(s, i)], ni)
2477        return ni
2478
2479      r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1)
2480
2481      x = constant_op.constant(0)
2482
2483      def c1(i, _):
2484        return math_ops.greater(i, 0)
2485
2486      def b1(i, x):
2487        ni = math_ops.subtract(i, 1)
2488        nx = x + gen_data_flow_ops.stack_pop_v2(s, dtypes.int32)
2489        return [ni, nx]
2490
2491      _, rx = control_flow_ops.while_loop(
2492          c1,
2493          b1, [r, x],
2494          [r.get_shape(), tensor_shape.unknown_shape()],
2495          parallel_iterations=1)
2496      self.assertEqual(45, self.evaluate(rx))
2497
2498  def _testWhileGrad_ColocateGradients(self, colocate):
2499    gpu_dev_name = test.gpu_device_name() if test.is_gpu_available(
2500    ) else "/device:CPU:0"
2501
2502    graph = ops.Graph()
2503    with graph.as_default():
2504      v = constant_op.constant(2.0, name="v")
2505      c = lambda v: math_ops.less(v, 100.0)
2506
2507      def b(x):
2508        with ops.device(gpu_dev_name):
2509          return math_ops.square(x)
2510
2511      loop = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
2512      r = gradients_impl.gradients(
2513          loop, v, colocate_gradients_with_ops=colocate)[0]
2514
2515    r_ops = graph.get_operations()
2516    r_devices = [(op.name, op.device) for op in r_ops]
2517
2518    self.assertTrue(any("Square" in op.name for op in r_ops))
2519
2520    for (name, dev) in r_devices:
2521      if not colocate and name.endswith("Square"):
2522        # Only forward graph contain gpu in Square device
2523        self.assertTrue(gpu_dev_name in dev)
2524      elif colocate and "Square" in name:
2525        # Forward and backward graphs contain gpu in Square/Square_grad devices
2526        self.assertTrue(gpu_dev_name in dev)
2527      else:
2528        self.assertFalse(gpu_dev_name in dev)
2529
2530    with self.session(graph=graph) as sess:
2531      self.assertAllClose(1024.0, self.evaluate(r))
2532
2533  @test_util.disable_control_flow_v2("b/116351701 (colocation)")
2534  @test_util.run_v1_only("b/120545219")
2535  def testWhileGrad_ColocateGradients(self):
2536    self._testWhileGrad_ColocateGradients(colocate=False)
2537    self._testWhileGrad_ColocateGradients(colocate=True)
2538
2539  @test_util.run_v1_only("b/120545219")
2540  def testWhileGrad_Square(self):
2541    with self.cached_session():
2542      v = constant_op.constant(2.0, name="v")
2543      c = lambda v: math_ops.less(v, 100.0)
2544      b = math_ops.square
2545      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
2546      r = control_flow_ops.cond(math_ops.less(1, 2), lambda: r, lambda: v)
2547
2548      r = gradients_impl.gradients(r, v)[0]
2549      self.assertAllClose(1024.0, self.evaluate(r))
2550
2551  @test_util.run_v1_only("b/120545219")
2552  def testWhileGrad_Shape(self):
2553    with self.cached_session():
2554      x = array_ops.placeholder(dtypes.float32, shape=[None])
2555      v = constant_op.constant([2.0], name="v")
2556      n = constant_op.constant(0, name="n")
2557      c = lambda i, v: math_ops.less(i, 5)
2558      b = lambda i, v: [i + 1, math_ops.multiply(x, v)]
2559      r = control_flow_ops.while_loop(
2560          c,
2561          b, [n, v],
2562          [n.get_shape(), tensor_shape.unknown_shape()],
2563          parallel_iterations=1)
2564
2565      r = gradients_impl.gradients(r[1], x)[0]
2566      self.assertEqual([None], r.get_shape().as_list())
2567      self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]}))
2568
2569  @test_util.run_deprecated_v1
2570  def testWhileGrad_BaseShape(self):
2571    with self.cached_session() as sess:
2572      x = array_ops.placeholder(dtypes.float32, [None])
2573      v0 = constant_op.constant([2.0, 2.0], name="v")
2574      c = lambda v: constant_op.constant(False)
2575      b = lambda v: math_ops.multiply(v, x)
2576      r = control_flow_ops.while_loop(c, b, [v0])
2577      y = math_ops.square(x)
2578
2579      r = gradients_impl.gradients([r, y], x)[0]
2580      self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]}))
2581
2582  @test_util.run_v1_only("b/120545219")
2583  def testWhileGrad_MultipleUses(self):
2584    with self.cached_session():
2585      v = constant_op.constant(2.0, name="v")
2586      c = lambda v: math_ops.less(v, 100.0)
2587      b = math_ops.square
2588      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
2589      r = math_ops.multiply(r, r)
2590
2591      r = gradients_impl.gradients(r, v)[0]
2592      self.assertEqual(524288.0, self.evaluate(r))
2593
2594  @test_util.run_v1_only("b/120545219")
2595  def testWhileGrad_LoopAdd(self):
2596    with self.cached_session():
2597      v = constant_op.constant(2.0, name="v")
2598      c = lambda v: math_ops.less(v, 100.0)
2599      b = math_ops.square
2600      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
2601      r = math_ops.add(r, r)
2602
2603      r = gradients_impl.gradients(r, v)[0]
2604      self.assertAllClose(2048.0, self.evaluate(r))
2605
2606  def _testWhileGrad_Mul(self, use_gpu, p_iters):
2607    with self.cached_session(use_gpu=use_gpu) as sess:
2608      a = constant_op.constant(3.0, name="a")
2609      v = constant_op.constant(2.0, name="v")
2610      c = lambda v: math_ops.less(v, 100.0)
2611      b = lambda v: math_ops.multiply(v, a)
2612      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=p_iters)
2613
2614      grad_a, grad_v = gradients_impl.gradients(r, [a, v])
2615      grad_a_val, grad_v_val = self.evaluate([grad_a, grad_v])
2616      self.assertAllClose(216.0, grad_a_val)
2617      self.assertAllClose(81.0, grad_v_val)
2618
2619  @test_util.run_deprecated_v1
2620  def testWhileGrad_Mul(self):
2621    self._testWhileGrad_Mul(use_gpu=False, p_iters=1)
2622    self._testWhileGrad_Mul(use_gpu=False, p_iters=10)
2623    self._testWhileGrad_Mul(use_gpu=True, p_iters=1)
2624    self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
2625
2626  def _testNestedWhileCondWhileGrad(self, use_gpu):
2627
2628    with self.cached_session(use_gpu=use_gpu):
2629      v = constant_op.constant(1.0)
2630
2631      def inner_loop(s):
2632        z = constant_op.constant(0)
2633        c = lambda i, x: math_ops.less(i, 4)
2634        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
2635        return control_flow_ops.while_loop(c, b, [z, s])
2636
2637      c = lambda x: math_ops.less(x, 128.0)
2638
2639      def b(x):
2640        return control_flow_ops.cond(
2641            constant_op.constant(True),
2642            lambda: math_ops.square(inner_loop(x)[1]),
2643            lambda: math_ops.multiply(x, 2.0))
2644
2645      r = control_flow_ops.while_loop(c, b, [v])
2646      r = gradients_impl.gradients(r, v)[0]
2647      self.assertAllClose(512.0, self.evaluate(r))
2648
2649  @test_util.run_deprecated_v1
2650  def testNestedWhileCondWhileGrad(self):
2651    self._testNestedWhileCondWhileGrad(use_gpu=False)
2652
2653  @test_util.run_deprecated_v1
2654  def testNestedWhileCondWhileGradGpu(self):
2655    self._testNestedWhileCondWhileGrad(use_gpu=True)
2656
2657  @test_util.run_v1_only("b/120545219")
2658  def testWhileGrad_Variable(self):
2659    with self.cached_session():
2660      a = variables.Variable(3.0)
2661      v = constant_op.constant(2.0, name="v")
2662      c = lambda v: math_ops.less(v, 100.0)
2663      b = lambda v: math_ops.multiply(v, a)
2664      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
2665
2666      r = gradients_impl.gradients(r, a)
2667      self.evaluate(variables.global_variables_initializer())
2668      self.assertAllClose(216.0, r[0])
2669
2670  @test_util.run_deprecated_v1
2671  def testWhileGrad_ResourceVariable(self):
2672    with self.cached_session():
2673      a = resource_variable_ops.ResourceVariable(3.0)
2674      v = constant_op.constant(2.0, name="v")
2675      c = lambda v: math_ops.less(v, 100.0)
2676      b = lambda v: math_ops.multiply(v, a)
2677      r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
2678
2679      g = gradients_impl.gradients(r, a)
2680      self.evaluate(variables.global_variables_initializer())
2681      self.assertAllClose(216.0, g[0])
2682
2683  def testWhileGrad_EagerResourceVariable(self):
2684    with context.eager_mode():
2685      a = resource_variable_ops.ResourceVariable(
2686          np.ones([2, 2], dtype=np.float32))
2687      v = constant_op.constant(1.0)
2688
2689      @eager_function.defun
2690      def fn():
2691        r = control_flow_ops.while_loop(
2692            lambda i, _: i < 2,
2693            lambda i, x: (i + 1, x * math_ops.reduce_sum(a) * v),
2694            [0, 1.0])[1]
2695        return gradients_impl.gradients(r, [v])[0]
2696
2697      self.assertEqual(self.evaluate(fn()), 32.)
2698
2699  @test_util.disable_xla("b/128643381")
2700  def testWhileGrad_ResourceVarInFunctionCall(self):
2701
2702    @def_function.function
2703    def foo(x, var):
2704      return x + math_ops.reduce_sum(var.sparse_read([1, 3]))
2705
2706    @def_function.function
2707    def bar(var):
2708      r = control_flow_ops.while_loop(
2709          lambda i, _: i < 2,
2710          lambda i, x: (i + 1, foo(x, var)),
2711          [0, 0.0])[1]
2712      return gradients_impl.gradients(r, var)[0]
2713
2714    var = resource_variable_ops.ResourceVariable([1., 2., 3., 4.])
2715    self.evaluate(variables.global_variables_initializer())
2716    grad = self.evaluate(bar(var))
2717    self.assertIsInstance(grad, ops.IndexedSlicesValue)
2718    self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.])
2719
2720  @test_util.disable_xla("b/128643461")
2721  def testWhileGrad_ResourceVarInNestedFunctionCall(self):
2722
2723    @def_function.function
2724    def foo(x, var):
2725      return x + math_ops.reduce_sum(var.sparse_read([1, 3]))
2726
2727    @def_function.function
2728    def foo2(x, var):
2729      return foo(x, var)
2730
2731    @def_function.function
2732    def bar(var):
2733      r = control_flow_ops.while_loop(
2734          lambda i, _: i < 2,
2735          lambda i, x: (i + 1, foo2(x, var)),
2736          [0, 0.0])[1]
2737      return gradients_impl.gradients(r, var)[0]
2738
2739    var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.])
2740    self.evaluate(variables.global_variables_initializer())
2741    grad = self.evaluate(bar(var))
2742    self.assertIsInstance(grad, ops.IndexedSlicesValue)
2743    self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.])
2744
2745  def testWhileGrad_ResourceVarInLoopInFunctionCall(self):
2746    if test.is_gpu_available():
2747      self.skipTest("b/128635252")
2748
2749    @def_function.function
2750    def foo(x, var):
2751      return control_flow_ops.while_loop(
2752          lambda j, _: j < 3,
2753          lambda j, y: (j + 1,
2754                        y + math_ops.reduce_sum(var.sparse_read([1, 2]))),
2755          [0, x])[1]
2756
2757    @def_function.function
2758    def bar(var):
2759      r = control_flow_ops.while_loop(
2760          lambda i, _: i < 2,
2761          lambda i, x: (i + 1, foo(x, var)),
2762          [0, 0.0])[1]
2763      return gradients_impl.gradients(r, var)[0]
2764
2765    var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.])
2766    self.evaluate(variables.global_variables_initializer())
2767    grad = self.evaluate(bar(var))
2768    self.assertIsInstance(grad, ops.IndexedSlicesValue)
2769    self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 6., 6., 0.])
2770
2771  @test_util.disable_xla("b/128639858")
2772  def testWhileCondGrad_ResourceVarInFunctionCall(self):
2773
2774    @def_function.function
2775    def foo(x, var):
2776      return x + var.sparse_read([1])[0]
2777
2778    def body(i, x):
2779      return (i + 1, control_flow_ops.cond(
2780          math_ops.equal(i % 2, 0),
2781          lambda: foo(x, var1),
2782          lambda: foo(x, var2)))
2783
2784    @def_function.function
2785    def bar(var1, var2):
2786      r = control_flow_ops.while_loop(
2787          lambda i, _: i < 4, body, [0, 0.0])
2788      return gradients_impl.gradients(r, [var1, var2])
2789
2790    var1 = resource_variable_ops.ResourceVariable([1., 2., 3.])
2791    var2 = resource_variable_ops.ResourceVariable([4., 5.])
2792    self.evaluate(variables.global_variables_initializer())
2793    grads = self.evaluate(bar(var1, var2))
2794    self.assertAllEqual(gradient_checker_v2._to_numpy(grads[0]), [0., 2., 0.])
2795    self.assertAllEqual(gradient_checker_v2._to_numpy(grads[1]), [0., 2.])
2796
2797  @test_util.run_deprecated_v1
2798  def testWhileGrad_ResourceVarSparseRead(self):
2799    # NOTE(skyewm): this test is interesting because the
2800    # ResourceVariable.sparse_read gradient function returns an IndexedSlices.
2801    var = resource_variable_ops.ResourceVariable(np.ones(5),
2802                                                 dtype=dtypes.float32)
2803    r = control_flow_ops.while_loop(
2804        lambda i, _: i < 3,
2805        lambda i, x: (i + 1, x * math_ops.reduce_sum(var.sparse_read([1, 3]))),
2806        [0, constant_op.constant(1.0)])[1]
2807    grad = gradients_impl.gradients(r, var)[0]
2808
2809    self.evaluate(variables.global_variables_initializer())
2810    grad_val = self.evaluate(grad)
2811    self.assertIsInstance(grad_val, ops.IndexedSlicesValue)
2812    arr = gradient_checker_v2._to_numpy(grad_val)
2813    self.assertAllEqual(arr, [0., 12., 0., 12., 0.])
2814
2815  @test_util.run_deprecated_v1
2816  def testWhileGrad_MultiResourceVarSparseRead(self):
2817    # NOTE(skyewm): this test is interesting because the
2818    # ResourceVariable.sparse_read gradient function returns an IndexedSlices.
2819    var1 = resource_variable_ops.ResourceVariable(np.ones(5),
2820                                                  dtype=dtypes.float32)
2821    var2 = resource_variable_ops.ResourceVariable(np.ones(3),
2822                                                  dtype=dtypes.float32)
2823    x1_init = constant_op.constant([0., 0.])
2824    x2_init = constant_op.constant(1.)
2825    x3_init = constant_op.constant(1.)
2826
2827    def body(i, unused_x1, x2, x3):
2828      y1 = var1.sparse_read([1, 3])
2829      y2 = x2 * 2
2830      y3 = x3 * math_ops.reduce_sum(var2.sparse_read([0]))
2831      return i + 1, y1, y2, y3
2832
2833    r = control_flow_ops.while_loop(
2834        lambda i, x1, x2, x3: i < 3, body,
2835        [0, x1_init, x2_init, x3_init])[1:]
2836    var1_grad, var2_grad = gradients_impl.gradients(r, [var1, var2])
2837
2838    self.evaluate(variables.global_variables_initializer())
2839    var1_grad_val = self.evaluate(var1_grad)
2840    var2_grad_val = self.evaluate(var2_grad)
2841    self.assertIsInstance(var1_grad_val, ops.IndexedSlicesValue)
2842    self.assertIsInstance(var2_grad_val, ops.IndexedSlicesValue)
2843    self.assertAllEqual(gradient_checker_v2._to_numpy(var1_grad_val),
2844                        [0., 1., 0., 1., 0.])
2845    self.assertAllEqual(gradient_checker_v2._to_numpy(var2_grad_val),
2846                        [3., 0., 0.])
2847
2848  @test_util.run_deprecated_v1
2849  def testWhileGrad_Gather(self):
2850    # NOTE(skyewm): this test is interesting because the gather gradient
2851    # function returns an IndexedSlices.
2852    x = constant_op.constant([1., 1., 1., 1., 1.])
2853    y = control_flow_ops.while_loop(
2854        lambda i, _: i < 3,
2855        lambda i, x: (i + 1, x + array_ops.gather(x, [0])),
2856        [0, x[:1]])[1]
2857    z = y * 3.0
2858    grad = gradients_impl.gradients(z, x)[0]
2859    self.assertEqual(self.evaluate(y), 8.)
2860    self.assertAllEqual(self.evaluate(grad), [24., 0., 0., 0., 0.])
2861
2862  @test_util.run_deprecated_v1
2863  def testWhileGrad_GatherNoFanOut(self):
2864    # NOTE(skyewm): this test is interesting because the gather gradient
2865    # function returns an IndexedSlices.
2866    x = constant_op.constant([1., 1., 1., 1., 1.])
2867    y = control_flow_ops.while_loop(
2868        lambda i, _: i < 3,
2869        lambda i, x: (i + 1, array_ops.gather(x, [0])),
2870        [0, x[:1]])[1]
2871    z = y * 3.0
2872    grad = gradients_impl.gradients(z, x)[0]
2873    self.assertEqual(self.evaluate(y), 1.)
2874    self.assertAllEqual(self.evaluate(grad), [3., 0., 0., 0., 0.])
2875
2876  @test_util.run_v1_only("b/120545219")
2877  def testWhileGradInCond(self):
2878
2879    with self.cached_session():
2880      n = ops.convert_to_tensor(1.0, name="n")
2881      x = array_ops.placeholder(dtypes.float32, shape=None)
2882      c = lambda n: math_ops.less(n, 10.0)
2883      b = lambda n: math_ops.add(n, x)
2884
2885      def fn1():
2886        r = control_flow_ops.while_loop(c, b, [n],
2887                                        [tensor_shape.unknown_shape()])
2888        return gradients_impl.gradients(r, x)[0]
2889
2890      r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
2891      self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
2892
2893  @test_util.disable_control_flow_v2("b/116340060")
2894  @test_util.run_v1_only("b/120545219")
2895  def testGradInWhileWrtInitialLoopVal(self):
2896    with self.cached_session():
2897      x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
2898      y = x + 1
2899
2900      def body(i, v):
2901        z = v * 2
2902        return i + 1, gradients_impl.gradients(z, x)[0]
2903
2904      with self.assertRaisesRegexp(
2905          ValueError,
2906          "Cannot compute gradient inside while loop with respect to op 'x'. "
2907          "We do not support taking the gradient wrt or through the initial "
2908          "value of a loop variable. Gradients can be computed through "
2909          "loop invariants or wrt the input parameters to the loop body."):
2910        control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
2911
2912  @test_util.run_v1_only("b/120545219")
2913  def testWhileGradInWhile(self):
2914    with self.cached_session():
2915      n = ops.convert_to_tensor(1.0, name="n")
2916      x = array_ops.placeholder(dtypes.float32, shape=None)
2917      c = lambda n: math_ops.less(n, 10.0)
2918      b = lambda n: math_ops.add(n, x)
2919
2920      def b1(n):
2921        r = control_flow_ops.while_loop(c, b, [n],
2922                                        [tensor_shape.unknown_shape()])
2923        return gradients_impl.gradients(r, x)
2924
2925      r = control_flow_ops.while_loop(lambda n: n < 6.0, b1, [n],
2926                                      [tensor_shape.unknown_shape()])
2927      self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
2928
2929  @test_util.run_v1_only("b/120545219")
2930  def testCondGradInNestedWhiles(self):
2931
2932    def outer_body(i, x):
2933      _, x = control_flow_ops.while_loop(
2934          lambda j, x: j < 3, inner_body, [0, 0.0])
2935      return i + 1, x
2936
2937    def inner_body(j, x):
2938      y = control_flow_ops.cond(math_ops.less(x, 1), lambda: 2 * x, lambda: x)
2939      return j + 1, gradients_impl.gradients(y, x)[0]
2940
2941    i, x = control_flow_ops.while_loop(lambda i, x: i < 3, outer_body, [0, 0.0])
2942
2943    with self.cached_session() as sess:
2944      i_val, x_val = self.evaluate([i, x])
2945      self.assertEqual(i_val, 3)
2946      self.assertAllClose(x_val, 1.0)
2947
2948  @test_util.run_gpu_only
2949  def testGpuResourceAccess(self):
2950    with ops.device(test.gpu_device_name()):
2951      var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0))
2952
2953    @def_function.function
2954    def foo():
2955      return control_flow_ops.while_loop(
2956          lambda i, _: i < 3,
2957          lambda i, x: (i + 1, control_flow_ops.cond(
2958              constant_op.constant(True),
2959              lambda: x + var,
2960              lambda: x)),
2961          [0, 0.0])[1]
2962
2963    self.evaluate(variables.global_variables_initializer())
2964    self.assertEqual(self.evaluate(foo()), 9.0)
2965
2966  @test_util.disable_xla("b/128643398")
2967  def testNestedResourceAccess(self):
2968    var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0))
2969
2970    @eager_function.defun
2971    def test_fn():
2972      x = constant_op.constant(0.0)
2973      r = control_flow_ops.while_loop(
2974          # Outer loop condition
2975          lambda i, y: i < 2,
2976          # Outer loop body
2977          lambda i, y: (i + 1, y + control_flow_ops.cond(
2978              constant_op.constant(True),
2979              # True branch
2980              lambda: control_flow_ops.while_loop(
2981                  # Inner loop condition
2982                  lambda j, z: j < 3,
2983                  # Inner loop body
2984                  lambda j, z: (j + 1, z + math_ops.square(var)),
2985                  # Inner initial loop value
2986                  [0, y])[1],
2987              # False branch
2988              lambda: (0.0))),
2989          # Outer initial loop value
2990          [0, x])[1]
2991
2992      grad = gradients_impl.gradients(r, x)[0]
2993      return r, grad
2994
2995    self.evaluate(variables.global_variables_initializer())
2996    r, grad = self.evaluate(test_fn())
2997    # 2 * 3 * 3^2
2998    self.assertEqual(r, 81.0)
2999    # v1 control flow gets the wrong answer!!!
3000    # Gradient computation:
3001    #   f(x) = x + 3^2
3002    #   inner_loop(x) = f(f(f(x))) = x + 3*3^2 = x + 27
3003    #   g(x) = x + inner_loop(x) = 2x + 27
3004    #   outer_loop(x) = g(g(x)) = 4x + 81
3005    #   outer_loop'(x) = 4
3006    # Note that v1 control flow gets 4.0 as well if the cond is removed.
3007    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
3008      self.assertEqual(grad, 4.0)
3009
3010  def testWhile_NestedInput(self):
3011    with self.cached_session() as sess:
3012      named = collections.namedtuple("named", ("a", "b"))
3013      loop_vars = [
3014          named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
3015          (constant_op.constant(2.0), constant_op.constant(3.0)),
3016          constant_op.constant(4.0)
3017      ]
3018      c = lambda lv0, _1, _2: lv0.a < 100.0
3019
3020      def b(lv0, lv1, lv2):
3021        lv0 = named(a=lv0.a + 1, b=lv0.b)
3022        lv1 = (lv1[0] + 1, lv1[1])
3023        lv2 += 2
3024        return [lv0, lv1, lv2]
3025
3026      r = control_flow_ops.while_loop(c, b, loop_vars)
3027
3028      self.assertTrue(isinstance(r, list))
3029      self.assertTrue(isinstance(r[0], named))
3030      self.assertTrue(isinstance(r[1], tuple))
3031      self.assertTrue(isinstance(r[2], ops.Tensor))
3032
3033      r_flattened = nest.flatten(r)
3034      self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0],
3035                       self.evaluate(r_flattened))
3036
3037  @test_util.run_v1_only("b/120545219")
3038  def testWhile_NestedBadArityFails(self):
3039    with self.cached_session():
3040      named = collections.namedtuple("named", ("a", "b"))
3041      loop_vars = [
3042          named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
3043          (constant_op.constant(2.0), constant_op.constant(3.0)),
3044          constant_op.constant(4.0)
3045      ]
3046      c = lambda lv0, _1, _2: lv0.a < 100.0
3047
3048      def b(lv0, lv1, _):
3049        return [lv0, lv1]
3050
3051      with self.assertRaisesRegexp(ValueError, "the same number of elements"):
3052        control_flow_ops.while_loop(c, b, loop_vars)
3053
3054  @test_util.run_v1_only("b/120545219")
3055  def testWhileGrad_ys_xs(self):
3056    with self.cached_session():
3057      x = constant_op.constant(3.0, name="x")
3058      y = constant_op.constant(2.0, name="y")
3059
3060      c = lambda x, y: math_ops.less(x, 100.0)
3061
3062      def b(x, y):
3063        y1 = math_ops.add(x, y)
3064        x1 = math_ops.multiply(x, y1)
3065        return x1, y1
3066
3067      rx, ry = control_flow_ops.while_loop(c, b, [x, y], parallel_iterations=1)
3068
3069      r = gradients_impl.gradients([rx, ry], x)
3070      self.assertAllClose(304.0, r[0])
3071      r = gradients_impl.gradients([rx, ry], y)
3072      self.assertAllClose(124.0, r[0])
3073      r = gradients_impl.gradients([rx], x)
3074      self.assertAllClose(295.0, r[0])
3075      r = gradients_impl.gradients([rx], y)
3076      self.assertAllClose(120.0, r[0])
3077
3078  @test_util.run_deprecated_v1
3079  def testWhileGrad_Dependency(self):
3080    with self.cached_session():
3081      i = constant_op.constant(0, name="i")
3082      x = constant_op.constant(2.0, name="x")
3083
3084      c = lambda i, x: math_ops.less(i, 10)
3085
3086      def b(i, x):
3087        x = math_ops.multiply(x, 2.0)
3088        i = math_ops.add(i, 1)
3089        return i, x
3090
3091      ri, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
3092
3093      r = gradients_impl.gradients([ri, rx], x)
3094      self.assertAllClose(1024.0, r[0])
3095      r = gradients_impl.gradients([rx], x)
3096      self.assertAllClose(1024.0, r[0])
3097
3098  @test_util.disable_control_flow_v2("b/116355153 (back_prop flag)")
3099  @test_util.run_v1_only("b/120545219")
3100  def testWhileGrad_NoGradient(self):
3101    with self.cached_session():
3102      v = constant_op.constant(2.0, name="v")
3103      c = lambda v: math_ops.less(v, 100.0)
3104      b = math_ops.square
3105      r = control_flow_ops.while_loop(c, b, [v], back_prop=False)
3106      r = math_ops.add(r, v)
3107      r = gradients_impl.gradients(r, v)
3108      self.assertAllClose(1.0, r[0])
3109
3110  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
3111  @test_util.run_v1_only("b/120545219")
3112  def testWhileGrad_NoDependency(self):
3113    with self.cached_session() as sess:
3114      variable = variables.Variable(array_ops.ones([2, 3]))
3115      duration = array_ops.zeros([], dtype=dtypes.int32)
3116
3117      def cond(duration, tensor, _):
3118        del tensor
3119        return duration < 10
3120
3121      def body(duration, tensor, _):
3122        return (duration + 1, tensor, tensor)
3123
3124      loop_vars = [duration, variable, variable]
3125      tensors = control_flow_ops.while_loop(
3126          cond=cond, body=body, loop_vars=loop_vars)
3127      cost = math_ops.reduce_sum(tensors[2])
3128      grad = gradients_impl.gradients(cost, [variable])
3129      self.evaluate(variables.global_variables_initializer())
3130      self.assertAllClose(np.ones([2, 3]), sess.run(grad[0]))
3131
3132  @test_util.run_deprecated_v1
3133  def testWhileGrad_Const(self):
3134    with self.cached_session() as sess:
3135      c0 = constant_op.constant(0.0, name="c0")
3136      c1 = constant_op.constant(1.0, name="c1")
3137      duration = constant_op.constant(0, name="t")
3138
3139      def cond(duration, _):
3140        return duration < 1
3141
3142      def body(duration, _):
3143        return duration + 1, c1
3144
3145      loop_vars = [duration, c0]
3146      tensors = control_flow_ops.while_loop(
3147          cond=cond, body=body, loop_vars=loop_vars)
3148      cost = math_ops.reduce_sum(tensors[1])
3149      grad = gradients_impl.gradients(cost, [c0])
3150      self.assertAllClose(0.0, sess.run(grad[0]))
3151
3152  @test_util.run_v1_only("b/120545219")
3153  def testWhileGrad_SerialTwoLoops(self):
3154    with self.cached_session():
3155      i = constant_op.constant(0, name="i")
3156      x = constant_op.constant(2.0, name="x")
3157
3158      c = lambda i, x: math_ops.less(i, 5)
3159
3160      def b(i, x):
3161        x = math_ops.multiply(x, 2.0)
3162        i = math_ops.add(i, 1)
3163        return i, x
3164
3165      _, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
3166      _, rx = control_flow_ops.while_loop(c, b, [i, rx], parallel_iterations=1)
3167
3168      r = gradients_impl.gradients([rx], x)
3169      self.assertAllClose(1024.0, r[0])
3170
3171  @test_util.run_v1_only("b/120545219")
3172  def testWhileGrad_ParallelTwoLoops(self):
3173    with self.cached_session():
3174      i = constant_op.constant(0, name="i")
3175      x = constant_op.constant(2.0, name="x")
3176
3177      c = lambda i, x: math_ops.less(i, 5)
3178
3179      def b(i, x):
3180        x = math_ops.multiply(x, 2.0)
3181        i = math_ops.add(i, 1)
3182        return i, x
3183
3184      _, r1 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
3185      _, r2 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
3186      rx = math_ops.add(r1, r2)
3187
3188      r = gradients_impl.gradients([rx], x)
3189      self.assertAllClose(64.0, r[0])
3190
3191  @test_util.run_v1_only("b/120545219")
3192  def testWhileGrad_OneOutputWithControlDependencyOnSecond(self):
3193    with self.cached_session():
3194      i = constant_op.constant(0, name="i")
3195      x = constant_op.constant(1.0, name="x")
3196      y = constant_op.constant(1.0, name="y")
3197      c = lambda i, *_: math_ops.less(i, 1, name="cond_less")
3198
3199      def b(i, xi, yi):
3200        # return (i + 1, xi, xi + yi)
3201        return (math_ops.add(i, 1, name="inc"), array_ops.identity(
3202            xi, name="xi"), math_ops.add(xi, yi, name="xi_plus_yi"))
3203
3204      _, x_f, y_f = control_flow_ops.while_loop(c, b, [i, x, y])
3205      with ops.control_dependencies([x_f]):
3206        y_f_d = array_ops.identity(y_f, name="y_f_d")
3207
3208      self.assertAllClose(2.0, self.evaluate(y_f_d))  # y_f_d = 1.0 + 1.0
3209      g = gradients_impl.gradients([y_f_d], [x])[0]
3210      self.assertTrue(g is not None)
3211      self.assertAllClose(1.0,
3212                          self.evaluate(g))  # y_f_d = x + 1.0, dy_f_d/dx = 1.0
3213
3214  def _testNestedWhileGrad_Simple(self, use_gpu):
3215    with self.cached_session(use_gpu=use_gpu):
3216      v = constant_op.constant(1.0)
3217
3218      def inner_loop(s):
3219        c = lambda x: math_ops.less(x, 4.0)
3220        b = lambda x: math_ops.multiply(x, 2.0)
3221        return control_flow_ops.while_loop(c, b, [s])
3222
3223      c = lambda x: math_ops.less(x, 2.0)
3224      b = lambda x: math_ops.multiply(inner_loop(x), 2.0)
3225      r = control_flow_ops.while_loop(c, b, [v])
3226
3227      r = gradients_impl.gradients(r, v)[0]
3228      self.assertAllClose(8.0, self.evaluate(r))
3229
3230  @test_util.run_deprecated_v1
3231  def testNestedWhileGrad_Simple(self):
3232    self._testNestedWhileGrad_Simple(use_gpu=False)
3233    self._testNestedWhileGrad_Simple(use_gpu=True)
3234
3235  @test_util.run_v1_only("b/120545219")
3236  def testNestedWhileGrad_SerialInner(self):
3237    with self.cached_session():
3238      v = constant_op.constant(1.0)
3239
3240      def inner_loop1(s):
3241        z = constant_op.constant(0)
3242        c = lambda i, x: math_ops.less(i, 4)
3243        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
3244        return control_flow_ops.while_loop(c, b, [z, s])
3245
3246      def inner_loop2(s):
3247        z = constant_op.constant(0)
3248        c = lambda i, x: math_ops.less(i, 4)
3249        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
3250        return control_flow_ops.while_loop(c, b, [z, s])
3251
3252      c = lambda x: math_ops.less(x, 128.0)
3253      b = lambda x: inner_loop2(inner_loop1(x)[1])[1]
3254      r = control_flow_ops.while_loop(c, b, [v])
3255
3256      r = gradients_impl.gradients(r, v)[0]
3257      self.assertAllClose(256.0, self.evaluate(r))
3258
3259  @test_util.run_deprecated_v1
3260  def testNestedWhileGrad_ParallelInner(self):
3261    with self.cached_session():
3262      v = constant_op.constant(1.0)
3263
3264      def inner_loop1(s):
3265        z = constant_op.constant(0)
3266        c = lambda i, x: math_ops.less(i, 4)
3267        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
3268        return control_flow_ops.while_loop(c, b, [z, s])
3269
3270      def inner_loop2(s):
3271        z = constant_op.constant(0)
3272        c = lambda i, x: math_ops.less(i, 4)
3273        b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
3274        return control_flow_ops.while_loop(c, b, [z, s])
3275
3276      c = lambda x: math_ops.less(x, 128.0)
3277      b = lambda x: math_ops.multiply(inner_loop1(x)[1], inner_loop2(x)[1])
3278      r = control_flow_ops.while_loop(c, b, [v])
3279
3280      r = gradients_impl.gradients(r, v)[0]
3281      self.assertAllClose(512.0, self.evaluate(r))
3282
3283  @test_util.run_v1_only("b/120545219")
3284  def testNestedWhileGrad_ParallelIterations(self):
3285    # Make sure the stack pushes and pops of an inner loop are executed in
3286    # the sequential order of the iterations of its outer loop.
3287    with self.cached_session() as sess:
3288
3289      def inner_loop(t):
3290        fn = lambda n: n + math_ops.square(var)
3291        return map_fn.map_fn(fn=fn, elems=t, parallel_iterations=10)
3292
3293      def outer_loop(inp):
3294        return map_fn.map_fn(
3295            fn=inner_loop, elems=inp, parallel_iterations=10)
3296
3297      var = variables.Variable(constant_op.constant(3.0))
3298      inp = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
3299      res = outer_loop(inp)
3300      optimizer = adam.AdamOptimizer(learning_rate=0.001)
3301      train_op = optimizer.minimize(math_ops.reduce_mean(math_ops.square(res)))
3302      self.evaluate(variables.global_variables_initializer())
3303      self.evaluate(train_op)
3304      self.assertAllClose(2.999, var.read_value())
3305
3306  def _testWhileCondGrad_Simple(self, use_gpu):
3307    with self.cached_session(use_gpu=use_gpu):
3308      v = ops.convert_to_tensor(2.0, name="v")
3309      n = ops.convert_to_tensor(100.0, name="n")
3310      one = ops.convert_to_tensor(1.0, name="one")
3311      c = lambda x: math_ops.less(x, n)
3312      # pylint: disable=undefined-variable
3313      # for OSS build
3314      b = lambda x: control_flow_ops.cond(constant_op.constant(True),
3315                                          lambda: math_ops.square(x),
3316                                          lambda: math_ops.subtract(x, one))
3317      # pylint: enable=undefined-variable
3318      r = control_flow_ops.while_loop(c, b, [v])
3319      r = gradients_impl.gradients(r, v)[0]
3320      self.assertAllClose(1024.0, self.evaluate(r))
3321
3322  @test_util.run_deprecated_v1
3323  def testWhileCondGrad_Simple(self):
3324    self._testWhileCondGrad_Simple(use_gpu=False)
3325    self._testWhileCondGrad_Simple(use_gpu=True)
3326
3327  @test_util.run_deprecated_v1
3328  def testWhileCondGrad_UnknownShape(self):
3329    with self.cached_session() as sess:
3330      v = array_ops.placeholder(dtypes.float32)
3331      n = ops.convert_to_tensor(100.0, name="n")
3332      one = ops.convert_to_tensor(1.0, name="one")
3333      c = lambda x: math_ops.less(x, n)
3334      # pylint: disable=undefined-variable
3335      # for OSS build
3336      b = lambda x: control_flow_ops.cond(constant_op.constant(True),
3337                                          lambda: math_ops.square(x),
3338                                          lambda: math_ops.subtract(x, one))
3339      # pylint: enable=undefined-variable
3340      r = control_flow_ops.while_loop(c, b, [v])
3341      r = gradients_impl.gradients(r, v)[0]
3342      r = sess.run(r, feed_dict={v: 2.0})
3343      self.assertAllClose(1024.0, r)
3344
3345  @test_util.run_deprecated_v1
3346  def testWhileGrad_Concat(self):
3347    with self.cached_session() as sess:
3348      x = variable_scope.get_variable("x", initializer=[[1., 2.]])
3349      i0 = constant_op.constant(0)
3350      h0 = array_ops.zeros([0, 2])
3351
3352      def condition(i, _):
3353        return i < 2
3354
3355      def body(i, h):
3356        return i + 1, array_ops.concat([h, x], 0)
3357
3358      _, h = control_flow_ops.while_loop(
3359          condition, body, [i0, h0],
3360          [i0.get_shape(), tensor_shape.TensorShape([None, 2])])
3361      s = math_ops.reduce_sum(h)
3362
3363      optimizer = gradient_descent.GradientDescentOptimizer(0.01)
3364      op = optimizer.minimize(s)
3365
3366      self.evaluate(variables.global_variables_initializer())
3367      self.evaluate(op)
3368      self.assertAllClose([[0.98000002, 1.98000002]], self.evaluate(x))
3369
3370  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
3371  @test_util.run_v1_only("b/120545219")
3372  def testWhileWithRefsWithGradients_1(self):
3373    with self.cached_session() as sess:
3374      x = variables.VariableV1(0.)._ref()  # pylint: disable=protected-access
3375      i = constant_op.constant(0)
3376      c = lambda i, x: math_ops.less(i, 10)
3377
3378      self.assertEqual(x.dtype, dtypes.float32_ref)
3379
3380      def body(i, x):
3381        self.assertEqual(x.dtype, dtypes.float32_ref)
3382        return [i + 1, gen_array_ops.ref_identity(x)]
3383
3384      r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5)
3385
3386      grad_ys = [variables.VariableV1(73)._ref()]  # pylint: disable=protected-access
3387      grad = gradients_impl.gradients([r[1]], [x], grad_ys=grad_ys)
3388
3389      self.evaluate(variables.global_variables_initializer())
3390
3391      self.assertEqual(r[0].dtype, dtypes.int32)
3392      self.assertEqual(r[1].dtype, dtypes.float32_ref)
3393
3394      value_i, value_x, value_x_grad = sess.run(r + grad)
3395
3396    self.assertEqual(10, value_i)
3397    self.assertEqual(0, value_x)
3398    self.assertEqual(73, value_x_grad)
3399
3400  @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
3401  @test_util.run_v1_only("b/120545219")
3402  def testWhileGrad_IndexedSlices(self):
3403    with self.cached_session():
3404      values = constant_op.constant([2.0, 4.0], name="values")
3405      indices = constant_op.constant([0, 3], name="indices")
3406      shape = constant_op.constant([10], name="dense_shape")
3407      i = constant_op.constant(0)
3408      x = ops.IndexedSlices(values, indices, dense_shape=shape)
3409
3410      def c(i, _):
3411        return i < 10
3412
3413      def b(i, x):
3414        return [
3415            i + 1,
3416            ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape)
3417        ]
3418
3419      _, r = control_flow_ops.while_loop(c, b, [i, x])
3420      r = gradients_impl.gradients(r.values, values)[0]
3421      self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))
3422
3423  @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
3424  @test_util.run_v1_only("b/120545219")
3425  def testWhileGrad_SparseTensor(self):
3426    with self.cached_session():
3427      values = constant_op.constant([2.0, 4.0], name="values")
3428      indices = constant_op.constant(
3429          [[0], [3]], dtype=dtypes.int64, name="indices")
3430      shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
3431      i = constant_op.constant(0)
3432      x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
3433
3434      def c(i, _):
3435        return i < 10
3436
3437      def b(i, x):
3438        return [
3439            i + 1,
3440            sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape)
3441        ]
3442
3443      _, r = control_flow_ops.while_loop(c, b, [i, x])
3444      r = gradients_impl.gradients(r.values, values)[0]
3445      self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))
3446
3447  @test_util.run_v1_only("b/120545219")
3448  def testCallGradInLoop(self):
3449    with self.cached_session() as sess:
3450      i0 = constant_op.constant(0)
3451      params = constant_op.constant(5.0)
3452      params_1 = math_ops.square(params)
3453
3454      def c(i, _):
3455        return i < 10
3456
3457      def b(i, x):
3458        data = constant_op.constant([1.0, 2.0, 3.0])
3459        data = math_ops.multiply(data, params_1)
3460        x1 = x + gradients_impl.gradients(data, params)[0]
3461        return i + 1, x1
3462
3463      output_grad = control_flow_ops.while_loop(
3464          c, b, [i0, constant_op.constant(0.0)])
3465      self.assertAllClose(600.0, self.evaluate(output_grad)[1])
3466
3467  @test_util.run_deprecated_v1
3468  def testWhileAndTensorArray(self):
3469    with self.cached_session() as sess:
3470      param = constant_op.constant(2.0)
3471      n0 = constant_op.constant(0)
3472      y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
3473
3474      def c(i, _):
3475        return i < 10
3476
3477      def b(i, y):
3478        return [
3479            i + 1,
3480            map_fn.map_fn(lambda x: math_ops.multiply(x, param), y)
3481        ]
3482
3483      r = control_flow_ops.while_loop(c, b, [n0, y0], parallel_iterations=1)
3484      r = gradients_impl.gradients(r, param)[0]
3485      self.assertAllClose(107520.0, self.evaluate(r))
3486
3487  @test_util.run_deprecated_v1
3488  def testNestedWhileAndTensorArray(self):
3489    n = constant_op.constant(3.0)
3490
3491    def Body(row, ta):
3492
3493      def InnerBody(row, col, ta):
3494        # Note: row and col are 1-based.
3495        ta = ta.write(
3496            math_ops.cast(n * (row - 1.) + col - 1., dtypes.int32), row * col)
3497        return row, col + 1., ta
3498
3499      ta = control_flow_ops.while_loop(
3500          lambda _, col, _1: col <= n,
3501          InnerBody, [row, constant_op.constant(1.), ta],
3502          return_same_structure=False)[2]
3503      return row + 1., ta
3504
3505    ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=9)
3506    ta = control_flow_ops.while_loop(
3507        lambda row, _: row <= n,
3508        Body, [constant_op.constant(1.), ta],
3509        return_same_structure=False)[1]
3510
3511    output = array_ops.reshape(ta.stack(), [3, 3])
3512    self.assertAllEqual(
3513        self.evaluate(output), [[1., 2., 3.], [2., 4., 6.], [3., 6., 9.]])
3514    # TODO(b/117675481): This does not work with current TA. Enable with new TA.
3515    # grad = gradients_impl.gradients(output, [n])
3516    # self.assertEqual(self.evaluate(grad), 3.5)
3517
3518  @test_util.run_deprecated_v1
3519  def testWhileGrad_StopGrad(self):
3520    with self.cached_session():
3521      x = constant_op.constant(3.0, name="x")
3522      y = constant_op.constant(2.0, name="y")
3523
3524      c = lambda x, y: math_ops.less(x, 100.0)
3525
3526      def b(x, y):
3527        y1 = math_ops.square(y)
3528        x1 = math_ops.add(math_ops.square(x), y1)
3529        return x1, y1
3530
3531      rx, ry = control_flow_ops.while_loop(c, b, [x, y])
3532
3533      r = gradients_impl.gradients(rx, y)[0]
3534      self.assertEqual(136.0, self.evaluate(r))
3535      r = gradients_impl.gradients(ry, y)[0]
3536      self.assertEqual(32.0, self.evaluate(r))
3537
3538      r = gradients_impl.gradients(array_ops.stop_gradient(rx), y)[0]
3539      self.assertEqual(r, None)
3540      r = gradients_impl.gradients(array_ops.stop_gradient(ry), y)[0]
3541      self.assertEqual(r, None)
3542
3543      r = gradients_impl.gradients(
3544          array_ops.stop_gradient(math_ops.square(rx)), y)[0]
3545      self.assertEqual(r, None)
3546      r = gradients_impl.gradients(
3547          array_ops.stop_gradient(math_ops.add(rx, ry)), x)[0]
3548      self.assertEqual(r, None)
3549      r = gradients_impl.gradients(
3550          array_ops.stop_gradient(math_ops.add(rx, ry)), y)[0]
3551      self.assertEqual(r, None)
3552
3553      r = gradients_impl.gradients(math_ops.add(rx, ry), y)[0]
3554      self.assertEqual(168.0, self.evaluate(r))
3555      r = gradients_impl.gradients(
3556          math_ops.add(rx, array_ops.stop_gradient(ry)), y)[0]
3557      self.assertEqual(136.0, self.evaluate(r))
3558      r = gradients_impl.gradients(
3559          math_ops.add(array_ops.stop_gradient(rx), ry), y)[0]
3560      self.assertEqual(32.0, self.evaluate(r))
3561
3562  @test_util.run_deprecated_v1
3563  @test_util.disable_control_flow_v2("b/118712257")
3564  def testWhileGrad_StopGradInside(self):
3565    with self.cached_session():
3566      x = constant_op.constant(3.0, name="x")
3567      y = constant_op.constant(2.0, name="y")
3568
3569      c = lambda x, y: math_ops.less(x, 100.0)
3570
3571      def b(x, y):
3572        y1 = array_ops.stop_gradient(math_ops.square(y))
3573        x1 = math_ops.add(math_ops.square(x), y1)
3574        return x1, y1
3575
3576      rx, _ = control_flow_ops.while_loop(c, b, [x, y])
3577
3578      r = gradients_impl.gradients(rx, y)[0]
3579      self.assertAllClose(0.0, self.evaluate(r))
3580      r = gradients_impl.gradients(rx, x)[0]
3581      self.assertAllClose(156.0, self.evaluate(r))
3582
3583  @test_util.run_deprecated_v1
3584  @test_util.disable_control_flow_v2("b/118712257")
3585  def testWhileGrad_StopGradInsideNoShape(self):
3586    with self.cached_session() as sess:
3587      x = array_ops.placeholder(dtypes.float32)
3588      y = array_ops.placeholder(dtypes.float32)
3589
3590      c = lambda x, y: math_ops.less(math_ops.reduce_sum(x), 100.0)
3591
3592      def b(x, y):
3593        y1 = array_ops.stop_gradient(math_ops.square(y, name="stopped"))
3594        x1 = math_ops.add(math_ops.square(x), y1)
3595        return x1, y1
3596
3597      rx, _ = control_flow_ops.while_loop(c, b, [x, y])
3598
3599      r = gradients_impl.gradients(rx, y)[0]
3600      feed_dict = {x: [3.0, 4.0], y: [2.0, 3.0]}
3601      self.assertAllClose([0.0, 0.0], sess.run(r, feed_dict=feed_dict))
3602      r = gradients_impl.gradients(rx, x)[0]
3603      self.assertAllClose([156.0, 400.0], sess.run(r, feed_dict=feed_dict))
3604      name = "gradients/while/stopped_grad"
3605      all_ops = x.graph.get_operations()
3606      self.assertFalse(any(name in op.name for op in all_ops))
3607
3608  @test_util.run_deprecated_v1
3609  def testWhileGradGradFail(self):
3610    theta = variables.Variable(initial_value=1.)
3611
3612    def fn(prev, x):
3613      return prev + x * theta
3614
3615    result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32))
3616    grad_theta = gradients_impl.gradients(result, theta)
3617    if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
3618      with self.assertRaisesRegexp(TypeError, "Second-order gradient"):
3619        gradients_impl.gradients(grad_theta, theta)
3620    grad_theta_stopped = array_ops.stop_gradient(grad_theta)
3621    gradients_impl.gradients(grad_theta_stopped, theta)
3622
3623  @test_util.run_deprecated_v1
3624  def testStopGradOnWhileGrad(self):
3625    with self.cached_session():
3626      x = constant_op.constant(2.0, name="x")
3627      y = constant_op.constant(2.0, name="y")
3628
3629      c = lambda x: math_ops.less(x, 100.0)
3630      b = lambda x: math_ops.multiply(x, y)
3631      rx = control_flow_ops.while_loop(c, b, [x])
3632
3633      rg = gradients_impl.gradients(rx, y)[0]
3634      rg = array_ops.stop_gradient(rg)
3635      r = math_ops.add(math_ops.square(y), rx)
3636      r = math_ops.add(r, rg)
3637      r = gradients_impl.gradients(r, y)[0]
3638      self.assertEqual(388.0, self.evaluate(r))
3639
3640  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
3641  @test_util.run_deprecated_v1
3642  def testWhileGradientWithNontrainablePath1(self):
3643    q = variables.Variable([7., 8.])
3644
3645    def cond(_, y):
3646      del y
3647      return False
3648
3649    def body(x, _):
3650      return x, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)
3651
3652    _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
3653    dy_dq, = gradients_impl.gradients(y, q)
3654    self.assertIsNotNone(dy_dq)
3655    with self.cached_session() as sess:
3656      self.evaluate(q.initializer)
3657      self.assertAllClose([0., 0.], self.evaluate(dy_dq))
3658
3659  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
3660  @test_util.run_v1_only("b/120545219")
3661  def testWhileGradientWithNontrainablePath2(self):
3662    q = variables.Variable([7., 8.])
3663
3664    def cond(_, y):
3665      return math_ops.equal(y, 0.)
3666
3667    def body(x, _):
3668      zero = constant_op.constant(0, dtype=dtypes.int64)
3669      return zero, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)
3670
3671    _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
3672    dy_dq, = gradients_impl.gradients(y, q)
3673    self.assertIsNotNone(dy_dq)
3674    with self.cached_session() as sess:
3675      self.evaluate(q.initializer)
3676      self.assertAllClose([1., 1.], self.evaluate(dy_dq))
3677
3678  @test_util.run_v1_only("b/120545219")
3679  def testIssue16504(self):
3680    c = constant_op.constant(np.arange(100), dtype=dtypes.float32)
3681    w = variables.Variable(
3682        initial_value=np.ones(100), dtype=dtypes.float32) / 100
3683    k = variables.Variable(0, dtype=dtypes.int32)
3684    chg_w = constant_op.constant(np.inf, dtype=dtypes.float32)
3685
3686    def cond(k, _, chg_w):
3687      return math_ops.logical_and(k < 10, chg_w > 1e-3)
3688
3689    def body(k, w, chg_w):
3690      grad, = gradients_impl.gradients(-math_ops.reduce_sum(w * c), w)
3691      w_n = w * math_ops.exp(-0.1 * grad)
3692      w_n /= math_ops.reduce_sum(w_n)
3693      chg_w = (
3694          math_ops.reduce_sum(math_ops.abs(w_n - w)) / math_ops.reduce_sum(
3695              math_ops.abs(w)))
3696      return k + 1, w_n, chg_w
3697
3698    _, w, _ = control_flow_ops.while_loop(cond, body, [k, w, chg_w])
3699    grad, = gradients_impl.gradients(w, c)
3700    self.assertIsNotNone(grad)
3701
3702  @test_util.run_v1_only("b/120545219")
3703  def testStopGradMultiFlows(self):
3704    with self.cached_session():
3705
3706      def body(i, y, r):
3707        x = variable_scope.get_variable(
3708            "x",
3709            shape=(),
3710            dtype=dtypes.float32,
3711            initializer=init_ops.ones_initializer())
3712        y *= x
3713        return [i + 1, y, r + math_ops.reduce_sum(y)]
3714
3715      i0 = constant_op.constant(0)
3716      y0 = array_ops.ones(5)
3717      r0 = constant_op.constant(0.0)
3718      cond = lambda i, y, r: i < 1
3719      _, _, r = control_flow_ops.while_loop(
3720          cond, body, [i0, y0, r0], back_prop=True)
3721
3722      vars_ = variables.global_variables()
3723      grads = linalg_ops.norm(gradients_impl.gradients(r, vars_)[0])
3724      z = math_ops.add(r, array_ops.stop_gradient(math_ops.reduce_sum(grads)))
3725      result = gradients_impl.gradients(z, vars_)[0]
3726      self.evaluate(variables.global_variables_initializer())
3727      self.assertEqual(5.0, self.evaluate(result))
3728
3729  @test_util.run_v1_only("b/120545219")
3730  def testOneValueCond(self):
3731
3732    with self.cached_session():
3733      c = array_ops.placeholder(dtypes.int32, shape=[])
3734      one = ops.convert_to_tensor(1, name="one")
3735      two = ops.convert_to_tensor(2, name="two")
3736      p = math_ops.greater_equal(c, 1)
3737      i = control_flow_ops.cond(p, lambda: one, lambda: two)
3738      self.assertTrue(isinstance(i, ops.Tensor))
3739
3740      # True case: c = 2 is >= 1
3741      self.assertEqual([1], i.eval(feed_dict={c: 2}))
3742
3743      # False case: c = 0 is not >= 1
3744      self.assertEqual([2], i.eval(feed_dict={c: 0}))
3745
3746  @test_util.run_deprecated_v1
3747  def testExampleCond(self):
3748
3749    with self.cached_session():
3750      x = ops.convert_to_tensor([-2.0, 2.0], name="x")
3751      d = array_ops.placeholder(dtypes.int32, shape=[])
3752
3753      def l2():
3754        return math_ops.sqrt(math_ops.reduce_sum(math_ops.square(x)))
3755
3756      def l1():
3757        return math_ops.reduce_sum(math_ops.abs(x))
3758
3759      i = control_flow_ops.cond(math_ops.equal(d, 2), l2, l1)
3760      self.assertAllClose(4.0, i.eval(feed_dict={d: 1}))
3761      self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
3762
3763  @test_util.run_v1_only("b/120545219")
3764  def testCase(self):
3765    with self.cached_session():
3766      x = constant_op.constant(1)
3767      y = constant_op.constant(2)
3768      z = constant_op.constant(3)
3769      f1 = lambda: constant_op.constant(17)
3770      f2 = lambda: constant_op.constant(23)
3771      f3 = lambda: constant_op.constant(-1)
3772
3773      r1 = control_flow_ops.case(
3774          {
3775              x < y: f1,
3776              x > z: f2
3777          }, default=f3, exclusive=True)
3778      self.assertAllEqual(r1, 17)
3779
3780      r2 = control_flow_ops.case([(y > z, f1), (y > x, f2)], default=f3)
3781      self.assertAllEqual(r2, 23)
3782
3783      # Duplicate events can happen, first one is selected
3784      r3 = control_flow_ops.case([(x < y, f1), (x < y, f2)], default=f3)
3785      self.assertAllEqual(r3, 17)
3786
3787      # Duplicate events cause an error if exclusive = True
3788      r4 = control_flow_ops.case(
3789          [(x < y, f1), (x < y, f2)], default=f3, exclusive=True)
3790      with self.assertRaisesOpError("Input error:"):
3791        self.evaluate(r4)
3792
3793      # Check that the default is called if none of the others are
3794      r5 = control_flow_ops.case({x > y: f1}, default=f3)
3795      self.assertAllEqual(r5, -1)
3796
3797      ran_once = [False, False, False]
3798
3799      def break_run_twice(ix):
3800
3801        def _break():
3802          ran_once[ix] = True
3803          return constant_op.constant(ix)
3804
3805        return _break
3806
3807      # Should not fail - each conditional gets called exactly once
3808      # except default.  Default gets called twice: once to create an
3809      # empty output and once for the actual cond switch.
3810      r6 = control_flow_ops.case(
3811          [(x < y, break_run_twice(0)), (x > y, break_run_twice(1))],
3812          default=lambda: constant_op.constant(2))
3813
3814      self.assertAllEqual(r6, 0)
3815
3816  @test_util.run_v1_only("b/120545219")
3817  def testCaseSideEffects(self):
3818    with self.cached_session() as sess:
3819      v0 = variables.Variable(-1)
3820      v1 = variables.Variable(-1)
3821      v2 = variables.Variable(-1)
3822
3823      a = lambda: control_flow_ops.with_dependencies([state_ops.assign(v0, 0)], 0)
3824      b = lambda: control_flow_ops.with_dependencies([state_ops.assign(v1, 1)], 1)
3825      c = lambda: control_flow_ops.with_dependencies([state_ops.assign(v2, 2)], 2)
3826
3827      x = constant_op.constant(1)
3828      y = constant_op.constant(2)
3829
3830      r0 = control_flow_ops.case(
3831          ((x < y, a), (x > y, b)), default=c, exclusive=True)
3832      r1 = control_flow_ops.case(
3833          ((x > y, a), (x < y, b)), default=c, exclusive=True)
3834      r2 = control_flow_ops.case(
3835          ((x > y, a), (x > y, b)), default=c, exclusive=True)
3836
3837      self.evaluate(variables.global_variables_initializer())
3838      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
3839      self.assertEqual(2, self.evaluate(r2))
3840      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, -1, 2])
3841
3842      self.evaluate(variables.global_variables_initializer())
3843      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
3844      self.assertEqual(1, self.evaluate(r1))
3845      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, 1, -1])
3846
3847      self.evaluate(variables.global_variables_initializer())
3848      self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
3849      self.assertEqual(0, self.evaluate(r0))
3850      self.assertAllEqual(self.evaluate([v0, v1, v2]), [0, -1, -1])
3851
3852  @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
3853  @test_util.run_v1_only("b/120545219")
3854  def testOneOpCond(self):
3855    with self.cached_session():
3856      v = variables.Variable(0)
3857      c = ops.convert_to_tensor(0)
3858      one = ops.convert_to_tensor(1)
3859      two = ops.convert_to_tensor(2)
3860      p = math_ops.greater_equal(c, 1)
3861
3862      def a():
3863        return state_ops.assign(v, one)
3864
3865      def b():
3866        return state_ops.assign(v, two)
3867
3868      i = control_flow_ops.cond(p, a, b)
3869      self.assertTrue(isinstance(i, ops.Tensor))
3870      self.evaluate(variables.global_variables_initializer())
3871
3872      self.assertEqual(0, self.evaluate(v))
3873
3874      # True case: c = 2 is >= 1, v is set to 1.
3875      self.assertEqual(1, i.eval(feed_dict={c.name: 2}))
3876      self.assertEqual(1, self.evaluate(v))
3877
3878      # False case: c = 0 is not >= 1, v is set to 2.
3879      self.assertEqual(2, i.eval(feed_dict={c.name: 0}))
3880      self.assertEqual(2, self.evaluate(v))
3881
3882  @test_util.run_v1_only("b/120545219")
3883  def testWithOpsDependencies(self):
3884    with self.cached_session() as sess:
3885      v = variables.VariableV1(0.0)
3886      c = constant_op.constant(10)
3887
3888      # Fetching v directly will result in an uninitialized error
3889      with self.assertRaisesOpError("Attempting to use uninitialized value"):
3890        self.evaluate([c, v])
3891
3892      # Use a control dependency to ensure init_variable is run
3893      # while asking for c
3894      real_v = control_flow_ops.with_dependencies(
3895          name="real_tensor",
3896          output_tensor=v._ref(),  # pylint: disable=protected-access
3897          dependencies=[v.initializer])
3898      c_val, real_v_val = self.evaluate([c, real_v])
3899
3900    # Ensure the result of 'real_c' is the same as 'c'
3901    self.assertAllEqual(10, c_val)
3902
3903    # Ensure that 'v' is initialized
3904    self.assertAllClose(0.0, real_v_val)
3905
3906  @test_util.run_v1_only("b/120545219")
3907  def testWithTensorDependencies(self):
3908    with self.cached_session():
3909      v = variables.VariableV1(0.0)
3910      c1 = constant_op.constant(10)
3911      c2 = constant_op.constant(20)
3912
3913      # c1_with_init_v depends on the init op for v
3914      c1_with_init_v = control_flow_ops.with_dependencies(
3915          name="c1_with_init_v", output_tensor=c1, dependencies=[v.initializer])
3916      # c2_with_c1 depends on the value of c1_with_init_v
3917      c2_with_c1_dep = control_flow_ops.with_dependencies(
3918          name="c2_with_c1_dep",
3919          output_tensor=c2,
3920          dependencies=[c1_with_init_v])
3921
3922      # Fetching v directly will result in an uninitialized error
3923      with self.assertRaisesOpError("Attempting to use uninitialized value"):
3924        self.evaluate(v)
3925
3926      # Get the value of 'c2_with_c1_dep', which should cause 'v'
3927      # to be initialized.
3928      self.assertAllEqual(20, self.evaluate(c2_with_c1_dep))
3929
3930      # Ensure that 'v' is initialized
3931      self.assertAllClose(0.0, self.evaluate(v))
3932
3933  @test_util.run_v1_only("b/120545219")
3934  def testWithIndexedSlicesDependencies(self):
3935    with self.cached_session():
3936      v = variables.VariableV1(
3937          np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32))
3938      v_at_1 = ops.IndexedSlices(v, constant_op.constant([1]))
3939      gather_v_at_1 = array_ops.gather(v_at_1.values, v_at_1.indices)
3940      v_at_1_after_init = control_flow_ops.with_dependencies([v.initializer],
3941                                                             v_at_1)
3942      gather_v_at_1_after_init = array_ops.gather(v_at_1_after_init.values,
3943                                                  v_at_1_after_init.indices)
3944
3945      # Fetching gather_v_at_1 will result in an uninitialized error
3946      with self.assertRaisesOpError("Attempting to use uninitialized value"):
3947        self.evaluate(gather_v_at_1)
3948
3949      # Getting gather_v_at_1_after_init will work, and initialize v.
3950      self.assertAllEqual([[10.0, 11.0]],
3951                          self.evaluate(gather_v_at_1_after_init))
3952
3953      # Double check that 'v' is initialized
3954      self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]],
3955                          self.evaluate(v))
3956
3957  def testDependenciesDevice(self):
3958    with ops.Graph().as_default():
3959      # device set on tensor => same device on dep.
3960      with ops.device("/job:ps"):
3961        vd = variables.VariableV1([0.0])
3962      with_vd_dep = control_flow_ops.with_dependencies([vd.initializer], vd)
3963      self.assertTrue("/job:ps" in with_vd_dep.device)
3964
3965      # No device set on tensor => no device on dep.
3966      vnod = variables.VariableV1([0.0])
3967      with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer],
3968                                                         vnod)
3969      self.assertDeviceEqual(None, with_vnod_dep.device)
3970
3971      # device set on tensor, default device on graph => default device on dep.
3972      vdef = variables.VariableV1([0.0], name="vdef")
3973      with ops.device("/job:worker/device:GPU:1"):
3974        with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer],
3975                                                           vdef)
3976        # The device is empty, but the colocation constraint is set.
3977        self.assertDeviceEqual("", with_vdef_dep.device)
3978        self.assertEqual([b"loc:@vdef"], with_vdef_dep.op.colocation_groups())
3979
3980  @test_util.run_v1_only("b/120545219")
3981  def testGroup(self):
3982    with self.cached_session() as sess:
3983      v1 = variables.VariableV1([0.0])
3984      v2 = variables.VariableV1([1.0])
3985
3986      # Group init1 and init2 and run.
3987      init = control_flow_ops.group(v1.initializer, v2.initializer)
3988      # Fetching v1 directly will result in an uninitialized error
3989      with self.assertRaisesOpError("Attempting to use uninitialized value"):
3990        self.evaluate(v1)
3991
3992      # Runs "init" before fetching v1 and v2.
3993      init.run()
3994      v1_val, v2_val = self.evaluate([v1, v2])
3995
3996    # Ensure that v1 and v2 are initialized
3997    self.assertAllClose([0.0], v1_val)
3998    self.assertAllClose([1.0], v2_val)
3999
4000  @test_util.run_v1_only("b/120545219")
4001  def testGroupEmpty(self):
4002    op = control_flow_ops.group()
4003    self.assertEqual(op.type, "NoOp")
4004    self.assertEqual(op.control_inputs, [])
4005
4006  @test_util.run_deprecated_v1
4007  def testMergeShapes(self):
4008    # All inputs unknown.
4009    p1 = array_ops.placeholder(dtypes.float32)
4010    p2 = array_ops.placeholder(dtypes.float32)
4011    p3 = array_ops.placeholder(dtypes.float32)
4012    m, index = control_flow_ops.merge([p1, p2, p3])
4013    self.assertIs(None, m.get_shape().ndims)
4014    self.assertEqual([], index.get_shape())
4015
4016    # All inputs known with different ranks.
4017    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
4018    p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2, 3])
4019    m, index = control_flow_ops.merge([p1, p2])
4020    self.assertIs(None, m.get_shape().ndims)
4021    self.assertEqual([], index.get_shape())
4022
4023    # All inputs known with some dimensions different.
4024    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
4025    p2 = array_ops.placeholder(dtypes.float32, shape=[2, 1])
4026    m, index = control_flow_ops.merge([p1, p2])
4027    self.assertEqual([None, None], m.get_shape().as_list())
4028    self.assertEqual([], index.get_shape())
4029
4030    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
4031    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
4032    m, index = control_flow_ops.merge([p1, p2])
4033    self.assertEqual([None, 2], m.get_shape().as_list())
4034    self.assertEqual([], index.get_shape())
4035
4036    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
4037    p2 = array_ops.placeholder(dtypes.float32, shape=[2, 2])
4038    m, index = control_flow_ops.merge([p1, p2])
4039    self.assertEqual([None, 2], m.get_shape().as_list())
4040    self.assertEqual([], index.get_shape())
4041
4042    # All inputs known with same dimensions.
4043    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
4044    p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
4045    m, index = control_flow_ops.merge([p1, p2])
4046    self.assertEqual([1, 2], m.get_shape().as_list())
4047    self.assertEqual([], index.get_shape())
4048
4049    p1 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
4050    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
4051    m, index = control_flow_ops.merge([p1, p2])
4052    self.assertEqual([None, 2], m.get_shape().as_list())
4053    self.assertEqual([], index.get_shape())
4054
4055    p1 = array_ops.placeholder(dtypes.float32, shape=[None, None])
4056    p2 = array_ops.placeholder(dtypes.float32, shape=[None, None])
4057    m, index = control_flow_ops.merge([p1, p2])
4058    self.assertEqual([None, None], m.get_shape().as_list())
4059    self.assertEqual([], index.get_shape())
4060
4061  @test_util.run_v1_only("b/120545219")
4062  def testRefSelect(self):
4063    index = array_ops.placeholder(dtypes.int32)
4064
4065    # All inputs unknown.
4066    p1 = array_ops.placeholder(dtypes.float32)
4067    p2 = array_ops.placeholder(dtypes.float32)
4068    p3 = array_ops.placeholder(dtypes.float32)
4069    v1 = variables.VariableV1(p1, validate_shape=False)
4070    v2 = variables.VariableV1(p2, validate_shape=False)
4071    v3 = variables.VariableV1(p3, validate_shape=False)
4072    self.assertIs(None, v1.get_shape().ndims)
4073    s = control_flow_ops.ref_select(index, [v1, v2, v3])
4074    self.assertIs(None, s.get_shape().ndims)
4075
4076    # All inputs known but different.
4077    v1 = variables.VariableV1([[1, 2]])
4078    v2 = variables.VariableV1([[2], [1]])
4079    s = control_flow_ops.ref_select(index, [v1, v2])
4080    self.assertIs(None, s.get_shape().ndims)
4081
4082    # All inputs known and same.
4083    v1 = variables.VariableV1([[1, 2]])
4084    v2 = variables.VariableV1([[1, 2]])
4085    s = control_flow_ops.ref_select(index, [v1, v2])
4086    self.assertEqual([1, 2], s.get_shape())
4087
4088    # Possibly the same but not guaranteed.
4089    v1 = variables.VariableV1([[1., 2.]])
4090    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
4091    v2 = variables.VariableV1(p2, validate_shape=False)
4092    s = control_flow_ops.ref_select(index, [v1, v2])
4093    self.assertEqual(None, s.get_shape())
4094
4095  @test_util.run_deprecated_v1
4096  def testRunLoopTensor(self):
4097    with self.cached_session() as sess:
4098      tensor_list = []
4099
4100      def condition(t):
4101        return t < constant_op.constant(5)
4102
4103      def body(_):
4104        tensor_list.append(constant_op.constant(5))
4105        return constant_op.constant(10)
4106
4107      result = control_flow_ops.while_loop(condition, body,
4108                                           [constant_op.constant(4)])
4109      self.assertEqual(10, self.evaluate(result))
4110
4111      # Ensure that we cannot run a tensor that escapes the loop body
4112      # accidentally.
4113      with self.assertRaises(ValueError):
4114        sess.run(tensor_list[0])
4115
4116  @test_util.run_v1_only("b/120545219")
4117  def testWhilePyFuncBasic(self):
4118
4119    def func(x):
4120      return np.square(x)
4121
4122    with self.cached_session():
4123      r = control_flow_ops.while_loop(
4124          lambda i, v: i < 4,
4125          lambda i, v: [i + 1, script_ops.py_func(func, [v], [dtypes.float32])[0]],
4126          [constant_op.constant(0), constant_op.constant(2.0, dtypes.float32)],
4127          [tensor_shape.unknown_shape(), tensor_shape.unknown_shape()])
4128      self.assertEqual(self.evaluate(r[1]), 65536.0)
4129
4130  @test_util.run_v1_only("b/120545219")
4131  def testWhileFuncBasic(self):
4132
4133    @function.Defun(dtypes.float32)
4134    def func(x):
4135      return math_ops.square(math_ops.square(x))
4136
4137    with self.cached_session():
4138      x = constant_op.constant(2.0, dtypes.float32)
4139      r = control_flow_ops.while_loop(
4140          lambda i, v: i < 2, lambda i, v: [i + 1, func(v)],
4141          [constant_op.constant(0), x],
4142          [tensor_shape.unknown_shape(),
4143           tensor_shape.unknown_shape()])
4144      grad = gradients_impl.gradients(r, x)[0]
4145      self.assertEqual(self.evaluate(r[1]), 65536.0)
4146      self.assertEqual(self.evaluate(grad), 524288.0)
4147      # while_v2 does not have stacks.
4148      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
4149        self.assertEqual(
4150            len([op for op in x.graph.get_operations() if op.type == "StackV2"
4151                ]), 1)
4152
4153
4154  @test_util.run_v1_only("b/120545219")
4155  def testQIntSwitchMerge(self):
4156    with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
4157      constant_qint = constant_op.constant(np.array([42]), dtypes.qint8)
4158      cond = constant_op.constant(True, dtypes.bool)
4159      v_f, v_t = control_flow_ops.switch(constant_qint, cond)
4160      result = control_flow_ops.merge([v_f, v_t])
4161      self.evaluate(result)
4162
4163  @test_util.run_v1_only("b/120545219")
4164  def testQIntRefSwitchMerge(self):
4165    with self.cached_session(use_gpu=test.is_gpu_available()) as sess:
4166      var_qint = gen_state_ops.variable(
4167          shape=[1], dtype=dtypes.qint8, name="v", container="", shared_name="")
4168      assign_op = state_ops.assign(
4169          var_qint, constant_op.constant(np.array([42]), dtypes.qint8))
4170      self.evaluate(assign_op)
4171
4172      cond = constant_op.constant(True, dtypes.bool)
4173      v_f, v_t = control_flow_ops.ref_switch(var_qint, cond)
4174      result = control_flow_ops.ref_merge([v_f, v_t])
4175      self.evaluate(result)
4176
4177  @test_util.run_v1_only("b/120545219")
4178  def testUInt64SwitchMerge(self):
4179    with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
4180      constant_uint64 = constant_op.constant(np.array([42]), dtypes.uint64)
4181      cond = constant_op.constant(True, dtypes.bool)
4182      v_f, v_t = control_flow_ops.switch(constant_uint64, cond)
4183      result = control_flow_ops.merge([v_f, v_t])
4184      self.evaluate(result)
4185
4186  @test_util.run_deprecated_v1
4187  def testQIntArgAndRet(self):
4188
4189    @function.Defun(dtypes.qint8)
4190    def func(x):
4191      return x
4192
4193    with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
4194      qint = constant_op.constant(np.array([42]), dtypes.qint8)
4195      result = func(qint)
4196      self.evaluate(result)
4197
4198  def testSparseIdentity(self):
4199    st1 = sparse_tensor.SparseTensor([[0, 5]], ['x'], [10, 10])
4200    st2 = control_flow_ops._Identity(st1)
4201    self.assertAllEqual(st1.indices, st2.indices)
4202    self.assertAllEqual(st1.values, st2.values)
4203    self.assertAllEqual(st1.dense_shape, st2.dense_shape)
4204
4205  def testSparseEnterExit(self):
4206    st1 = sparse_tensor.SparseTensor([[0, 5]], ['x'], [10, 10])
4207    st2 = control_flow_ops._Enter(st1, "foo_1")
4208    st3 = control_flow_ops.exit(st2)
4209    self.assertAllEqual(st1.indices, st3.indices)
4210    self.assertAllEqual(st1.values, st3.values)
4211    self.assertAllEqual(st1.dense_shape, st3.dense_shape)
4212
4213
4214class ControlFlowContextCheckTest(test.TestCase):
4215
4216  def _getWhileTensor(self):
4217    """Creates and returns a tensor from a while context."""
4218    tensor = []
4219
4220    def body(i):
4221      if not tensor:
4222        tensor.append(constant_op.constant(1))
4223      return i + tensor[0]
4224
4225    control_flow_ops.while_loop(lambda i: i < 10, body, [0])
4226    return tensor[0]
4227
4228  def _getCondTensor(self):
4229    cond_tensor = []
4230
4231    def true_fn():
4232      if not cond_tensor:
4233        cond_tensor.append(constant_op.constant(1))
4234      return cond_tensor[0]
4235
4236    control_flow_ops.cond(
4237        math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))
4238    return cond_tensor[0]
4239
4240  @test_util.run_v1_only("b/120545219")
4241  def testInvalidContext(self):
4242    # Accessing a while loop tensor outside of control flow is illegal.
4243    while_tensor = self._getWhileTensor()
4244    with self.assertRaisesRegexp(
4245        ValueError,
4246        "Cannot use 'while/Const_1' as input to 'Add' because 'while/Const_1' "
4247        "is in a while loop. See info log for more details."):
4248      math_ops.add(1, while_tensor)
4249
4250  @test_util.run_v1_only("b/120545219")
4251  def testInvalidContextInCond(self):
4252    # Accessing a while loop tensor in cond is illegal.
4253    while_tensor = self._getWhileTensor()
4254    with self.assertRaisesRegexp(
4255        ValueError, "Cannot use 'while/Const_1' as input to 'cond/Add' because "
4256        "'while/Const_1' is in a while loop. See info log for more details."):
4257      # TODO(skyewm): this passes if we return while_tensor directly instead
4258      # of using it as input to another op.
4259      control_flow_ops.cond(
4260          math_ops.less(1, 2), lambda: math_ops.add(1, while_tensor),
4261          lambda: constant_op.constant(0))
4262
4263  @test_util.run_v1_only("b/120545219")
4264  def testInvalidContextInWhile(self):
4265    # Accessing a while loop tensor in a different while loop is illegal.
4266    while_tensor = self._getWhileTensor()
4267    with self.assertRaisesRegexp(
4268        ValueError,
4269        "Cannot use 'while/Const_1' as input to 'while_1/Add' because they are "
4270        "in different while loops. See info log for more details."):
4271      control_flow_ops.while_loop(lambda i: i < 10,
4272                                  lambda x: math_ops.add(1, while_tensor), [0])
4273
4274    with self.assertRaisesRegexp(
4275        ValueError,
4276        "Cannot use 'while/Const_1' as input to 'while_2/NextIteration' "
4277        "because they are in different while loops. See info log for more "
4278        "details."):
4279      control_flow_ops.while_loop(lambda i: i < 10, lambda i: while_tensor, [0])
4280
4281  def testValidCondContext(self):
4282    # Accessing a tensor from a cond context is OK (although dangerous).
4283    cond_tensor = self._getCondTensor()
4284    math_ops.add(1, cond_tensor)
4285
4286  def testValidCondContextBranches(self):
4287    # Accessing a tensor from a cond context from the other branch's cond
4288    # context is OK (although dangerous).
4289    cond_tensor = []
4290
4291    def branch_fn():
4292      if not cond_tensor:
4293        cond_tensor.append(constant_op.constant(1))
4294      return cond_tensor[0]
4295
4296    control_flow_ops.cond(math_ops.less(1, 2), branch_fn, branch_fn)
4297
4298  @test_util.run_v1_only("b/120545219")
4299  def testValidWhileContext(self):
4300    # Accessing a tensor in a nested while is OK.
4301    def body(_):
4302      c = constant_op.constant(1)
4303      return control_flow_ops.while_loop(lambda i: i < 3, lambda i: i + c, [0])
4304
4305    control_flow_ops.while_loop(lambda i: i < 5, body, [0])
4306
4307  @test_util.run_v1_only("b/120545219")
4308  def testValidNestedContexts(self):
4309    # Accessing a tensor from a cond context in a while context, all inside an
4310    # outer while context, is OK.
4311    def body(_):
4312      cond_tensor = self._getCondTensor()
4313      # Create another cond containing the while loop for good measure
4314      return control_flow_ops.cond(
4315          math_ops.less(1, 2),
4316          lambda: control_flow_ops.while_loop(lambda i: i < 3,
4317                                              lambda i: i + cond_tensor, [0]),
4318          lambda: constant_op.constant(0))
4319
4320    control_flow_ops.while_loop(lambda i: i < 5, body, [0])
4321
4322  @test_util.run_v1_only("b/120545219")
4323  def testInvalidNestedContexts(self):
4324    # Accessing a tensor from a while context in a different while context, all
4325    # inside a cond context, is illegal.
4326    def true_fn():
4327      while_tensor = self._getWhileTensor()
4328      return control_flow_ops.while_loop(lambda i: i < 3,
4329                                         lambda i: i + while_tensor, [0])
4330
4331    with self.assertRaisesRegexp(
4332        ValueError,
4333        "Cannot use 'cond/while/Const_1' as input to 'cond/while_1/add' because"
4334        " they are in different while loops. See info log for more details."):
4335      control_flow_ops.cond(
4336          math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))
4337
4338
4339class TupleTest(test.TestCase):
4340
4341  @test_util.run_v1_only("b/120545219")
4342  def testTensors(self):
4343    for v1_first in [True, False]:
4344      with self.cached_session():
4345        v1 = variables.VariableV1([1.0])
4346        add1 = math_ops.add(
4347            control_flow_ops.with_dependencies([v1.initializer], v1._ref()),  # pylint: disable=protected-access
4348            2.0)
4349        v2 = variables.VariableV1([10.0])
4350        add2 = math_ops.add(
4351            control_flow_ops.with_dependencies([v2.initializer], v2._ref()),  # pylint: disable=protected-access
4352            20.0)
4353        t1, _, t2 = control_flow_ops.tuple([add1, None, add2])
4354
4355        # v1 is not initialized.
4356        with self.assertRaisesOpError("Attempting to use uninitialized value"):
4357          self.evaluate(v1)
4358
4359        # v2 is not initialized.
4360        with self.assertRaisesOpError("Attempting to use uninitialized value"):
4361          self.evaluate(v2)
4362
4363        if v1_first:
4364          # Getting t1 initializes v2.
4365          self.assertAllClose([3.0], self.evaluate(t1))
4366          self.assertAllClose([10.0], self.evaluate(v2))
4367        else:
4368          # Getting t2 initializes v1.
4369          self.assertAllClose([30.0], self.evaluate(t2))
4370          self.assertAllClose([1.0], self.evaluate(v1))
4371
4372  @test_util.run_v1_only("b/120545219")
4373  def testIndexedSlices(self):
4374    for v1_first in [True, False]:
4375      with self.cached_session():
4376        v1 = variables.VariableV1(
4377            np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
4378                np.float32))
4379        v1_at_1 = ops.IndexedSlices(
4380            control_flow_ops.with_dependencies([v1.initializer], v1._ref()),  # pylint: disable=protected-access
4381            constant_op.constant([1]))
4382
4383        v2 = variables.VariableV1(
4384            np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype(
4385                np.float32))
4386        v2_at_1 = ops.IndexedSlices(
4387            control_flow_ops.with_dependencies([v2.initializer], v2._ref()),  # pylint: disable=protected-access
4388            constant_op.constant([1]))
4389
4390        st1, st2 = control_flow_ops.tuple([v1_at_1, v2_at_1])
4391        g1 = array_ops.gather(st1.values, st1.indices)
4392        g2 = array_ops.gather(st2.values, st2.indices)
4393
4394        # v1 is not initialized.
4395        with self.assertRaisesOpError("Attempting to use uninitialized value"):
4396          self.evaluate(v1)
4397
4398        # v2 is not initialized.
4399        with self.assertRaisesOpError("Attempting to use uninitialized value"):
4400          self.evaluate(v2)
4401
4402        if v1_first:
4403          # Getting g1 initializes v2.
4404          self.assertAllClose([[10.0, 11.0]], self.evaluate(g1))
4405          self.assertAllClose([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]],
4406                              self.evaluate(v2))
4407        else:
4408          # Getting g2 initializes v1.
4409          self.assertAllClose([[10.1, 11.1]], self.evaluate(g2))
4410          self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]],
4411                              self.evaluate(v1))
4412
4413  def testAcceptTensorsAsControlInputs(self):
4414    with self.cached_session():
4415      var = variables.VariableV1(0)
4416      assign = state_ops.assign(var, 1)
4417      t, = control_flow_ops.tuple(
4418          [constant_op.constant(0)], control_inputs=[assign])
4419
4420      # Should trigger the assign.
4421      self.evaluate(t)
4422
4423      self.assertEquals(1, self.evaluate(var))
4424
4425
4426class AssertTest(test.TestCase):
4427
4428  @test_util.run_deprecated_v1
4429  def testGuardedAssertDoesNotCopyWhenTrue(self):
4430    if test_util.is_gpu_available():
4431      self.skipTest("b/128646478 fails in opensource")
4432
4433    with self.session(use_gpu=True) as sess:
4434      with ops.device(test.gpu_device_name()):
4435        value = constant_op.constant(1.0)
4436      with ops.device("/cpu:0"):
4437        true = constant_op.constant(True)
4438        guarded_assert = control_flow_ops.Assert(true, [value], name="guarded")
4439        unguarded_assert = gen_logging_ops._assert(
4440            true, [value], name="unguarded")
4441      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
4442      guarded_metadata = config_pb2.RunMetadata()
4443      sess.run(guarded_assert, options=opts, run_metadata=guarded_metadata)
4444      unguarded_metadata = config_pb2.RunMetadata()
4445      sess.run(unguarded_assert, options=opts, run_metadata=unguarded_metadata)
4446      guarded_nodestat_names = [
4447          n.node_name
4448          for d in guarded_metadata.step_stats.dev_stats
4449          for n in d.node_stats
4450      ]
4451      unguarded_nodestat_names = [
4452          n.node_name
4453          for d in unguarded_metadata.step_stats.dev_stats
4454          for n in d.node_stats
4455      ]
4456      guarded_memcpy_nodestat_names = [
4457          n for n in guarded_nodestat_names if "MEMCPYDtoH" in n
4458      ]
4459      unguarded_memcpy_nodestat_names = [
4460          n for n in unguarded_nodestat_names if "MEMCPYDtoH" in n
4461      ]
4462      if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
4463        # A copy was performed for the unguarded assert
4464        self.assertLess(0, len(unguarded_memcpy_nodestat_names),
4465                        str(unguarded_nodestat_names))
4466      # No copy was performed for the guarded assert
4467      self.assertEqual([], guarded_memcpy_nodestat_names)
4468
4469
4470class WhileOpBenchmark(test.Benchmark):
4471  """Evaluate the performance of while_loop op."""
4472
4473  def _getInitVariables(self):
4474    batch_size = 10
4475    image_size = 256
4476    kernel_size = 3
4477    depth = 16
4478
4479    init_step = constant_op.constant(-1)
4480    image = variable_scope.get_variable(
4481        "image",
4482        initializer=random_ops.random_normal(
4483            [batch_size, image_size, image_size, depth],
4484            dtype=dtypes.float32,
4485            stddev=1e-1))
4486    kernel = variable_scope.get_variable(
4487        "weights",
4488        initializer=random_ops.truncated_normal(
4489            [kernel_size, kernel_size, depth, depth],
4490            dtype=dtypes.float32,
4491            stddev=1e-1))
4492    return init_step, image, kernel
4493
4494  def _runOneBenchmark(self,
4495                       default_device,
4496                       num_iters=10,
4497                       static_unroll=False,
4498                       steps=10):
4499    """Evaluate the while loop performance.
4500
4501    Args:
4502      default_device: The default device to run all ops except the loop_body.
4503        loop_body is always run on GPU.
4504      num_iters: Number of iterations to run.
4505      static_unroll: If true, run unrolled version; otherwise, run while_loop.
4506      steps: Total number of repeated steps to run the loop.
4507
4508    Returns:
4509      The duration of the run in seconds.
4510    """
4511
4512    def loop_body(i, x):
4513      with ops.device("/gpu:0"):
4514        # Always put loop body on GPU.
4515        nx = nn_ops.conv2d(
4516            input=x,
4517            filter=kernel,
4518            strides=[1, 1, 1, 1],
4519            padding="SAME",
4520            data_format="NHWC",
4521            name="conv2d")
4522        ni = math_ops.add(i, 1)
4523        return ni, nx
4524
4525    ops.reset_default_graph()
4526    with session.Session() as sess, ops.device(default_device):
4527      # Get the initial id i, input x, and kernel.
4528      i, x, kernel = self._getInitVariables()
4529      variables.global_variables_initializer().run()
4530
4531      if static_unroll:
4532        for _ in xrange(steps):
4533          i, x = loop_body(i, x)
4534      else:
4535        i, x = control_flow_ops.while_loop(
4536            lambda i, _: i < steps,
4537            loop_body, [i, x],
4538            parallel_iterations=steps,
4539            swap_memory=True)
4540
4541      r = math_ops.reduce_sum(x)
4542      dx, dk = gradients_impl.gradients(r, [x, kernel])
4543      # Use group to avoid fetching back results.
4544      r = control_flow_ops.group(dx, dk)
4545
4546      for _ in xrange(3):
4547        # exclude warm up time
4548        self.evaluate(r)
4549
4550      start_time = time.time()
4551      for _ in xrange(num_iters):
4552        self.evaluate(r)
4553      return (time.time() - start_time) / num_iters
4554
4555  def benchmarkWhileOpCrossDevicePlacement(self):
4556    iters = 10
4557    # Run loop body on GPU, but other ops on CPU.
4558    duration = self._runOneBenchmark("cpu", iters, static_unroll=False)
4559    self.report_benchmark(
4560        name="while_op_cross_device", iters=iters, wall_time=duration)
4561
4562  def benchmarkWhileOpSameDevicePlacement(self):
4563    iters = 10
4564    # Run all ops on the same GPU device.
4565    duration = self._runOneBenchmark("gpu", iters, static_unroll=False)
4566    self.report_benchmark(
4567        name="while_op_same_device", iters=iters, wall_time=duration)
4568
4569  def benchmarkWhileOpUnrollCrossDevicePlacement(self):
4570    iters = 10
4571    # Run loop body on GPU, but other ops on CPU.
4572    duration = self._runOneBenchmark("cpu", iters, static_unroll=True)
4573    self.report_benchmark(
4574        name="unroll_cross_device_cpu", iters=iters, wall_time=duration)
4575
4576  def benchmarkWhileOpUnrollSameDevicePlacement(self):
4577    iters = 10
4578    # Run all ops on GPU.
4579    duration = self._runOneBenchmark("gpu", iters, static_unroll=True)
4580    self.report_benchmark(
4581        name="unroll_same_device", iters=iters, wall_time=duration)
4582
4583
4584@test_util.with_control_flow_v2
4585class EagerTest(test.TestCase):
4586
4587  def testCond(self):
4588    with context.eager_mode():
4589      pred = math_ops.less(1, 2)
4590      fn1 = lambda: [constant_op.constant(10)]
4591      fn2 = lambda: [constant_op.constant(20)]
4592      r = control_flow_ops.cond(pred, fn1, fn2)
4593
4594      self.assertAllEqual(r.numpy(), 10)
4595      self.assertFalse(isinstance(r, list))
4596
4597  # TODO(b/117279927): Re-enable once msan failure is fixed.
4598  def DISABLED_testCondInDefun(self):
4599    with context.eager_mode():
4600
4601      @eager_function.defun
4602      def foo(pred):
4603        # TODO(b/111124878): this only needs to output one element.
4604        fn1 = lambda: (constant_op.constant(10), constant_op.constant(100))
4605        fn2 = lambda: (constant_op.constant(20), constant_op.constant(200))
4606        return control_flow_ops.cond(constant_op.constant(pred), fn1, fn2)
4607
4608      r = foo(True)
4609      self.assertAllEqual(r[0].numpy(), 10)
4610      self.assertNotIsInstance(r, list)
4611
4612      r = foo(False)
4613      self.assertAllEqual(r[0].numpy(), 20)
4614      self.assertFalse(isinstance(r, list))
4615
4616  def testWhileLoop(self):
4617    with context.eager_mode():
4618      tensor = constant_op.constant([1, 2, 3, 4, 5])
4619      self.assertAllEqual(isum(tensor).numpy(), [46, 47, 48, 49, 50])
4620
4621  def testWhileLoopWithMaxIterations(self):
4622    with context.eager_mode():
4623      tensor = constant_op.constant([1, 2, 3, 4, 5])
4624      self.assertAllEqual(
4625          isum(tensor, maximum_iterations=3).numpy(),
4626          [1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3])
4627
4628  @test_util.run_v1_only("b/120545219")
4629  def testWhileWithMaximumIterationsAndSingleArgument(self):
4630    with context.eager_mode():
4631      tensor = constant_op.constant(0)
4632      r = control_flow_ops.while_loop(
4633          lambda i: i < 3, lambda i: i + 1, [tensor], maximum_iterations=1)
4634      self.assertEqual(1, r.numpy())
4635
4636  def testWithDependencies(self):
4637    with context.eager_mode():
4638      t1 = constant_op.constant(1)
4639      t2 = constant_op.constant(2)
4640      t3 = control_flow_ops.with_dependencies(t1, t2)
4641      self.assertAllEqual(t2.numpy(), t3.numpy())
4642
4643  def testTuple(self):
4644    with context.eager_mode():
4645      t1 = constant_op.constant(1)
4646      t2 = constant_op.constant(2)
4647      tup1, tup2 = control_flow_ops.tuple([t1, t2])
4648      self.assertAllEqual(t1.numpy(), tup1.numpy())
4649      self.assertAllEqual(t2.numpy(), tup2.numpy())
4650
4651  @test_util.run_v1_only("b/120545219")
4652  def testCase(self):
4653    with context.eager_mode():
4654      x = constant_op.constant(1)
4655      y = constant_op.constant(2)
4656      z = constant_op.constant(3)
4657      f1 = lambda: constant_op.constant(17)
4658      f2 = lambda: constant_op.constant(23)
4659      f3 = lambda: constant_op.constant(-1)
4660
4661      r1 = control_flow_ops.case(
4662          [(x < y, f1), (x > z, f2)], default=f3, exclusive=True)
4663      self.assertAllEqual(r1.numpy(), 17)
4664
4665
4666if __name__ == "__main__":
4667  test.main()
4668