//! The futures-rs `join!` and `try_join!` macro implementations.
2 
3 use proc_macro::TokenStream;
4 use proc_macro2::{Span, TokenStream as TokenStream2};
5 use quote::{format_ident, quote};
6 use syn::parse::{Parse, ParseStream};
7 use syn::{Expr, Ident, Token};
8 
9 #[derive(Default)]
10 struct Join {
11     fut_exprs: Vec<Expr>,
12 }
13 
14 impl Parse for Join {
parse(input: ParseStream<'_>) -> syn::Result<Self>15     fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
16         let mut join = Self::default();
17 
18         while !input.is_empty() {
19             join.fut_exprs.push(input.parse::<Expr>()?);
20 
21             if !input.is_empty() {
22                 input.parse::<Token![,]>()?;
23             }
24         }
25 
26         Ok(join)
27     }
28 }
29 
bind_futures( fut_exprs: Vec<Expr>, span: Span, ) -> (Vec<TokenStream2>, Vec<Ident>)30 fn bind_futures(
31     fut_exprs: Vec<Expr>,
32     span: Span,
33 ) -> (Vec<TokenStream2>, Vec<Ident>) {
34     let mut future_let_bindings = Vec::with_capacity(fut_exprs.len());
35     let future_names: Vec<_> = fut_exprs
36         .into_iter()
37         .enumerate()
38         .map(|(i, expr)| {
39             let name = format_ident!("_fut{}", i, span = span);
40             future_let_bindings.push(quote! {
41                 // Move future into a local so that it is pinned in one place and
42                 // is no longer accessible by the end user.
43                 let mut #name = __futures_crate::future::maybe_done(#expr);
44             });
45             name
46         })
47         .collect();
48 
49     (future_let_bindings, future_names)
50 }
51 
52 /// The `join!` macro.
join(input: TokenStream) -> TokenStream53 pub(crate) fn join(input: TokenStream) -> TokenStream {
54     let parsed = syn::parse_macro_input!(input as Join);
55 
56     // should be def_site, but that's unstable
57     let span = Span::call_site();
58 
59     let (future_let_bindings, future_names) = bind_futures(parsed.fut_exprs, span);
60 
61     let poll_futures = future_names.iter().map(|fut| {
62         quote! {
63             __all_done &= __futures_crate::future::Future::poll(
64                 unsafe { __futures_crate::Pin::new_unchecked(&mut #fut) }, __cx).is_ready();
65         }
66     });
67     let take_outputs = future_names.iter().map(|fut| {
68         quote! {
69             unsafe { __futures_crate::Pin::new_unchecked(&mut #fut) }.take_output().unwrap(),
70         }
71     });
72 
73     TokenStream::from(quote! { {
74         #( #future_let_bindings )*
75 
76         __futures_crate::future::poll_fn(move |__cx: &mut __futures_crate::task::Context<'_>| {
77             let mut __all_done = true;
78             #( #poll_futures )*
79             if __all_done {
80                 __futures_crate::task::Poll::Ready((
81                     #( #take_outputs )*
82                 ))
83             } else {
84                 __futures_crate::task::Poll::Pending
85             }
86         }).await
87     } })
88 }
89 
90 /// The `try_join!` macro.
try_join(input: TokenStream) -> TokenStream91 pub(crate) fn try_join(input: TokenStream) -> TokenStream {
92     let parsed = syn::parse_macro_input!(input as Join);
93 
94     // should be def_site, but that's unstable
95     let span = Span::call_site();
96 
97     let (future_let_bindings, future_names) = bind_futures(parsed.fut_exprs, span);
98 
99     let poll_futures = future_names.iter().map(|fut| {
100         quote! {
101             if __futures_crate::future::Future::poll(
102                 unsafe { __futures_crate::Pin::new_unchecked(&mut #fut) }, __cx).is_pending()
103             {
104                 __all_done = false;
105             } else if unsafe { __futures_crate::Pin::new_unchecked(&mut #fut) }.output_mut().unwrap().is_err() {
106                 // `.err().unwrap()` rather than `.unwrap_err()` so that we don't introduce
107                 // a `T: Debug` bound.
108                 // Also, for an error type of ! any code after `err().unwrap()` is unreachable.
109                 #[allow(unreachable_code)]
110                 return __futures_crate::task::Poll::Ready(
111                     __futures_crate::Err(
112                         unsafe { __futures_crate::Pin::new_unchecked(&mut #fut) }.take_output().unwrap().err().unwrap()
113                     )
114                 );
115             }
116         }
117     });
118     let take_outputs = future_names.iter().map(|fut| {
119         quote! {
120             // `.ok().unwrap()` rather than `.unwrap()` so that we don't introduce
121             // an `E: Debug` bound.
122             // Also, for an ok type of ! any code after `ok().unwrap()` is unreachable.
123             #[allow(unreachable_code)]
124             unsafe { __futures_crate::Pin::new_unchecked(&mut #fut) }.take_output().unwrap().ok().unwrap(),
125         }
126     });
127 
128     TokenStream::from(quote! { {
129         #( #future_let_bindings )*
130 
131         #[allow(clippy::diverging_sub_expression)]
132         __futures_crate::future::poll_fn(move |__cx: &mut __futures_crate::task::Context<'_>| {
133             let mut __all_done = true;
134             #( #poll_futures )*
135             if __all_done {
136                 __futures_crate::task::Poll::Ready(
137                     __futures_crate::Ok((
138                         #( #take_outputs )*
139                     ))
140                 )
141             } else {
142                 __futures_crate::task::Poll::Pending
143             }
144         }).await
145     } })
146 }
147