1 //! Synchronization primitive allowing multiple threads to synchronize the 2 //! beginning of some computation. 3 //! 4 //! Implementation adopted the 'Barrier' type of the standard library. See: 5 //! https://doc.rust-lang.org/std/sync/struct.Barrier.html 6 //! 7 //! Copyright 2014 The Rust Project Developers. See the COPYRIGHT 8 //! file at the top-level directory of this distribution and at 9 //! http://rust-lang.org/COPYRIGHT. 10 //! 11 //! Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or 12 //! http://www.apache.org/licenses/LICENSE-2.0> or the MIT license 13 //! <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your 14 //! option. This file may not be copied, modified, or distributed 15 //! except according to those terms. 16 17 use core::sync::atomic::spin_loop_hint as cpu_relax; 18 19 use crate::Mutex; 20 21 /// A primitive that synchronizes the execution of multiple threads. 22 /// 23 /// # Example 24 /// 25 /// ``` 26 /// use spin; 27 /// use std::sync::Arc; 28 /// use std::thread; 29 /// 30 /// let mut handles = Vec::with_capacity(10); 31 /// let barrier = Arc::new(spin::Barrier::new(10)); 32 /// for _ in 0..10 { 33 /// let c = barrier.clone(); 34 /// // The same messages will be printed together. 35 /// // You will NOT see any interleaving. 36 /// handles.push(thread::spawn(move|| { 37 /// println!("before wait"); 38 /// c.wait(); 39 /// println!("after wait"); 40 /// })); 41 /// } 42 /// // Wait for other threads to finish. 43 /// for handle in handles { 44 /// handle.join().unwrap(); 45 /// } 46 /// ``` 47 pub struct Barrier { 48 lock: Mutex<BarrierState>, 49 num_threads: usize, 50 } 51 52 // The inner state of a double barrier 53 struct BarrierState { 54 count: usize, 55 generation_id: usize, 56 } 57 58 /// A `BarrierWaitResult` is returned by [`wait`] when all threads in the [`Barrier`] 59 /// have rendezvoused. 60 /// 61 /// [`wait`]: struct.Barrier.html#method.wait 62 /// [`Barrier`]: struct.Barrier.html 63 /// 64 /// # Examples 65 /// 66 /// ``` 67 /// use spin; 68 /// 69 /// let barrier = spin::Barrier::new(1); 70 /// let barrier_wait_result = barrier.wait(); 71 /// ``` 72 pub struct BarrierWaitResult(bool); 73 74 impl Barrier { 75 /// Creates a new barrier that can block a given number of threads. 76 /// 77 /// A barrier will block `n`-1 threads which call [`wait`] and then wake up 78 /// all threads at once when the `n`th thread calls [`wait`]. A Barrier created 79 /// with n = 0 will behave identically to one created with n = 1. 80 /// 81 /// [`wait`]: #method.wait 82 /// 83 /// # Examples 84 /// 85 /// ``` 86 /// use spin; 87 /// 88 /// let barrier = spin::Barrier::new(10); 89 /// ``` new(n: usize) -> Barrier90 pub const fn new(n: usize) -> Barrier { 91 Barrier { 92 lock: Mutex::new(BarrierState { 93 count: 0, 94 generation_id: 0, 95 }), 96 num_threads: n, 97 } 98 } 99 100 /// Blocks the current thread until all threads have rendezvoused here. 101 /// 102 /// Barriers are re-usable after all threads have rendezvoused once, and can 103 /// be used continuously. 104 /// 105 /// A single (arbitrary) thread will receive a [`BarrierWaitResult`] that 106 /// returns `true` from [`is_leader`] when returning from this function, and 107 /// all other threads will receive a result that will return `false` from 108 /// [`is_leader`]. 109 /// 110 /// [`BarrierWaitResult`]: struct.BarrierWaitResult.html 111 /// [`is_leader`]: struct.BarrierWaitResult.html#method.is_leader 112 /// 113 /// # Examples 114 /// 115 /// ``` 116 /// use spin; 117 /// use std::sync::Arc; 118 /// use std::thread; 119 /// 120 /// let mut handles = Vec::with_capacity(10); 121 /// let barrier = Arc::new(spin::Barrier::new(10)); 122 /// for _ in 0..10 { 123 /// let c = barrier.clone(); 124 /// // The same messages will be printed together. 125 /// // You will NOT see any interleaving. 126 /// handles.push(thread::spawn(move|| { 127 /// println!("before wait"); 128 /// c.wait(); 129 /// println!("after wait"); 130 /// })); 131 /// } 132 /// // Wait for other threads to finish. 133 /// for handle in handles { 134 /// handle.join().unwrap(); 135 /// } 136 /// ``` wait(&self) -> BarrierWaitResult137 pub fn wait(&self) -> BarrierWaitResult { 138 let mut lock = self.lock.lock(); 139 lock.count += 1; 140 141 if lock.count < self.num_threads { 142 // not the leader 143 let local_gen = lock.generation_id; 144 145 while local_gen == lock.generation_id && 146 lock.count < self.num_threads { 147 drop(lock); 148 cpu_relax(); 149 lock = self.lock.lock(); 150 } 151 BarrierWaitResult(false) 152 } else { 153 // this thread is the leader, 154 // and is responsible for incrementing the generation 155 lock.count = 0; 156 lock.generation_id = lock.generation_id.wrapping_add(1); 157 BarrierWaitResult(true) 158 } 159 } 160 } 161 162 impl BarrierWaitResult { 163 /// Returns whether this thread from [`wait`] is the "leader thread". 164 /// 165 /// Only one thread will have `true` returned from their result, all other 166 /// threads will have `false` returned. 167 /// 168 /// [`wait`]: struct.Barrier.html#method.wait 169 /// 170 /// # Examples 171 /// 172 /// ``` 173 /// use spin; 174 /// 175 /// let barrier = spin::Barrier::new(1); 176 /// let barrier_wait_result = barrier.wait(); 177 /// println!("{:?}", barrier_wait_result.is_leader()); 178 /// ``` is_leader(&self) -> bool179 pub fn is_leader(&self) -> bool { self.0 } 180 } 181 182 #[cfg(test)] 183 mod tests { 184 use std::prelude::v1::*; 185 186 use std::sync::mpsc::{channel, TryRecvError}; 187 use std::sync::Arc; 188 use std::thread; 189 190 use super::Barrier; 191 use_barrier(n: usize, barrier: Arc<Barrier>)192 fn use_barrier(n: usize, barrier: Arc<Barrier>) { 193 let (tx, rx) = channel(); 194 195 for _ in 0..n - 1 { 196 let c = barrier.clone(); 197 let tx = tx.clone(); 198 thread::spawn(move|| { 199 tx.send(c.wait().is_leader()).unwrap(); 200 }); 201 } 202 203 // At this point, all spawned threads should be blocked, 204 // so we shouldn't get anything from the port 205 assert!(match rx.try_recv() { 206 Err(TryRecvError::Empty) => true, 207 _ => false, 208 }); 209 210 let mut leader_found = barrier.wait().is_leader(); 211 212 // Now, the barrier is cleared and we should get data. 213 for _ in 0..n - 1 { 214 if rx.recv().unwrap() { 215 assert!(!leader_found); 216 leader_found = true; 217 } 218 } 219 assert!(leader_found); 220 } 221 222 #[test] test_barrier()223 fn test_barrier() { 224 const N: usize = 10; 225 226 let barrier = Arc::new(Barrier::new(N)); 227 228 use_barrier(N, barrier.clone()); 229 230 // use barrier twice to ensure it is reusable 231 use_barrier(N, barrier.clone()); 232 } 233 } 234