1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""critical section tests."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import itertools
22
23from absl.testing import parameterized
24
25from tensorflow.python.data.experimental.ops import prefetching_ops
26from tensorflow.python.data.ops import dataset_ops
27from tensorflow.python.eager import context
28from tensorflow.python.eager import def_function
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import test_util
31from tensorflow.python.ops import array_ops
32from tensorflow.python.ops import control_flow_ops
33from tensorflow.python.ops import control_flow_v2_toggles
34from tensorflow.python.ops import critical_section_ops
35from tensorflow.python.ops import resource_variable_ops
36from tensorflow.python.platform import test
37from tensorflow.python.platform import tf_logging as logging
38# TODO(ebrevdo): Re-enable once CriticalSection is in core.
39# from tensorflow.python.training import saver as saver_lib
40
41
42@test_util.with_control_flow_v2
43class CriticalSectionTest(test.TestCase, parameterized.TestCase):
44
45  @test_util.run_in_graph_and_eager_modes
46  def testCreateCriticalSection(self):
47    cs = critical_section_ops.CriticalSection(shared_name="cs")
48    v = resource_variable_ops.ResourceVariable(0.0, name="v")
49
50    def fn(a, b):
51      c = v.value()
52      with ops.control_dependencies([c]):
53        nv = v.assign_add(a * b)
54        with ops.control_dependencies([nv]):
55          return array_ops.identity(c)
56
57    num_concurrent = 100
58    r = [cs.execute(lambda: fn(1.0, 2.0)) for _ in range(num_concurrent)]
59    self.evaluate(v.initializer)
60    r_value = self.evaluate(r)
61    self.assertAllClose([2.0 * i for i in range(num_concurrent)],
62                        sorted(r_value))
63
64  @parameterized.named_parameters(
65      ("Inner%sOuter%s" % (inner, outer), inner, outer)
66      for (inner, outer) in itertools.product(*([(False, True)] * 2)))
67  @test_util.run_in_graph_and_eager_modes
68  @test_util.xla_allow_fallback("b/128495870")
69  def testCriticalSectionWithControlFlow(self, outer_cond, inner_cond):
70    if (not context.executing_eagerly() and
71        control_flow_v2_toggles.control_flow_v2_enabled()):
72      self.skipTest("b/135070612")
73    cs = critical_section_ops.CriticalSection(shared_name="cs")
74    v = resource_variable_ops.ResourceVariable(0.0, name="v")
75    num_concurrent = 100
76
77    # pylint: disable=cell-var-from-loop
78    def fn(a, b):
79      c = v.read_value()
80      def true_fn():
81        with ops.control_dependencies([c]):
82          nv = v.assign_add(a * b)
83          with ops.control_dependencies([nv]):
84            return array_ops.identity(c)
85      return control_flow_ops.cond(
86          array_ops.identity(inner_cond), true_fn, lambda: c)
87
88    def execute():
89      return cs.execute(lambda: fn(1.0, 2.0))
90
91    r = [
92        control_flow_ops.cond(array_ops.identity(outer_cond),
93                              execute,
94                              v.read_value)
95        for _ in range(num_concurrent)
96    ]
97    # pylint: enable=cell-var-from-loop
98
99    self.evaluate(v.initializer)
100    r_value = self.evaluate(r)
101    if inner_cond and outer_cond:
102      self.assertAllClose([2.0 * i for i in range(num_concurrent)],
103                          sorted(r_value))
104    else:
105      self.assertAllClose([0] * num_concurrent, r_value)
106
107  @test_util.run_v1_only("b/123990562 Sees CancelledError on some calls")
108  def testCriticalSectionInParallelDoesntDeadlockOnError(self):
109    # No eager mode execution of this test because eager does not
110    # run fn() in parallel, which is where the deadlock could
111    # potentially occur (in graph mode).
112    cs = critical_section_ops.CriticalSection(shared_name="cs")
113    v = resource_variable_ops.ResourceVariable(0.0, name="v")
114
115    def fn(i):
116      error = control_flow_ops.Assert((i % 2) == 1, ["Error"])
117      with ops.control_dependencies([error]):
118        return v.read_value()
119
120    num_concurrent = 2
121
122    @def_function.function(autograph=False)
123    def run_concurrently():
124      return [cs.execute(lambda: fn(i)) for i in range(num_concurrent)]
125
126    if not context.executing_eagerly():
127      run_concurrently = run_concurrently()
128
129    self.evaluate(v.initializer)
130    for _ in range(100):
131      with self.assertRaisesOpError("Error"):
132        if context.executing_eagerly():
133          run_concurrently()
134        else:
135          self.evaluate(run_concurrently)
136
137  @test_util.run_in_graph_and_eager_modes
138  def testCreateCriticalSectionFnReturnsOp(self):
139    cs = critical_section_ops.CriticalSection(shared_name="cs")
140    v = resource_variable_ops.ResourceVariable(0.0, name="v")
141
142    def fn_return_op(a, b):
143      c = v.read_value()
144      with ops.control_dependencies([c]):
145        nv = v.assign_add(a * b)
146        with ops.control_dependencies([nv]):
147          return control_flow_ops.no_op()
148
149    num_concurrent = 100
150    r = [cs.execute(lambda: fn_return_op(1.0, 2.0))
151         for _ in range(num_concurrent)]
152    self.evaluate(v.initializer)
153    self.evaluate(r)
154    final_v = self.evaluate(v)
155    self.assertAllClose(2.0 * num_concurrent, final_v)
156
157  @test_util.run_v1_only("Collections don't exist in TF2")
158  def testCollection(self):
159    cs = critical_section_ops.CriticalSection(shared_name="cs")
160    self.assertIn(
161        cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS))
162    add = lambda x: x + 1
163    execute = cs.execute(lambda: add(1.0), name="my_execute")
164    execute_op = [
165        x for x in execute.graph.get_operations()
166        if "my_execute" in x.name and "MutexLock" in x.type
167    ][0]
168    self.assertIn(
169        execute_op,
170        [signature.op for signature in
171         ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)])
172
173  def testRecursiveCriticalSectionAccessIsIllegal(self):
174    # This does not work properly in eager mode.  Eager users will
175    # just hit a deadlock if they do this.  But at least it'll be easier
176    # to debug.
177    cs = critical_section_ops.CriticalSection()
178    add = lambda y: y + 1
179    def fn(x):
180      return cs.execute(lambda: add(x))
181
182    with self.assertRaisesRegex(
183        ValueError, r"Attempting to lock a CriticalSection in which we are"):
184      cs.execute(lambda: fn(1.0))
185
  def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
    """Checks that capturing a tensor produced by the same section is safe.

    `to_capture` is itself the output of `cs.execute`.  The later executions
    capture it either directly or through an identity op, with and without an
    explicit control dependency, and each one must evaluate to 2.0.
    """
    # This one is subtle; and we're being overly cautious here.  The
    # deadlock we are ensuring we catch is:
    #
    # to_capture = CS[lambda x: x + 1](1.0)
    # deadlocked = CS[lambda x: x + to_capture](1.0)
    #
    # This would have caused a deadlock because executing `deadlocked` will
    # lock the mutex on CS; but then due to dependencies, will attempt
    # to compute `to_capture`.  This computation requires locking CS,
    # but that is not possible now because CS is already locked by
    # `deadlocked`.
    #
    # We check that CriticalSection.execute properly inserts new
    # control dependencies to its lock to ensure all captured
    # operations are finished before anything runs within the critical section.
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    fn = array_ops.identity
    to_capture = cs.execute(lambda: fn(1.0))
    fn_captures = lambda x: x + to_capture
    # Created before any of the capturing executions below are built.
    to_capture_too = array_ops.identity(to_capture)

    # Direct capture, with no explicit control dependency.
    ex_0 = cs.execute(lambda: fn_captures(1.0))

    with ops.control_dependencies([to_capture]):
      # This is OK because to_capture will execute before this next call
      ex_1 = cs.execute(lambda: fn_captures(1.0))

    dependency = array_ops.identity(to_capture)

    fn_captures_dependency = lambda x: x + dependency

    # Indirect capture through an identity op, with no explicit dependency.
    ex_2 = cs.execute(lambda: fn_captures_dependency(1.0))

    # Indirect capture with a control dependency on a different identity op.
    with ops.control_dependencies([to_capture_too]):
      ex_3 = cs.execute(lambda: fn_captures_dependency(1.0))

    # Ensure none of the executions actually deadlocks when evaluated; each
    # computes 1.0 + to_capture, and to_capture is identity(1.0).
    self.assertEqual(2.0, self.evaluate(ex_0))
    self.assertEqual(2.0, self.evaluate(ex_1))
    self.assertEqual(2.0, self.evaluate(ex_2))
    self.assertEqual(2.0, self.evaluate(ex_3))
228
229  def testRecursiveCriticalSectionAccessWithinLoopIsProtected(self):
230    cs = critical_section_ops.CriticalSection(shared_name="cs")
231
232    def body_implicit_capture(i, j):
233      # This would have caused a deadlock if not for logic in execute
234      # that inserts additional control dependencies onto the lock op:
235      #   * Loop body argument j is captured by fn()
236      #   * i is running in parallel to move forward the execution
237      #   * j is not being checked by the predicate function
238      #   * output of cs.execute() is returned as next j.
239      fn = lambda: j + 1
240      return (i + 1, cs.execute(fn))
241
242    (i_n, j_n) = control_flow_ops.while_loop(
243        lambda i, _: i < 1000,
244        body_implicit_capture,
245        [0, 0],
246        parallel_iterations=25)
247    # For consistency between eager and graph mode.
248    i_n = array_ops.identity(i_n)
249    logging.warn(
250        "\n==============\nRunning "
251        "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
252        "body_implicit_capture'\n"
253        "==============\n")
254    self.assertEqual((1000, 1000), self.evaluate((i_n, j_n)))
255    logging.warn(
256        "\n==============\nSuccessfully finished running "
257        "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
258        "body_implicit_capture'\n"
259        "==============\n")
260
261    def body_implicit_capture_protected(i, j):
262      # This version is ok because we manually add a control
263      # dependency on j, which is an argument to the while_loop body
264      # and captured by fn.
265      fn = lambda: j + 1
266      with ops.control_dependencies([j]):
267        return (i + 1, cs.execute(fn))
268
269    (i_n, j_n) = control_flow_ops.while_loop(
270        lambda i, _: i < 1000,
271        body_implicit_capture_protected,
272        [0, 0],
273        parallel_iterations=25)
274    # For consistency between eager and graph mode.
275    i_n = array_ops.identity(i_n)
276    logging.warn(
277        "\n==============\nRunning "
278        "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
279        "body_implicit_capture_protected'\n"
280        "==============\n")
281    self.assertEqual((1000, 1000), self.evaluate((i_n, j_n)))
282    logging.warn(
283        "\n==============\nSuccessfully finished running "
284        "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
285        "body_implicit_capture_protected'\n"
286        "==============\n")
287
288    def body_args_capture(i, j):
289      # This version is ok because j is an argument to fn and we can
290      # ensure there's a control dependency on j.
291      fn = lambda x: x + 1
292      return (i + 1, cs.execute(lambda: fn(j)))
293
294    (i_n, j_n) = control_flow_ops.while_loop(
295        lambda i, _: i < 1000,
296        body_args_capture,
297        [0, 0],
298        parallel_iterations=25)
299    # For consistency between eager and graph mode.
300    i_n = array_ops.identity(i_n)
301    logging.warn(
302        "\n==============\nRunning "
303        "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
304        "body_args_capture'\n"
305        "==============\n")
306    self.assertEqual((1000, 1000), self.evaluate((i_n, j_n)))
307    logging.warn(
308        "\n==============\nSuccessfully finished running "
309        "'testRecursiveCriticalSectionAccessWithinLoopDoesNotDeadlock "
310        "body_args_capture'\n"
311        "==============\n")
312
313  def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
314    # This does not work properly in eager mode.  Eager users will
315    # just hit a deadlock if they do this.  But at least it'll be easier
316    # to debug.
317    cs = critical_section_ops.CriticalSection(shared_name="cs")
318    cs_same = critical_section_ops.CriticalSection(shared_name="cs")
319    add = lambda x: x + 1
320    def fn(x):
321      return cs_same.execute(lambda: add(x))
322
323    with self.assertRaisesRegex(
324        ValueError, r"Attempting to lock a CriticalSection in which we are"):
325      cs.execute(lambda: fn(1.0))
326
327  @test_util.run_v1_only(
328      "b/123955885 Can't identify consumed resources in eager mode")
329  def testMultipleCSExecutionsRequestSameResource(self):
330    cs0 = critical_section_ops.CriticalSection()
331    cs1 = critical_section_ops.CriticalSection()
332    v = resource_variable_ops.ResourceVariable(0.0, name="v")
333    cs0.execute(lambda: v + 1)
334    # It's OK for the same CriticalSection to access this resource.
335    cs0.execute(lambda: v - 1)
336    # It's *not* OK for a different CriticalSection to access it by
337    # default.
338    with self.assertRaisesRegex(ValueError,
339                                "requested exclusive resource access"):
340      cs1.execute(lambda: v + 1)
341    # It's not even OK if the second call doesn't request exclusive access.
342    with self.assertRaisesRegex(ValueError,
343                                "requested exclusive resource access"):
344      cs1.execute(lambda: v + 1, exclusive_resource_access=False)
345
346    v2 = resource_variable_ops.ResourceVariable(0.0, name="v2")
347    cs0.execute(lambda: v2 + 1, exclusive_resource_access=False)
348    # It's OK if neither requests exclusive resource access.
349    cs1.execute(lambda: v2 + 1, exclusive_resource_access=False)
350
351    # It's not OK if the second request requires exclusive resource
352    # access.
353    with self.assertRaisesRegex(ValueError,
354                                "requested exclusive resource access"):
355      cs1.execute(lambda: v2 + 1)
356
357  def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self):
358    cs = critical_section_ops.CriticalSection()
359    v = resource_variable_ops.ResourceVariable(0, name="v")
360    # Make sure that the control dependencies on v do not cause issues
361    # in the lock_op's automatic control dependency adder.
362    #
363    # Note, here v must be a resource variable (or something similar),
364    # otherwise it gets hoisted into the while_loop by the time we add
365    # control dependencies to the lock_op.
366    def body(i):
367      add_j = lambda j: v + j + 1
368      return cs.execute(lambda: add_j(i))
369    out = control_flow_ops.while_loop(
370        lambda i: i < 10, body, [0])
371    self.evaluate(v.initializer)
372    self.assertEqual(10, self.evaluate(out))
373
374  @test_util.run_in_graph_and_eager_modes
375  def testInsideFunction(self):
376    if test_util.is_gpu_available():
377      self.skipTest(
378          "b/123899495: Colocation errors for critical sections in map on GPU")
379    cs = critical_section_ops.CriticalSection()
380    with ops.device("/gpu:0" if test_util.is_gpu_available() else "/cpu:0"):
381      v = resource_variable_ops.ResourceVariable(1)
382    def fn():
383      return v.read_value()
384
385    # map() creates a TensorFlow function.
386    ds = dataset_ops.Dataset.range(1)
387    if test_util.is_gpu_available():
388      ds = (ds.apply(prefetching_ops.copy_to_device("/gpu:0"))
389            .apply(prefetching_ops.map_on_gpu(lambda _: cs.execute(fn))))
390    else:
391      ds = ds.map(lambda _: cs.execute(fn))
392
393    def get_first():
394      if context.executing_eagerly():
395        return self.evaluate(dataset_ops.make_one_shot_iterator(ds).get_next())
396      itr = dataset_ops.make_initializable_iterator(ds)
397      self.evaluate([v.initializer, itr.initializer])
398      return self.evaluate(itr.get_next())
399
400    self.assertEqual(1, get_first())
401
402
# Invoke the TensorFlow test runner when this file is executed directly.
if __name__ == "__main__":
  test.main()
405