1"""
2Various tests for synchronization primitives.
3"""
4
5import os
6import sys
7import time
8from _thread import start_new_thread, TIMEOUT_MAX
9import threading
10import unittest
11import weakref
12
13from test import support
14
15
# Decorator for tests that can only run where os.fork() exists.
requires_fork = unittest.skipUnless(
    hasattr(os, 'fork'),
    "platform doesn't support fork (no _at_fork_reinit method)",
)
19
20
21def _wait():
22    # A crude wait/yield function not relying on synchronization primitives.
23    time.sleep(0.01)
24
class Bunch(object):
    """
    A group of raw threads that all run the same callable.
    """
    def __init__(self, f, n, wait_before_exit=False):
        """
        Start `n` threads, each calling `f` with no arguments.

        When `wait_before_exit` is true, each thread lingers after `f`
        returns until do_finish() is called.
        """
        self.f = f
        self.n = n
        # Thread idents, appended by the workers as they start / finish.
        self.started = []
        self.finished = []
        self._can_exit = not wait_before_exit
        # Entered by hand: the matching __exit__ happens later, in
        # wait_for_finished().
        self.wait_thread = support.wait_threads_exit()
        self.wait_thread.__enter__()

        def task():
            ident = threading.get_ident()
            self.started.append(ident)
            try:
                f()
            finally:
                # Record completion even when f() raised, then linger until
                # the thread is allowed to exit.
                self.finished.append(ident)
                while not self._can_exit:
                    _wait()

        try:
            for _ in range(n):
                start_new_thread(task, ())
        except BaseException:
            # Don't leave already-started threads lingering forever.
            self._can_exit = True
            raise

    def wait_for_started(self):
        """Poll until all `n` threads have recorded that they started."""
        while len(self.started) < self.n:
            _wait()

    def wait_for_finished(self):
        """Poll until all `n` threads have recorded completion, then wait for them to exit."""
        while len(self.finished) < self.n:
            _wait()
        # Leave the wait_threads_exit() context entered in __init__.
        self.wait_thread.__exit__(None, None, None)

    def do_finish(self):
        """Allow threads created with `wait_before_exit` to terminate."""
        self._can_exit = True
72
73
class BaseTestCase(unittest.TestCase):
    """
    Base class for the tests below.

    Calls support.threading_setup() before each test and passes its result
    to support.threading_cleanup() afterwards, followed by
    support.reap_children().
    """

    def setUp(self):
        self._threads = support.threading_setup()

    def tearDown(self):
        support.threading_cleanup(*self._threads)
        support.reap_children()

    def assertTimeout(self, actual, expected):
        """
        Check that the measured duration `actual` is plausible for a wait
        that was expected to last `expected`.
        """
        # Waiting and/or time.monotonic() can be imprecise (especially
        # under Windows), so a strict comparison against the expected value
        # would sometimes fail; accept anything from 60% upwards.
        lower_bound = expected * 0.6
        # Upper limit only guards against something insane happening.
        upper_bound = expected * 10.0
        self.assertGreaterEqual(actual, lower_bound)
        self.assertLess(actual, upper_bound)
89
90
class BaseLockTests(BaseTestCase):
    """
    Tests for both recursive and non-recursive locks.

    The lock under test is created with `self.locktype()`; that attribute
    is not defined here and must be supplied by the concrete test class.
    """

    def test_constructor(self):
        # Creating a lock and dropping it straight away must not raise.
        lk = self.locktype()
        del lk

    def test_repr(self):
        # A fresh lock describes itself as unlocked.
        lk = self.locktype()
        self.assertRegex(repr(lk), "<unlocked .* object (.*)?at .*>")
        del lk

    def test_locked_repr(self):
        # A held lock describes itself as locked.
        lk = self.locktype()
        lk.acquire()
        self.assertRegex(repr(lk), "<locked .* object (.*)?at .*>")
        del lk

    def test_acquire_destroy(self):
        # Dropping a lock while it is still held must not raise.
        lk = self.locktype()
        lk.acquire()
        del lk

    def test_acquire_release(self):
        # Plain acquire/release round trip.
        lk = self.locktype()
        lk.acquire()
        lk.release()
        del lk

    def test_try_acquire(self):
        # A non-blocking acquire of a free lock succeeds.
        lk = self.locktype()
        acquired = lk.acquire(False)
        self.assertTrue(acquired)
        lk.release()

    def test_try_acquire_contended(self):
        # A non-blocking acquire from another thread fails while the lock
        # is held here.
        lk = self.locktype()
        lk.acquire()
        outcome = []

        def try_lock():
            outcome.append(lk.acquire(False))

        Bunch(try_lock, 1).wait_for_finished()
        self.assertFalse(outcome[0])
        lk.release()

    def test_acquire_contended(self):
        # Blocking acquires in other threads complete only after the lock
        # is released here.
        lk = self.locktype()
        lk.acquire()
        thread_count = 5

        def take_and_drop():
            lk.acquire()
            lk.release()

        bunch = Bunch(take_and_drop, thread_count)
        bunch.wait_for_started()
        _wait()
        self.assertEqual(len(bunch.finished), 0)
        lk.release()
        bunch.wait_for_finished()
        self.assertEqual(len(bunch.finished), thread_count)

    def test_with(self):
        # The lock works as a context manager and is released on exit,
        # whether the block ends normally or through an exception.
        lk = self.locktype()

        def take_and_drop():
            lk.acquire()
            lk.release()

        def use_as_context(err=None):
            with lk:
                if err is not None:
                    raise err

        use_as_context()
        # Check the lock is unacquired
        Bunch(take_and_drop, 1).wait_for_finished()
        with self.assertRaises(TypeError):
            use_as_context(TypeError)
        # Check the lock is unacquired
        Bunch(take_and_drop, 1).wait_for_finished()

    def test_thread_leak(self):
        # The lock shouldn't leak a Thread instance when used from a foreign
        # (non-threading) thread.
        lk = self.locktype()

        def take_and_drop():
            lk.acquire()
            lk.release()

        baseline = len(threading.enumerate())
        # We run many threads in the hope that existing threads ids won't
        # be recycled.
        Bunch(take_and_drop, 15).wait_for_finished()
        if len(threading.enumerate()) != baseline:
            # There is a small window during which a Thread instance's
            # target function has finished running, but the Thread is still
            # alive and registered.  Avoid spurious failures by waiting a
            # bit more (seen on a buildbot).
            time.sleep(0.4)
            self.assertEqual(baseline, len(threading.enumerate()))

    def test_timeout(self):
        lk = self.locktype()
        # Can't set timeout if not blocking
        with self.assertRaises(ValueError):
            lk.acquire(False, 1)
        # Invalid timeout values
        with self.assertRaises(ValueError):
            lk.acquire(timeout=-100)
        with self.assertRaises(OverflowError):
            lk.acquire(timeout=1e100)
        with self.assertRaises(OverflowError):
            lk.acquire(timeout=TIMEOUT_MAX + 1)
        # TIMEOUT_MAX is ok
        lk.acquire(timeout=TIMEOUT_MAX)
        lk.release()
        began = time.monotonic()
        self.assertTrue(lk.acquire(timeout=5))
        ended = time.monotonic()
        # Just a sanity test that it didn't actually wait for the timeout.
        self.assertLess(ended - began, 5)
        outcome = []

        def timed_acquire():
            # Records the acquire() result first, then how long it took.
            began = time.monotonic()
            outcome.append(lk.acquire(timeout=0.5))
            ended = time.monotonic()
            outcome.append(ended - began)

        Bunch(timed_acquire, 1).wait_for_finished()
        self.assertFalse(outcome[0])
        self.assertTimeout(outcome[1], 0.5)

    def test_weakref_exists(self):
        # A weak reference resolves while the lock is alive.
        lk = self.locktype()
        wr = weakref.ref(lk)
        self.assertIsNotNone(wr())

    def test_weakref_deleted(self):
        # A weak reference goes dead once the lock is deleted.
        lk = self.locktype()
        wr = weakref.ref(lk)
        del lk
        self.assertIsNone(wr())
224
225
class LockTests(BaseLockTests):
    """
    Tests for non-recursive, weak locks
    (which can be acquired and released from different threads).
    """
    def test_reacquire(self):
        # Lock needs to be released before re-acquiring.
        lk = self.locktype()
        steps = []

        def acquire_twice():
            lk.acquire()
            steps.append(None)
            # Blocks until the main thread releases the lock.
            lk.acquire()
            steps.append(None)

        with support.wait_threads_exit():
            start_new_thread(acquire_twice, ())
            while not steps:
                _wait()
            _wait()
            self.assertEqual(len(steps), 1)
            lk.release()
            while len(steps) == 1:
                _wait()
            self.assertEqual(len(steps), 2)

    def test_different_thread(self):
        # Lock can be released from a different thread.
        lk = self.locktype()
        lk.acquire()

        def release_it():
            lk.release()

        bunch = Bunch(release_it, 1)
        bunch.wait_for_finished()
        lk.acquire()
        lk.release()

    def test_state_after_timeout(self):
        # Issue #11618: check that lock is in a proper state after a
        # (non-zero) timeout.
        lk = self.locktype()
        lk.acquire()
        timed_out_result = lk.acquire(timeout=0.01)
        self.assertFalse(timed_out_result)
        lk.release()
        self.assertFalse(lk.locked())
        self.assertTrue(lk.acquire(blocking=False))

    @requires_fork
    def test_at_fork_reinit(self):
        def exercise(candidate):
            # make sure that the lock still works normally
            # after _at_fork_reinit()
            candidate.acquire()
            candidate.release()

        # unlocked
        free_lock = self.locktype()
        free_lock._at_fork_reinit()
        exercise(free_lock)

        # locked: _at_fork_reinit() resets the lock to the unlocked state
        held_lock = self.locktype()
        held_lock.acquire()
        held_lock._at_fork_reinit()
        exercise(held_lock)
292
293
class RLockTests(BaseLockTests):
    """
    Tests for recursive locks.
    """
    def test_reacquire(self):
        # Nested acquires by the same thread are allowed, as long as each
        # one is matched by a release.
        rl = self.locktype()
        rl.acquire()
        rl.acquire()
        rl.release()
        rl.acquire()
        rl.release()
        rl.release()

    def test_release_unacquired(self):
        # Cannot release an unacquired lock
        rl = self.locktype()
        with self.assertRaises(RuntimeError):
            rl.release()
        rl.acquire()
        rl.acquire()
        rl.release()
        rl.acquire()
        rl.release()
        rl.release()
        # Fully released again: one more release must fail.
        with self.assertRaises(RuntimeError):
            rl.release()

    def test_release_save_unacquired(self):
        # Cannot _release_save an unacquired lock
        rl = self.locktype()
        with self.assertRaises(RuntimeError):
            rl._release_save()
        rl.acquire()
        rl.acquire()
        rl.release()
        rl.acquire()
        rl.release()
        rl.release()
        # Fully released again: _release_save must fail once more.
        with self.assertRaises(RuntimeError):
            rl._release_save()

    def test_different_thread(self):
        # Cannot release from a different thread
        rl = self.locktype()

        def grab():
            rl.acquire()

        # The worker is kept alive (wait_before_exit=True) until do_finish().
        bunch = Bunch(grab, 1, True)
        try:
            with self.assertRaises(RuntimeError):
                rl.release()
        finally:
            bunch.do_finish()
        bunch.wait_for_finished()

    def test__is_owned(self):
        # _is_owned() is true only in the thread holding the lock, and
        # stays true until the outermost release.
        rl = self.locktype()
        self.assertFalse(rl._is_owned())
        rl.acquire()
        self.assertTrue(rl._is_owned())
        rl.acquire()
        self.assertTrue(rl._is_owned())
        seen_by_other = []

        def ask():
            seen_by_other.append(rl._is_owned())

        Bunch(ask, 1).wait_for_finished()
        self.assertFalse(seen_by_other[0])
        rl.release()
        self.assertTrue(rl._is_owned())
        rl.release()
        self.assertFalse(rl._is_owned())
359
360
class EventTests(BaseTestCase):
    """
    Tests for Event objects.

    The event under test is created with `self.eventtype()`; that attribute
    is not defined here and must be supplied by the concrete test class.
    """

    def test_is_set(self):
        # set() and clear() toggle is_set(); repeating either is harmless.
        ev = self.eventtype()
        self.assertFalse(ev.is_set())
        for _ in range(2):
            ev.set()
            self.assertTrue(ev.is_set())
        for _ in range(2):
            ev.clear()
            self.assertFalse(ev.is_set())

    def _check_notify(self, evt):
        # All threads get notified
        thread_count = 5
        first_waits = []
        second_waits = []

        def wait_twice():
            first_waits.append(evt.wait())
            second_waits.append(evt.wait())

        bunch = Bunch(wait_twice, thread_count)
        bunch.wait_for_started()
        _wait()
        # Nobody may get past wait() before the event is set.
        self.assertEqual(len(first_waits), 0)
        evt.set()
        bunch.wait_for_finished()
        everyone_true = [True] * thread_count
        self.assertEqual(first_waits, everyone_true)
        self.assertEqual(second_waits, everyone_true)

    def test_notify(self):
        ev = self.eventtype()
        self._check_notify(ev)
        # Another time, after an explicit clear()
        ev.set()
        ev.clear()
        self._check_notify(ev)

    def test_timeout(self):
        ev = self.eventtype()
        immediate = []
        timed = []
        thread_count = 5

        def probe():
            # A zero-timeout wait, followed by a timed 0.5s wait.
            immediate.append(ev.wait(0.0))
            began = time.monotonic()
            flag = ev.wait(0.5)
            ended = time.monotonic()
            timed.append((flag, ended - began))

        Bunch(probe, thread_count).wait_for_finished()
        self.assertEqual(immediate, [False] * thread_count)
        for flag, elapsed in timed:
            self.assertFalse(flag)
            self.assertTimeout(elapsed, 0.5)
        # The event is set
        immediate = []
        timed = []
        ev.set()
        Bunch(probe, thread_count).wait_for_finished()
        self.assertEqual(immediate, [True] * thread_count)
        for flag, _ in timed:
            self.assertTrue(flag)

    def test_set_and_clear(self):
        # Issue #13502: check that wait() returns true even when the event is
        # cleared before the waiting thread is woken up.
        ev = self.eventtype()
        outcomes = []
        timeout = 0.250
        thread_count = 5

        def waiter():
            outcomes.append(ev.wait(timeout * 4))

        bunch = Bunch(waiter, thread_count)
        bunch.wait_for_started()
        time.sleep(timeout)
        ev.set()
        ev.clear()
        bunch.wait_for_finished()
        self.assertEqual(outcomes, [True] * thread_count)

    @requires_fork
    def test_at_fork_reinit(self):
        # ensure that condition is still using a Lock after reset
        ev = self.eventtype()
        with ev._cond:
            # Already held by the with-block, so a second non-blocking
            # acquire must fail.
            self.assertFalse(ev._cond.acquire(False))
        ev._at_fork_reinit()
        with ev._cond:
            self.assertFalse(ev._cond.acquire(False))
454
455
class ConditionTests(BaseTestCase):
    """
    Tests for condition variables.

    The condition under test is created with `self.condtype(...)`; that
    attribute is not defined here and must be supplied by the concrete
    test class.
    """

    def test_acquire(self):
        cond = self.condtype()
        # By default we have an RLock: the condition can be acquired multiple
        # times.
        cond.acquire()
        cond.acquire()
        cond.release()
        cond.release()
        # With an explicit plain Lock, the condition and the lock share
        # the same state.
        lock = threading.Lock()
        cond = self.condtype(lock)
        cond.acquire()
        self.assertFalse(lock.acquire(False))
        cond.release()
        self.assertTrue(lock.acquire(False))
        self.assertFalse(cond.acquire(False))
        lock.release()
        with cond:
            self.assertFalse(lock.acquire(False))

    def test_unacquired_wait(self):
        # wait() without holding the condition is an error.
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.wait)

    def test_unacquired_notify(self):
        # notify() without holding the condition is an error.
        cond = self.condtype()
        self.assertRaises(RuntimeError, cond.notify)

    def _check_notify(self, cond):
        # Note that this test is sensitive to timing.  If the worker threads
        # don't execute in a timely fashion, the main thread may think they
        # are further along than they are.  The main thread therefore issues
        # _wait() statements to try to make sure that it doesn't race ahead
        # of the workers.
        # Secondly, this test assumes that condition variables are not subject
        # to spurious wakeups.  The absence of spurious wakeups is an implementation
        # detail of Condition Variables in current CPython, but in general, not
        # a guaranteed property of condition variables as a programming
        # construct.  In particular, it is possible that this can no longer
        # be conveniently guaranteed should their implementation ever change.
        N = 5
        ready = []
        results1 = []
        results2 = []
        phase_num = 0
        def f():
            # Each worker waits on the condition twice, recording the
            # phase that was current when it woke up.
            cond.acquire()
            ready.append(phase_num)
            result = cond.wait()
            cond.release()
            results1.append((result, phase_num))
            cond.acquire()
            ready.append(phase_num)
            result = cond.wait()
            cond.release()
            results2.append((result, phase_num))
        b = Bunch(f, N)
        b.wait_for_started()
        # first wait, to ensure all workers settle into cond.wait() before
        # we continue. See issues #8799 and #30727.
        while len(ready) < N:
            _wait()
        ready.clear()
        self.assertEqual(results1, [])
        # Notify 3 threads at first
        cond.acquire()
        cond.notify(3)
        _wait()
        phase_num = 1
        cond.release()
        while len(results1) < 3:
            _wait()
        self.assertEqual(results1, [(True, 1)] * 3)
        self.assertEqual(results2, [])
        # make sure all awaken workers settle into cond.wait()
        while len(ready) < 3:
            _wait()
        # Notify 5 threads: they might be in their first or second wait
        cond.acquire()
        cond.notify(5)
        _wait()
        phase_num = 2
        cond.release()
        while len(results1) + len(results2) < 8:
            _wait()
        self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * 2)
        self.assertEqual(results2, [(True, 2)] * 3)
        # make sure all workers settle into cond.wait()
        while len(ready) < N:
            _wait()
        # Notify all threads: they are all in their second wait
        cond.acquire()
        cond.notify_all()
        _wait()
        phase_num = 3
        cond.release()
        while len(results2) < N:
            _wait()
        self.assertEqual(results1, [(True, 1)] * 3 + [(True, 2)] * 2)
        self.assertEqual(results2, [(True, 2)] * 3 + [(True, 3)] * 2)
        b.wait_for_finished()

    def test_notify(self):
        cond = self.condtype()
        self._check_notify(cond)
        # A second time, to check internal state is still ok.
        self._check_notify(cond)

    def test_timeout(self):
        cond = self.condtype()
        results = []
        N = 5
        def f():
            cond.acquire()
            t1 = time.monotonic()
            result = cond.wait(0.5)
            t2 = time.monotonic()
            cond.release()
            results.append((t2 - t1, result))
        Bunch(f, N).wait_for_finished()
        self.assertEqual(len(results), N)
        for dt, result in results:
            self.assertTimeout(dt, 0.5)
            # Note that conceptually (that's the condition variable protocol)
            # a wait() may succeed even if no one notifies us and before any
            # timeout occurs.  Spurious wakeups can occur.
            # This makes it hard to verify the result value.
            # In practice, this implementation has no spurious wakeups.
            self.assertFalse(result)

    def test_waitfor(self):
        cond = self.condtype()
        state = 0
        # The worker's assertions run in a raw thread, where a failure does
        # not reach the test runner.  The worker appends to this list only
        # after all of them passed, and the main thread checks it.
        success = []
        def f():
            with cond:
                result = cond.wait_for(lambda: state == 4)
                self.assertTrue(result)
                self.assertEqual(state, 4)
                success.append(None)
        b = Bunch(f, 1)
        b.wait_for_started()
        for i in range(4):
            time.sleep(0.01)
            with cond:
                state += 1
                cond.notify()
        b.wait_for_finished()
        self.assertEqual(len(success), 1)

    def test_waitfor_timeout(self):
        cond = self.condtype()
        state = 0
        # Appended to by the worker only after its own assertions passed.
        success = []
        def f():
            with cond:
                dt = time.monotonic()
                result = cond.wait_for(lambda: state == 4, timeout=0.1)
                dt = time.monotonic() - dt
                self.assertFalse(result)
                self.assertTimeout(dt, 0.1)
                success.append(None)
        b = Bunch(f, 1)
        b.wait_for_started()
        # Only increment 3 times, so state == 4 is never reached.
        for i in range(3):
            time.sleep(0.01)
            with cond:
                state += 1
                cond.notify()
        b.wait_for_finished()
        self.assertEqual(len(success), 1)
629
630
class BaseSemaphoreTests(BaseTestCase):
    """
    Common tests for {bounded, unbounded} semaphore objects.

    The semaphore under test is created with `self.semtype(...)`; that
    attribute is not defined here and must be supplied by the concrete
    test class.
    """

    def test_constructor(self):
        # Negative initial values are rejected.
        with self.assertRaises(ValueError):
            self.semtype(value=-1)
        with self.assertRaises(ValueError):
            self.semtype(value=-sys.maxsize)

    def test_acquire(self):
        # A semaphore can be acquired as many times as its initial value.
        single = self.semtype(1)
        single.acquire()
        single.release()
        double = self.semtype(2)
        double.acquire()
        double.acquire()
        double.release()
        double.release()

    def test_acquire_destroy(self):
        # Dropping a semaphore while it is acquired must not raise.
        sem = self.semtype()
        sem.acquire()
        del sem

    def test_acquire_contended(self):
        sem = self.semtype(7)
        sem.acquire()
        N = 10
        acquire_results = []
        first_pass = []
        second_pass = []
        phase_num = 0

        def worker():
            # Each worker acquires twice, recording the phase that was
            # current when each acquire completed.
            acquire_results.append(sem.acquire())
            first_pass.append(phase_num)
            acquire_results.append(sem.acquire())
            second_pass.append(phase_num)

        bunch = Bunch(worker, N)
        bunch.wait_for_started()
        # 7 - 1 = 6 acquires can succeed before anything is released.
        while len(first_pass) + len(second_pass) < 6:
            _wait()
        self.assertEqual(first_pass + second_pass, [0] * 6)
        phase_num = 1
        for _ in range(7):
            sem.release()
        while len(first_pass) + len(second_pass) < 13:
            _wait()
        self.assertEqual(sorted(first_pass + second_pass),
                         [0] * 6 + [1] * 7)
        phase_num = 2
        for _ in range(6):
            sem.release()
        while len(first_pass) + len(second_pass) < 19:
            _wait()
        self.assertEqual(sorted(first_pass + second_pass),
                         [0] * 6 + [1] * 7 + [2] * 6)
        # The semaphore is still locked
        self.assertFalse(sem.acquire(False))
        # Final release, to let the last thread finish
        sem.release()
        bunch.wait_for_finished()
        self.assertEqual(acquire_results, [True] * (6 + 7 + 6 + 1))

    def test_multirelease(self):
        # Same scenario as test_acquire_contended, but releasing several
        # units with a single release(n) call.
        sem = self.semtype(7)
        sem.acquire()
        first_pass = []
        second_pass = []
        phase_num = 0

        def worker():
            sem.acquire()
            first_pass.append(phase_num)
            sem.acquire()
            second_pass.append(phase_num)

        bunch = Bunch(worker, 10)
        bunch.wait_for_started()
        while len(first_pass) + len(second_pass) < 6:
            _wait()
        self.assertEqual(first_pass + second_pass, [0] * 6)
        phase_num = 1
        sem.release(7)
        while len(first_pass) + len(second_pass) < 13:
            _wait()
        self.assertEqual(sorted(first_pass + second_pass),
                         [0] * 6 + [1] * 7)
        phase_num = 2
        sem.release(6)
        while len(first_pass) + len(second_pass) < 19:
            _wait()
        self.assertEqual(sorted(first_pass + second_pass),
                         [0] * 6 + [1] * 7 + [2] * 6)
        # The semaphore is still locked
        self.assertFalse(sem.acquire(False))
        # Final release, to let the last thread finish
        sem.release()
        bunch.wait_for_finished()

    def test_try_acquire(self):
        # Non-blocking acquires succeed until the counter is exhausted.
        sem = self.semtype(2)
        self.assertTrue(sem.acquire(False))
        self.assertTrue(sem.acquire(False))
        self.assertFalse(sem.acquire(False))
        sem.release()
        self.assertTrue(sem.acquire(False))

    def test_try_acquire_contended(self):
        sem = self.semtype(4)
        sem.acquire()
        outcomes = []

        def try_twice():
            outcomes.append(sem.acquire(False))
            outcomes.append(sem.acquire(False))

        Bunch(try_twice, 5).wait_for_finished()
        # There can be a thread switch between acquiring the semaphore and
        # appending the result, therefore results will not necessarily be
        # ordered.
        self.assertEqual(sorted(outcomes), [False] * 7 + [True] * 3)

    def test_acquire_timeout(self):
        sem = self.semtype(2)
        # A timeout cannot be combined with a non-blocking acquire.
        with self.assertRaises(ValueError):
            sem.acquire(False, timeout=1.0)
        self.assertTrue(sem.acquire(timeout=0.005))
        self.assertTrue(sem.acquire(timeout=0.005))
        self.assertFalse(sem.acquire(timeout=0.005))
        sem.release()
        self.assertTrue(sem.acquire(timeout=0.005))
        began = time.monotonic()
        self.assertFalse(sem.acquire(timeout=0.5))
        elapsed = time.monotonic() - began
        self.assertTimeout(elapsed, 0.5)

    def test_default_value(self):
        # The default initial value is 1.
        sem = self.semtype()
        sem.acquire()

        def take_and_drop():
            sem.acquire()
            sem.release()

        bunch = Bunch(take_and_drop, 1)
        bunch.wait_for_started()
        _wait()
        self.assertFalse(bunch.finished)
        sem.release()
        bunch.wait_for_finished()

    def test_with(self):
        # Each with-block takes one unit and gives it back on exit,
        # whether the block ends normally or through an exception.
        sem = self.semtype(2)

        def nested_use(err=None):
            with sem:
                self.assertTrue(sem.acquire(False))
                sem.release()
                with sem:
                    self.assertFalse(sem.acquire(False))
                    if err:
                        raise err

        nested_use()
        self.assertTrue(sem.acquire(False))
        sem.release()
        with self.assertRaises(TypeError):
            nested_use(TypeError)
        self.assertTrue(sem.acquire(False))
        sem.release()
788
class SemaphoreTests(BaseSemaphoreTests):
    """
    Tests for unbounded semaphores.
    """

    def test_release_unacquired(self):
        # Unbounded releases are allowed and increment the semaphore's value
        unbounded = self.semtype(1)
        unbounded.release()
        # Two acquires go through here, after starting at 1 and releasing
        # once.
        for _ in range(2):
            unbounded.acquire()
        unbounded.release()
801
802
class BoundedSemaphoreTests(BaseSemaphoreTests):
    """
    Tests for bounded semaphores.
    """

    def test_release_unacquired(self):
        # Cannot go past the initial value
        bounded = self.semtype()
        with self.assertRaises(ValueError):
            bounded.release()
        bounded.acquire()
        bounded.release()
        # Back at the initial value: another release must fail again.
        with self.assertRaises(ValueError):
            bounded.release()
815
816
class BarrierTests(BaseTestCase):
    """
    Tests for Barrier objects.

    The barrier under test is created with `self.barriertype(...)`; that
    attribute is not defined here and must be supplied by the concrete
    test class.
    """
    # Number of parties for the barriers created by these tests.
    N = 5
    # Default timeout, in seconds, given to the per-test barrier.
    defaultTimeout = 2.0

    def setUp(self):
        # Chain to BaseTestCase so its thread bookkeeping also covers the
        # barrier tests.
        super().setUp()
        self.barrier = self.barriertype(self.N, timeout=self.defaultTimeout)

    def tearDown(self):
        # Abort the barrier before the base class runs its thread cleanup.
        self.barrier.abort()
        super().tearDown()

    def run_threads(self, f):
        """
        Run `f` in N-1 new threads and in the calling thread, then wait for
        the new threads to finish.
        """
        b = Bunch(f, self.N-1)
        f()
        b.wait_for_finished()

    def multipass(self, results, n):
        """
        Pass the barrier 2*n times, checking via the two lists in `results`
        that all parties advance in lockstep.
        """
        m = self.barrier.parties
        self.assertEqual(m, self.N)
        for i in range(n):
            results[0].append(True)
            self.assertEqual(len(results[1]), i * m)
            self.barrier.wait()
            results[1].append(True)
            self.assertEqual(len(results[0]), (i + 1) * m)
            self.barrier.wait()
        self.assertEqual(self.barrier.n_waiting, 0)
        self.assertFalse(self.barrier.broken)

    def test_barrier(self, passes=1):
        """
        Test that a barrier is passed in lockstep
        """
        results = [[], []]
        def f():
            self.multipass(results, passes)
        self.run_threads(f)

    def test_barrier_10(self):
        """
        Test that a barrier works for 10 consecutive runs
        """
        return self.test_barrier(10)

    def test_wait_return(self):
        """
        test the return value from barrier.wait
        """
        results = []
        def f():
            r = self.barrier.wait()
            results.append(r)

        self.run_threads(f)
        # The N return values must add up to 0 + 1 + ... + (N-1).
        self.assertEqual(sum(results), sum(range(self.N)))

    def test_action(self):
        """
        Test the 'action' callback
        """
        results = []
        def action():
            results.append(True)
        barrier = self.barriertype(self.N, action)
        def f():
            barrier.wait()
            self.assertEqual(len(results), 1)

        self.run_threads(f)

    def test_abort(self):
        """
        Test that an abort will put the barrier in a broken state
        """
        results1 = []
        results2 = []
        def f():
            try:
                i = self.barrier.wait()
                if i == self.N//2:
                    # Exactly one party bails out and aborts the barrier.
                    raise RuntimeError
                self.barrier.wait()
                results1.append(True)
            except threading.BrokenBarrierError:
                results2.append(True)
            except RuntimeError:
                self.barrier.abort()

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertTrue(self.barrier.broken)

    def test_reset(self):
        """
        Test that a 'reset' on a barrier frees the waiting threads
        """
        results1 = []
        results2 = []
        results3 = []
        def f():
            i = self.barrier.wait()
            if i == self.N//2:
                # Wait until the other threads are all in the barrier.
                while self.barrier.n_waiting < self.N-1:
                    time.sleep(0.001)
                self.barrier.reset()
            else:
                try:
                    self.barrier.wait()
                    results1.append(True)
                except threading.BrokenBarrierError:
                    results2.append(True)
            # Now, pass the barrier again
            self.barrier.wait()
            results3.append(True)

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertEqual(len(results3), self.N)


    def test_abort_and_reset(self):
        """
        Test that a barrier can be reset after being broken.
        """
        results1 = []
        results2 = []
        results3 = []
        # Second barrier, used only to synchronize around the reset.
        barrier2 = self.barriertype(self.N)
        def f():
            try:
                i = self.barrier.wait()
                if i == self.N//2:
                    raise RuntimeError
                self.barrier.wait()
                results1.append(True)
            except threading.BrokenBarrierError:
                results2.append(True)
            except RuntimeError:
                self.barrier.abort()
            # Synchronize and reset the barrier.  Must synchronize first so
            # that everyone has left it when we reset, and after so that no
            # one enters it before the reset.
            if barrier2.wait() == self.N//2:
                self.barrier.reset()
            barrier2.wait()
            self.barrier.wait()
            results3.append(True)

        self.run_threads(f)
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertEqual(len(results3), self.N)

    def test_timeout(self):
        """
        Test wait(timeout)
        """
        def f():
            i = self.barrier.wait()
            if i == self.N // 2:
                # One thread is late!
                time.sleep(1.0)
            # Default timeout is 2.0, so this is shorter.
            self.assertRaises(threading.BrokenBarrierError,
                              self.barrier.wait, 0.5)
        self.run_threads(f)

    def test_default_timeout(self):
        """
        Test the barrier's default timeout
        """
        # create a barrier with a low default timeout
        barrier = self.barriertype(self.N, timeout=0.3)
        def f():
            i = barrier.wait()
            if i == self.N // 2:
                # One thread is later than the default timeout of 0.3s.
                time.sleep(1.0)
            self.assertRaises(threading.BrokenBarrierError, barrier.wait)
        self.run_threads(f)

    def test_single_thread(self):
        # A one-party barrier never blocks, even when waited on repeatedly.
        b = self.barriertype(1)
        b.wait()
        b.wait()
1008