1 #![cfg(test)]
2 
3 #[allow(deprecated)]
4 use crate::Configuration;
5 use crate::{ThreadPoolBuildError, ThreadPoolBuilder};
6 use std::sync::atomic::{AtomicUsize, Ordering};
7 use std::sync::{Arc, Barrier};
8 
/// `current_thread_index` is `None` on a thread outside the pool and a valid
/// worker index (below the configured thread count) inside `install`.
#[test]
fn worker_thread_index() {
    let thread_count = 22;
    let pool = ThreadPoolBuilder::new()
        .num_threads(thread_count)
        .build()
        .unwrap();
    assert_eq!(pool.current_num_threads(), thread_count);

    // The test thread itself is not one of the pool's workers.
    assert!(pool.current_thread_index().is_none());

    let index = pool.install(|| pool.current_thread_index()).unwrap();
    assert!(index < thread_count);
}
17 
/// The start handler must run on every worker thread the pool creates.
#[test]
fn start_callback_called() {
    let n_threads = 16;
    let n_called = Arc::new(AtomicUsize::new(0));
    // One slot per pool thread, plus one for the thread running the test.
    let barrier = Arc::new(Barrier::new(n_threads + 1));

    let start_handler = {
        let barrier = Arc::clone(&barrier);
        let n_called = Arc::clone(&n_called);
        move |_| {
            n_called.fetch_add(1, Ordering::SeqCst);
            barrier.wait();
        }
    };

    let builder = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .start_handler(start_handler);
    // The pool is dropped right away; the barrier below still requires all
    // `n_threads` workers to reach the start handler.
    drop(builder.build().unwrap());

    // Block until every worker has entered the start handler.
    barrier.wait();

    assert_eq!(n_called.load(Ordering::SeqCst), n_threads);
}
43 
/// The exit handler must run on every worker thread when the pool shuts down.
#[test]
fn exit_callback_called() {
    let n_threads = 16;
    let n_called = Arc::new(AtomicUsize::new(0));
    // One slot per pool thread, plus one for the thread running the test.
    let barrier = Arc::new(Barrier::new(n_threads + 1));

    let exit_handler = {
        let barrier = Arc::clone(&barrier);
        let n_called = Arc::clone(&n_called);
        move |_| {
            n_called.fetch_add(1, Ordering::SeqCst);
            barrier.wait();
        }
    };

    let builder = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .exit_handler(exit_handler);
    // Build the pool and drop it immediately so its threads are told to stop.
    drop(builder.build().unwrap());

    // Block until every worker has entered the exit handler.
    barrier.wait();

    assert_eq!(n_called.load(Ordering::SeqCst), n_threads);
}
72 
/// Panics raised inside the start and exit handlers must be routed to the
/// panic handler: once per thread at startup and once per thread at shutdown.
#[test]
fn handler_panics_handled_correctly() {
    let n_threads = 16;
    let n_called = Arc::new(AtomicUsize::new(0));
    // Each barrier waits for all the threads in the pool plus the one running
    // the test.
    let start_barrier = Arc::new(Barrier::new(n_threads + 1));
    let exit_barrier = Arc::new(Barrier::new(n_threads + 1));

    let start_handler = move |_| {
        panic!("ensure panic handler is called when starting");
    };
    let exit_handler = move |_| {
        panic!("ensure panic handler is called when exiting");
    };

    let sb = Arc::clone(&start_barrier);
    let eb = Arc::clone(&exit_barrier);
    let nc = Arc::clone(&n_called);
    let panic_handler = move |_| {
        // The first `n_threads` calls are expected to come from start-handler
        // panics (each blocks on `start_barrier` until all have arrived); any
        // later calls come from exit-handler panics.
        let val = nc.fetch_add(1, Ordering::SeqCst);
        if val < n_threads {
            sb.wait();
        } else {
            eb.wait();
        }
    };

    let builder = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .start_handler(start_handler)
        .exit_handler(exit_handler)
        .panic_handler(panic_handler);
    {
        // Bind to a named variable: `let _ = ...` would drop the pool
        // immediately instead of at the end of this scope.
        let _pool = builder.build().unwrap();

        // Wait for all the threads to start, panic in the start handler,
        // and be handled by the panic handler.
        start_barrier.wait();

        // `_pool` is dropped here, which stops the running threads.
    }

    // Wait for all the threads to exit, panic in the exit handler,
    // and be handled by the panic handler.
    exit_barrier.wait();

    // The panic handler must have been called twice on every thread.
    assert_eq!(n_called.load(Ordering::SeqCst), 2 * n_threads);
}
122 
/// Building a pool with an explicit thread count reports that same count back.
///
/// No deprecated API is used here, so no `#[allow(deprecated)]` is needed;
/// keeping one would silently hide future deprecations.
#[test]
fn check_config_build() {
    let pool = ThreadPoolBuilder::new().num_threads(22).build().unwrap();
    assert_eq!(pool.current_num_threads(), 22);
}
129 
/// Compile-time check: instantiating this function with `T` only type-checks
/// when `T` is both `Send` and `Sync`. Used by `check_error_send_sync` for
/// `ThreadPoolBuildError`.
fn _send_sync<T>()
where
    T: Send + Sync,
{
}
132 
/// `ThreadPoolBuildError` must be usable across threads.
#[test]
fn check_error_send_sync() {
    // This line only compiles if the error type is both `Send` and `Sync`.
    _send_sync::<ThreadPoolBuildError>()
}
137 
/// Smoke test: every public setter of the deprecated `Configuration` type can
/// be chained together and the result builds successfully.
// `Configuration` is deprecated; this test intentionally exercises it.
#[allow(deprecated)]
#[test]
fn configuration() {
    let four_megabytes: usize = 4_000_000;

    Configuration::new()
        .thread_name(|i| format!("thread_name_{}", i))
        .num_threads(5)
        .panic_handler(|_| {})
        .stack_size(four_megabytes)
        .breadth_first()
        .start_handler(|_| {})
        .exit_handler(|_| {})
        .build()
        .unwrap();
}
158 
/// A default-configured builder must produce a pool without error.
#[test]
fn default_pool() {
    let builder = ThreadPoolBuilder::default();
    builder.build().unwrap();
}
163 
/// Test that custom-spawned threads have their `WorkerThread` cleared once the
/// pool is done with them, so they can be used with rayon again later (for
/// example, WebAssembly embedders that manage their own pool of threads).
#[test]
fn cleared_current_thread() -> Result<(), ThreadPoolBuildError> {
    let n_threads = 5;
    let mut join_handles = Vec::new();
    let pool = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .spawn_handler(|worker| {
            join_handles.push(std::thread::spawn(move || {
                worker.run();

                // Once `run` returns, this thread must no longer be
                // registered as a worker.
                assert_eq!(crate::current_thread_index(), None);
            }));
            Ok(())
        })
        .build()?;
    assert_eq!(join_handles.len(), n_threads);

    pool.install(|| assert!(crate::current_thread_index().is_some()));
    // Dropping the pool shuts it down; each worker's `run` is then expected
    // to return.
    drop(pool);

    // Join every thread so its post-`run` assertion is observed here.
    join_handles
        .into_iter()
        .for_each(|handle| handle.join().unwrap());

    Ok(())
}
196