//! The futures-rs `select!` macro implementation.
2 
3 use proc_macro::TokenStream;
4 use proc_macro2::Span;
5 use quote::{format_ident, quote};
6 use syn::{parse_quote, Expr, Ident, Pat, Token};
7 use syn::parse::{Parse, ParseStream};
8 
/// Custom keywords recognized by the `select!` input parser.
mod kw {
    // Defines a `complete` keyword token type; the `Parse` impl for `Select`
    // peeks and parses it to recognize the `complete => <expr>` case.
    syn::custom_keyword!(complete);
}
12 
/// Parsed input of a `select!` / `select_biased!` invocation.
struct Select {
    // Handler expression of the `complete => <expr>` case, if one was given.
    complete: Option<Expr>,
    // Handler expression of the `default => <expr>` case, if one was given.
    default: Option<Expr>,
    // Future expressions of the `<pat> = <future> => <handler>` cases, in source order.
    normal_fut_exprs: Vec<Expr>,
    // Pattern and handler of each normal case; index-aligned with `normal_fut_exprs`
    // (both are pushed in the same parse iteration).
    normal_fut_handlers: Vec<(Pat, Expr)>,
}
20 
/// Kind of a single `select!` case while it is being parsed, before its handler
/// expression is stored into `Select`.
// `Normal` holds a `Pat` and an `Expr` by value while the other variants hold
// nothing. A `CaseKind` only lives for a single iteration of the parse loop, so
// the size imbalance clippy warns about is presumably not worth boxing.
#[allow(clippy::large_enum_variant)]
enum CaseKind {
    // `complete => ...`
    Complete,
    // `default => ...`
    Default,
    // `<pat> = <future expr> => ...`: the pattern and the future expression.
    Normal(Pat, Expr),
}
27 
28 impl Parse for Select {
parse(input: ParseStream<'_>) -> syn::Result<Self>29     fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
30         let mut select = Self {
31             complete: None,
32             default: None,
33             normal_fut_exprs: vec![],
34             normal_fut_handlers: vec![],
35         };
36 
37         while !input.is_empty() {
38             let case_kind = if input.peek(kw::complete) {
39                 // `complete`
40                 if select.complete.is_some() {
41                     return Err(input.error("multiple `complete` cases found, only one allowed"));
42                 }
43                 input.parse::<kw::complete>()?;
44                 CaseKind::Complete
45             } else if input.peek(Token![default]) {
46                 // `default`
47                 if select.default.is_some() {
48                     return Err(input.error("multiple `default` cases found, only one allowed"));
49                 }
50                 input.parse::<Ident>()?;
51                 CaseKind::Default
52             } else {
53                 // `<pat> = <expr>`
54                 let pat = input.parse()?;
55                 input.parse::<Token![=]>()?;
56                 let expr = input.parse()?;
57                 CaseKind::Normal(pat, expr)
58             };
59 
60             // `=> <expr>`
61             input.parse::<Token![=>]>()?;
62             let expr = input.parse::<Expr>()?;
63 
64             // Commas after the expression are only optional if it's a `Block`
65             // or it is the last branch in the `match`.
66             let is_block = match expr { Expr::Block(_) => true, _ => false };
67             if is_block || input.is_empty() {
68                 input.parse::<Option<Token![,]>>()?;
69             } else {
70                 input.parse::<Token![,]>()?;
71             }
72 
73             match case_kind {
74                 CaseKind::Complete => select.complete = Some(expr),
75                 CaseKind::Default => select.default = Some(expr),
76                 CaseKind::Normal(pat, fut_expr) => {
77                     select.normal_fut_exprs.push(fut_expr);
78                     select.normal_fut_handlers.push((pat, expr));
79                 },
80             }
81         }
82 
83         Ok(select)
84     }
85 }
86 
87 // Enum over all the cases in which the `select!` waiting has completed and the result
88 // can be processed.
89 //
90 // `enum __PrivResult<_1, _2, ...> { _1(_1), _2(_2), ..., Complete }`
declare_result_enum( result_ident: Ident, variants: usize, complete: bool, span: Span ) -> (Vec<Ident>, syn::ItemEnum)91 fn declare_result_enum(
92     result_ident: Ident,
93     variants: usize,
94     complete: bool,
95     span: Span
96 ) -> (Vec<Ident>, syn::ItemEnum) {
97     // "_0", "_1", "_2"
98     let variant_names: Vec<Ident> =
99         (0..variants)
100             .map(|num| format_ident!("_{}", num, span = span))
101             .collect();
102 
103     let type_parameters = &variant_names;
104     let variants = &variant_names;
105 
106     let complete_variant = if complete {
107         Some(quote!(Complete))
108     } else {
109         None
110     };
111 
112     let enum_item = parse_quote! {
113         enum #result_ident<#(#type_parameters,)*> {
114             #(
115                 #variants(#type_parameters),
116             )*
117             #complete_variant
118         }
119     };
120 
121     (variant_names, enum_item)
122 }
123 
124 /// The `select!` macro.
select(input: TokenStream) -> TokenStream125 pub(crate) fn select(input: TokenStream) -> TokenStream {
126     select_inner(input, true)
127 }
128 
129 /// The `select_biased!` macro.
select_biased(input: TokenStream) -> TokenStream130 pub(crate) fn select_biased(input: TokenStream) -> TokenStream {
131     select_inner(input, false)
132 }
133 
select_inner(input: TokenStream, random: bool) -> TokenStream134 fn select_inner(input: TokenStream, random: bool) -> TokenStream {
135     let parsed = syn::parse_macro_input!(input as Select);
136 
137     // should be def_site, but that's unstable
138     let span = Span::call_site();
139 
140     let enum_ident = Ident::new("__PrivResult", span);
141 
142     let (variant_names, enum_item) = declare_result_enum(
143         enum_ident.clone(),
144         parsed.normal_fut_exprs.len(),
145         parsed.complete.is_some(),
146         span,
147     );
148 
149     // bind non-`Ident` future exprs w/ `let`
150     let mut future_let_bindings = Vec::with_capacity(parsed.normal_fut_exprs.len());
151     let bound_future_names: Vec<_> = parsed.normal_fut_exprs.into_iter()
152         .zip(variant_names.iter())
153         .map(|(expr, variant_name)| {
154             match expr {
155                 syn::Expr::Path(path) => {
156                     // Don't bind futures that are already a path.
157                     // This prevents creating redundant stack space
158                     // for them.
159                     // Passing Futures by path requires those Futures to implement Unpin.
160                     // We check for this condition here in order to be able to
161                     // safely use Pin::new_unchecked(&mut #path) later on.
162                     future_let_bindings.push(quote! {
163                         __futures_crate::async_await::assert_fused_future(&#path);
164                         __futures_crate::async_await::assert_unpin(&#path);
165                     });
166                     path
167                 },
168                 _ => {
169                     // Bind and pin the resulting Future on the stack. This is
170                     // necessary to support direct select! calls on !Unpin
171                     // Futures. The Future is not explicitly pinned here with
172                     // a Pin call, but assumed as pinned. The actual Pin is
173                     // created inside the poll() function below to defer the
174                     // creation of the temporary pointer, which would otherwise
175                     // increase the size of the generated Future.
176                     // Safety: This is safe since the lifetime of the Future
177                     // is totally constraint to the lifetime of the select!
178                     // expression, and the Future can't get moved inside it
179                     // (it is shadowed).
180                     future_let_bindings.push(quote! {
181                         let mut #variant_name = #expr;
182                     });
183                     parse_quote! { #variant_name }
184                 }
185             }
186         })
187         .collect();
188 
189     // For each future, make an `&mut dyn FnMut(&mut Context<'_>) -> Option<Poll<__PrivResult<...>>`
190     // to use for polling that individual future. These will then be put in an array.
191     let poll_functions = bound_future_names.iter().zip(variant_names.iter())
192         .map(|(bound_future_name, variant_name)| {
193             // Below we lazily create the Pin on the Future below.
194             // This is done in order to avoid allocating memory in the generator
195             // for the Pin variable.
196             // Safety: This is safe because one of the following condition applies:
197             // 1. The Future is passed by the caller by name, and we assert that
198             //    it implements Unpin.
199             // 2. The Future is created in scope of the select! function and will
200             //    not be moved for the duration of it. It is thereby stack-pinned
201             quote! {
202                 let mut #variant_name = |__cx: &mut __futures_crate::task::Context<'_>| {
203                     let mut #bound_future_name = unsafe {
204                         __futures_crate::Pin::new_unchecked(&mut #bound_future_name)
205                     };
206                     if __futures_crate::future::FusedFuture::is_terminated(&#bound_future_name) {
207                         __futures_crate::None
208                     } else {
209                         __futures_crate::Some(__futures_crate::future::FutureExt::poll_unpin(
210                             &mut #bound_future_name,
211                             __cx,
212                         ).map(#enum_ident::#variant_name))
213                     }
214                 };
215                 let #variant_name: &mut dyn FnMut(
216                     &mut __futures_crate::task::Context<'_>
217                 ) -> __futures_crate::Option<__futures_crate::task::Poll<_>> = &mut #variant_name;
218             }
219         });
220 
221     let none_polled = if parsed.complete.is_some() {
222         quote! {
223             __futures_crate::task::Poll::Ready(#enum_ident::Complete)
224         }
225     } else {
226         quote! {
227             panic!("all futures in select! were completed,\
228                     but no `complete =>` handler was provided")
229         }
230     };
231 
232     let branches = parsed.normal_fut_handlers.into_iter()
233         .zip(variant_names.iter())
234         .map(|((pat, expr), variant_name)| {
235             quote! {
236                 #enum_ident::#variant_name(#pat) => { #expr },
237             }
238         });
239     let branches = quote! { #( #branches )* };
240 
241     let complete_branch = parsed.complete.map(|complete_expr| {
242         quote! {
243             #enum_ident::Complete => { #complete_expr },
244         }
245     });
246 
247     let branches = quote! {
248         #branches
249         #complete_branch
250     };
251 
252     let await_select_fut = if parsed.default.is_some() {
253         // For select! with default this returns the Poll result
254         quote! {
255             __poll_fn(&mut __futures_crate::task::Context::from_waker(
256                 __futures_crate::task::noop_waker_ref()
257             ))
258         }
259     } else {
260         quote! {
261             __futures_crate::future::poll_fn(__poll_fn).await
262         }
263     };
264 
265     let execute_result_expr = if let Some(default_expr) = &parsed.default {
266         // For select! with default __select_result is a Poll, otherwise not
267         quote! {
268             match __select_result {
269                 __futures_crate::task::Poll::Ready(result) => match result {
270                     #branches
271                 },
272                 _ => #default_expr
273             }
274         }
275     } else {
276         quote! {
277             match __select_result {
278                 #branches
279             }
280         }
281     };
282 
283     let shuffle = if random {
284         quote! {
285             __futures_crate::async_await::shuffle(&mut __select_arr);
286         }
287     } else {
288         quote!()
289     };
290 
291     TokenStream::from(quote! { {
292         #enum_item
293 
294         let __select_result = {
295             #( #future_let_bindings )*
296 
297             let mut __poll_fn = |__cx: &mut __futures_crate::task::Context<'_>| {
298                 let mut __any_polled = false;
299 
300                 #( #poll_functions )*
301 
302                 let mut __select_arr = [#( #variant_names ),*];
303                 #shuffle
304                 for poller in &mut __select_arr {
305                     let poller: &mut &mut dyn FnMut(
306                         &mut __futures_crate::task::Context<'_>
307                     ) -> __futures_crate::Option<__futures_crate::task::Poll<_>> = poller;
308                     match poller(__cx) {
309                         __futures_crate::Some(x @ __futures_crate::task::Poll::Ready(_)) =>
310                             return x,
311                         __futures_crate::Some(__futures_crate::task::Poll::Pending) => {
312                             __any_polled = true;
313                         }
314                         __futures_crate::None => {}
315                     }
316                 }
317 
318                 if !__any_polled {
319                     #none_polled
320                 } else {
321                     __futures_crate::task::Poll::Pending
322                 }
323             };
324 
325             #await_select_fut
326         };
327 
328         #execute_result_expr
329     } })
330 }
331