1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""critical section tests.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import itertools 22 23from absl.testing import parameterized 24 25from tensorflow.python.data.experimental.ops import prefetching_ops 26from tensorflow.python.data.ops import dataset_ops 27from tensorflow.python.eager import context 28from tensorflow.python.eager import def_function 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import test_util 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import control_flow_ops 33from tensorflow.python.ops import control_flow_v2_toggles 34from tensorflow.python.ops import critical_section_ops 35from tensorflow.python.ops import resource_variable_ops 36from tensorflow.python.platform import test 37from tensorflow.python.platform import tf_logging as logging 38# TODO(ebrevdo): Re-enable once CriticalSection is in core. 
# from tensorflow.python.training import saver as saver_lib


@test_util.with_control_flow_v2
class CriticalSectionTest(test.TestCase, parameterized.TestCase):
  """Tests for `critical_section_ops.CriticalSection`."""

  @test_util.run_in_graph_and_eager_modes
  def testCreateCriticalSection(self):
    """Concurrent read-then-increment executions observe distinct values."""
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    v = resource_variable_ops.ResourceVariable(0.0, name="v")

    def fn(a, b):
      # Read v, then add a * b, and return the value read *before* the add.
      c = v.value()
      with ops.control_dependencies([c]):
        nv = v.assign_add(a * b)
        with ops.control_dependencies([nv]):
          return array_ops.identity(c)

    num_concurrent = 100
    r = [cs.execute(lambda: fn(1.0, 2.0)) for _ in range(num_concurrent)]
    self.evaluate(v.initializer)
    r_value = self.evaluate(r)
    # If the executions are serialized, each one sees a different prior
    # value: 0, 2, 4, ... (in some order).
    self.assertAllClose([2.0 * i for i in range(num_concurrent)],
                        sorted(r_value))

  @parameterized.named_parameters(
      ("Inner%sOuter%s" % (inner, outer), inner, outer)
      for (inner, outer) in itertools.product((False, True), repeat=2))
  @test_util.run_in_graph_and_eager_modes
  @test_util.xla_allow_fallback("b/128495870")
  def testCriticalSectionWithControlFlow(self, inner_cond, outer_cond):
    """CriticalSection.execute nested inside conds on both sides.

    The variable is only updated when both `outer_cond` (around the
    execute) and `inner_cond` (inside the executed fn) are true.
    """
    # named_parameters passes (inner, outer) positionally, so the argument
    # order above must be (inner_cond, outer_cond) for the generated test
    # names to describe the conditions actually used.
    if (not context.executing_eagerly() and
        control_flow_v2_toggles.control_flow_v2_enabled()):
      self.skipTest("b/135070612")
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    v = resource_variable_ops.ResourceVariable(0.0, name="v")
    num_concurrent = 100

    # pylint: disable=cell-var-from-loop
    def fn(a, b):
      c = v.read_value()
      def true_fn():
        with ops.control_dependencies([c]):
          nv = v.assign_add(a * b)
          with ops.control_dependencies([nv]):
            return array_ops.identity(c)
      return control_flow_ops.cond(
          array_ops.identity(inner_cond), true_fn, lambda: c)

    def execute():
      return cs.execute(lambda: fn(1.0, 2.0))

    r = [
        control_flow_ops.cond(array_ops.identity(outer_cond),
                              execute,
                              v.read_value)
        for _ in range(num_concurrent)
    ]
    # pylint: enable=cell-var-from-loop

    self.evaluate(v.initializer)
    r_value = self.evaluate(r)
    if inner_cond and outer_cond:
      self.assertAllClose([2.0 * i for i in range(num_concurrent)],
                          sorted(r_value))
    else:
      self.assertAllClose([0] * num_concurrent, r_value)

  @test_util.run_v1_only("b/123990562 Sees CancelledError on some calls")
  def testCriticalSectionInParallelDoesntDeadlockOnError(self):
    """A failing op inside a parallel execution raises instead of hanging."""
    # No eager mode execution of this test because eager does not
    # run fn() in parallel, which is where the deadlock could
    # potentially occur (in graph mode).
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    v = resource_variable_ops.ResourceVariable(0.0, name="v")

    def fn(i):
      # The Assert fails for even i, so at least one execution errors.
      error = control_flow_ops.Assert((i % 2) == 1, ["Error"])
      with ops.control_dependencies([error]):
        return v.read_value()

    num_concurrent = 2

    @def_function.function(autograph=False)
    def run_concurrently():
      # cs.execute invokes the lambda immediately, so `i` is bound correctly.
      return [cs.execute(lambda: fn(i)) for i in range(num_concurrent)]

    if not context.executing_eagerly():
      run_concurrently = run_concurrently()

    self.evaluate(v.initializer)
    for _ in range(100):
      with self.assertRaisesOpError("Error"):
        if context.executing_eagerly():
          run_concurrently()
        else:
          self.evaluate(run_concurrently)

  @test_util.run_in_graph_and_eager_modes
  def testCreateCriticalSectionFnReturnsOp(self):
    """The executed fn may return an op (no_op) instead of a tensor."""
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    v = resource_variable_ops.ResourceVariable(0.0, name="v")

    def fn_return_op(a, b):
      c = v.read_value()
      with ops.control_dependencies([c]):
        nv = v.assign_add(a * b)
        with ops.control_dependencies([nv]):
          return control_flow_ops.no_op()

    num_concurrent = 100
    r = [cs.execute(lambda: fn_return_op(1.0, 2.0))
         for _ in range(num_concurrent)]
    self.evaluate(v.initializer)
    self.evaluate(r)
    final_v = self.evaluate(v)
    self.assertAllClose(2.0 * num_concurrent, final_v)

  @test_util.run_v1_only("Collections don't exist in TF2")
  def testCollection(self):
    """CriticalSections and their executions are added to graph collections."""
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    self.assertIn(
        cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS))
    add = lambda x: x + 1
    execute = cs.execute(lambda: add(1.0), name="my_execute")
    execute_op = [
        x for x in execute.graph.get_operations()
        if "my_execute" in x.name and "MutexLock" in x.type
    ][0]
    self.assertIn(
        execute_op,
        [signature.op for signature in
         ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)])

  def testRecursiveCriticalSectionAccessIsIllegal(self):
    """Nested execute on the same CriticalSection raises ValueError."""
    # This does not work properly in eager mode. Eager users will
    # just hit a deadlock if they do this. But at least it'll be easier
    # to debug.
    cs = critical_section_ops.CriticalSection()
    add = lambda y: y + 1
    def fn(x):
      return cs.execute(lambda: add(x))

    with self.assertRaisesRegex(
        ValueError, r"Attempting to lock a CriticalSection in which we are"):
      cs.execute(lambda: fn(1.0))

  def testRecursiveCriticalSectionAccessViaCapturedTensorIsProtected(self):
    """Captured outputs of earlier executions do not cause a deadlock."""
    # This one is subtle; and we're being overly cautious here. The
    # deadlock we are ensuring we catch is:
    #
    # to_capture = CS[lambda x: x + 1](1.0)
    # deadlocked = CS[lambda x: x + to_capture](1.0)
    #
    # This would have caused a deadlock because executing `deadlocked` will
    # lock the mutex on CS; but then due to dependencies, will attempt
    # to compute `to_capture`. This computation requires locking CS,
    # but that is not possible now because CS is already locked by
    # `deadlocked`.
    #
    # We check that CriticalSection.execute properly inserts new
    # control dependencies to its lock to ensure all captured
    # operations are finished before anything runs within the critical section.
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    fn = array_ops.identity
    to_capture = cs.execute(lambda: fn(1.0))
    fn_captures = lambda x: x + to_capture
    to_capture_too = array_ops.identity(to_capture)

    ex_0 = cs.execute(lambda: fn_captures(1.0))

    with ops.control_dependencies([to_capture]):
      # This is OK because to_capture will execute before this next call
      ex_1 = cs.execute(lambda: fn_captures(1.0))

    dependency = array_ops.identity(to_capture)

    fn_captures_dependency = lambda x: x + dependency

    ex_2 = cs.execute(lambda: fn_captures_dependency(1.0))

    with ops.control_dependencies([to_capture_too]):
      ex_3 = cs.execute(lambda: fn_captures_dependency(1.0))

    # Ensure there's no actual deadlock on to_capture.
    self.assertEqual(2.0, self.evaluate(ex_0))
    self.assertEqual(2.0, self.evaluate(ex_1))
    self.assertEqual(2.0, self.evaluate(ex_2))
    self.assertEqual(2.0, self.evaluate(ex_3))

  def testRecursiveCriticalSectionAccessWithinLoopIsProtected(self):
    """Executes inside a parallel while_loop body complete without deadlock."""
    cs = critical_section_ops.CriticalSection(shared_name="cs")

    def run_loop_and_check(body, body_name):
      """Runs `body` for 1000 iterations and checks both loop vars hit 1000."""
      (i_n, j_n) = control_flow_ops.while_loop(
          lambda i, _: i < 1000,
          body,
          [0, 0],
          parallel_iterations=25)
      # For consistency between eager and graph mode.
      i_n = array_ops.identity(i_n)
      logging.warning(
          "\n==============\nRunning "
          "'testRecursiveCriticalSectionAccessWithinLoopIsProtected "
          "%s'\n"
          "==============\n", body_name)
      self.assertEqual((1000, 1000), self.evaluate((i_n, j_n)))
      logging.warning(
          "\n==============\nSuccessfully finished running "
          "'testRecursiveCriticalSectionAccessWithinLoopIsProtected "
          "%s'\n"
          "==============\n", body_name)

    def body_implicit_capture(i, j):
      # This would have caused a deadlock if not for logic in execute
      # that inserts additional control dependencies onto the lock op:
      # * Loop body argument j is captured by fn()
      # * i is running in parallel to move forward the execution
      # * j is not being checked by the predicate function
      # * output of cs.execute() is returned as next j.
      fn = lambda: j + 1
      return (i + 1, cs.execute(fn))

    run_loop_and_check(body_implicit_capture, "body_implicit_capture")

    def body_implicit_capture_protected(i, j):
      # This version is ok because we manually add a control
      # dependency on j, which is an argument to the while_loop body
      # and captured by fn.
      fn = lambda: j + 1
      with ops.control_dependencies([j]):
        return (i + 1, cs.execute(fn))

    run_loop_and_check(body_implicit_capture_protected,
                       "body_implicit_capture_protected")

    def body_args_capture(i, j):
      # This version is ok because j is an argument to fn and we can
      # ensure there's a control dependency on j.
      fn = lambda x: x + 1
      return (i + 1, cs.execute(lambda: fn(j)))

    run_loop_and_check(body_args_capture, "body_args_capture")

  def testRecursiveCriticalSectionAccessIsIllegalSameSharedName(self):
    """Nested execute on two CriticalSections sharing a name raises."""
    # This does not work properly in eager mode. Eager users will
    # just hit a deadlock if they do this. But at least it'll be easier
    # to debug.
    cs = critical_section_ops.CriticalSection(shared_name="cs")
    cs_same = critical_section_ops.CriticalSection(shared_name="cs")
    add = lambda x: x + 1
    def fn(x):
      return cs_same.execute(lambda: add(x))

    with self.assertRaisesRegex(
        ValueError, r"Attempting to lock a CriticalSection in which we are"):
      cs.execute(lambda: fn(1.0))

  @test_util.run_v1_only(
      "b/123955885 Can't identify consumed resources in eager mode")
  def testMultipleCSExecutionsRequestSameResource(self):
    """Exclusive resource access conflicts across CriticalSections raise."""
    cs0 = critical_section_ops.CriticalSection()
    cs1 = critical_section_ops.CriticalSection()
    v = resource_variable_ops.ResourceVariable(0.0, name="v")
    cs0.execute(lambda: v + 1)
    # It's OK for the same CriticalSection to access this resource.
    cs0.execute(lambda: v - 1)
    # It's *not* OK for a different CriticalSection to access it by
    # default.
    with self.assertRaisesRegex(ValueError,
                                "requested exclusive resource access"):
      cs1.execute(lambda: v + 1)
    # It's not even OK if the second call doesn't request exclusive access.
    with self.assertRaisesRegex(ValueError,
                                "requested exclusive resource access"):
      cs1.execute(lambda: v + 1, exclusive_resource_access=False)

    v2 = resource_variable_ops.ResourceVariable(0.0, name="v2")
    cs0.execute(lambda: v2 + 1, exclusive_resource_access=False)
    # It's OK if neither requests exclusive resource access.
    cs1.execute(lambda: v2 + 1, exclusive_resource_access=False)

    # It's not OK if the second request requires exclusive resource
    # access.
    with self.assertRaisesRegex(ValueError,
                                "requested exclusive resource access"):
      cs1.execute(lambda: v2 + 1)

  def testControlDependencyFromOutsideWhileLoopMixedWithInsideLoop(self):
    """A variable from outside a while_loop can be used by an inner execute."""
    cs = critical_section_ops.CriticalSection()
    v = resource_variable_ops.ResourceVariable(0, name="v")
    # Make sure that the control dependencies on v do not cause issues
    # in the lock_op's automatic control dependency adder.
    #
    # Note, here v must be a resource variable (or something similar),
    # otherwise it gets hoisted into the while_loop by the time we add
    # control dependencies to the lock_op.
    def body(i):
      add_j = lambda j: v + j + 1
      return cs.execute(lambda: add_j(i))
    out = control_flow_ops.while_loop(
        lambda i: i < 10, body, [0])
    self.evaluate(v.initializer)
    self.assertEqual(10, self.evaluate(out))

  @test_util.run_in_graph_and_eager_modes
  def testInsideFunction(self):
    """CriticalSection.execute works inside a Dataset.map function."""
    if test_util.is_gpu_available():
      self.skipTest(
          "b/123899495: Colocation errors for critical sections in map on GPU")
    cs = critical_section_ops.CriticalSection()
    # NOTE(review): because of the skip above, the GPU branches below are
    # currently unreachable; presumably kept for when b/123899495 is fixed.
    with ops.device("/gpu:0" if test_util.is_gpu_available() else "/cpu:0"):
      v = resource_variable_ops.ResourceVariable(1)
    def fn():
      return v.read_value()

    # map() creates a TensorFlow function.
    ds = dataset_ops.Dataset.range(1)
    if test_util.is_gpu_available():
      ds = (ds.apply(prefetching_ops.copy_to_device("/gpu:0"))
            .apply(prefetching_ops.map_on_gpu(lambda _: cs.execute(fn))))
    else:
      ds = ds.map(lambda _: cs.execute(fn))

    def get_first():
      if context.executing_eagerly():
        return self.evaluate(dataset_ops.make_one_shot_iterator(ds).get_next())
      itr = dataset_ops.make_initializable_iterator(ds)
      self.evaluate([v.initializer, itr.initializer])
      return self.evaluate(itr.get_next())

    self.assertEqual(1, get_first())


if __name__ == "__main__":
  test.main()