1""" 2Various tests for synchronization primitives. 3""" 4 5import os 6import sys 7import time 8from _thread import start_new_thread, TIMEOUT_MAX 9import threading 10import unittest 11import weakref 12 13from test import support 14 15 16requires_fork = unittest.skipUnless(hasattr(os, 'fork'), 17 "platform doesn't support fork " 18 "(no _at_fork_reinit method)") 19 20 21def _wait(): 22 # A crude wait/yield function not relying on synchronization primitives. 23 time.sleep(0.01) 24 25class Bunch(object): 26 """ 27 A bunch of threads. 28 """ 29 def __init__(self, f, n, wait_before_exit=False): 30 """ 31 Construct a bunch of `n` threads running the same function `f`. 32 If `wait_before_exit` is True, the threads won't terminate until 33 do_finish() is called. 34 """ 35 self.f = f 36 self.n = n 37 self.started = [] 38 self.finished = [] 39 self._can_exit = not wait_before_exit 40 self.wait_thread = support.wait_threads_exit() 41 self.wait_thread.__enter__() 42 43 def task(): 44 tid = threading.get_ident() 45 self.started.append(tid) 46 try: 47 f() 48 finally: 49 self.finished.append(tid) 50 while not self._can_exit: 51 _wait() 52 53 try: 54 for i in range(n): 55 start_new_thread(task, ()) 56 except: 57 self._can_exit = True 58 raise 59 60 def wait_for_started(self): 61 while len(self.started) < self.n: 62 _wait() 63 64 def wait_for_finished(self): 65 while len(self.finished) < self.n: 66 _wait() 67 # Wait for threads exit 68 self.wait_thread.__exit__(None, None, None) 69 70 def do_finish(self): 71 self._can_exit = True 72 73 74class BaseTestCase(unittest.TestCase): 75 def setUp(self): 76 self._threads = support.threading_setup() 77 78 def tearDown(self): 79 support.threading_cleanup(*self._threads) 80 support.reap_children() 81 82 def assertTimeout(self, actual, expected): 83 # The waiting and/or time.monotonic() can be imprecise, which 84 # is why comparing to the expected value would sometimes fail 85 # (especially under Windows). 86 self.assertGreaterEqual(actual, expected * 0.6) 87 # Test nothing insane happened 88 self.assertLess(actual, expected * 10.0) 89 90 91class BaseLockTests(BaseTestCase): 92 """ 93 Tests for both recursive and non-recursive locks. 94 """ 95 96 def test_constructor(self): 97 lock = self.locktype() 98 del lock 99 100 def test_repr(self): 101 lock = self.locktype() 102 self.assertRegex(repr(lock), "<unlocked .* object (.*)?at .*>") 103 del lock 104 105 def test_locked_repr(self): 106 lock = self.locktype() 107 lock.acquire() 108 self.assertRegex(repr(lock), "<locked .* object (.*)?at .*>") 109 del lock 110 111 def test_acquire_destroy(self): 112 lock = self.locktype() 113 lock.acquire() 114 del lock 115 116 def test_acquire_release(self): 117 lock = self.locktype() 118 lock.acquire() 119 lock.release() 120 del lock 121 122 def test_try_acquire(self): 123 lock = self.locktype() 124 self.assertTrue(lock.acquire(False)) 125 lock.release() 126 127 def test_try_acquire_contended(self): 128 lock = self.locktype() 129 lock.acquire() 130 result = [] 131 def f(): 132 result.append(lock.acquire(False)) 133 Bunch(f, 1).wait_for_finished() 134 self.assertFalse(result[0]) 135 lock.release() 136 137 def test_acquire_contended(self): 138 lock = self.locktype() 139 lock.acquire() 140 N = 5 141 def f(): 142 lock.acquire() 143 lock.release() 144 145 b = Bunch(f, N) 146 b.wait_for_started() 147 _wait() 148 self.assertEqual(len(b.finished), 0) 149 lock.release() 150 b.wait_for_finished() 151 self.assertEqual(len(b.finished), N) 152 153 def test_with(self): 154 lock = self.locktype() 155 def f(): 156 lock.acquire() 157 lock.release() 158 def _with(err=None): 159 with lock: 160 if err is not None: 161 raise err 162 _with() 163 # Check the lock is unacquired 164 Bunch(f, 1).wait_for_finished() 165 self.assertRaises(TypeError, _with, TypeError) 166 # Check the lock is unacquired 167 Bunch(f, 1).wait_for_finished() 168 169 def test_thread_leak(self): 170 # The lock shouldn't leak a Thread instance when used from a foreign 171 # (non-threading) thread. 172 lock = self.locktype() 173 def f(): 174 lock.acquire() 175 lock.release() 176 n = len(threading.enumerate()) 177 # We run many threads in the hope that existing threads ids won't 178 # be recycled. 179 Bunch(f, 15).wait_for_finished() 180 if len(threading.enumerate()) != n: 181 # There is a small window during which a Thread instance's 182 # target function has finished running, but the Thread is still 183 # alive and registered. Avoid spurious failures by waiting a 184 # bit more (seen on a buildbot). 185 time.sleep(0.4) 186 self.assertEqual(n, len(threading.enumerate())) 187 188 def test_timeout(self): 189 lock = self.locktype() 190 # Can't set timeout if not blocking 191 self.assertRaises(ValueError, lock.acquire, False, 1) 192 # Invalid timeout values 193 self.assertRaises(ValueError, lock.acquire, timeout=-100) 194 self.assertRaises(OverflowError, lock.acquire, timeout=1e100) 195 self.assertRaises(OverflowError, lock.acquire, timeout=TIMEOUT_MAX + 1) 196 # TIMEOUT_MAX is ok 197 lock.acquire(timeout=TIMEOUT_MAX) 198 lock.release() 199 t1 = time.monotonic() 200 self.assertTrue(lock.acquire(timeout=5)) 201 t2 = time.monotonic() 202 # Just a sanity test that it didn't actually wait for the timeout. 203 self.assertLess(t2 - t1, 5) 204 results = [] 205 def f(): 206 t1 = time.monotonic() 207 results.append(lock.acquire(timeout=0.5)) 208 t2 = time.monotonic() 209 results.append(t2 - t1) 210 Bunch(f, 1).wait_for_finished() 211 self.assertFalse(results[0]) 212 self.assertTimeout(results[1], 0.5) 213 214 def test_weakref_exists(self): 215 lock = self.locktype() 216 ref = weakref.ref(lock) 217 self.assertIsNotNone(ref()) 218 219 def test_weakref_deleted(self): 220 lock = self.locktype() 221 ref = weakref.ref(lock) 222 del lock 223 self.assertIsNone(ref()) 224 225 226class LockTests(BaseLockTests): 227 """ 228 Tests for non-recursive, weak locks 229 (which can be acquired and released from different threads). 230 """ 231 def test_reacquire(self): 232 # Lock needs to be released before re-acquiring. 233 lock = self.locktype() 234 phase = [] 235 236 def f(): 237 lock.acquire() 238 phase.append(None) 239 lock.acquire() 240 phase.append(None) 241 242 with support.wait_threads_exit(): 243 start_new_thread(f, ()) 244 while len(phase) == 0: 245 _wait() 246 _wait() 247 self.assertEqual(len(phase), 1) 248 lock.release() 249 while len(phase) == 1: 250 _wait() 251 self.assertEqual(len(phase), 2) 252 253 def test_different_thread(self): 254 # Lock can be released from a different thread. 255 lock = self.locktype() 256 lock.acquire() 257 def f(): 258 lock.release() 259 b = Bunch(f, 1) 260 b.wait_for_finished() 261 lock.acquire() 262 lock.release() 263 264 def test_state_after_timeout(self): 265 # Issue #11618: check that lock is in a proper state after a 266 # (non-zero) timeout. 267 lock = self.locktype() 268 lock.acquire() 269 self.assertFalse(lock.acquire(timeout=0.01)) 270 lock.release() 271 self.assertFalse(lock.locked()) 272 self.assertTrue(lock.acquire(blocking=False)) 273 274 @requires_fork 275 def test_at_fork_reinit(self): 276 def use_lock(lock): 277 # make sure that the lock still works normally 278 # after _at_fork_reinit() 279 lock.acquire() 280 lock.release() 281 282 # unlocked 283 lock = self.locktype() 284 lock._at_fork_reinit() 285 use_lock(lock) 286 287 # locked: _at_fork_reinit() resets the lock to the unlocked state 288 lock2 = self.locktype() 289 lock2.acquire() 290 lock2._at_fork_reinit() 291 use_lock(lock2) 292 293 294class RLockTests(BaseLockTests): 295 """ 296 Tests for recursive locks. 297 """ 298 def test_reacquire(self): 299 lock = self.locktype() 300 lock.acquire() 301 lock.acquire() 302 lock.release() 303 lock.acquire() 304 lock.release() 305 lock.release() 306 307 def test_release_unacquired(self): 308 # Cannot release an unacquired lock 309 lock = self.locktype() 310 self.assertRaises(RuntimeError, lock.release) 311 lock.acquire() 312 lock.acquire() 313 lock.release() 314 lock.acquire() 315 lock.release() 316 lock.release() 317 self.assertRaises(RuntimeError, lock.release) 318 319 def test_release_save_unacquired(self): 320 # Cannot _release_save an unacquired lock 321 lock = self.locktype() 322 self.assertRaises(RuntimeError, lock._release_save) 323 lock.acquire() 324 lock.acquire() 325 lock.release() 326 lock.acquire() 327 lock.release() 328 lock.release() 329 self.assertRaises(RuntimeError, lock._release_save) 330 331 def test_different_thread(self): 332 # Cannot release from a different thread 333 lock = self.locktype() 334 def f(): 335 lock.acquire() 336 b = Bunch(f, 1, True) 337 try: 338 self.assertRaises(RuntimeError, lock.release) 339 finally: 340 b.do_finish() 341 b.wait_for_finished() 342 343 def test__is_owned(self): 344 lock = self.locktype() 345 self.assertFalse(lock._is_owned()) 346 lock.acquire() 347 self.assertTrue(lock._is_owned()) 348 lock.acquire() 349 self.assertTrue(lock._is_owned()) 350 result = [] 351 def f(): 352 result.append(lock._is_owned()) 353 Bunch(f, 1).wait_for_finished() 354 self.assertFalse(result[0]) 355 lock.release() 356 self.assertTrue(lock._is_owned()) 357 lock.release() 358 self.assertFalse(lock._is_owned()) 359 360 361class EventTests(BaseTestCase): 362 """ 363 Tests for Event objects. 364 """ 365 366 def test_is_set(self): 367 evt = self.eventtype() 368 self.assertFalse(evt.is_set()) 369 evt.set() 370 self.assertTrue(evt.is_set()) 371 evt.set() 372 self.assertTrue(evt.is_set()) 373 evt.clear() 374 self.assertFalse(evt.is_set()) 375 evt.clear() 376 self.assertFalse(evt.is_set()) 377 378 def _check_notify(self, evt): 379 # All threads get notified 380 N = 5 381 results1 = [] 382 results2 = [] 383 def f(): 384 results1.append(evt.wait()) 385 results2.append(evt.wait()) 386 b = Bunch(f, N) 387 b.wait_for_started() 388 _wait() 389 self.assertEqual(len(results1), 0) 390 evt.set() 391 b.wait_for_finished() 392 self.assertEqual(results1, [True] * N) 393 self.assertEqual(results2, [True] * N) 394 395 def test_notify(self): 396 evt = self.eventtype() 397 self._check_notify(evt) 398 # Another time, after an explicit clear() 399 evt.set() 400 evt.clear() 401 self._check_notify(evt) 402 403 def test_timeout(self): 404 evt = self.eventtype() 405 results1 = [] 406 results2 = [] 407 N = 5 408 def f(): 409 results1.append(evt.wait(0.0)) 410 t1 = time.monotonic() 411 r = evt.wait(0.5) 412 t2 = time.monotonic() 413 results2.append((r, t2 - t1)) 414 Bunch(f, N).wait_for_finished() 415 self.assertEqual(results1, [False] * N) 416 for r, dt in results2: 417 self.assertFalse(r) 418 self.assertTimeout(dt, 0.5) 419 # The event is set 420 results1 = [] 421 results2 = [] 422 evt.set() 423 Bunch(f, N).wait_for_finished() 424 self.assertEqual(results1, [True] * N) 425 for r, dt in results2: 426 self.assertTrue(r) 427 428 def test_set_and_clear(self): 429 # Issue #13502: check that wait() returns true even when the event is 430 # cleared before the waiting thread is woken up. 431 evt = self.eventtype() 432 results = [] 433 timeout = 0.250 434 N = 5 435 def f(): 436 results.append(evt.wait(timeout * 4)) 437 b = Bunch(f, N) 438 b.wait_for_started() 439 time.sleep(timeout) 440 evt.set() 441 evt.clear() 442 b.wait_for_finished() 443 self.assertEqual(results, [True] * N) 444 445 @requires_fork 446 def test_at_fork_reinit(self): 447 # ensure that condition is still using a Lock after reset 448 evt = self.eventtype() 449 with evt._cond: 450 self.assertFalse(evt._cond.acquire(False)) 451 evt._at_fork_reinit() 452 with evt._cond: 453 self.assertFalse(evt._cond.acquire(False)) 454 455 456class ConditionTests(BaseTestCase): 457 """ 458 Tests for condition variables. 459 """ 460 461 def test_acquire(self): 462 cond = self.condtype() 463 # Be default we have an RLock: the condition can be acquired multiple 464 # times. 465 cond.acquire() 466 cond.acquire() 467 cond.release() 468 cond.release() 469 lock = threading.Lock() 470 cond = self.condtype(lock) 471 cond.acquire() 472 self.assertFalse(lock.acquire(False)) 473 cond.release() 474 self.assertTrue(lock.acquire(False)) 475 self.assertFalse(cond.acquire(False)) 476 lock.release() 477 with cond: 478 self.assertFalse(lock.acquire(False)) 479 480 def test_unacquired_wait(self): 481 cond = self.condtype() 482 self.assertRaises(RuntimeError, cond.wait) 483 484 def test_unacquired_notify(self): 485 cond = self.condtype() 486 self.assertRaises(RuntimeError, cond.notify) 487 488 def _check_notify(self, cond): 489 # Note that this test is sensitive to timing. If the worker threads 490 # don't execute in a timely fashion, the main thread may think they 491 # are further along then they are. The main thread therefore issues 492 # _wait() statements to try to make sure that it doesn't race ahead 493 # of the workers. 494 # Secondly, this test assumes that condition variables are not subject 495 # to spurious wakeups. The absence of spurious wakeups is an implementation 496 # detail of Condition Variables in current CPython, but in general, not 497 # a guaranteed property of condition variables as a programming 498 # construct. In particular, it is possible that this can no longer 499 # be conveniently guaranteed should their implementation ever change. 500 N = 5 501 ready = [] 502 results1 = [] 503 results2 = [] 504 phase_num = 0 505 def f(): 506 cond.acquire() 507 ready.append(phase_num) 508 result = cond.wait() 509 cond.release() 510 results1.append((result, phase_num)) 511 cond.acquire() 512 ready.append(phase_num) 513 result = cond.wait() 514 cond.release() 515 results2.append((result, phase_num)) 516 b = Bunch(f, N) 517 b.wait_for_started() 518 # first wait, to ensure all workers settle into cond.wait() before 519 # we continue. See issues #8799 and #30727. 520 while len(ready) < 5: 521 _wait() 522 ready.clear() 523 self.assertEqual(results1, []) 524 # Notify 3 threads at first 525 cond.acquire() 526 cond.notify(3) 527 _wait() 528 phase_num = 1 529 cond.release() 530 while len(results1) < 3: 531 _wait() 532 self.assertEqual(results1, [(True, 1)] * 3) 533 self.assertEqual(results2, []) 534 # make sure all awaken workers settle into cond.wait() 535 while len(ready) < 3: 536 _wait() 537 # Notify 5 threads: they might be in their first or second wait 538 cond.acquire() 539 cond.notify(5) 540 _wait() 541 phase_num = 2 542 cond.release() 543 while len(results1) + len(results2) < 8: 544 _wait() 545 self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * 2) 546 self.assertEqual(results2, [(True, 2)] * 3) 547 # make sure all workers settle into cond.wait() 548 while len(ready) < 5: 549 _wait() 550 # Notify all threads: they are all in their second wait 551 cond.acquire() 552 cond.notify_all() 553 _wait() 554 phase_num = 3 555 cond.release() 556 while len(results2) < 5: 557 _wait() 558 self.assertEqual(results1, [(True, 1)] * 3 + [(True,2)] * 2) 559 self.assertEqual(results2, [(True, 2)] * 3 + [(True, 3)] * 2) 560 b.wait_for_finished() 561 562 def test_notify(self): 563 cond = self.condtype() 564 self._check_notify(cond) 565 # A second time, to check internal state is still ok. 566 self._check_notify(cond) 567 568 def test_timeout(self): 569 cond = self.condtype() 570 results = [] 571 N = 5 572 def f(): 573 cond.acquire() 574 t1 = time.monotonic() 575 result = cond.wait(0.5) 576 t2 = time.monotonic() 577 cond.release() 578 results.append((t2 - t1, result)) 579 Bunch(f, N).wait_for_finished() 580 self.assertEqual(len(results), N) 581 for dt, result in results: 582 self.assertTimeout(dt, 0.5) 583 # Note that conceptually (that"s the condition variable protocol) 584 # a wait() may succeed even if no one notifies us and before any 585 # timeout occurs. Spurious wakeups can occur. 586 # This makes it hard to verify the result value. 587 # In practice, this implementation has no spurious wakeups. 588 self.assertFalse(result) 589 590 def test_waitfor(self): 591 cond = self.condtype() 592 state = 0 593 def f(): 594 with cond: 595 result = cond.wait_for(lambda : state==4) 596 self.assertTrue(result) 597 self.assertEqual(state, 4) 598 b = Bunch(f, 1) 599 b.wait_for_started() 600 for i in range(4): 601 time.sleep(0.01) 602 with cond: 603 state += 1 604 cond.notify() 605 b.wait_for_finished() 606 607 def test_waitfor_timeout(self): 608 cond = self.condtype() 609 state = 0 610 success = [] 611 def f(): 612 with cond: 613 dt = time.monotonic() 614 result = cond.wait_for(lambda : state==4, timeout=0.1) 615 dt = time.monotonic() - dt 616 self.assertFalse(result) 617 self.assertTimeout(dt, 0.1) 618 success.append(None) 619 b = Bunch(f, 1) 620 b.wait_for_started() 621 # Only increment 3 times, so state == 4 is never reached. 622 for i in range(3): 623 time.sleep(0.01) 624 with cond: 625 state += 1 626 cond.notify() 627 b.wait_for_finished() 628 self.assertEqual(len(success), 1) 629 630 631class BaseSemaphoreTests(BaseTestCase): 632 """ 633 Common tests for {bounded, unbounded} semaphore objects. 634 """ 635 636 def test_constructor(self): 637 self.assertRaises(ValueError, self.semtype, value = -1) 638 self.assertRaises(ValueError, self.semtype, value = -sys.maxsize) 639 640 def test_acquire(self): 641 sem = self.semtype(1) 642 sem.acquire() 643 sem.release() 644 sem = self.semtype(2) 645 sem.acquire() 646 sem.acquire() 647 sem.release() 648 sem.release() 649 650 def test_acquire_destroy(self): 651 sem = self.semtype() 652 sem.acquire() 653 del sem 654 655 def test_acquire_contended(self): 656 sem = self.semtype(7) 657 sem.acquire() 658 N = 10 659 sem_results = [] 660 results1 = [] 661 results2 = [] 662 phase_num = 0 663 def f(): 664 sem_results.append(sem.acquire()) 665 results1.append(phase_num) 666 sem_results.append(sem.acquire()) 667 results2.append(phase_num) 668 b = Bunch(f, 10) 669 b.wait_for_started() 670 while len(results1) + len(results2) < 6: 671 _wait() 672 self.assertEqual(results1 + results2, [0] * 6) 673 phase_num = 1 674 for i in range(7): 675 sem.release() 676 while len(results1) + len(results2) < 13: 677 _wait() 678 self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7) 679 phase_num = 2 680 for i in range(6): 681 sem.release() 682 while len(results1) + len(results2) < 19: 683 _wait() 684 self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7 + [2] * 6) 685 # The semaphore is still locked 686 self.assertFalse(sem.acquire(False)) 687 # Final release, to let the last thread finish 688 sem.release() 689 b.wait_for_finished() 690 self.assertEqual(sem_results, [True] * (6 + 7 + 6 + 1)) 691 692 def test_multirelease(self): 693 sem = self.semtype(7) 694 sem.acquire() 695 results1 = [] 696 results2 = [] 697 phase_num = 0 698 def f(): 699 sem.acquire() 700 results1.append(phase_num) 701 sem.acquire() 702 results2.append(phase_num) 703 b = Bunch(f, 10) 704 b.wait_for_started() 705 while len(results1) + len(results2) < 6: 706 _wait() 707 self.assertEqual(results1 + results2, [0] * 6) 708 phase_num = 1 709 sem.release(7) 710 while len(results1) + len(results2) < 13: 711 _wait() 712 self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7) 713 phase_num = 2 714 sem.release(6) 715 while len(results1) + len(results2) < 19: 716 _wait() 717 self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7 + [2] * 6) 718 # The semaphore is still locked 719 self.assertFalse(sem.acquire(False)) 720 # Final release, to let the last thread finish 721 sem.release() 722 b.wait_for_finished() 723 724 def test_try_acquire(self): 725 sem = self.semtype(2) 726 self.assertTrue(sem.acquire(False)) 727 self.assertTrue(sem.acquire(False)) 728 self.assertFalse(sem.acquire(False)) 729 sem.release() 730 self.assertTrue(sem.acquire(False)) 731 732 def test_try_acquire_contended(self): 733 sem = self.semtype(4) 734 sem.acquire() 735 results = [] 736 def f(): 737 results.append(sem.acquire(False)) 738 results.append(sem.acquire(False)) 739 Bunch(f, 5).wait_for_finished() 740 # There can be a thread switch between acquiring the semaphore and 741 # appending the result, therefore results will not necessarily be 742 # ordered. 743 self.assertEqual(sorted(results), [False] * 7 + [True] * 3 ) 744 745 def test_acquire_timeout(self): 746 sem = self.semtype(2) 747 self.assertRaises(ValueError, sem.acquire, False, timeout=1.0) 748 self.assertTrue(sem.acquire(timeout=0.005)) 749 self.assertTrue(sem.acquire(timeout=0.005)) 750 self.assertFalse(sem.acquire(timeout=0.005)) 751 sem.release() 752 self.assertTrue(sem.acquire(timeout=0.005)) 753 t = time.monotonic() 754 self.assertFalse(sem.acquire(timeout=0.5)) 755 dt = time.monotonic() - t 756 self.assertTimeout(dt, 0.5) 757 758 def test_default_value(self): 759 # The default initial value is 1. 760 sem = self.semtype() 761 sem.acquire() 762 def f(): 763 sem.acquire() 764 sem.release() 765 b = Bunch(f, 1) 766 b.wait_for_started() 767 _wait() 768 self.assertFalse(b.finished) 769 sem.release() 770 b.wait_for_finished() 771 772 def test_with(self): 773 sem = self.semtype(2) 774 def _with(err=None): 775 with sem: 776 self.assertTrue(sem.acquire(False)) 777 sem.release() 778 with sem: 779 self.assertFalse(sem.acquire(False)) 780 if err: 781 raise err 782 _with() 783 self.assertTrue(sem.acquire(False)) 784 sem.release() 785 self.assertRaises(TypeError, _with, TypeError) 786 self.assertTrue(sem.acquire(False)) 787 sem.release() 788 789class SemaphoreTests(BaseSemaphoreTests): 790 """ 791 Tests for unbounded semaphores. 792 """ 793 794 def test_release_unacquired(self): 795 # Unbounded releases are allowed and increment the semaphore's value 796 sem = self.semtype(1) 797 sem.release() 798 sem.acquire() 799 sem.acquire() 800 sem.release() 801 802 803class BoundedSemaphoreTests(BaseSemaphoreTests): 804 """ 805 Tests for bounded semaphores. 806 """ 807 808 def test_release_unacquired(self): 809 # Cannot go past the initial value 810 sem = self.semtype() 811 self.assertRaises(ValueError, sem.release) 812 sem.acquire() 813 sem.release() 814 self.assertRaises(ValueError, sem.release) 815 816 817class BarrierTests(BaseTestCase): 818 """ 819 Tests for Barrier objects. 820 """ 821 N = 5 822 defaultTimeout = 2.0 823 824 def setUp(self): 825 self.barrier = self.barriertype(self.N, timeout=self.defaultTimeout) 826 def tearDown(self): 827 self.barrier.abort() 828 829 def run_threads(self, f): 830 b = Bunch(f, self.N-1) 831 f() 832 b.wait_for_finished() 833 834 def multipass(self, results, n): 835 m = self.barrier.parties 836 self.assertEqual(m, self.N) 837 for i in range(n): 838 results[0].append(True) 839 self.assertEqual(len(results[1]), i * m) 840 self.barrier.wait() 841 results[1].append(True) 842 self.assertEqual(len(results[0]), (i + 1) * m) 843 self.barrier.wait() 844 self.assertEqual(self.barrier.n_waiting, 0) 845 self.assertFalse(self.barrier.broken) 846 847 def test_barrier(self, passes=1): 848 """ 849 Test that a barrier is passed in lockstep 850 """ 851 results = [[],[]] 852 def f(): 853 self.multipass(results, passes) 854 self.run_threads(f) 855 856 def test_barrier_10(self): 857 """ 858 Test that a barrier works for 10 consecutive runs 859 """ 860 return self.test_barrier(10) 861 862 def test_wait_return(self): 863 """ 864 test the return value from barrier.wait 865 """ 866 results = [] 867 def f(): 868 r = self.barrier.wait() 869 results.append(r) 870 871 self.run_threads(f) 872 self.assertEqual(sum(results), sum(range(self.N))) 873 874 def test_action(self): 875 """ 876 Test the 'action' callback 877 """ 878 results = [] 879 def action(): 880 results.append(True) 881 barrier = self.barriertype(self.N, action) 882 def f(): 883 barrier.wait() 884 self.assertEqual(len(results), 1) 885 886 self.run_threads(f) 887 888 def test_abort(self): 889 """ 890 Test that an abort will put the barrier in a broken state 891 """ 892 results1 = [] 893 results2 = [] 894 def f(): 895 try: 896 i = self.barrier.wait() 897 if i == self.N//2: 898 raise RuntimeError 899 self.barrier.wait() 900 results1.append(True) 901 except threading.BrokenBarrierError: 902 results2.append(True) 903 except RuntimeError: 904 self.barrier.abort() 905 pass 906 907 self.run_threads(f) 908 self.assertEqual(len(results1), 0) 909 self.assertEqual(len(results2), self.N-1) 910 self.assertTrue(self.barrier.broken) 911 912 def test_reset(self): 913 """ 914 Test that a 'reset' on a barrier frees the waiting threads 915 """ 916 results1 = [] 917 results2 = [] 918 results3 = [] 919 def f(): 920 i = self.barrier.wait() 921 if i == self.N//2: 922 # Wait until the other threads are all in the barrier. 923 while self.barrier.n_waiting < self.N-1: 924 time.sleep(0.001) 925 self.barrier.reset() 926 else: 927 try: 928 self.barrier.wait() 929 results1.append(True) 930 except threading.BrokenBarrierError: 931 results2.append(True) 932 # Now, pass the barrier again 933 self.barrier.wait() 934 results3.append(True) 935 936 self.run_threads(f) 937 self.assertEqual(len(results1), 0) 938 self.assertEqual(len(results2), self.N-1) 939 self.assertEqual(len(results3), self.N) 940 941 942 def test_abort_and_reset(self): 943 """ 944 Test that a barrier can be reset after being broken. 945 """ 946 results1 = [] 947 results2 = [] 948 results3 = [] 949 barrier2 = self.barriertype(self.N) 950 def f(): 951 try: 952 i = self.barrier.wait() 953 if i == self.N//2: 954 raise RuntimeError 955 self.barrier.wait() 956 results1.append(True) 957 except threading.BrokenBarrierError: 958 results2.append(True) 959 except RuntimeError: 960 self.barrier.abort() 961 pass 962 # Synchronize and reset the barrier. Must synchronize first so 963 # that everyone has left it when we reset, and after so that no 964 # one enters it before the reset. 965 if barrier2.wait() == self.N//2: 966 self.barrier.reset() 967 barrier2.wait() 968 self.barrier.wait() 969 results3.append(True) 970 971 self.run_threads(f) 972 self.assertEqual(len(results1), 0) 973 self.assertEqual(len(results2), self.N-1) 974 self.assertEqual(len(results3), self.N) 975 976 def test_timeout(self): 977 """ 978 Test wait(timeout) 979 """ 980 def f(): 981 i = self.barrier.wait() 982 if i == self.N // 2: 983 # One thread is late! 984 time.sleep(1.0) 985 # Default timeout is 2.0, so this is shorter. 986 self.assertRaises(threading.BrokenBarrierError, 987 self.barrier.wait, 0.5) 988 self.run_threads(f) 989 990 def test_default_timeout(self): 991 """ 992 Test the barrier's default timeout 993 """ 994 # create a barrier with a low default timeout 995 barrier = self.barriertype(self.N, timeout=0.3) 996 def f(): 997 i = barrier.wait() 998 if i == self.N // 2: 999 # One thread is later than the default timeout of 0.3s. 1000 time.sleep(1.0) 1001 self.assertRaises(threading.BrokenBarrierError, barrier.wait) 1002 self.run_threads(f) 1003 1004 def test_single_thread(self): 1005 b = self.barriertype(1) 1006 b.wait() 1007 b.wait() 1008