//! The futures-rs `join!` macro implementation.
2
3 use proc_macro::TokenStream;
4 use proc_macro2::{Span, TokenStream as TokenStream2};
5 use quote::{format_ident, quote};
6 use syn::parse::{Parse, ParseStream};
7 use syn::{Expr, Ident, Token};
8
9 #[derive(Default)]
10 struct Join {
11 fut_exprs: Vec<Expr>,
12 }
13
14 impl Parse for Join {
parse(input: ParseStream<'_>) -> syn::Result<Self>15 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
16 let mut join = Self::default();
17
18 while !input.is_empty() {
19 join.fut_exprs.push(input.parse::<Expr>()?);
20
21 if !input.is_empty() {
22 input.parse::<Token![,]>()?;
23 }
24 }
25
26 Ok(join)
27 }
28 }
29
bind_futures( fut_exprs: Vec<Expr>, span: Span, ) -> (Vec<TokenStream2>, Vec<Ident>)30 fn bind_futures(
31 fut_exprs: Vec<Expr>,
32 span: Span,
33 ) -> (Vec<TokenStream2>, Vec<Ident>) {
34 let mut future_let_bindings = Vec::with_capacity(fut_exprs.len());
35 let future_names: Vec<_> = fut_exprs
36 .into_iter()
37 .enumerate()
38 .map(|(i, expr)| {
39 let name = format_ident!("_fut{}", i, span = span);
40 future_let_bindings.push(quote! {
41 // Move future into a local so that it is pinned in one place and
42 // is no longer accessible by the end user.
43 let mut #name = __futures_crate::future::maybe_done(#expr);
44 });
45 name
46 })
47 .collect();
48
49 (future_let_bindings, future_names)
50 }
51
52 /// The `join!` macro.
join(input: TokenStream) -> TokenStream53 pub(crate) fn join(input: TokenStream) -> TokenStream {
54 let parsed = syn::parse_macro_input!(input as Join);
55
56 // should be def_site, but that's unstable
57 let span = Span::call_site();
58
59 let (future_let_bindings, future_names) = bind_futures(parsed.fut_exprs, span);
60
61 let poll_futures = future_names.iter().map(|fut| {
62 quote! {
63 __all_done &= __futures_crate::future::Future::poll(
64 unsafe { __futures_crate::Pin::new_unchecked(&mut #fut) }, __cx).is_ready();
65 }
66 });
67 let take_outputs = future_names.iter().map(|fut| {
68 quote! {
69 unsafe { __futures_crate::Pin::new_unchecked(&mut #fut) }.take_output().unwrap(),
70 }
71 });
72
73 TokenStream::from(quote! { {
74 #( #future_let_bindings )*
75
76 __futures_crate::future::poll_fn(move |__cx: &mut __futures_crate::task::Context<'_>| {
77 let mut __all_done = true;
78 #( #poll_futures )*
79 if __all_done {
80 __futures_crate::task::Poll::Ready((
81 #( #take_outputs )*
82 ))
83 } else {
84 __futures_crate::task::Poll::Pending
85 }
86 }).await
87 } })
88 }
89
90 /// The `try_join!` macro.
try_join(input: TokenStream) -> TokenStream91 pub(crate) fn try_join(input: TokenStream) -> TokenStream {
92 let parsed = syn::parse_macro_input!(input as Join);
93
94 // should be def_site, but that's unstable
95 let span = Span::call_site();
96
97 let (future_let_bindings, future_names) = bind_futures(parsed.fut_exprs, span);
98
99 let poll_futures = future_names.iter().map(|fut| {
100 quote! {
101 if __futures_crate::future::Future::poll(
102 unsafe { __futures_crate::Pin::new_unchecked(&mut #fut) }, __cx).is_pending()
103 {
104 __all_done = false;
105 } else if unsafe { __futures_crate::Pin::new_unchecked(&mut #fut) }.output_mut().unwrap().is_err() {
106 // `.err().unwrap()` rather than `.unwrap_err()` so that we don't introduce
107 // a `T: Debug` bound.
108 // Also, for an error type of ! any code after `err().unwrap()` is unreachable.
109 #[allow(unreachable_code)]
110 return __futures_crate::task::Poll::Ready(
111 __futures_crate::Err(
112 unsafe { __futures_crate::Pin::new_unchecked(&mut #fut) }.take_output().unwrap().err().unwrap()
113 )
114 );
115 }
116 }
117 });
118 let take_outputs = future_names.iter().map(|fut| {
119 quote! {
120 // `.ok().unwrap()` rather than `.unwrap()` so that we don't introduce
121 // an `E: Debug` bound.
122 // Also, for an ok type of ! any code after `ok().unwrap()` is unreachable.
123 #[allow(unreachable_code)]
124 unsafe { __futures_crate::Pin::new_unchecked(&mut #fut) }.take_output().unwrap().ok().unwrap(),
125 }
126 });
127
128 TokenStream::from(quote! { {
129 #( #future_let_bindings )*
130
131 #[allow(clippy::diverging_sub_expression)]
132 __futures_crate::future::poll_fn(move |__cx: &mut __futures_crate::task::Context<'_>| {
133 let mut __all_done = true;
134 #( #poll_futures )*
135 if __all_done {
136 __futures_crate::task::Poll::Ready(
137 __futures_crate::Ok((
138 #( #take_outputs )*
139 ))
140 )
141 } else {
142 __futures_crate::task::Poll::Pending
143 }
144 }).await
145 } })
146 }
147