1 use crate::sync::watch;
2 
3 use std::sync::Mutex;
4 
/// A barrier enables multiple threads to synchronize the beginning of some computation.
///
/// The barrier is reusable: once a group of waiters has been released, the next calls to
/// [`Barrier::wait`] start a new round.
///
/// ```
/// # #[tokio::main]
/// # async fn main() {
/// use tokio::sync::Barrier;
/// use std::sync::Arc;
///
/// let mut handles = Vec::with_capacity(10);
/// let barrier = Arc::new(Barrier::new(10));
/// for _ in 0..10 {
///     let c = barrier.clone();
///     // The same messages will be printed together.
///     // You will NOT see any interleaving.
///     handles.push(tokio::spawn(async move {
///         println!("before wait");
///         let wait_result = c.wait().await;
///         println!("after wait");
///         wait_result
///     }));
/// }
///
/// // Will not resolve until all "after wait" messages have been printed
/// let mut num_leaders = 0;
/// for handle in handles {
///     let wait_result = handle.await.unwrap();
///     if wait_result.is_leader() {
///         num_leaders += 1;
///     }
/// }
///
/// // Exactly one barrier will resolve as the "leader"
/// assert_eq!(num_leaders, 1);
/// # }
/// ```
#[derive(Debug)]
pub struct Barrier {
    // Arrival count, generation number and the sending half of the wake-up channel.
    // Guarded by a synchronous (std) mutex; `wait` only holds it in a short non-yielding
    // section.
    state: Mutex<BarrierState>,
    // Receiving half of the wake-up channel. Waiters clone it to observe the number of the
    // most recently completed generation (initially 0, meaning none completed yet).
    wait: watch::Receiver<usize>,
    // Number of callers of `wait` needed to release a generation; `new` maps 0 to 1, so
    // this is always at least 1.
    n: usize,
}
46 
// Mutable bookkeeping for a `Barrier`, always accessed through `Barrier::state`'s mutex.
#[derive(Debug)]
struct BarrierState {
    // Used by the last arrival of a generation (the leader) to publish that generation's
    // number to every waiting task.
    waker: watch::Sender<usize>,
    // How many callers have arrived in the current generation; reset to 0 when the
    // generation completes.
    arrived: usize,
    // Number of the generation currently being assembled. Starts at 1 and is incremented
    // each time a generation completes.
    generation: usize,
}
53 
54 impl Barrier {
55     /// Creates a new barrier that can block a given number of threads.
56     ///
57     /// A barrier will block `n`-1 threads which call [`Barrier::wait`] and then wake up all
58     /// threads at once when the `n`th thread calls `wait`.
new(mut n: usize) -> Barrier59     pub fn new(mut n: usize) -> Barrier {
60         let (waker, wait) = crate::sync::watch::channel(0);
61 
62         if n == 0 {
63             // if n is 0, it's not clear what behavior the user wants.
64             // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every
65             // .wait() immediately unblocks, so we adopt that here as well.
66             n = 1;
67         }
68 
69         Barrier {
70             state: Mutex::new(BarrierState {
71                 waker,
72                 arrived: 0,
73                 generation: 1,
74             }),
75             n,
76             wait,
77         }
78     }
79 
80     /// Does not resolve until all tasks have rendezvoused here.
81     ///
82     /// Barriers are re-usable after all threads have rendezvoused once, and can
83     /// be used continuously.
84     ///
85     /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from
86     /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other threads
87     /// will receive a result that will return `false` from `is_leader`.
wait(&self) -> BarrierWaitResult88     pub async fn wait(&self) -> BarrierWaitResult {
89         // NOTE: we are taking a _synchronous_ lock here.
90         // It is okay to do so because the critical section is fast and never yields, so it cannot
91         // deadlock even if another future is concurrently holding the lock.
92         // It is _desireable_ to do so as synchronous Mutexes are, at least in theory, faster than
93         // the asynchronous counter-parts, so we should use them where possible [citation needed].
94         // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across
95         // a yield point, and thus marks the returned future as !Send.
96         let generation = {
97             let mut state = self.state.lock().unwrap();
98             let generation = state.generation;
99             state.arrived += 1;
100             if state.arrived == self.n {
101                 // we are the leader for this generation
102                 // wake everyone, increment the generation, and return
103                 state
104                     .waker
105                     .send(state.generation)
106                     .expect("there is at least one receiver");
107                 state.arrived = 0;
108                 state.generation += 1;
109                 return BarrierWaitResult(true);
110             }
111 
112             generation
113         };
114 
115         // we're going to have to wait for the last of the generation to arrive
116         let mut wait = self.wait.clone();
117 
118         loop {
119             let _ = wait.changed().await;
120 
121             // note that the first time through the loop, this _will_ yield a generation
122             // immediately, since we cloned a receiver that has never seen any values.
123             if *wait.borrow() >= generation {
124                 break;
125             }
126         }
127 
128         BarrierWaitResult(false)
129     }
130 }
131 
/// The value produced by [`Barrier::wait`] once every participant of a generation has
/// reached the barrier.
///
/// Exactly one waiter per generation receives a result for which
/// [`BarrierWaitResult::is_leader`] returns `true`.
#[derive(Debug, Clone)]
pub struct BarrierWaitResult(bool);

impl BarrierWaitResult {
    /// Reports whether this caller was chosen as the "leader" of its generation.
    ///
    /// Of all the callers released together, exactly one gets `true` and every other
    /// caller gets `false`.
    pub fn is_leader(&self) -> bool {
        let Self(leader) = self;
        *leader
    }
}
145