1 use futures_core::ready; 2 use futures_core::task::{Context, Poll}; 3 #[cfg(feature = "read-initializer")] 4 use futures_io::Initializer; 5 use futures_io::{AsyncRead, AsyncBufRead}; 6 use pin_project_lite::pin_project; 7 use std::{cmp, io}; 8 use std::pin::Pin; 9 10 pin_project! { 11 /// Reader for the [`take`](super::AsyncReadExt::take) method. 12 #[derive(Debug)] 13 #[must_use = "readers do nothing unless you `.await` or poll them"] 14 pub struct Take<R> { 15 #[pin] 16 inner: R, 17 // Add '_' to avoid conflicts with `limit` method. 18 limit_: u64, 19 } 20 } 21 22 impl<R: AsyncRead> Take<R> { new(inner: R, limit: u64) -> Self23 pub(super) fn new(inner: R, limit: u64) -> Self { 24 Self { inner, limit_: limit } 25 } 26 27 /// Returns the remaining number of bytes that can be 28 /// read before this instance will return EOF. 29 /// 30 /// # Note 31 /// 32 /// This instance may reach `EOF` after reading fewer bytes than indicated by 33 /// this method if the underlying [`AsyncRead`] instance reaches EOF. 34 /// 35 /// # Examples 36 /// 37 /// ``` 38 /// # futures::executor::block_on(async { 39 /// use futures::io::{AsyncReadExt, Cursor}; 40 /// 41 /// let reader = Cursor::new(&b"12345678"[..]); 42 /// let mut buffer = [0; 2]; 43 /// 44 /// let mut take = reader.take(4); 45 /// let n = take.read(&mut buffer).await?; 46 /// 47 /// assert_eq!(take.limit(), 2); 48 /// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap(); 49 /// ``` limit(&self) -> u6450 pub fn limit(&self) -> u64 { 51 self.limit_ 52 } 53 54 /// Sets the number of bytes that can be read before this instance will 55 /// return EOF. This is the same as constructing a new `Take` instance, so 56 /// the amount of bytes read and the previous limit value don't matter when 57 /// calling this method. 58 /// 59 /// # Examples 60 /// 61 /// ``` 62 /// # futures::executor::block_on(async { 63 /// use futures::io::{AsyncReadExt, Cursor}; 64 /// 65 /// let reader = Cursor::new(&b"12345678"[..]); 66 /// let mut buffer = [0; 4]; 67 /// 68 /// let mut take = reader.take(4); 69 /// let n = take.read(&mut buffer).await?; 70 /// 71 /// assert_eq!(n, 4); 72 /// assert_eq!(take.limit(), 0); 73 /// 74 /// take.set_limit(10); 75 /// let n = take.read(&mut buffer).await?; 76 /// assert_eq!(n, 4); 77 /// 78 /// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap(); 79 /// ``` set_limit(&mut self, limit: u64)80 pub fn set_limit(&mut self, limit: u64) { 81 self.limit_ = limit 82 } 83 84 delegate_access_inner!(inner, R, ()); 85 } 86 87 impl<R: AsyncRead> AsyncRead for Take<R> { poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<Result<usize, io::Error>>88 fn poll_read( 89 self: Pin<&mut Self>, 90 cx: &mut Context<'_>, 91 buf: &mut [u8], 92 ) -> Poll<Result<usize, io::Error>> { 93 let this = self.project(); 94 95 if *this.limit_ == 0 { 96 return Poll::Ready(Ok(0)); 97 } 98 99 let max = cmp::min(buf.len() as u64, *this.limit_) as usize; 100 let n = ready!(this.inner.poll_read(cx, &mut buf[..max]))?; 101 *this.limit_ -= n as u64; 102 Poll::Ready(Ok(n)) 103 } 104 105 #[cfg(feature = "read-initializer")] initializer(&self) -> Initializer106 unsafe fn initializer(&self) -> Initializer { 107 self.inner.initializer() 108 } 109 } 110 111 impl<R: AsyncBufRead> AsyncBufRead for Take<R> { poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>112 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { 113 let this = self.project(); 114 115 // Don't call into inner reader at all at EOF because it may still block 116 if *this.limit_ == 0 { 117 return Poll::Ready(Ok(&[])); 118 } 119 120 let buf = ready!(this.inner.poll_fill_buf(cx)?); 121 let cap = cmp::min(buf.len() as u64, *this.limit_) as usize; 122 Poll::Ready(Ok(&buf[..cap])) 123 } 124 consume(self: Pin<&mut Self>, amt: usize)125 fn consume(self: Pin<&mut Self>, amt: usize) { 126 let this = self.project(); 127 128 // Don't let callers reset the limit by passing an overlarge value 129 let amt = cmp::min(amt as u64, *this.limit_) as usize; 130 *this.limit_ -= amt as u64; 131 this.inner.consume(amt); 132 } 133 } 134