// `Mutex<usize>` is required (rather than an atomic) because a `Condvar` can only wait on a `Mutex` guard
2 #![allow(clippy::mutex_atomic)]
3 
4 use crate::primitive::sync::{Arc, Condvar, Mutex};
5 use std::fmt;
6 
/// Enables threads to synchronize the beginning or end of some computation.
///
/// Every `WaitGroup` value is a handle to one shared counter. Cloning a handle registers another
/// participant, dropping a handle unregisters it, and [`WaitGroup::wait`] blocks until every
/// handle to the group has been dropped.
///
/// # Wait groups vs barriers
///
/// `WaitGroup` is very similar to [`Barrier`], but there are a few differences:
///
/// * [`Barrier`] needs to know the number of threads at construction, while `WaitGroup` is cloned to
///   register more threads.
///
/// * A [`Barrier`] can be reused even after all threads have synchronized, while a `WaitGroup`
///   synchronizes threads only once.
///
/// * All threads wait for others to reach the [`Barrier`]. With `WaitGroup`, each thread can choose
///   to either wait for other threads or to continue without blocking.
///
/// # Examples
///
/// The spawned threads below are never joined. `wg.wait()` alone guarantees that each of them
/// has dropped its handle before the main thread continues.
///
/// ```
/// use crossbeam_utils::sync::WaitGroup;
/// use std::thread;
///
/// // Create a new wait group.
/// let wg = WaitGroup::new();
///
/// for _ in 0..4 {
///     // Create another reference to the wait group.
///     let wg = wg.clone();
///
///     thread::spawn(move || {
///         // Do some work.
///
///         // Drop the reference to the wait group.
///         drop(wg);
///     });
/// }
///
/// // Block until all threads have finished their work.
/// wg.wait();
/// ```
///
/// [`Barrier`]: std::sync::Barrier
pub struct WaitGroup {
    // Shared state; every clone of this handle points at the same `Inner`.
    inner: Arc<Inner>,
}
51 
/// Inner state of a `WaitGroup`, shared by all of its handles.
struct Inner {
    // Notified (`notify_all`) when `count` reaches zero, waking threads blocked in `wait`.
    cvar: Condvar,
    // Number of live `WaitGroup` handles: starts at 1, incremented by `clone`, decremented by
    // `drop`. It lives behind a `Mutex` (not an atomic) because `cvar` waits on its guard.
    count: Mutex<usize>,
}
57 
58 impl Default for WaitGroup {
default() -> Self59     fn default() -> Self {
60         Self {
61             inner: Arc::new(Inner {
62                 cvar: Condvar::new(),
63                 count: Mutex::new(1),
64             }),
65         }
66     }
67 }
68 
69 impl WaitGroup {
70     /// Creates a new wait group and returns the single reference to it.
71     ///
72     /// # Examples
73     ///
74     /// ```
75     /// use crossbeam_utils::sync::WaitGroup;
76     ///
77     /// let wg = WaitGroup::new();
78     /// ```
new() -> Self79     pub fn new() -> Self {
80         Self::default()
81     }
82 
83     /// Drops this reference and waits until all other references are dropped.
84     ///
85     /// # Examples
86     ///
87     /// ```
88     /// use crossbeam_utils::sync::WaitGroup;
89     /// use std::thread;
90     ///
91     /// let wg = WaitGroup::new();
92     ///
93     /// thread::spawn({
94     ///     let wg = wg.clone();
95     ///     move || {
96     ///         // Block until both threads have reached `wait()`.
97     ///         wg.wait();
98     ///     }
99     /// });
100     ///
101     /// // Block until both threads have reached `wait()`.
102     /// wg.wait();
103     /// ```
wait(self)104     pub fn wait(self) {
105         if *self.inner.count.lock().unwrap() == 1 {
106             return;
107         }
108 
109         let inner = self.inner.clone();
110         drop(self);
111 
112         let mut count = inner.count.lock().unwrap();
113         while *count > 0 {
114             count = inner.cvar.wait(count).unwrap();
115         }
116     }
117 }
118 
119 impl Drop for WaitGroup {
drop(&mut self)120     fn drop(&mut self) {
121         let mut count = self.inner.count.lock().unwrap();
122         *count -= 1;
123 
124         if *count == 0 {
125             self.inner.cvar.notify_all();
126         }
127     }
128 }
129 
130 impl Clone for WaitGroup {
clone(&self) -> WaitGroup131     fn clone(&self) -> WaitGroup {
132         let mut count = self.inner.count.lock().unwrap();
133         *count += 1;
134 
135         WaitGroup {
136             inner: self.inner.clone(),
137         }
138     }
139 }
140 
141 impl fmt::Debug for WaitGroup {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result142     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143         let count: &usize = &*self.inner.count.lock().unwrap();
144         f.debug_struct("WaitGroup").field("count", count).finish()
145     }
146 }
147