1 extern crate proc_macro;
2 use proc_macro::TokenStream;
3 use proc_macro2::{Delimiter, Group, TokenStream as TokenStream2, TokenTree};
4 use quote::quote;
5 use syn::visit_mut::VisitMut;
6 
7 struct Scrub {
8     is_xforming: bool,
9     is_try: bool,
10     unit: Box<syn::Expr>,
11     num_yield: u32,
12 }
13 
parse_input(input: TokenStream) -> syn::Result<Vec<syn::Stmt>>14 fn parse_input(input: TokenStream) -> syn::Result<Vec<syn::Stmt>> {
15     let input = replace_for_await(input.into());
16     // syn does not provide a way to parse `Vec<Stmt>` directly from `TokenStream`,
17     // so wrap input in a brace and then parse it as a block.
18     let input = TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Brace, input)));
19     let syn::Block { stmts, .. } = syn::parse2(input)?;
20 
21     Ok(stmts)
22 }
23 
24 impl VisitMut for Scrub {
visit_expr_mut(&mut self, i: &mut syn::Expr)25     fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
26         if !self.is_xforming {
27             syn::visit_mut::visit_expr_mut(self, i);
28             return;
29         }
30 
31         match i {
32             syn::Expr::Yield(yield_expr) => {
33                 self.num_yield += 1;
34 
35                 let value_expr = if let Some(ref e) = yield_expr.expr {
36                     e
37                 } else {
38                     &self.unit
39                 };
40 
41                 // let ident = &self.yielder;
42 
43                 *i = if self.is_try {
44                     syn::parse_quote! { __yield_tx.send(Ok(#value_expr)).await }
45                 } else {
46                     syn::parse_quote! { __yield_tx.send(#value_expr).await }
47                 };
48             }
49             syn::Expr::Try(try_expr) => {
50                 syn::visit_mut::visit_expr_try_mut(self, try_expr);
51                 // let ident = &self.yielder;
52                 let e = &try_expr.expr;
53 
54                 *i = syn::parse_quote! {
55                     match #e {
56                         Ok(v) => v,
57                         Err(e) => {
58                             __yield_tx.send(Err(e.into())).await;
59                             return;
60                         }
61                     }
62                 };
63             }
64             syn::Expr::Closure(_) | syn::Expr::Async(_) => {
65                 let prev = self.is_xforming;
66                 self.is_xforming = false;
67                 syn::visit_mut::visit_expr_mut(self, i);
68                 self.is_xforming = prev;
69             }
70             syn::Expr::ForLoop(expr) => {
71                 syn::visit_mut::visit_expr_for_loop_mut(self, expr);
72                 // TODO: Should we allow other attributes?
73                 if expr.attrs.len() != 1 || !expr.attrs[0].path.is_ident("await") {
74                     return;
75                 }
76                 let syn::ExprForLoop {
77                     attrs,
78                     label,
79                     pat,
80                     expr,
81                     body,
82                     ..
83                 } = expr;
84 
85                 let attr = attrs.pop().unwrap();
86                 if let Err(e) = syn::parse2::<syn::parse::Nothing>(attr.tokens) {
87                     *i = syn::parse2(e.to_compile_error()).unwrap();
88                     return;
89                 }
90 
91                 *i = syn::parse_quote! {{
92                     let mut __pinned = #expr;
93                     let mut __pinned = unsafe {
94                         ::core::pin::Pin::new_unchecked(&mut __pinned)
95                     };
96                     #label
97                     loop {
98                         let #pat = match ::async_stream::reexport::next(&mut __pinned).await {
99                             ::core::option::Option::Some(e) => e,
100                             ::core::option::Option::None => break,
101                         };
102                         #body
103                     }
104                 }}
105             }
106             _ => syn::visit_mut::visit_expr_mut(self, i),
107         }
108     }
109 
visit_item_mut(&mut self, i: &mut syn::Item)110     fn visit_item_mut(&mut self, i: &mut syn::Item) {
111         let prev = self.is_xforming;
112         self.is_xforming = false;
113         syn::visit_mut::visit_item_mut(self, i);
114         self.is_xforming = prev;
115     }
116 }
117 
118 /// Asynchronous stream
119 ///
120 /// See [crate](index.html) documentation for more details.
121 ///
122 /// # Examples
123 ///
124 /// ```rust
125 /// use async_stream::stream;
126 ///
127 /// use futures_util::pin_mut;
128 /// use futures_util::stream::StreamExt;
129 ///
130 /// #[tokio::main]
131 /// async fn main() {
132 ///     let s = stream! {
133 ///         for i in 0..3 {
134 ///             yield i;
135 ///         }
136 ///     };
137 ///
138 ///     pin_mut!(s); // needed for iteration
139 ///
140 ///     while let Some(value) = s.next().await {
141 ///         println!("got {}", value);
142 ///     }
143 /// }
144 /// ```
145 #[proc_macro]
stream(input: TokenStream) -> TokenStream146 pub fn stream(input: TokenStream) -> TokenStream {
147     let mut stmts = match parse_input(input) {
148         Ok(x) => x,
149         Err(e) => return e.to_compile_error().into(),
150     };
151 
152     let mut scrub = Scrub {
153         is_xforming: true,
154         is_try: false,
155         unit: syn::parse_quote!(()),
156         num_yield: 0,
157     };
158 
159     for mut stmt in &mut stmts[..] {
160         scrub.visit_stmt_mut(&mut stmt);
161     }
162 
163     let dummy_yield = if scrub.num_yield == 0 {
164         Some(quote!(if false {
165             __yield_tx.send(()).await;
166         }))
167     } else {
168         None
169     };
170 
171     quote!({
172         let (mut __yield_tx, __yield_rx) = ::async_stream::yielder::pair();
173         ::async_stream::AsyncStream::new(__yield_rx, async move {
174             #dummy_yield
175             #(#stmts)*
176         })
177     })
178     .into()
179 }
180 
181 /// Asynchronous fallible stream
182 ///
183 /// See [crate](index.html) documentation for more details.
184 ///
185 /// # Examples
186 ///
187 /// ```rust
188 /// use tokio::net::{TcpListener, TcpStream};
189 ///
190 /// use async_stream::try_stream;
191 /// use futures_core::stream::Stream;
192 ///
193 /// use std::io;
194 /// use std::net::SocketAddr;
195 ///
196 /// fn bind_and_accept(addr: SocketAddr)
197 ///     -> impl Stream<Item = io::Result<TcpStream>>
198 /// {
199 ///     try_stream! {
200 ///         let mut listener = TcpListener::bind(addr).await?;
201 ///
202 ///         loop {
203 ///             let (stream, addr) = listener.accept().await?;
204 ///             println!("received on {:?}", addr);
205 ///             yield stream;
206 ///         }
207 ///     }
208 /// }
209 /// ```
210 #[proc_macro]
try_stream(input: TokenStream) -> TokenStream211 pub fn try_stream(input: TokenStream) -> TokenStream {
212     let mut stmts = match parse_input(input) {
213         Ok(x) => x,
214         Err(e) => return e.to_compile_error().into(),
215     };
216 
217     let mut scrub = Scrub {
218         is_xforming: true,
219         is_try: true,
220         unit: syn::parse_quote!(()),
221         num_yield: 0,
222     };
223 
224     for mut stmt in &mut stmts[..] {
225         scrub.visit_stmt_mut(&mut stmt);
226     }
227 
228     let dummy_yield = if scrub.num_yield == 0 {
229         Some(quote!(if false {
230             __yield_tx.send(()).await;
231         }))
232     } else {
233         None
234     };
235 
236     quote!({
237         let (mut __yield_tx, __yield_rx) = ::async_stream::yielder::pair();
238         ::async_stream::AsyncStream::new(__yield_rx, async move {
239             #dummy_yield
240             #(#stmts)*
241         })
242     })
243     .into()
244 }
245 
replace_for_await(input: TokenStream2) -> TokenStream2246 fn replace_for_await(input: TokenStream2) -> TokenStream2 {
247     let mut input = input.into_iter().peekable();
248     let mut tokens = Vec::new();
249 
250     while let Some(token) = input.next() {
251         match token {
252             TokenTree::Ident(ident) => {
253                 match input.peek() {
254                     Some(TokenTree::Ident(next)) if ident == "for" && next == "await" => {
255                         tokens.extend(quote!(#[#next]));
256                         let _ = input.next();
257                     }
258                     _ => {}
259                 }
260                 tokens.push(ident.into());
261             }
262             TokenTree::Group(group) => {
263                 let stream = replace_for_await(group.stream());
264                 tokens.push(Group::new(group.delimiter(), stream).into());
265             }
266             _ => tokens.push(token),
267         }
268     }
269 
270     tokens.into_iter().collect()
271 }
272