1 extern crate proc_macro;
2 use proc_macro::TokenStream;
3 use proc_macro2::{Delimiter, Group, TokenStream as TokenStream2, TokenTree};
4 use quote::quote;
5 use syn::visit_mut::VisitMut;
6
7 struct Scrub {
8 is_xforming: bool,
9 is_try: bool,
10 unit: Box<syn::Expr>,
11 num_yield: u32,
12 }
13
parse_input(input: TokenStream) -> syn::Result<Vec<syn::Stmt>>14 fn parse_input(input: TokenStream) -> syn::Result<Vec<syn::Stmt>> {
15 let input = replace_for_await(input.into());
16 // syn does not provide a way to parse `Vec<Stmt>` directly from `TokenStream`,
17 // so wrap input in a brace and then parse it as a block.
18 let input = TokenStream2::from(TokenTree::Group(Group::new(Delimiter::Brace, input)));
19 let syn::Block { stmts, .. } = syn::parse2(input)?;
20
21 Ok(stmts)
22 }
23
24 impl VisitMut for Scrub {
visit_expr_mut(&mut self, i: &mut syn::Expr)25 fn visit_expr_mut(&mut self, i: &mut syn::Expr) {
26 if !self.is_xforming {
27 syn::visit_mut::visit_expr_mut(self, i);
28 return;
29 }
30
31 match i {
32 syn::Expr::Yield(yield_expr) => {
33 self.num_yield += 1;
34
35 let value_expr = if let Some(ref e) = yield_expr.expr {
36 e
37 } else {
38 &self.unit
39 };
40
41 // let ident = &self.yielder;
42
43 *i = if self.is_try {
44 syn::parse_quote! { __yield_tx.send(Ok(#value_expr)).await }
45 } else {
46 syn::parse_quote! { __yield_tx.send(#value_expr).await }
47 };
48 }
49 syn::Expr::Try(try_expr) => {
50 syn::visit_mut::visit_expr_try_mut(self, try_expr);
51 // let ident = &self.yielder;
52 let e = &try_expr.expr;
53
54 *i = syn::parse_quote! {
55 match #e {
56 Ok(v) => v,
57 Err(e) => {
58 __yield_tx.send(Err(e.into())).await;
59 return;
60 }
61 }
62 };
63 }
64 syn::Expr::Closure(_) | syn::Expr::Async(_) => {
65 let prev = self.is_xforming;
66 self.is_xforming = false;
67 syn::visit_mut::visit_expr_mut(self, i);
68 self.is_xforming = prev;
69 }
70 syn::Expr::ForLoop(expr) => {
71 syn::visit_mut::visit_expr_for_loop_mut(self, expr);
72 // TODO: Should we allow other attributes?
73 if expr.attrs.len() != 1 || !expr.attrs[0].path.is_ident("await") {
74 return;
75 }
76 let syn::ExprForLoop {
77 attrs,
78 label,
79 pat,
80 expr,
81 body,
82 ..
83 } = expr;
84
85 let attr = attrs.pop().unwrap();
86 if let Err(e) = syn::parse2::<syn::parse::Nothing>(attr.tokens) {
87 *i = syn::parse2(e.to_compile_error()).unwrap();
88 return;
89 }
90
91 *i = syn::parse_quote! {{
92 let mut __pinned = #expr;
93 let mut __pinned = unsafe {
94 ::core::pin::Pin::new_unchecked(&mut __pinned)
95 };
96 #label
97 loop {
98 let #pat = match ::async_stream::reexport::next(&mut __pinned).await {
99 ::core::option::Option::Some(e) => e,
100 ::core::option::Option::None => break,
101 };
102 #body
103 }
104 }}
105 }
106 _ => syn::visit_mut::visit_expr_mut(self, i),
107 }
108 }
109
visit_item_mut(&mut self, i: &mut syn::Item)110 fn visit_item_mut(&mut self, i: &mut syn::Item) {
111 let prev = self.is_xforming;
112 self.is_xforming = false;
113 syn::visit_mut::visit_item_mut(self, i);
114 self.is_xforming = prev;
115 }
116 }
117
118 /// Asynchronous stream
119 ///
120 /// See [crate](index.html) documentation for more details.
121 ///
122 /// # Examples
123 ///
124 /// ```rust
125 /// use async_stream::stream;
126 ///
127 /// use futures_util::pin_mut;
128 /// use futures_util::stream::StreamExt;
129 ///
130 /// #[tokio::main]
131 /// async fn main() {
132 /// let s = stream! {
133 /// for i in 0..3 {
134 /// yield i;
135 /// }
136 /// };
137 ///
138 /// pin_mut!(s); // needed for iteration
139 ///
140 /// while let Some(value) = s.next().await {
141 /// println!("got {}", value);
142 /// }
143 /// }
144 /// ```
145 #[proc_macro]
stream(input: TokenStream) -> TokenStream146 pub fn stream(input: TokenStream) -> TokenStream {
147 let mut stmts = match parse_input(input) {
148 Ok(x) => x,
149 Err(e) => return e.to_compile_error().into(),
150 };
151
152 let mut scrub = Scrub {
153 is_xforming: true,
154 is_try: false,
155 unit: syn::parse_quote!(()),
156 num_yield: 0,
157 };
158
159 for mut stmt in &mut stmts[..] {
160 scrub.visit_stmt_mut(&mut stmt);
161 }
162
163 let dummy_yield = if scrub.num_yield == 0 {
164 Some(quote!(if false {
165 __yield_tx.send(()).await;
166 }))
167 } else {
168 None
169 };
170
171 quote!({
172 let (mut __yield_tx, __yield_rx) = ::async_stream::yielder::pair();
173 ::async_stream::AsyncStream::new(__yield_rx, async move {
174 #dummy_yield
175 #(#stmts)*
176 })
177 })
178 .into()
179 }
180
181 /// Asynchronous fallible stream
182 ///
183 /// See [crate](index.html) documentation for more details.
184 ///
185 /// # Examples
186 ///
187 /// ```rust
188 /// use tokio::net::{TcpListener, TcpStream};
189 ///
190 /// use async_stream::try_stream;
191 /// use futures_core::stream::Stream;
192 ///
193 /// use std::io;
194 /// use std::net::SocketAddr;
195 ///
196 /// fn bind_and_accept(addr: SocketAddr)
197 /// -> impl Stream<Item = io::Result<TcpStream>>
198 /// {
199 /// try_stream! {
200 /// let mut listener = TcpListener::bind(addr).await?;
201 ///
202 /// loop {
203 /// let (stream, addr) = listener.accept().await?;
204 /// println!("received on {:?}", addr);
205 /// yield stream;
206 /// }
207 /// }
208 /// }
209 /// ```
210 #[proc_macro]
try_stream(input: TokenStream) -> TokenStream211 pub fn try_stream(input: TokenStream) -> TokenStream {
212 let mut stmts = match parse_input(input) {
213 Ok(x) => x,
214 Err(e) => return e.to_compile_error().into(),
215 };
216
217 let mut scrub = Scrub {
218 is_xforming: true,
219 is_try: true,
220 unit: syn::parse_quote!(()),
221 num_yield: 0,
222 };
223
224 for mut stmt in &mut stmts[..] {
225 scrub.visit_stmt_mut(&mut stmt);
226 }
227
228 let dummy_yield = if scrub.num_yield == 0 {
229 Some(quote!(if false {
230 __yield_tx.send(()).await;
231 }))
232 } else {
233 None
234 };
235
236 quote!({
237 let (mut __yield_tx, __yield_rx) = ::async_stream::yielder::pair();
238 ::async_stream::AsyncStream::new(__yield_rx, async move {
239 #dummy_yield
240 #(#stmts)*
241 })
242 })
243 .into()
244 }
245
replace_for_await(input: TokenStream2) -> TokenStream2246 fn replace_for_await(input: TokenStream2) -> TokenStream2 {
247 let mut input = input.into_iter().peekable();
248 let mut tokens = Vec::new();
249
250 while let Some(token) = input.next() {
251 match token {
252 TokenTree::Ident(ident) => {
253 match input.peek() {
254 Some(TokenTree::Ident(next)) if ident == "for" && next == "await" => {
255 tokens.extend(quote!(#[#next]));
256 let _ = input.next();
257 }
258 _ => {}
259 }
260 tokens.push(ident.into());
261 }
262 TokenTree::Group(group) => {
263 let stream = replace_for_await(group.stream());
264 tokens.push(Group::new(group.delimiter(), stream).into());
265 }
266 _ => tokens.push(token),
267 }
268 }
269
270 tokens.into_iter().collect()
271 }
272