1 use crate::sync::watch; 2 3 use std::sync::Mutex; 4 5 /// A barrier enables multiple threads to synchronize the beginning of some computation. 6 /// 7 /// ``` 8 /// # #[tokio::main] 9 /// # async fn main() { 10 /// use tokio::sync::Barrier; 11 /// use std::sync::Arc; 12 /// 13 /// let mut handles = Vec::with_capacity(10); 14 /// let barrier = Arc::new(Barrier::new(10)); 15 /// for _ in 0..10 { 16 /// let c = barrier.clone(); 17 /// // The same messages will be printed together. 18 /// // You will NOT see any interleaving. 19 /// handles.push(tokio::spawn(async move { 20 /// println!("before wait"); 21 /// let wait_result = c.wait().await; 22 /// println!("after wait"); 23 /// wait_result 24 /// })); 25 /// } 26 /// 27 /// // Will not resolve until all "after wait" messages have been printed 28 /// let mut num_leaders = 0; 29 /// for handle in handles { 30 /// let wait_result = handle.await.unwrap(); 31 /// if wait_result.is_leader() { 32 /// num_leaders += 1; 33 /// } 34 /// } 35 /// 36 /// // Exactly one barrier will resolve as the "leader" 37 /// assert_eq!(num_leaders, 1); 38 /// # } 39 /// ``` 40 #[derive(Debug)] 41 pub struct Barrier { 42 state: Mutex<BarrierState>, 43 wait: watch::Receiver<usize>, 44 n: usize, 45 } 46 47 #[derive(Debug)] 48 struct BarrierState { 49 waker: watch::Sender<usize>, 50 arrived: usize, 51 generation: usize, 52 } 53 54 impl Barrier { 55 /// Creates a new barrier that can block a given number of threads. 56 /// 57 /// A barrier will block `n`-1 threads which call [`Barrier::wait`] and then wake up all 58 /// threads at once when the `n`th thread calls `wait`. new(mut n: usize) -> Barrier59 pub fn new(mut n: usize) -> Barrier { 60 let (waker, wait) = crate::sync::watch::channel(0); 61 62 if n == 0 { 63 // if n is 0, it's not clear what behavior the user wants. 64 // in std::sync::Barrier, an n of 0 exhibits the same behavior as n == 1, where every 65 // .wait() immediately unblocks, so we adopt that here as well. 66 n = 1; 67 } 68 69 Barrier { 70 state: Mutex::new(BarrierState { 71 waker, 72 arrived: 0, 73 generation: 1, 74 }), 75 n, 76 wait, 77 } 78 } 79 80 /// Does not resolve until all tasks have rendezvoused here. 81 /// 82 /// Barriers are re-usable after all threads have rendezvoused once, and can 83 /// be used continuously. 84 /// 85 /// A single (arbitrary) future will receive a [`BarrierWaitResult`] that returns `true` from 86 /// [`BarrierWaitResult::is_leader`] when returning from this function, and all other threads 87 /// will receive a result that will return `false` from `is_leader`. wait(&self) -> BarrierWaitResult88 pub async fn wait(&self) -> BarrierWaitResult { 89 // NOTE: we are taking a _synchronous_ lock here. 90 // It is okay to do so because the critical section is fast and never yields, so it cannot 91 // deadlock even if another future is concurrently holding the lock. 92 // It is _desireable_ to do so as synchronous Mutexes are, at least in theory, faster than 93 // the asynchronous counter-parts, so we should use them where possible [citation needed]. 94 // NOTE: the extra scope here is so that the compiler doesn't think `state` is held across 95 // a yield point, and thus marks the returned future as !Send. 96 let generation = { 97 let mut state = self.state.lock().unwrap(); 98 let generation = state.generation; 99 state.arrived += 1; 100 if state.arrived == self.n { 101 // we are the leader for this generation 102 // wake everyone, increment the generation, and return 103 state 104 .waker 105 .send(state.generation) 106 .expect("there is at least one receiver"); 107 state.arrived = 0; 108 state.generation += 1; 109 return BarrierWaitResult(true); 110 } 111 112 generation 113 }; 114 115 // we're going to have to wait for the last of the generation to arrive 116 let mut wait = self.wait.clone(); 117 118 loop { 119 let _ = wait.changed().await; 120 121 // note that the first time through the loop, this _will_ yield a generation 122 // immediately, since we cloned a receiver that has never seen any values. 123 if *wait.borrow() >= generation { 124 break; 125 } 126 } 127 128 BarrierWaitResult(false) 129 } 130 } 131 132 /// A `BarrierWaitResult` is returned by `wait` when all threads in the `Barrier` have rendezvoused. 133 #[derive(Debug, Clone)] 134 pub struct BarrierWaitResult(bool); 135 136 impl BarrierWaitResult { 137 /// Returns `true` if this thread from wait is the "leader thread". 138 /// 139 /// Only one thread will have `true` returned from their result, all other threads will have 140 /// `false` returned. is_leader(&self) -> bool141 pub fn is_leader(&self) -> bool { 142 self.0 143 } 144 } 145