1# Some simple queue module tests, plus some failure conditions
2# to ensure the Queue locks remain stable.
3import collections
4import itertools
5import queue
6import random
7import sys
8import threading
9import time
10import unittest
11import weakref
12from test import support
13
14
# The _queue module is optional: when it cannot be imported, _queue is set
# to None and CSimpleQueueTest is skipped.
try:
    import _queue
except ImportError:
    _queue = None

# Capacity given to the bounded queues created by the tests below.
QUEUE_SIZE = 5
21
def qfull(q):
    """Return True if *q* is bounded and currently holds maxsize items.

    A queue whose maxsize is zero or negative is never reported as full.
    """
    limit = q.maxsize
    if limit <= 0:
        return False
    return q.qsize() == limit
24
25# A thread to run a function that unclogs a blocked Queue.
class _TriggerThread(threading.Thread):
    """Thread that, after a short delay, calls ``fn(*args)``.

    ``startedEvent`` is set just before the call is made, so the code that
    started the thread can tell whether the trigger has fired yet.
    """

    def __init__(self, fn, args):
        super().__init__()
        self.startedEvent = threading.Event()
        self.fn, self.args = fn, args

    def run(self):
        # The pause is not strictly required; it only gives the blocking
        # function in the main thread a chance to really block before it is
        # unclogged.  If the pause were longer than the timeouts used by the
        # timeout-based tests, those tests would fail, so they use much
        # longer timeouts than this (about 10 seconds for blocking
        # functions).  They should never wait that long in practice: they
        # make progress as soon as self.fn() has been called.
        time.sleep(0.1)
        self.startedEvent.set()
        trigger, trigger_args = self.fn, self.args
        trigger(*trigger_args)
45
46
47# Execute a function that blocks, and in a separate thread, a function that
48# triggers the release.  Returns the result of the blocking function.  Caution:
49# block_func must guarantee to block until trigger_func is called, and
50# trigger_func must guarantee to change queue state so that block_func can make
51# enough progress to return.  In particular, a block_func that just raises an
52# exception regardless of whether trigger_func is called will lead to
53# timing-dependent sporadic failures, and one of those went rarely seen but
54# undiagnosed for years.  Now block_func must be unexceptional.  If block_func
55# is supposed to raise an exception, call do_exceptional_blocking_test()
56# instead.
57
class BlockingTestMixin:
    """Helpers that pair a blocking call with a thread that releases it."""

    def do_blocking_test(self, block_func, block_args, trigger_func, trigger_args):
        """Call ``block_func(*block_args)`` and return its result.

        ``trigger_func(*trigger_args)`` is called from a _TriggerThread to
        release the blocking call.  The test fails if block_func returns
        before the trigger thread has set its startedEvent.  The result is
        also stored on ``self.result``.
        """
        trigger = _TriggerThread(trigger_func, trigger_args)
        trigger.start()
        try:
            self.result = block_func(*block_args)
            # Getting here before the trigger thread made its call means
            # block_func never really blocked.
            returned_early = not trigger.startedEvent.is_set()
            if returned_early:
                self.fail("blocking function %r appeared not to block" %
                          block_func)
            return self.result
        finally:
            # Whatever happened above, make sure the thread terminates.
            support.join_thread(trigger, 10)

    def do_exceptional_blocking_test(self,block_func, block_args, trigger_func,
                                   trigger_args, expected_exception_class):
        """Variant of do_blocking_test() for a block_func that must raise.

        Any exception raised by ``block_func(*block_args)``, including the
        expected one, propagates to the caller; the test fails if the call
        returns normally.  After the trigger thread has been joined, the test
        also fails if that thread never set its startedEvent.
        """
        trigger = _TriggerThread(trigger_func, trigger_args)
        trigger.start()
        try:
            block_func(*block_args)
            # Only reached when block_func did not raise at all.
            self.fail("expected exception of kind %r" %
                             expected_exception_class)
        finally:
            # Make sure the thread terminates before inspecting its event.
            support.join_thread(trigger, 10)
            if not trigger.startedEvent.is_set():
                self.fail("trigger thread ended but event never set")
90
91
class BaseQueueTestMixin(BlockingTestMixin):
    """Tests shared by QueueTest, LifoQueueTest and PriorityQueueTest.

    Concrete test classes provide the queue class under test through the
    ``type2test`` class attribute.
    """

    def setUp(self):
        # Running total of the items consumed by worker(), plus the lock
        # that protects it.
        self.cum = 0
        self.cumlock = threading.Lock()

    def basic_queue_test(self, q):
        """Exercise ordering, the full/empty conditions and blocking calls.

        *q* must be empty on entry (RuntimeError is raised otherwise) and
        able to hold exactly QUEUE_SIZE items.  It is empty again on return,
        so the same instance can be passed in repeatedly.
        """
        if q.qsize():
            raise RuntimeError("Call this function with an empty queue")
        self.assertTrue(q.empty())
        self.assertFalse(q.full())
        # I guess we better check things actually queue correctly a little :)
        q.put(111)
        q.put(333)
        q.put(222)
        # Expected retrieval order, keyed by the name of the queue class.
        target_order = dict(Queue = [111, 333, 222],
                            LifoQueue = [222, 333, 111],
                            PriorityQueue = [111, 222, 333])
        actual_order = [q.get(), q.get(), q.get()]
        self.assertEqual(actual_order, target_order[q.__class__.__name__],
                         "Didn't seem to queue the correct data!")
        for i in range(QUEUE_SIZE-1):
            q.put(i)
            self.assertTrue(q.qsize(), "Queue should not be empty")
        self.assertFalse(qfull(q), "Queue should not be full")
        last = 2 * QUEUE_SIZE
        full = 3 * 2 * QUEUE_SIZE
        q.put(last)
        self.assertTrue(qfull(q), "Queue should be full")
        self.assertFalse(q.empty())
        self.assertTrue(q.full())
        with self.assertRaises(
                queue.Full,
                msg="Didn't appear to block with a full queue"):
            q.put(full, block=0)
        with self.assertRaises(
                queue.Full,
                msg="Didn't appear to time-out with a full queue"):
            q.put(full, timeout=0.01)
        # Test a blocking put
        self.do_blocking_test(q.put, (full,), q.get, ())
        self.do_blocking_test(q.put, (full, True, 10), q.get, ())
        # Empty it
        for i in range(QUEUE_SIZE):
            q.get()
        self.assertFalse(q.qsize(), "Queue should be empty")
        with self.assertRaises(
                queue.Empty,
                msg="Didn't appear to block with an empty queue"):
            q.get(block=0)
        with self.assertRaises(
                queue.Empty,
                msg="Didn't appear to time-out with an empty queue"):
            q.get(timeout=0.01)
        # Test a blocking get
        self.do_blocking_test(q.get, (), q.put, ('empty',))
        self.do_blocking_test(q.get, (True, 10), q.put, ('empty',))

    def worker(self, q):
        """Add every item taken from *q* to self.cum until a negative one.

        task_done() is called for every item, including the negative item
        that makes the worker return.
        """
        while True:
            x = q.get()
            if x < 0:
                q.task_done()
                return
            with self.cumlock:
                self.cum += x
            q.task_done()

    def queue_join_test(self, q):
        """Check that q.join() waits until two workers consumed 100 items."""
        self.cum = 0
        threads = []
        try:
            for _ in range(2):
                thread = threading.Thread(target=self.worker, args=(q,))
                thread.start()
                threads.append(thread)
            for i in range(100):
                q.put(i)
            q.join()
            self.assertEqual(self.cum, sum(range(100)),
                             "q.join() did not block until all tasks were done")
        finally:
            # Shut the workers down even when the check above fails;
            # otherwise the non-daemon threads would stay blocked in q.get()
            # forever.
            for _ in threads:
                q.put(-1)         # instruct the threads to close
            q.join()                # verify that you can join twice
            for thread in threads:
                thread.join()

    def test_queue_task_done(self):
        # task_done() on a queue that never held an item must be rejected.
        q = self.type2test()
        with self.assertRaises(
                ValueError,
                msg="Did not detect task count going negative"):
            q.task_done()

    def test_queue_join(self):
        # Test that a queue join()s successfully, and before anything else
        # (done twice for insurance).
        q = self.type2test()
        self.queue_join_test(q)
        self.queue_join_test(q)
        with self.assertRaises(
                ValueError,
                msg="Did not detect task count going negative"):
            q.task_done()

    def test_basic(self):
        # Do it a couple of times on the same queue.
        # Done twice to make sure works with same instance reused.
        q = self.type2test(QUEUE_SIZE)
        self.basic_queue_test(q)
        self.basic_queue_test(q)

    def test_negative_timeout_raises_exception(self):
        # A negative timeout is rejected by both put() and get().
        q = self.type2test(QUEUE_SIZE)
        with self.assertRaises(ValueError):
            q.put(1, timeout=-1)
        with self.assertRaises(ValueError):
            q.get(1, timeout=-1)

    def test_nowait(self):
        # put_nowait() raises Full once QUEUE_SIZE items are queued, and
        # get_nowait() raises Empty once they have all been taken out.
        q = self.type2test(QUEUE_SIZE)
        for i in range(QUEUE_SIZE):
            q.put_nowait(1)
        with self.assertRaises(queue.Full):
            q.put_nowait(1)

        for i in range(QUEUE_SIZE):
            q.get_nowait()
        with self.assertRaises(queue.Empty):
            q.get_nowait()

    def test_shrinking_queue(self):
        # issue 10110
        q = self.type2test(3)
        q.put(1)
        q.put(2)
        q.put(3)
        with self.assertRaises(queue.Full):
            q.put_nowait(4)
        self.assertEqual(q.qsize(), 3)
        q.maxsize = 2                       # shrink the queue
        # The queue now holds more items than maxsize and must still
        # report Full.
        with self.assertRaises(queue.Full):
            q.put_nowait(4)
# Runs the BaseQueueTestMixin tests against queue.Queue.
class QueueTest(BaseQueueTestMixin, unittest.TestCase):
    type2test = queue.Queue
246
# Runs the BaseQueueTestMixin tests against queue.LifoQueue.
class LifoQueueTest(BaseQueueTestMixin, unittest.TestCase):
    type2test = queue.LifoQueue
249
# Runs the BaseQueueTestMixin tests against queue.PriorityQueue.
class PriorityQueueTest(BaseQueueTestMixin, unittest.TestCase):
    type2test = queue.PriorityQueue
252
253
254
255# A Queue subclass that can provoke failure at a moment's notice :)
class FailingQueueException(Exception):
    """Raised by FailingQueue when it has been told to fail."""
258
class FailingQueue(queue.Queue):
    """Queue whose next _put() or _get() can be made to raise.

    Setting ``fail_next_put`` or ``fail_next_get`` to True makes the next
    _put() or _get() raise FailingQueueException; the flag is cleared before
    the exception is raised, so only that one call fails.
    """

    def __init__(self, *args):
        self.fail_next_put = self.fail_next_get = False
        super().__init__(*args)

    def _put(self, item):
        if not self.fail_next_put:
            return super()._put(item)
        self.fail_next_put = False
        raise FailingQueueException("You Lose")

    def _get(self):
        if not self.fail_next_get:
            return super()._get()
        self.fail_next_get = False
        raise FailingQueueException("You Lose")
274
class FailingQueueTest(BlockingTestMixin, unittest.TestCase):
    """Check that a FailingQueue stays usable after _put()/_get() raise."""

    def _assert_queue_fails(self):
        """Return a context manager that expects FailingQueueException."""
        return self.assertRaises(
            FailingQueueException,
            msg="The queue didn't fail when it should have")

    def failing_queue_test(self, q):
        """Provoke put/get failures on *q* and check its state after each.

        *q* must be an empty FailingQueue able to hold QUEUE_SIZE items
        (RuntimeError is raised if it is not empty).  It is empty again on
        return.
        """
        if q.qsize():
            raise RuntimeError("Call this function with an empty queue")
        for i in range(QUEUE_SIZE-1):
            q.put(i)
        # Test a failing non-blocking put.
        q.fail_next_put = True
        with self._assert_queue_fails():
            q.put("oops", block=0)
        # Same again, this time passing a timeout.
        q.fail_next_put = True
        with self._assert_queue_fails():
            q.put("oops", timeout=0.1)
        q.put("last")
        self.assertTrue(qfull(q), "Queue should be full")
        # Test a failing blocking put.  The put is expected to raise, so the
        # exceptional variant of the blocking helper is the one to use.
        q.fail_next_put = True
        with self._assert_queue_fails():
            self.do_exceptional_blocking_test(q.put, ("full",), q.get, (),
                                              FailingQueueException)
        # Check the Queue isn't damaged.
        # put failed, but get succeeded - re-add
        q.put("last")
        # Test a failing timeout put
        q.fail_next_put = True
        with self._assert_queue_fails():
            self.do_exceptional_blocking_test(q.put, ("full", True, 10), q.get, (),
                                              FailingQueueException)
        # Check the Queue isn't damaged.
        # put failed, but get succeeded - re-add
        q.put("last")
        self.assertTrue(qfull(q), "Queue should be full")
        q.get()
        self.assertFalse(qfull(q), "Queue should not be full")
        q.put("last")
        self.assertTrue(qfull(q), "Queue should be full")
        # Test a blocking put
        self.do_blocking_test(q.put, ("full",), q.get, ())
        # Empty it
        for i in range(QUEUE_SIZE):
            q.get()
        self.assertFalse(q.qsize(), "Queue should be empty")
        q.put("first")
        # A failing get must leave the item in the queue.
        q.fail_next_get = True
        with self._assert_queue_fails():
            q.get()
        self.assertTrue(q.qsize(), "Queue should not be empty")
        q.fail_next_get = True
        with self._assert_queue_fails():
            q.get(timeout=0.1)
        self.assertTrue(q.qsize(), "Queue should not be empty")
        q.get()
        self.assertFalse(q.qsize(), "Queue should be empty")
        # Test a failing blocking get
        q.fail_next_get = True
        with self._assert_queue_fails():
            self.do_exceptional_blocking_test(q.get, (), q.put, ('empty',),
                                              FailingQueueException)
        # put succeeded, but get failed.
        self.assertTrue(q.qsize(), "Queue should not be empty")
        q.get()
        self.assertFalse(q.qsize(), "Queue should be empty")

    def test_failing_queue(self):
        # Test to make sure a queue is functioning correctly.
        # Done twice to the same instance.
        q = FailingQueue(QUEUE_SIZE)
        self.failing_queue_test(q)
        self.failing_queue_test(q)
364
365
class BaseSimpleQueueTest:
    """Tests shared by PySimpleQueueTest and CSimpleQueueTest.

    The queue class under test is read from ``self.type2test``; a fresh
    instance is stored in ``self.q`` before every test.
    """

    def setUp(self):
        self.q = self.type2test()

    def feed(self, q, seq, rnd):
        """Pop items off the end of *seq* and put them on *q* until empty.

        *rnd* is a random.Random instance used to sleep for up to a
        millisecond after roughly half of the puts.
        """
        while True:
            try:
                val = seq.pop()
            except IndexError:
                return
            q.put(val)
            if rnd.random() > 0.5:
                time.sleep(rnd.random() * 1e-3)

    def consume(self, q, results, sentinel):
        """Append items from blocking get() calls to *results*.

        Returns when an item equal to *sentinel* is received.
        """
        while True:
            val = q.get()
            if val == sentinel:
                return
            results.append(val)

    def consume_nonblock(self, q, results, sentinel):
        """Like consume(), but polls with get(block=False)."""
        while True:
            try:
                val = q.get(block=False)
            except queue.Empty:
                # Nothing available yet: back off briefly and poll again.
                time.sleep(1e-5)
                continue
            if val == sentinel:
                return
            results.append(val)

    def consume_timeout(self, q, results, sentinel):
        """Like consume(), but polls with a very short get() timeout."""
        while True:
            try:
                val = q.get(timeout=1e-5)
            except queue.Empty:
                # Timed out: simply try again.
                continue
            if val == sentinel:
                return
            results.append(val)

    def run_threads(self, n_feeders, n_consumers, q, inputs,
                    feed_func, consume_func):
        """Push *inputs* through *q* with feeder and consumer threads.

        Starts *n_feeders* threads running *feed_func* and *n_consumers*
        threads running *consume_func*, waits for all of them, checks that
        no thread raised and that *q* ended up empty, and returns the list of
        consumed items.
        """
        results = []
        sentinel = None
        # One sentinel per consumer, placed after the inputs.  The feeders
        # pop from the end of the list, hence the reversal.
        seq = inputs + [sentinel] * n_consumers
        seq.reverse()
        rnd = random.Random(42)

        exceptions = []
        def log_exceptions(f):
            # Record anything raised inside a thread so that it can be
            # reported from the main thread below.
            def wrapper(*args, **kwargs):
                try:
                    f(*args, **kwargs)
                except BaseException as e:
                    exceptions.append(e)
            return wrapper

        feeders = [threading.Thread(target=log_exceptions(feed_func),
                                    args=(q, seq, rnd))
                   for i in range(n_feeders)]
        consumers = [threading.Thread(target=log_exceptions(consume_func),
                                      args=(q, results, sentinel))
                     for i in range(n_consumers)]

        with support.start_threads(feeders + consumers):
            pass

        self.assertFalse(exceptions)
        self.assertTrue(q.empty())
        self.assertEqual(q.qsize(), 0)

        return results

    def test_basic(self):
        # Basic tests for get(), put() etc.
        q = self.q
        self.assertTrue(q.empty())
        self.assertEqual(q.qsize(), 0)
        q.put(1)
        self.assertFalse(q.empty())
        self.assertEqual(q.qsize(), 1)
        q.put(2)
        q.put_nowait(3)
        q.put(4)
        self.assertFalse(q.empty())
        self.assertEqual(q.qsize(), 4)

        self.assertEqual(q.get(), 1)
        self.assertEqual(q.qsize(), 3)

        self.assertEqual(q.get_nowait(), 2)
        self.assertEqual(q.qsize(), 2)

        self.assertEqual(q.get(block=False), 3)
        self.assertFalse(q.empty())
        self.assertEqual(q.qsize(), 1)

        self.assertEqual(q.get(timeout=0.1), 4)
        self.assertTrue(q.empty())
        self.assertEqual(q.qsize(), 0)

        with self.assertRaises(queue.Empty):
            q.get(block=False)
        with self.assertRaises(queue.Empty):
            q.get(timeout=1e-3)
        with self.assertRaises(queue.Empty):
            q.get_nowait()
        self.assertTrue(q.empty())
        self.assertEqual(q.qsize(), 0)

    def test_negative_timeout_raises_exception(self):
        # A negative timeout is rejected even though an item is available.
        q = self.q
        q.put(1)
        with self.assertRaises(ValueError):
            q.get(timeout=-1)

    def test_order(self):
        # Test a pair of concurrent put() and get()
        q = self.q
        inputs = list(range(100))
        results = self.run_threads(1, 1, q, inputs, self.feed, self.consume)

        # One producer, one consumer => results appended in well-defined order
        self.assertEqual(results, inputs)

    def test_many_threads(self):
        # Test multiple concurrent put() and get()
        N = 50
        q = self.q
        inputs = list(range(10000))
        results = self.run_threads(N, N, q, inputs, self.feed, self.consume)

        # Multiple consumers without synchronization append the
        # results in random order
        self.assertEqual(sorted(results), inputs)

    def test_many_threads_nonblock(self):
        # Test multiple concurrent put() and get(block=False)
        N = 50
        q = self.q
        inputs = list(range(10000))
        results = self.run_threads(N, N, q, inputs,
                                   self.feed, self.consume_nonblock)

        self.assertEqual(sorted(results), inputs)

    def test_many_threads_timeout(self):
        # Test multiple concurrent put() and get(timeout=...)
        N = 50
        q = self.q
        inputs = list(range(1000))
        results = self.run_threads(N, N, q, inputs,
                                   self.feed, self.consume_timeout)

        self.assertEqual(sorted(results), inputs)

    def test_references(self):
        # The queue should lose references to each item as soon as
        # it leaves the queue.
        class C:
            pass

        N = 20
        q = self.q
        for i in range(N):
            q.put(C())
        for i in range(N):
            wr = weakref.ref(q.get())
            # Without reference counting the object is only reclaimed by a
            # collection, so force one before checking the weak reference.
            support.gc_collect()
            self.assertIsNone(wr())
544
# Runs the BaseSimpleQueueTest tests against queue._PySimpleQueue.
class PySimpleQueueTest(BaseSimpleQueueTest, unittest.TestCase):
    type2test = queue._PySimpleQueue
547
548
# Runs the BaseSimpleQueueTest tests against _queue.SimpleQueue, plus two
# tests of its own.  Skipped when the _queue module is unavailable.
@unittest.skipIf(_queue is None, "No _queue module found")
class CSimpleQueueTest(BaseSimpleQueueTest, unittest.TestCase):

    def setUp(self):
        # Looked up here rather than in the class body: the class body is
        # executed even when _queue is None, and the attribute access would
        # fail there.
        self.type2test = _queue.SimpleQueue
        super().setUp()

    def test_is_default(self):
        # queue.SimpleQueue must be the very class provided by _queue.
        self.assertIs(self.type2test, queue.SimpleQueue)

    def test_reentrancy(self):
        # bpo-14976: put() may be called reentrantly in an asynchronous
        # callback.
        q = self.q
        gen = itertools.count()
        N = 10000
        results = []

        # This test exploits the fact that __del__ in a reference cycle
        # can be called any time the GC may run.

        class Circular(object):
            def __init__(self):
                # Self-reference: the instance can only be reclaimed by the
                # cyclic garbage collector.
                self.circular = self

            def __del__(self):
                q.put(next(gen))

        # Every put, whether made by the loop or by a __del__ call, takes
        # the next value from the shared counter.
        while True:
            o = Circular()
            q.put(next(gen))
            del o
            results.append(q.get())
            if results[-1] >= N:
                break

        # Items must come out in exactly the order the counter produced them.
        self.assertEqual(results, list(range(N + 1)))
586
587
# Run every test in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
590