1 //! Bounded channel based on a preallocated array.
2 //!
3 //! This flavor has a fixed, positive capacity.
4 //!
5 //! The implementation is based on Dmitry Vyukov's bounded MPMC queue.
6 //!
7 //! Source:
8 //!   - http://www.1024cores.net/home/lock-free-algorithms/queues/bounded-mpmc-queue
9 //!   - https://docs.google.com/document/d/1yIAYmbvL3JxOKOjuCyon7JhW4cSv1wy5hC0ApeGMV9s/pub
10 //!
11 //! Copyright & License:
12 //!   - Copyright (c) 2010-2011 Dmitry Vyukov
13 //!   - Simplified BSD License and Apache License, Version 2.0
14 //!   - http://www.1024cores.net/home/code-license
15 
16 use std::cell::UnsafeCell;
17 use std::marker::PhantomData;
18 use std::mem::{self, MaybeUninit};
19 use std::ptr;
20 use std::sync::atomic::{self, AtomicUsize, Ordering};
21 use std::time::Instant;
22 
23 use crossbeam_utils::{Backoff, CachePadded};
24 
25 use crate::context::Context;
26 use crate::err::{RecvTimeoutError, SendTimeoutError, TryRecvError, TrySendError};
27 use crate::select::{Operation, SelectHandle, Selected, Token};
28 use crate::waker::SyncWaker;
29 
/// A slot in a channel.
///
/// The slot's `stamp` tells senders and receivers whether it may be used in the current lap:
/// `start_send` may claim the slot when its stamp equals the tail, and `start_recv` may claim it
/// when its stamp is one greater than the head.
struct Slot<T> {
    /// The current stamp. `with_capacity` initializes it to the slot's own index (lap 0).
    stamp: AtomicUsize,

    /// The message in this slot. It is initialized by `write` and moved out again by `read`;
    /// the `MaybeUninit` wrapper means dropping the slot never drops a message by itself.
    msg: UnsafeCell<MaybeUninit<T>>,
}
38 
/// The token type for the array flavor.
///
/// Filled in by `start_send` / `start_recv` and consumed by the follow-up `write` / `read`.
#[derive(Debug)]
pub struct ArrayToken {
    /// Slot to read from or write to.
    ///
    /// Stored type-erased as `*const u8` and cast back to `*const Slot<T>` by `write` / `read`.
    /// A null pointer means the channel was observed to be disconnected.
    slot: *const u8,

    /// Stamp to store into the slot after reading or writing.
    stamp: usize,
}
48 
49 impl Default for ArrayToken {
50     #[inline]
default() -> Self51     fn default() -> Self {
52         ArrayToken {
53             slot: ptr::null(),
54             stamp: 0,
55         }
56     }
57 }
58 
/// Bounded channel based on a preallocated array.
///
/// Senders claim slots by advancing `tail`, and receivers claim slots by advancing `head`. Each
/// slot's own stamp records whether it is currently ready to be written or read.
pub struct Channel<T> {
    /// The head of the channel.
    ///
    /// This value is a "stamp" consisting of an index into the buffer, a mark bit, and a lap, but
    /// packed into a single `usize`. The lower bits represent the index, while the upper bits
    /// represent the lap. The mark bit in the head is always zero.
    ///
    /// Messages are popped from the head of the channel.
    head: CachePadded<AtomicUsize>,

    /// The tail of the channel.
    ///
    /// This value is a "stamp" consisting of an index into the buffer, a mark bit, and a lap, but
    /// packed into a single `usize`. The lower bits represent the index, while the upper bits
    /// represent the lap. The mark bit indicates that the channel is disconnected.
    ///
    /// Messages are pushed into the tail of the channel.
    tail: CachePadded<AtomicUsize>,

    /// The buffer holding slots: a pointer to `cap` contiguous `Slot<T>` values, allocated in
    /// `with_capacity` and freed in `Drop`.
    buffer: *mut Slot<T>,

    /// The channel capacity.
    cap: usize,

    /// A stamp with the value of `{ lap: 1, mark: 0, index: 0 }`.
    ///
    /// Equal to `mark_bit * 2`; adding it to a stamp advances the lap by one.
    one_lap: usize,

    /// If this bit is set in the tail, that means the channel is disconnected.
    ///
    /// It is the smallest power of two greater than `cap`, so `mark_bit - 1` masks out the slot
    /// index of a stamp.
    mark_bit: usize,

    /// Senders waiting while the channel is full.
    senders: SyncWaker,

    /// Receivers waiting while the channel is empty and not disconnected.
    receivers: SyncWaker,

    /// Indicates that dropping a `Channel<T>` may drop values of type `T`.
    _marker: PhantomData<T>,
}
100 
101 impl<T> Channel<T> {
102     /// Creates a bounded channel of capacity `cap`.
with_capacity(cap: usize) -> Self103     pub fn with_capacity(cap: usize) -> Self {
104         assert!(cap > 0, "capacity must be positive");
105 
106         // Compute constants `mark_bit` and `one_lap`.
107         let mark_bit = (cap + 1).next_power_of_two();
108         let one_lap = mark_bit * 2;
109 
110         // Head is initialized to `{ lap: 0, mark: 0, index: 0 }`.
111         let head = 0;
112         // Tail is initialized to `{ lap: 0, mark: 0, index: 0 }`.
113         let tail = 0;
114 
115         // Allocate a buffer of `cap` slots initialized
116         // with stamps.
117         let buffer = {
118             let mut boxed: Box<[Slot<T>]> = (0..cap)
119                 .map(|i| {
120                     // Set the stamp to `{ lap: 0, mark: 0, index: i }`.
121                     Slot {
122                         stamp: AtomicUsize::new(i),
123                         msg: UnsafeCell::new(MaybeUninit::uninit()),
124                     }
125                 })
126                 .collect();
127             let ptr = boxed.as_mut_ptr();
128             mem::forget(boxed);
129             ptr
130         };
131 
132         Channel {
133             buffer,
134             cap,
135             one_lap,
136             mark_bit,
137             head: CachePadded::new(AtomicUsize::new(head)),
138             tail: CachePadded::new(AtomicUsize::new(tail)),
139             senders: SyncWaker::new(),
140             receivers: SyncWaker::new(),
141             _marker: PhantomData,
142         }
143     }
144 
145     /// Returns a receiver handle to the channel.
receiver(&self) -> Receiver<'_, T>146     pub fn receiver(&self) -> Receiver<'_, T> {
147         Receiver(self)
148     }
149 
150     /// Returns a sender handle to the channel.
sender(&self) -> Sender<'_, T>151     pub fn sender(&self) -> Sender<'_, T> {
152         Sender(self)
153     }
154 
    /// Attempts to reserve a slot for sending a message.
    ///
    /// Returns `true` when the caller should proceed to `write`. In that case either a slot was
    /// reserved (the token holds the slot pointer and the stamp `write` must publish), or the
    /// channel is disconnected (the token's slot is null, so `write` hands the message back).
    /// Returns `false` when the channel is full.
    fn start_send(&self, token: &mut Token) -> bool {
        let backoff = Backoff::new();
        let mut tail = self.tail.load(Ordering::Relaxed);

        loop {
            // Check if the channel is disconnected.
            if tail & self.mark_bit != 0 {
                token.array.slot = ptr::null();
                token.array.stamp = 0;
                return true;
            }

            // Deconstruct the tail.
            let index = tail & (self.mark_bit - 1);
            let lap = tail & !(self.one_lap - 1);

            // Inspect the corresponding slot.
            // SAFETY: tail indices only ever advance while `index + 1 < cap` and otherwise wrap
            // to zero, so `index < cap` and the pointer stays inside the buffer.
            let slot = unsafe { &*self.buffer.add(index) };
            let stamp = slot.stamp.load(Ordering::Acquire);

            // If the tail and the stamp match, we may attempt to push.
            if tail == stamp {
                let new_tail = if index + 1 < self.cap {
                    // Same lap, incremented index.
                    // Set to `{ lap: lap, mark: 0, index: index + 1 }`.
                    tail + 1
                } else {
                    // One lap forward, index wraps around to zero.
                    // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
                    lap.wrapping_add(self.one_lap)
                };

                // Try moving the tail.
                match self.tail.compare_exchange_weak(
                    tail,
                    new_tail,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // Prepare the token for the follow-up call to `write`.
                        // `tail + 1` is the stamp `start_recv` waits for (`head + 1 == stamp`)
                        // before it treats the slot as holding a message.
                        token.array.slot = slot as *const Slot<T> as *const u8;
                        token.array.stamp = tail + 1;
                        return true;
                    }
                    Err(t) => {
                        // Another thread moved the tail first (or the weak CAS failed
                        // spuriously); retry with the value that was observed.
                        tail = t;
                        backoff.spin();
                    }
                }
            } else if stamp.wrapping_add(self.one_lap) == tail + 1 {
                // The slot still carries the stamp written one lap ago, so its message has not
                // been consumed yet. Consult the head to tell whether the channel is full.
                atomic::fence(Ordering::SeqCst);
                let head = self.head.load(Ordering::Relaxed);

                // If the head lags one lap behind the tail as well...
                if head.wrapping_add(self.one_lap) == tail {
                    // ...then the channel is full.
                    return false;
                }

                // Otherwise the head has moved on (or `tail` is stale); reload and retry.
                backoff.spin();
                tail = self.tail.load(Ordering::Relaxed);
            } else {
                // Snooze because we need to wait for the stamp to get updated.
                backoff.snooze();
                tail = self.tail.load(Ordering::Relaxed);
            }
        }
    }
225 
226     /// Writes a message into the channel.
write(&self, token: &mut Token, msg: T) -> Result<(), T>227     pub unsafe fn write(&self, token: &mut Token, msg: T) -> Result<(), T> {
228         // If there is no slot, the channel is disconnected.
229         if token.array.slot.is_null() {
230             return Err(msg);
231         }
232 
233         let slot: &Slot<T> = &*(token.array.slot as *const Slot<T>);
234 
235         // Write the message into the slot and update the stamp.
236         slot.msg.get().write(MaybeUninit::new(msg));
237         slot.stamp.store(token.array.stamp, Ordering::Release);
238 
239         // Wake a sleeping receiver.
240         self.receivers.notify();
241         Ok(())
242     }
243 
    /// Attempts to reserve a slot for receiving a message.
    ///
    /// Returns `true` when the caller should proceed to `read`. In that case either a slot
    /// holding a message was reserved (the token holds the slot pointer and the stamp `read` must
    /// publish), or the channel is empty and disconnected (the token's slot is null). Returns
    /// `false` when the channel is empty but still connected.
    fn start_recv(&self, token: &mut Token) -> bool {
        let backoff = Backoff::new();
        let mut head = self.head.load(Ordering::Relaxed);

        loop {
            // Deconstruct the head.
            let index = head & (self.mark_bit - 1);
            let lap = head & !(self.one_lap - 1);

            // Inspect the corresponding slot.
            // SAFETY: head indices only ever advance while `index + 1 < cap` and otherwise wrap
            // to zero, so `index < cap` and the pointer stays inside the buffer.
            let slot = unsafe { &*self.buffer.add(index) };
            let stamp = slot.stamp.load(Ordering::Acquire);

            // If the stamp is ahead of the head by 1, we may attempt to pop.
            if head + 1 == stamp {
                let new = if index + 1 < self.cap {
                    // Same lap, incremented index.
                    // Set to `{ lap: lap, mark: 0, index: index + 1 }`.
                    head + 1
                } else {
                    // One lap forward, index wraps around to zero.
                    // Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
                    lap.wrapping_add(self.one_lap)
                };

                // Try moving the head.
                match self.head.compare_exchange_weak(
                    head,
                    new,
                    Ordering::SeqCst,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => {
                        // Prepare the token for the follow-up call to `read`.
                        // `head + one_lap` is the stamp `start_send` waits for (`tail == stamp`)
                        // before it reuses this slot in the next lap.
                        token.array.slot = slot as *const Slot<T> as *const u8;
                        token.array.stamp = head.wrapping_add(self.one_lap);
                        return true;
                    }
                    Err(h) => {
                        // Another thread moved the head first (or the weak CAS failed
                        // spuriously); retry with the value that was observed.
                        head = h;
                        backoff.spin();
                    }
                }
            } else if stamp == head {
                // The slot has not been written in this lap yet. Consult the tail to tell whether
                // the channel is empty.
                atomic::fence(Ordering::SeqCst);
                let tail = self.tail.load(Ordering::Relaxed);

                // If the tail equals the head, that means the channel is empty.
                if (tail & !self.mark_bit) == head {
                    // If the channel is disconnected...
                    if tail & self.mark_bit != 0 {
                        // ...then receive an error.
                        token.array.slot = ptr::null();
                        token.array.stamp = 0;
                        return true;
                    } else {
                        // Otherwise, the receive operation is not ready.
                        return false;
                    }
                }

                // Otherwise the tail has moved on (or `head` is stale); reload and retry.
                backoff.spin();
                head = self.head.load(Ordering::Relaxed);
            } else {
                // Snooze because we need to wait for the stamp to get updated.
                backoff.snooze();
                head = self.head.load(Ordering::Relaxed);
            }
        }
    }
315 
316     /// Reads a message from the channel.
read(&self, token: &mut Token) -> Result<T, ()>317     pub unsafe fn read(&self, token: &mut Token) -> Result<T, ()> {
318         if token.array.slot.is_null() {
319             // The channel is disconnected.
320             return Err(());
321         }
322 
323         let slot: &Slot<T> = &*(token.array.slot as *const Slot<T>);
324 
325         // Read the message from the slot and update the stamp.
326         let msg = slot.msg.get().read().assume_init();
327         slot.stamp.store(token.array.stamp, Ordering::Release);
328 
329         // Wake a sleeping sender.
330         self.senders.notify();
331         Ok(msg)
332     }
333 
334     /// Attempts to send a message into the channel.
try_send(&self, msg: T) -> Result<(), TrySendError<T>>335     pub fn try_send(&self, msg: T) -> Result<(), TrySendError<T>> {
336         let token = &mut Token::default();
337         if self.start_send(token) {
338             unsafe { self.write(token, msg).map_err(TrySendError::Disconnected) }
339         } else {
340             Err(TrySendError::Full(msg))
341         }
342     }
343 
    /// Sends a message into the channel, blocking while the channel is full.
    ///
    /// When `deadline` is `None` the timeout check below is skipped, so the loop keeps retrying
    /// until the message is sent or the channel reports disconnection.
    ///
    /// # Errors
    ///
    /// - `SendTimeoutError::Disconnected(msg)` if the channel is disconnected.
    /// - `SendTimeoutError::Timeout(msg)` if `deadline` passes while the channel is still full.
    pub fn send(&self, msg: T, deadline: Option<Instant>) -> Result<(), SendTimeoutError<T>> {
        let token = &mut Token::default();
        loop {
            // Try sending a message several times.
            let backoff = Backoff::new();
            loop {
                if self.start_send(token) {
                    // SAFETY: `start_send` returned `true`, so the token is ready for `write`.
                    let res = unsafe { self.write(token, msg) };
                    return res.map_err(SendTimeoutError::Disconnected);
                }

                if backoff.is_completed() {
                    break;
                } else {
                    backoff.snooze();
                }
            }

            if let Some(d) = deadline {
                if Instant::now() >= d {
                    return Err(SendTimeoutError::Timeout(msg));
                }
            }

            Context::with(|cx| {
                // Prepare for blocking until a receiver wakes us up.
                let oper = Operation::hook(token);
                self.senders.register(oper, cx);

                // Has the channel become ready just now?
                // Re-checking after registration means a slot freed between the failed attempts
                // above and the registration is not missed; aborting cancels the wait below.
                if !self.is_full() || self.is_disconnected() {
                    let _ = cx.try_select(Selected::Aborted);
                }

                // Block the current thread.
                let sel = cx.wait_until(deadline);

                match sel {
                    Selected::Waiting => unreachable!(),
                    Selected::Aborted | Selected::Disconnected => {
                        // Nobody removed our registration on our behalf; remove it ourselves.
                        self.senders.unregister(oper).unwrap();
                    }
                    Selected::Operation(_) => {}
                }
            });
        }
    }
392 
393     /// Attempts to receive a message without blocking.
try_recv(&self) -> Result<T, TryRecvError>394     pub fn try_recv(&self) -> Result<T, TryRecvError> {
395         let token = &mut Token::default();
396 
397         if self.start_recv(token) {
398             unsafe { self.read(token).map_err(|_| TryRecvError::Disconnected) }
399         } else {
400             Err(TryRecvError::Empty)
401         }
402     }
403 
    /// Receives a message from the channel, blocking while the channel is empty.
    ///
    /// When `deadline` is `None` the timeout check below is skipped, so the loop keeps retrying
    /// until a message arrives or the channel reports disconnection.
    ///
    /// # Errors
    ///
    /// - `RecvTimeoutError::Disconnected` once the channel is disconnected and holds no more
    ///   messages.
    /// - `RecvTimeoutError::Timeout` if `deadline` passes while the channel is still empty.
    pub fn recv(&self, deadline: Option<Instant>) -> Result<T, RecvTimeoutError> {
        let token = &mut Token::default();
        loop {
            // Try receiving a message several times.
            let backoff = Backoff::new();
            loop {
                if self.start_recv(token) {
                    // SAFETY: `start_recv` returned `true`, so the token is ready for `read`.
                    let res = unsafe { self.read(token) };
                    return res.map_err(|_| RecvTimeoutError::Disconnected);
                }

                if backoff.is_completed() {
                    break;
                } else {
                    backoff.snooze();
                }
            }

            if let Some(d) = deadline {
                if Instant::now() >= d {
                    return Err(RecvTimeoutError::Timeout);
                }
            }

            Context::with(|cx| {
                // Prepare for blocking until a sender wakes us up.
                let oper = Operation::hook(token);
                self.receivers.register(oper, cx);

                // Has the channel become ready just now?
                // Re-checking after registration means a message sent between the failed
                // attempts above and the registration is not missed; aborting cancels the wait.
                if !self.is_empty() || self.is_disconnected() {
                    let _ = cx.try_select(Selected::Aborted);
                }

                // Block the current thread.
                let sel = cx.wait_until(deadline);

                match sel {
                    Selected::Waiting => unreachable!(),
                    Selected::Aborted | Selected::Disconnected => {
                        self.receivers.unregister(oper).unwrap();
                        // If the channel was disconnected, we still have to check for remaining
                        // messages.
                    }
                    Selected::Operation(_) => {}
                }
            });
        }
    }
454 
455     /// Returns the current number of messages inside the channel.
len(&self) -> usize456     pub fn len(&self) -> usize {
457         loop {
458             // Load the tail, then load the head.
459             let tail = self.tail.load(Ordering::SeqCst);
460             let head = self.head.load(Ordering::SeqCst);
461 
462             // If the tail didn't change, we've got consistent values to work with.
463             if self.tail.load(Ordering::SeqCst) == tail {
464                 let hix = head & (self.mark_bit - 1);
465                 let tix = tail & (self.mark_bit - 1);
466 
467                 return if hix < tix {
468                     tix - hix
469                 } else if hix > tix {
470                     self.cap - hix + tix
471                 } else if (tail & !self.mark_bit) == head {
472                     0
473                 } else {
474                     self.cap
475                 };
476             }
477         }
478     }
479 
480     /// Returns the capacity of the channel.
capacity(&self) -> Option<usize>481     pub fn capacity(&self) -> Option<usize> {
482         Some(self.cap)
483     }
484 
485     /// Disconnects the channel and wakes up all blocked senders and receivers.
486     ///
487     /// Returns `true` if this call disconnected the channel.
disconnect(&self) -> bool488     pub fn disconnect(&self) -> bool {
489         let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst);
490 
491         if tail & self.mark_bit == 0 {
492             self.senders.disconnect();
493             self.receivers.disconnect();
494             true
495         } else {
496             false
497         }
498     }
499 
500     /// Returns `true` if the channel is disconnected.
is_disconnected(&self) -> bool501     pub fn is_disconnected(&self) -> bool {
502         self.tail.load(Ordering::SeqCst) & self.mark_bit != 0
503     }
504 
505     /// Returns `true` if the channel is empty.
is_empty(&self) -> bool506     pub fn is_empty(&self) -> bool {
507         let head = self.head.load(Ordering::SeqCst);
508         let tail = self.tail.load(Ordering::SeqCst);
509 
510         // Is the tail equal to the head?
511         //
512         // Note: If the head changes just before we load the tail, that means there was a moment
513         // when the channel was not empty, so it is safe to just return `false`.
514         (tail & !self.mark_bit) == head
515     }
516 
517     /// Returns `true` if the channel is full.
is_full(&self) -> bool518     pub fn is_full(&self) -> bool {
519         let tail = self.tail.load(Ordering::SeqCst);
520         let head = self.head.load(Ordering::SeqCst);
521 
522         // Is the head lagging one lap behind tail?
523         //
524         // Note: If the tail changes just before we load the head, that means there was a moment
525         // when the channel was not full, so it is safe to just return `false`.
526         head.wrapping_add(self.one_lap) == tail & !self.mark_bit
527     }
528 }
529 
530 impl<T> Drop for Channel<T> {
drop(&mut self)531     fn drop(&mut self) {
532         // Get the index of the head.
533         let hix = self.head.load(Ordering::Relaxed) & (self.mark_bit - 1);
534 
535         // Loop over all slots that hold a message and drop them.
536         for i in 0..self.len() {
537             // Compute the index of the next slot holding a message.
538             let index = if hix + i < self.cap {
539                 hix + i
540             } else {
541                 hix + i - self.cap
542             };
543 
544             unsafe {
545                 let p = {
546                     let slot = &mut *self.buffer.add(index);
547                     let msg = &mut *slot.msg.get();
548                     msg.as_mut_ptr()
549                 };
550                 p.drop_in_place();
551             }
552         }
553 
554         // Finally, deallocate the buffer, but don't run any destructors.
555         unsafe {
556             // Create a slice from the buffer to make
557             // a fat pointer. Then, use Box::from_raw
558             // to deallocate it.
559             let ptr = std::slice::from_raw_parts_mut(self.buffer, self.cap) as *mut [Slot<T>];
560             Box::from_raw(ptr);
561         }
562     }
563 }
564 
/// Receiver handle to a channel.
///
/// A borrowed view of the channel that exposes its receive side through `SelectHandle`.
pub struct Receiver<'a, T>(&'a Channel<T>);
567 
/// Sender handle to a channel.
///
/// A borrowed view of the channel that exposes its send side through `SelectHandle`.
pub struct Sender<'a, T>(&'a Channel<T>);
570 
571 impl<T> SelectHandle for Receiver<'_, T> {
try_select(&self, token: &mut Token) -> bool572     fn try_select(&self, token: &mut Token) -> bool {
573         self.0.start_recv(token)
574     }
575 
deadline(&self) -> Option<Instant>576     fn deadline(&self) -> Option<Instant> {
577         None
578     }
579 
register(&self, oper: Operation, cx: &Context) -> bool580     fn register(&self, oper: Operation, cx: &Context) -> bool {
581         self.0.receivers.register(oper, cx);
582         self.is_ready()
583     }
584 
unregister(&self, oper: Operation)585     fn unregister(&self, oper: Operation) {
586         self.0.receivers.unregister(oper);
587     }
588 
accept(&self, token: &mut Token, _cx: &Context) -> bool589     fn accept(&self, token: &mut Token, _cx: &Context) -> bool {
590         self.try_select(token)
591     }
592 
is_ready(&self) -> bool593     fn is_ready(&self) -> bool {
594         !self.0.is_empty() || self.0.is_disconnected()
595     }
596 
watch(&self, oper: Operation, cx: &Context) -> bool597     fn watch(&self, oper: Operation, cx: &Context) -> bool {
598         self.0.receivers.watch(oper, cx);
599         self.is_ready()
600     }
601 
unwatch(&self, oper: Operation)602     fn unwatch(&self, oper: Operation) {
603         self.0.receivers.unwatch(oper);
604     }
605 }
606 
607 impl<T> SelectHandle for Sender<'_, T> {
try_select(&self, token: &mut Token) -> bool608     fn try_select(&self, token: &mut Token) -> bool {
609         self.0.start_send(token)
610     }
611 
deadline(&self) -> Option<Instant>612     fn deadline(&self) -> Option<Instant> {
613         None
614     }
615 
register(&self, oper: Operation, cx: &Context) -> bool616     fn register(&self, oper: Operation, cx: &Context) -> bool {
617         self.0.senders.register(oper, cx);
618         self.is_ready()
619     }
620 
unregister(&self, oper: Operation)621     fn unregister(&self, oper: Operation) {
622         self.0.senders.unregister(oper);
623     }
624 
accept(&self, token: &mut Token, _cx: &Context) -> bool625     fn accept(&self, token: &mut Token, _cx: &Context) -> bool {
626         self.try_select(token)
627     }
628 
is_ready(&self) -> bool629     fn is_ready(&self) -> bool {
630         !self.0.is_full() || self.0.is_disconnected()
631     }
632 
watch(&self, oper: Operation, cx: &Context) -> bool633     fn watch(&self, oper: Operation, cx: &Context) -> bool {
634         self.0.senders.watch(oper, cx);
635         self.is_ready()
636     }
637 
unwatch(&self, oper: Operation)638     fn unwatch(&self, oper: Operation) {
639         self.0.senders.unwatch(oper);
640     }
641 }
642