1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3 
4 use futures::{
5     future::{pending, ready},
6     FutureExt,
7 };
8 
9 use tokio::runtime::{self, Runtime};
10 use tokio::sync::{mpsc, oneshot};
11 use tokio::task::{self, LocalSet};
12 use tokio::time;
13 
14 use std::cell::Cell;
15 use std::sync::atomic::Ordering::{self, SeqCst};
16 use std::sync::atomic::{AtomicBool, AtomicUsize};
17 use std::time::Duration;
18 
19 #[tokio::test(flavor = "current_thread")]
local_basic_scheduler()20 async fn local_basic_scheduler() {
21     LocalSet::new()
22         .run_until(async {
23             task::spawn_local(async {}).await.unwrap();
24         })
25         .await;
26 }
27 
28 #[tokio::test(flavor = "multi_thread")]
local_threadpool()29 async fn local_threadpool() {
30     thread_local! {
31         static ON_RT_THREAD: Cell<bool> = Cell::new(false);
32     }
33 
34     ON_RT_THREAD.with(|cell| cell.set(true));
35 
36     LocalSet::new()
37         .run_until(async {
38             assert!(ON_RT_THREAD.with(|cell| cell.get()));
39             task::spawn_local(async {
40                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
41             })
42             .await
43             .unwrap();
44         })
45         .await;
46 }
47 
48 #[tokio::test(flavor = "multi_thread")]
localset_future_threadpool()49 async fn localset_future_threadpool() {
50     thread_local! {
51         static ON_LOCAL_THREAD: Cell<bool> = Cell::new(false);
52     }
53 
54     ON_LOCAL_THREAD.with(|cell| cell.set(true));
55 
56     let local = LocalSet::new();
57     local.spawn_local(async move {
58         assert!(ON_LOCAL_THREAD.with(|cell| cell.get()));
59     });
60     local.await;
61 }
62 
63 #[tokio::test(flavor = "multi_thread")]
localset_future_timers()64 async fn localset_future_timers() {
65     static RAN1: AtomicBool = AtomicBool::new(false);
66     static RAN2: AtomicBool = AtomicBool::new(false);
67 
68     let local = LocalSet::new();
69     local.spawn_local(async move {
70         time::sleep(Duration::from_millis(10)).await;
71         RAN1.store(true, Ordering::SeqCst);
72     });
73     local.spawn_local(async move {
74         time::sleep(Duration::from_millis(20)).await;
75         RAN2.store(true, Ordering::SeqCst);
76     });
77     local.await;
78     assert!(RAN1.load(Ordering::SeqCst));
79     assert!(RAN2.load(Ordering::SeqCst));
80 }
81 
82 #[tokio::test]
localset_future_drives_all_local_futs()83 async fn localset_future_drives_all_local_futs() {
84     static RAN1: AtomicBool = AtomicBool::new(false);
85     static RAN2: AtomicBool = AtomicBool::new(false);
86     static RAN3: AtomicBool = AtomicBool::new(false);
87 
88     let local = LocalSet::new();
89     local.spawn_local(async move {
90         task::spawn_local(async {
91             task::yield_now().await;
92             RAN3.store(true, Ordering::SeqCst);
93         });
94         task::yield_now().await;
95         RAN1.store(true, Ordering::SeqCst);
96     });
97     local.spawn_local(async move {
98         task::yield_now().await;
99         RAN2.store(true, Ordering::SeqCst);
100     });
101     local.await;
102     assert!(RAN1.load(Ordering::SeqCst));
103     assert!(RAN2.load(Ordering::SeqCst));
104     assert!(RAN3.load(Ordering::SeqCst));
105 }
106 
107 #[tokio::test(flavor = "multi_thread")]
local_threadpool_timer()108 async fn local_threadpool_timer() {
109     // This test ensures that runtime services like the timer are properly
110     // set for the local task set.
111     thread_local! {
112         static ON_RT_THREAD: Cell<bool> = Cell::new(false);
113     }
114 
115     ON_RT_THREAD.with(|cell| cell.set(true));
116 
117     LocalSet::new()
118         .run_until(async {
119             assert!(ON_RT_THREAD.with(|cell| cell.get()));
120             let join = task::spawn_local(async move {
121                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
122                 time::sleep(Duration::from_millis(10)).await;
123                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
124             });
125             join.await.unwrap();
126         })
127         .await;
128 }
129 
#[test]
// Expected to panic: `task::block_in_place` is called from a local task while
// the `LocalSet` is driven by `block_on` on a *current-thread* runtime (built
// below). In-place blocking is not supported there; tokio's `block_in_place`
// docs state it panics when called from a current_thread runtime.
// NOTE(review): `#[should_panic]` has no `expected = "..."`, so any panic in
// this body satisfies the test. The panic inside the local task presumably
// surfaces via `join.await.unwrap()`. Consider pinning a message.
#[should_panic]
fn local_threadpool_blocking_in_place() {
    // Marks the test thread so the assertions can check that the local task
    // runs on the thread that calls `block_on`.
    thread_local! {
        static ON_RT_THREAD: Cell<bool> = Cell::new(false);
    }

    ON_RT_THREAD.with(|cell| cell.set(true));

    let rt = runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap();
    LocalSet::new().block_on(&rt, async {
        assert!(ON_RT_THREAD.with(|cell| cell.get()));
        let join = task::spawn_local(async move {
            assert!(ON_RT_THREAD.with(|cell| cell.get()));
            // This is the call expected to panic on a current-thread runtime.
            task::block_in_place(|| {});
            assert!(ON_RT_THREAD.with(|cell| cell.get()));
        });
        join.await.unwrap();
    });
}
155 
156 #[tokio::test(flavor = "multi_thread")]
local_threadpool_blocking_run()157 async fn local_threadpool_blocking_run() {
158     thread_local! {
159         static ON_RT_THREAD: Cell<bool> = Cell::new(false);
160     }
161 
162     ON_RT_THREAD.with(|cell| cell.set(true));
163 
164     LocalSet::new()
165         .run_until(async {
166             assert!(ON_RT_THREAD.with(|cell| cell.get()));
167             let join = task::spawn_local(async move {
168                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
169                 task::spawn_blocking(|| {
170                     assert!(
171                         !ON_RT_THREAD.with(|cell| cell.get()),
172                         "blocking must not run on the local task set's thread"
173                     );
174                 })
175                 .await
176                 .unwrap();
177                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
178             });
179             join.await.unwrap();
180         })
181         .await;
182 }
183 
184 #[tokio::test(flavor = "multi_thread")]
all_spawns_are_local()185 async fn all_spawns_are_local() {
186     use futures::future;
187     thread_local! {
188         static ON_RT_THREAD: Cell<bool> = Cell::new(false);
189     }
190 
191     ON_RT_THREAD.with(|cell| cell.set(true));
192 
193     LocalSet::new()
194         .run_until(async {
195             assert!(ON_RT_THREAD.with(|cell| cell.get()));
196             let handles = (0..128)
197                 .map(|_| {
198                     task::spawn_local(async {
199                         assert!(ON_RT_THREAD.with(|cell| cell.get()));
200                     })
201                 })
202                 .collect::<Vec<_>>();
203             for joined in future::join_all(handles).await {
204                 joined.unwrap();
205             }
206         })
207         .await;
208 }
209 
210 #[tokio::test(flavor = "multi_thread")]
nested_spawn_is_local()211 async fn nested_spawn_is_local() {
212     thread_local! {
213         static ON_RT_THREAD: Cell<bool> = Cell::new(false);
214     }
215 
216     ON_RT_THREAD.with(|cell| cell.set(true));
217 
218     LocalSet::new()
219         .run_until(async {
220             assert!(ON_RT_THREAD.with(|cell| cell.get()));
221             task::spawn_local(async {
222                 assert!(ON_RT_THREAD.with(|cell| cell.get()));
223                 task::spawn_local(async {
224                     assert!(ON_RT_THREAD.with(|cell| cell.get()));
225                     task::spawn_local(async {
226                         assert!(ON_RT_THREAD.with(|cell| cell.get()));
227                         task::spawn_local(async {
228                             assert!(ON_RT_THREAD.with(|cell| cell.get()));
229                         })
230                         .await
231                         .unwrap();
232                     })
233                     .await
234                     .unwrap();
235                 })
236                 .await
237                 .unwrap();
238             })
239             .await
240             .unwrap();
241         })
242         .await;
243 }
244 
#[test]
fn join_local_future_elsewhere() {
    // A local task's `JoinHandle` is awaited from a task running on a
    // multi-thread worker. The local task itself must still execute on the
    // thread that drives the `LocalSet`, which is the flagged test thread.
    thread_local! {
        static ON_RT_THREAD: Cell<bool> = Cell::new(false);
    }

    ON_RT_THREAD.with(|cell| cell.set(true));

    let rt = Runtime::new().unwrap();
    let set = LocalSet::new();
    set.block_on(&rt, async move {
        let (wake_tx, wake_rx) = oneshot::channel();

        // Runs on the local thread and waits for the worker to signal it.
        let local_work = async move {
            println!("hello world running...");
            assert!(
                ON_RT_THREAD.with(|cell| cell.get()),
                "local task must run on local thread, no matter where it is awaited"
            );
            wake_rx.await.unwrap();

            println!("hello world task done");
            "hello world"
        };
        let local_handle = task::spawn_local(local_work);

        // Runs on a runtime worker, wakes the local task, then joins it.
        let remote_work = async move {
            assert!(
                !ON_RT_THREAD.with(|cell| cell.get()),
                "spawned task should be on a worker"
            );

            wake_tx.send(()).expect("task shouldn't have ended yet");
            println!("waking up hello world...");

            local_handle.await.expect("task should complete successfully");

            println!("hello world task joined");
        };
        let remote_handle = task::spawn(remote_work);

        remote_handle.await.unwrap()
    });
}
284 
#[test]
fn drop_cancels_tasks() {
    use std::rc::Rc;

    // Regression test for issue #1842. Dropping a `LocalSet` must drop a task
    // that is still pending, releasing everything the task owns. Here the
    // task owns an `Rc` clone, so the strong count must fall back to 1.
    let rt = rt();
    let tracker = Rc::new(());
    let owned_by_task = Rc::clone(&tracker);

    let (started_tx, started_rx) = oneshot::channel();

    let set = LocalSet::new();
    set.spawn_local(async move {
        // Bind the clone so it stays alive exactly as long as the task does.
        let _held = owned_by_task;

        started_tx.send(()).unwrap();
        loop {
            time::sleep(Duration::from_secs(3600)).await;
        }
    });

    // Run just long enough to know the task has started.
    set.block_on(&rt, async {
        started_rx.await.unwrap();
    });
    drop(set);
    drop(rt);

    assert_eq!(1, Rc::strong_count(&tracker));
}
315 
/// Runs `f` on a separate thread and waits at most `timeout` for it to finish.
///
/// This is intended for tests whose failure mode is a hang or infinite loop
/// that cannot be detected otherwise. Callers should choose a generous
/// `timeout` so that a slow CI machine does not cause spurious failures.
///
/// # Panics
///
/// Panics if `f` does not complete within `timeout`. Also panics (with
/// "test thread should not panic!") if `f` itself panics; that panic is
/// surfaced when the worker thread is joined.
fn with_timeout(timeout: Duration, f: impl FnOnce() + Send + 'static) {
    use std::sync::mpsc::RecvTimeoutError;

    let (done_tx, done_rx) = std::sync::mpsc::channel();
    let thread = std::thread::spawn(move || {
        f();

        // Signal completion so the waiting thread can tell a finished test
        // apart from one stuck in an infinite loop. If the waiter has already
        // timed out, the receiver is gone. That is expected, and it must not
        // produce a second, spurious panic on this detached thread.
        let _ = done_tx.send(());
    });

    // The failure mode is a hang rather than something we can assert on, so
    // wait for the completion message for at most `timeout`. If it does not
    // arrive in time, assume the test thread has hung.
    match done_rx.recv_timeout(timeout) {
        // `Duration`'s Debug output already includes a unit (e.g. "60s"),
        // so no unit word is appended here.
        Err(RecvTimeoutError::Timeout) => panic!(
            "test did not complete within {:?}, \
             we have (probably) entered an infinite loop!",
            timeout,
        ),
        // The sender was dropped without sending, which most likely means the
        // test thread panicked. `join` below determines this for sure.
        Err(RecvTimeoutError::Disconnected) => {
            println!("done_tx dropped, did the test thread panic?");
        }
        // Test completed successfully!
        Ok(()) => {}
    }

    thread.join().expect("test thread should not panic!")
}
359 
#[test]
fn drop_cancels_remote_tasks() {
    // Regression test for issue #1885: dropping a `LocalSet` whose task has a
    // pending remote notification must not loop forever. `with_timeout`
    // turns such a hang into a test failure.
    with_timeout(Duration::from_secs(60), || {
        let (tx, mut rx) = mpsc::channel::<()>(1024);

        let rt = rt();
        let set = LocalSet::new();

        // Keep receiving until every sender has been dropped.
        set.spawn_local(async move {
            while let Some(()) = rx.recv().await {}
        });

        let tick = async {
            time::sleep(Duration::from_millis(1)).await;
        };
        set.block_on(&rt, tick);

        drop(tx);

        // This enters an infinite loop if the remote notified tasks are not
        // properly cancelled.
        drop(set);
    });
}
381 
#[test]
fn local_tasks_wake_join_all() {
    // This test reproduces issue #2460. There are 128 local tasks, and each
    // spawns and joins a nested local task. All of them must complete when
    // joined together via `join_all`. A missed wake-up presumably shows up as
    // a hang, which `with_timeout` turns into a failure.
    with_timeout(Duration::from_secs(60), || {
        use futures::future::join_all;

        let rt = rt();
        let set = LocalSet::new();

        // Spawn all 128 tasks before the set is driven.
        let handles: Vec<_> = (0..128)
            .map(|_| {
                set.spawn_local(async move {
                    tokio::task::spawn_local(async move {}).await.unwrap();
                })
            })
            .collect();

        rt.block_on(set.run_until(join_all(handles)));
    });
}
402 
403 #[tokio::test]
local_tasks_are_polled_after_tick()404 async fn local_tasks_are_polled_after_tick() {
405     // Reproduces issues #1899 and #1900
406 
407     static RX1: AtomicUsize = AtomicUsize::new(0);
408     static RX2: AtomicUsize = AtomicUsize::new(0);
409     static EXPECTED: usize = 500;
410 
411     let (tx, mut rx) = mpsc::unbounded_channel();
412 
413     let local = LocalSet::new();
414 
415     local
416         .run_until(async {
417             let task2 = task::spawn(async move {
418                 // Wait a bit
419                 time::sleep(Duration::from_millis(100)).await;
420 
421                 let mut oneshots = Vec::with_capacity(EXPECTED);
422 
423                 // Send values
424                 for _ in 0..EXPECTED {
425                     let (oneshot_tx, oneshot_rx) = oneshot::channel();
426                     oneshots.push(oneshot_tx);
427                     tx.send(oneshot_rx).unwrap();
428                 }
429 
430                 time::sleep(Duration::from_millis(100)).await;
431 
432                 for tx in oneshots.drain(..) {
433                     tx.send(()).unwrap();
434                 }
435 
436                 time::sleep(Duration::from_millis(300)).await;
437                 let rx1 = RX1.load(SeqCst);
438                 let rx2 = RX2.load(SeqCst);
439                 println!("EXPECT = {}; RX1 = {}; RX2 = {}", EXPECTED, rx1, rx2);
440                 assert_eq!(EXPECTED, rx1);
441                 assert_eq!(EXPECTED, rx2);
442             });
443 
444             while let Some(oneshot) = rx.recv().await {
445                 RX1.fetch_add(1, SeqCst);
446 
447                 task::spawn_local(async move {
448                     oneshot.await.unwrap();
449                     RX2.fetch_add(1, SeqCst);
450                 });
451             }
452 
453             task2.await.unwrap();
454         })
455         .await;
456 }
457 
458 #[tokio::test]
acquire_mutex_in_drop()459 async fn acquire_mutex_in_drop() {
460     use futures::future::pending;
461 
462     let (tx1, rx1) = oneshot::channel();
463     let (tx2, rx2) = oneshot::channel();
464     let local = LocalSet::new();
465 
466     local.spawn_local(async move {
467         let _ = rx2.await;
468         unreachable!();
469     });
470 
471     local.spawn_local(async move {
472         let _ = rx1.await;
473         tx2.send(()).unwrap();
474         unreachable!();
475     });
476 
477     // Spawn a task that will never notify
478     local.spawn_local(async move {
479         pending::<()>().await;
480         tx1.send(()).unwrap();
481     });
482 
483     // Tick the loop
484     local
485         .run_until(async {
486             task::yield_now().await;
487         })
488         .await;
489 
490     // Drop the LocalSet
491     drop(local);
492 }
493 
#[tokio::test]
async fn spawn_wakes_localset() {
    // Spawning onto a `LocalSet` that is currently being driven must get the
    // new task polled.
    //
    // The first branch drives the set via `run_until` on a future that never
    // completes, so that branch can never win (hence `unreachable!()`). The
    // second branch spawns a ready task onto the same set and awaits its
    // handle. It can only finish if the set runs the newly spawned task. When
    // the set is already idle, that presumably requires the spawn to wake the
    // `run_until` future. If this regresses, the test hangs instead of
    // failing.
    let local = LocalSet::new();
    futures::select! {
        _ = local.run_until(pending::<()>()).fuse() => unreachable!(),
        ret = async { local.spawn_local(ready(())).await.unwrap()}.fuse() => ret
    }
}
502 
rt() -> Runtime503 fn rt() -> Runtime {
504     tokio::runtime::Builder::new_current_thread()
505         .enable_all()
506         .build()
507         .unwrap()
508 }
509