1 //! `TcpStream` split support.
2 //!
3 //! A `TcpStream` can be split into a `ReadHalf` and a
4 //! `WriteHalf` with the `TcpStream::split` method. `ReadHalf`
5 //! implements `AsyncRead` while `WriteHalf` implements `AsyncWrite`.
6 //!
7 //! Compared to the generic split of `AsyncRead + AsyncWrite`, this specialized
8 //! split has no associated overhead and enforces all invariants at the type
9 //! level.
10 
11 use crate::future::poll_fn;
12 use crate::io::{AsyncRead, AsyncWrite, ReadBuf};
13 use crate::net::TcpStream;
14 
15 use std::io;
16 use std::net::Shutdown;
17 use std::pin::Pin;
18 use std::task::{Context, Poll};
19 
/// Borrowed read half of a [`TcpStream`], created by [`split`].
///
/// The half wraps a shared reference to the underlying stream, so it cannot
/// outlive the stream it was split from.
///
/// Reading from a `ReadHalf` is usually done using the convenience methods found on the
/// [`AsyncReadExt`] trait.
///
/// [`TcpStream`]: TcpStream
/// [`split`]: TcpStream::split()
/// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
#[derive(Debug)]
pub struct ReadHalf<'a>(&'a TcpStream);
30 
/// Borrowed write half of a [`TcpStream`], created by [`split`].
///
/// The half wraps a shared reference to the underlying stream, so it cannot
/// outlive the stream it was split from.
///
/// Note that in the [`AsyncWrite`] implementation of this type, [`poll_shutdown`] will
/// shut down the TCP stream in the write direction.
///
/// Writing to a `WriteHalf` is usually done using the convenience methods found
/// on the [`AsyncWriteExt`] trait.
///
/// [`TcpStream`]: TcpStream
/// [`split`]: TcpStream::split()
/// [`AsyncWrite`]: trait@crate::io::AsyncWrite
/// [`poll_shutdown`]: fn@crate::io::AsyncWrite::poll_shutdown
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
#[derive(Debug)]
pub struct WriteHalf<'a>(&'a TcpStream);
46 
split(stream: &mut TcpStream) -> (ReadHalf<'_>, WriteHalf<'_>)47 pub(crate) fn split(stream: &mut TcpStream) -> (ReadHalf<'_>, WriteHalf<'_>) {
48     (ReadHalf(&*stream), WriteHalf(&*stream))
49 }
50 
51 impl ReadHalf<'_> {
52     /// Attempt to receive data on the socket, without removing that data from
53     /// the queue, registering the current task for wakeup if data is not yet
54     /// available.
55     ///
56     /// Note that on multiple calls to `poll_peek` or `poll_read`, only the
57     /// `Waker` from the `Context` passed to the most recent call is scheduled
58     /// to receive a wakeup.
59     ///
60     /// See the [`TcpStream::poll_peek`] level documenation for more details.
61     ///
62     /// # Examples
63     ///
64     /// ```no_run
65     /// use tokio::io::{self, ReadBuf};
66     /// use tokio::net::TcpStream;
67     ///
68     /// use futures::future::poll_fn;
69     ///
70     /// #[tokio::main]
71     /// async fn main() -> io::Result<()> {
72     ///     let mut stream = TcpStream::connect("127.0.0.1:8000").await?;
73     ///     let (mut read_half, _) = stream.split();
74     ///     let mut buf = [0; 10];
75     ///     let mut buf = ReadBuf::new(&mut buf);
76     ///
77     ///     poll_fn(|cx| {
78     ///         read_half.poll_peek(cx, &mut buf)
79     ///     }).await?;
80     ///
81     ///     Ok(())
82     /// }
83     /// ```
84     ///
85     /// [`TcpStream::poll_peek`]: TcpStream::poll_peek
poll_peek( &mut self, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<usize>>86     pub fn poll_peek(
87         &mut self,
88         cx: &mut Context<'_>,
89         buf: &mut ReadBuf<'_>,
90     ) -> Poll<io::Result<usize>> {
91         self.0.poll_peek(cx, buf)
92     }
93 
94     /// Receives data on the socket from the remote address to which it is
95     /// connected, without removing that data from the queue. On success,
96     /// returns the number of bytes peeked.
97     ///
98     /// See the [`TcpStream::peek`] level documenation for more details.
99     ///
100     /// [`TcpStream::peek`]: TcpStream::peek
101     ///
102     /// # Examples
103     ///
104     /// ```no_run
105     /// use tokio::net::TcpStream;
106     /// use tokio::io::AsyncReadExt;
107     /// use std::error::Error;
108     ///
109     /// #[tokio::main]
110     /// async fn main() -> Result<(), Box<dyn Error>> {
111     ///     // Connect to a peer
112     ///     let mut stream = TcpStream::connect("127.0.0.1:8080").await?;
113     ///     let (mut read_half, _) = stream.split();
114     ///
115     ///     let mut b1 = [0; 10];
116     ///     let mut b2 = [0; 10];
117     ///
118     ///     // Peek at the data
119     ///     let n = read_half.peek(&mut b1).await?;
120     ///
121     ///     // Read the data
122     ///     assert_eq!(n, read_half.read(&mut b2[..n]).await?);
123     ///     assert_eq!(&b1[..n], &b2[..n]);
124     ///
125     ///     Ok(())
126     /// }
127     /// ```
128     ///
129     /// The [`read`] method is defined on the [`AsyncReadExt`] trait.
130     ///
131     /// [`read`]: fn@crate::io::AsyncReadExt::read
132     /// [`AsyncReadExt`]: trait@crate::io::AsyncReadExt
peek(&mut self, buf: &mut [u8]) -> io::Result<usize>133     pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
134         let mut buf = ReadBuf::new(buf);
135         poll_fn(|cx| self.poll_peek(cx, &mut buf)).await
136     }
137 }
138 
139 impl AsyncRead for ReadHalf<'_> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>140     fn poll_read(
141         self: Pin<&mut Self>,
142         cx: &mut Context<'_>,
143         buf: &mut ReadBuf<'_>,
144     ) -> Poll<io::Result<()>> {
145         self.0.poll_read_priv(cx, buf)
146     }
147 }
148 
149 impl AsyncWrite for WriteHalf<'_> {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>150     fn poll_write(
151         self: Pin<&mut Self>,
152         cx: &mut Context<'_>,
153         buf: &[u8],
154     ) -> Poll<io::Result<usize>> {
155         self.0.poll_write_priv(cx, buf)
156     }
157 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll<io::Result<usize>>158     fn poll_write_vectored(
159         self: Pin<&mut Self>,
160         cx: &mut Context<'_>,
161         bufs: &[io::IoSlice<'_>],
162     ) -> Poll<io::Result<usize>> {
163         self.0.poll_write_vectored_priv(cx, bufs)
164     }
165 
is_write_vectored(&self) -> bool166     fn is_write_vectored(&self) -> bool {
167         self.0.is_write_vectored()
168     }
169 
170     #[inline]
poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>>171     fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
172         // tcp flush is a no-op
173         Poll::Ready(Ok(()))
174     }
175 
176     // `poll_shutdown` on a write half shutdowns the stream in the "write" direction.
poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>>177     fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
178         self.0.shutdown_std(Shutdown::Write).into()
179     }
180 }
181 
182 impl AsRef<TcpStream> for ReadHalf<'_> {
as_ref(&self) -> &TcpStream183     fn as_ref(&self) -> &TcpStream {
184         self.0
185     }
186 }
187 
188 impl AsRef<TcpStream> for WriteHalf<'_> {
as_ref(&self) -> &TcpStream189     fn as_ref(&self) -> &TcpStream {
190         self.0
191     }
192 }
193