1 //! Split a single value implementing `AsyncRead + AsyncWrite` into separate 2 //! `AsyncRead` and `AsyncWrite` handles. 3 //! 4 //! To restore this read/write object from its `split::ReadHalf` and 5 //! `split::WriteHalf` use `unsplit`. 6 7 use crate::io::{AsyncRead, AsyncWrite, ReadBuf}; 8 9 use std::cell::UnsafeCell; 10 use std::fmt; 11 use std::io; 12 use std::pin::Pin; 13 use std::sync::atomic::AtomicBool; 14 use std::sync::atomic::Ordering::{Acquire, Release}; 15 use std::sync::Arc; 16 use std::task::{Context, Poll}; 17 18 cfg_io_util! { 19 /// The readable half of a value returned from [`split`](split()). 20 pub struct ReadHalf<T> { 21 inner: Arc<Inner<T>>, 22 } 23 24 /// The writable half of a value returned from [`split`](split()). 25 pub struct WriteHalf<T> { 26 inner: Arc<Inner<T>>, 27 } 28 29 /// Splits a single value implementing `AsyncRead + AsyncWrite` into separate 30 /// `AsyncRead` and `AsyncWrite` handles. 31 /// 32 /// To restore this read/write object from its `ReadHalf` and 33 /// `WriteHalf` use [`unsplit`](ReadHalf::unsplit()). 34 pub fn split<T>(stream: T) -> (ReadHalf<T>, WriteHalf<T>) 35 where 36 T: AsyncRead + AsyncWrite, 37 { 38 let inner = Arc::new(Inner { 39 locked: AtomicBool::new(false), 40 stream: UnsafeCell::new(stream), 41 }); 42 43 let rd = ReadHalf { 44 inner: inner.clone(), 45 }; 46 47 let wr = WriteHalf { inner }; 48 49 (rd, wr) 50 } 51 } 52 53 struct Inner<T> { 54 locked: AtomicBool, 55 stream: UnsafeCell<T>, 56 } 57 58 struct Guard<'a, T> { 59 inner: &'a Inner<T>, 60 } 61 62 impl<T> ReadHalf<T> { 63 /// Checks if this `ReadHalf` and some `WriteHalf` were split from the same 64 /// stream. is_pair_of(&self, other: &WriteHalf<T>) -> bool65 pub fn is_pair_of(&self, other: &WriteHalf<T>) -> bool { 66 other.is_pair_of(&self) 67 } 68 69 /// Reunites with a previously split `WriteHalf`. 70 /// 71 /// # Panics 72 /// 73 /// If this `ReadHalf` and the given `WriteHalf` do not originate from the 74 /// same `split` operation this method will panic. 75 /// This can be checked ahead of time by comparing the stream ID 76 /// of the two halves. unsplit(self, wr: WriteHalf<T>) -> T77 pub fn unsplit(self, wr: WriteHalf<T>) -> T { 78 if self.is_pair_of(&wr) { 79 drop(wr); 80 81 let inner = Arc::try_unwrap(self.inner) 82 .ok() 83 .expect("`Arc::try_unwrap` failed"); 84 85 inner.stream.into_inner() 86 } else { 87 panic!("Unrelated `split::Write` passed to `split::Read::unsplit`.") 88 } 89 } 90 } 91 92 impl<T> WriteHalf<T> { 93 /// Check if this `WriteHalf` and some `ReadHalf` were split from the same 94 /// stream. is_pair_of(&self, other: &ReadHalf<T>) -> bool95 pub fn is_pair_of(&self, other: &ReadHalf<T>) -> bool { 96 Arc::ptr_eq(&self.inner, &other.inner) 97 } 98 } 99 100 impl<T: AsyncRead> AsyncRead for ReadHalf<T> { poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>101 fn poll_read( 102 self: Pin<&mut Self>, 103 cx: &mut Context<'_>, 104 buf: &mut ReadBuf<'_>, 105 ) -> Poll<io::Result<()>> { 106 let mut inner = ready!(self.inner.poll_lock(cx)); 107 inner.stream_pin().poll_read(cx, buf) 108 } 109 } 110 111 impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> { poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, io::Error>>112 fn poll_write( 113 self: Pin<&mut Self>, 114 cx: &mut Context<'_>, 115 buf: &[u8], 116 ) -> Poll<Result<usize, io::Error>> { 117 let mut inner = ready!(self.inner.poll_lock(cx)); 118 inner.stream_pin().poll_write(cx, buf) 119 } 120 poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>121 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { 122 let mut inner = ready!(self.inner.poll_lock(cx)); 123 inner.stream_pin().poll_flush(cx) 124 } 125 poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>126 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> { 127 let mut inner = ready!(self.inner.poll_lock(cx)); 128 inner.stream_pin().poll_shutdown(cx) 129 } 130 } 131 132 impl<T> Inner<T> { poll_lock(&self, cx: &mut Context<'_>) -> Poll<Guard<'_, T>>133 fn poll_lock(&self, cx: &mut Context<'_>) -> Poll<Guard<'_, T>> { 134 if self 135 .locked 136 .compare_exchange(false, true, Acquire, Acquire) 137 .is_ok() 138 { 139 Poll::Ready(Guard { inner: self }) 140 } else { 141 // Spin... but investigate a better strategy 142 143 std::thread::yield_now(); 144 cx.waker().wake_by_ref(); 145 146 Poll::Pending 147 } 148 } 149 } 150 151 impl<T> Guard<'_, T> { stream_pin(&mut self) -> Pin<&mut T>152 fn stream_pin(&mut self) -> Pin<&mut T> { 153 // safety: the stream is pinned in `Arc` and the `Guard` ensures mutual 154 // exclusion. 155 unsafe { Pin::new_unchecked(&mut *self.inner.stream.get()) } 156 } 157 } 158 159 impl<T> Drop for Guard<'_, T> { drop(&mut self)160 fn drop(&mut self) { 161 self.inner.locked.store(false, Release); 162 } 163 } 164 165 unsafe impl<T: Send> Send for ReadHalf<T> {} 166 unsafe impl<T: Send> Send for WriteHalf<T> {} 167 unsafe impl<T: Sync> Sync for ReadHalf<T> {} 168 unsafe impl<T: Sync> Sync for WriteHalf<T> {} 169 170 impl<T: fmt::Debug> fmt::Debug for ReadHalf<T> { fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result171 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 172 fmt.debug_struct("split::ReadHalf").finish() 173 } 174 } 175 176 impl<T: fmt::Debug> fmt::Debug for WriteHalf<T> { fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result177 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { 178 fmt.debug_struct("split::WriteHalf").finish() 179 } 180 } 181