1 //! Split a single value implementing `AsyncRead + AsyncWrite` into separate
2 //! `AsyncRead` and `AsyncWrite` handles.
3 //!
//! To restore the original read/write object from its `split::ReadHalf` and
//! `split::WriteHalf`, use `unsplit`.
6 
7 use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
8 
9 use std::cell::UnsafeCell;
10 use std::fmt;
11 use std::io;
12 use std::pin::Pin;
13 use std::sync::atomic::AtomicBool;
14 use std::sync::atomic::Ordering::{Acquire, Release};
15 use std::sync::Arc;
16 use std::task::{Context, Poll};
17 
18 cfg_io_util! {
19     /// The readable half of a value returned from [`split`](split()).
20     pub struct ReadHalf<T> {
21         inner: Arc<Inner<T>>,
22     }
23 
24     /// The writable half of a value returned from [`split`](split()).
25     pub struct WriteHalf<T> {
26         inner: Arc<Inner<T>>,
27     }
28 
29     /// Splits a single value implementing `AsyncRead + AsyncWrite` into separate
30     /// `AsyncRead` and `AsyncWrite` handles.
31     ///
32     /// To restore this read/write object from its `ReadHalf` and
33     /// `WriteHalf` use [`unsplit`](ReadHalf::unsplit()).
34     pub fn split<T>(stream: T) -> (ReadHalf<T>, WriteHalf<T>)
35     where
36         T: AsyncRead + AsyncWrite,
37     {
38         let inner = Arc::new(Inner {
39             locked: AtomicBool::new(false),
40             stream: UnsafeCell::new(stream),
41         });
42 
43         let rd = ReadHalf {
44             inner: inner.clone(),
45         };
46 
47         let wr = WriteHalf { inner };
48 
49         (rd, wr)
50     }
51 }
52 
/// State shared, through an `Arc`, by the two halves returned from `split`.
struct Inner<T> {
    /// `true` while a `Guard` exists: set by the compare-exchange in
    /// `Inner::poll_lock` and cleared again in `Guard`'s `Drop` impl.
    locked: AtomicBool,
    /// The wrapped I/O object. While split, it is only reached through
    /// `Guard::stream_pin`. `ReadHalf::unsplit` moves it out once both halves
    /// have been consumed.
    stream: UnsafeCell<T>,
}

/// Token proving the holder currently owns `inner.locked`. It is handed out by
/// `Inner::poll_lock` and releases the lock when dropped.
struct Guard<'a, T> {
    inner: &'a Inner<T>,
}
61 
62 impl<T> ReadHalf<T> {
63     /// Checks if this `ReadHalf` and some `WriteHalf` were split from the same
64     /// stream.
is_pair_of(&self, other: &WriteHalf<T>) -> bool65     pub fn is_pair_of(&self, other: &WriteHalf<T>) -> bool {
66         other.is_pair_of(&self)
67     }
68 
69     /// Reunites with a previously split `WriteHalf`.
70     ///
71     /// # Panics
72     ///
73     /// If this `ReadHalf` and the given `WriteHalf` do not originate from the
74     /// same `split` operation this method will panic.
75     /// This can be checked ahead of time by comparing the stream ID
76     /// of the two halves.
unsplit(self, wr: WriteHalf<T>) -> T77     pub fn unsplit(self, wr: WriteHalf<T>) -> T {
78         if self.is_pair_of(&wr) {
79             drop(wr);
80 
81             let inner = Arc::try_unwrap(self.inner)
82                 .ok()
83                 .expect("`Arc::try_unwrap` failed");
84 
85             inner.stream.into_inner()
86         } else {
87             panic!("Unrelated `split::Write` passed to `split::Read::unsplit`.")
88         }
89     }
90 }
91 
92 impl<T> WriteHalf<T> {
93     /// Check if this `WriteHalf` and some `ReadHalf` were split from the same
94     /// stream.
is_pair_of(&self, other: &ReadHalf<T>) -> bool95     pub fn is_pair_of(&self, other: &ReadHalf<T>) -> bool {
96         Arc::ptr_eq(&self.inner, &other.inner)
97     }
98 }
99 
100 impl<T: AsyncRead> AsyncRead for ReadHalf<T> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>101     fn poll_read(
102         self: Pin<&mut Self>,
103         cx: &mut Context<'_>,
104         buf: &mut ReadBuf<'_>,
105     ) -> Poll<io::Result<()>> {
106         let mut inner = ready!(self.inner.poll_lock(cx));
107         inner.stream_pin().poll_read(cx, buf)
108     }
109 }
110 
111 impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, io::Error>>112     fn poll_write(
113         self: Pin<&mut Self>,
114         cx: &mut Context<'_>,
115         buf: &[u8],
116     ) -> Poll<Result<usize, io::Error>> {
117         let mut inner = ready!(self.inner.poll_lock(cx));
118         inner.stream_pin().poll_write(cx, buf)
119     }
120 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>121     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
122         let mut inner = ready!(self.inner.poll_lock(cx));
123         inner.stream_pin().poll_flush(cx)
124     }
125 
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>126     fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
127         let mut inner = ready!(self.inner.poll_lock(cx));
128         inner.stream_pin().poll_shutdown(cx)
129     }
130 }
131 
132 impl<T> Inner<T> {
poll_lock(&self, cx: &mut Context<'_>) -> Poll<Guard<'_, T>>133     fn poll_lock(&self, cx: &mut Context<'_>) -> Poll<Guard<'_, T>> {
134         if self
135             .locked
136             .compare_exchange(false, true, Acquire, Acquire)
137             .is_ok()
138         {
139             Poll::Ready(Guard { inner: self })
140         } else {
141             // Spin... but investigate a better strategy
142 
143             std::thread::yield_now();
144             cx.waker().wake_by_ref();
145 
146             Poll::Pending
147         }
148     }
149 }
150 
151 impl<T> Guard<'_, T> {
stream_pin(&mut self) -> Pin<&mut T>152     fn stream_pin(&mut self) -> Pin<&mut T> {
153         // safety: the stream is pinned in `Arc` and the `Guard` ensures mutual
154         // exclusion.
155         unsafe { Pin::new_unchecked(&mut *self.inner.stream.get()) }
156     }
157 }
158 
159 impl<T> Drop for Guard<'_, T> {
drop(&mut self)160     fn drop(&mut self) {
161         self.inner.locked.store(false, Release);
162     }
163 }
164 
// SAFETY: `Inner<T>` holds an `UnsafeCell`, so it is not `Sync`. That makes
// `Arc<Inner<T>>`, and therefore each half, neither `Send` nor `Sync`
// automatically. Sending a half to another thread is sound for `T: Send`:
// each half may then run on a different thread, but every access to the
// shared stream goes through `Inner::poll_lock`, which hands it to at most
// one half at a time.
unsafe impl<T: Send> Send for ReadHalf<T> {}
unsafe impl<T: Send> Send for WriteHalf<T> {}
// SAFETY: in this file, the methods that touch the stream take either
// `Pin<&mut Self>` (the `poll_*` methods) or `self` by value (`unsplit`).
// Through `&ReadHalf` / `&WriteHalf` only `is_pair_of` and `Debug` are
// reachable, and neither accesses the stream.
// NOTE(review): given that, the `T: Sync` bound looks conservative. Confirm
// before relaxing it.
unsafe impl<T: Sync> Sync for ReadHalf<T> {}
unsafe impl<T: Sync> Sync for WriteHalf<T> {}
169 
170 impl<T: fmt::Debug> fmt::Debug for ReadHalf<T> {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result171     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
172         fmt.debug_struct("split::ReadHalf").finish()
173     }
174 }
175 
176 impl<T: fmt::Debug> fmt::Debug for WriteHalf<T> {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result177     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
178         fmt.debug_struct("split::WriteHalf").finish()
179     }
180 }
181