1import asyncio
2import unittest
3
4from unittest import mock
5from . import utils as test_utils
6
7
class TestPolicy(asyncio.AbstractEventLoopPolicy):
    """Event loop policy that creates loops from a factory.

    The most recent non-None loop installed through set_event_loop() is
    kept in ``self.loop`` so a test can inspect it after asyncio.run()
    has returned.
    """

    def __init__(self, loop_factory):
        # loop_factory: zero-argument callable returning a new event loop.
        self.loop_factory = loop_factory
        # Last non-None loop passed to set_event_loop(); None until then.
        self.loop = None

    def get_event_loop(self):
        # asyncio.run() is expected never to ask the policy for the current
        # loop; fail loudly, with a diagnostic, if something does.
        raise RuntimeError(
            'TestPolicy.get_event_loop() should not be called by asyncio.run()')

    def new_event_loop(self):
        """Return a fresh loop produced by the configured factory."""
        return self.loop_factory()

    def set_event_loop(self, loop):
        """Remember *loop* for later inspection; calls with None are ignored.

        Ignoring None keeps the last real loop available, so
        BaseTest.tearDown can check that the loop was closed.
        """
        if loop is not None:
            self.loop = loop
26
27
class BaseTest(unittest.TestCase):
    """Installs a TestPolicy for each test and checks loop cleanup after it."""

    def new_loop(self):
        """Build a BaseEventLoop whose I/O layer is mocked out.

        The selector and event processing are replaced with mocks so the
        loop can run without real I/O. shutdown_asyncgens() is replaced by
        a stub that only records, via ``loop.shutdown_ag_run``, that it
        was awaited.
        """
        loop = asyncio.BaseEventLoop()
        loop._process_events = mock.Mock()
        loop._selector = mock.Mock()
        # select() always reports an empty event list.
        loop._selector.select.return_value = ()
        loop.shutdown_ag_run = False

        async def shutdown_asyncgens():
            loop.shutdown_ag_run = True
        loop.shutdown_asyncgens = shutdown_asyncgens

        return loop

    def setUp(self):
        super().setUp()

        # Every loop created through the policy comes from new_loop().
        policy = TestPolicy(self.new_loop)
        asyncio.set_event_loop_policy(policy)

    def tearDown(self):
        """Check that the recorded loop was shut down, then restore the policy.

        The policy reset is in a ``finally`` clause so a failing assertion
        cannot leak TestPolicy into subsequent tests.
        """
        policy = asyncio.get_event_loop_policy()
        try:
            if policy.loop is not None:
                self.assertTrue(policy.loop.is_closed())
                self.assertTrue(policy.loop.shutdown_ag_run)
        finally:
            asyncio.set_event_loop_policy(None)
            super().tearDown()
57
58
class RunTests(BaseTest):
    """Behavioral tests for asyncio.run() under the mocked TestPolicy loop."""

    def test_asyncio_run_return(self):
        # The coroutine's return value is handed back by asyncio.run().
        async def main():
            await asyncio.sleep(0)
            return 42

        self.assertEqual(asyncio.run(main()), 42)

    def test_asyncio_run_raises(self):
        # An exception raised inside the coroutine propagates to the caller.
        async def main():
            await asyncio.sleep(0)
            raise ValueError('spam')

        with self.assertRaisesRegex(ValueError, 'spam'):
            asyncio.run(main())

    def test_asyncio_run_only_coro(self):
        # Non-coroutine arguments are rejected with ValueError.  A tuple is
        # used rather than a set so subTest order is deterministic (a set
        # containing a lambda orders by hash, i.e. by object id).
        for o in (1, lambda: None):
            with self.subTest(obj=o), \
                    self.assertRaisesRegex(ValueError,
                                           'a coroutine was expected'):
                asyncio.run(o)

    def test_asyncio_run_debug(self):
        # The ``debug`` argument is reflected in the running loop.
        async def main(expected):
            loop = asyncio.get_running_loop()
            self.assertIs(loop.get_debug(), expected)

        asyncio.run(main(False))
        asyncio.run(main(True), debug=True)

    def test_asyncio_run_from_running_loop(self):
        # Nested asyncio.run() from inside a running loop is refused.
        async def main():
            coro = main()
            try:
                asyncio.run(coro)
            finally:
                coro.close()  # Suppress ResourceWarning

        with self.assertRaisesRegex(RuntimeError,
                                    'cannot be called from a running'):
            asyncio.run(main())

    def test_asyncio_run_cancels_hanging_tasks(self):
        # Tasks still pending when main() returns are cancelled on shutdown.
        lo_task = None

        async def leftover():
            await asyncio.sleep(0.1)

        async def main():
            nonlocal lo_task
            lo_task = asyncio.create_task(leftover())
            return 123

        self.assertEqual(asyncio.run(main()), 123)
        self.assertTrue(lo_task.done())
        self.assertTrue(lo_task.cancelled())

    def test_asyncio_run_reports_hanging_tasks_errors(self):
        # An error raised by a leftover task while it is being cancelled is
        # routed to the loop's exception handler instead of being lost.
        lo_task = None
        call_exc_handler_mock = mock.Mock()

        async def leftover():
            try:
                await asyncio.sleep(0.1)
            except asyncio.CancelledError:
                1 / 0

        async def main():
            loop = asyncio.get_running_loop()
            loop.call_exception_handler = call_exc_handler_mock

            nonlocal lo_task
            lo_task = asyncio.create_task(leftover())
            return 123

        self.assertEqual(asyncio.run(main()), 123)
        self.assertTrue(lo_task.done())

        call_exc_handler_mock.assert_called_with({
            'message': test_utils.MockPattern(r'asyncio.run.*shutdown'),
            'task': lo_task,
            'exception': test_utils.MockInstanceOf(ZeroDivisionError)
        })

    def test_asyncio_run_closes_gens_after_hanging_tasks_errors(self):
        # Even when main() raises and a leftover task errors during
        # cancellation, the async generator it was iterating ends up
        # finished (no frame) and not running.
        spinner = None
        lazyboy = None

        class FancyExit(Exception):
            pass

        async def fidget():
            while True:
                yield 1
                await asyncio.sleep(1)

        async def spin():
            nonlocal spinner
            spinner = fidget()
            try:
                async for the_meaning_of_life in spinner:  # NoQA
                    pass
            except asyncio.CancelledError:
                1 / 0

        async def main():
            loop = asyncio.get_running_loop()
            loop.call_exception_handler = mock.Mock()

            nonlocal lazyboy
            lazyboy = asyncio.create_task(spin())
            raise FancyExit

        with self.assertRaises(FancyExit):
            asyncio.run(main())

        self.assertTrue(lazyboy.done())

        self.assertIsNone(spinner.ag_frame)
        self.assertFalse(spinner.ag_running)
180