1 // Copyright 2018 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #![recursion_limit = "128"]
6
7 extern crate proc_macro;
8
9 use proc_macro2::{Ident, TokenStream};
10 use quote::quote;
11 use syn::{parse_macro_input, Data, DeriveInput, Field, Fields, Index, Member, Variant};
12
13 #[cfg(test)]
14 mod tests;
15
// The method for packing an enum into a u64 is as follows:
// 1) Reserve the lowest "ceil(log_2(x))" bits, where x is the number of enum variants.
// 2) Store the enum variant's index (0-based, following the order of the enum definition) in
//    the reserved bits.
// 3) If there is data in the enum variant, store the data in the remaining bits.
// The method for unpacking is as follows:
// 1) Mask the raw token down to just the reserved bits.
// 2) Match the reserved bits to the enum variant's index.
// 3) If the indicated enum variant has data, extract it from the unreserved bits.
25
26 // Calculates the number of bits needed to store the variant index. Essentially the log base 2
27 // of the number of variants, rounded up.
variant_bits(variants: &[Variant]) -> u3228 fn variant_bits(variants: &[Variant]) -> u32 {
29 if variants.is_empty() {
30 // The degenerate case of no variants.
31 0
32 } else {
33 variants.len().next_power_of_two().trailing_zeros()
34 }
35 }
36
37 // Name of the field if it has one, otherwise 0 assuming this is the zeroth
38 // field of a tuple variant.
field_member(field: &Field) -> Member39 fn field_member(field: &Field) -> Member {
40 match &field.ident {
41 Some(name) => Member::Named(name.clone()),
42 None => Member::Unnamed(Index::from(0)),
43 }
44 }
45
46 // Generates the function body for `as_raw_token`.
generate_as_raw_token(enum_name: &Ident, variants: &[Variant]) -> TokenStream47 fn generate_as_raw_token(enum_name: &Ident, variants: &[Variant]) -> TokenStream {
48 let variant_bits = variant_bits(variants);
49
50 // Each iteration corresponds to one variant's match arm.
51 let cases = variants.iter().enumerate().map(|(index, variant)| {
52 let variant_name = &variant.ident;
53 let index = index as u64;
54
55 // The capture string is for everything between the variant identifier and the `=>` in
56 // the match arm: the variant's data capture.
57 let capture = variant.fields.iter().next().map(|field| {
58 let member = field_member(&field);
59 quote!({ #member: data })
60 });
61
62 // The modifier string ORs the variant index with extra bits from the variant data
63 // field.
64 let modifier = match variant.fields {
65 Fields::Named(_) | Fields::Unnamed(_) => Some(quote! {
66 | ((data as u64) << #variant_bits)
67 }),
68 Fields::Unit => None,
69 };
70
71 // Assembly of the match arm.
72 quote! {
73 #enum_name::#variant_name #capture => #index #modifier
74 }
75 });
76
77 quote! {
78 match *self {
79 #(
80 #cases,
81 )*
82 }
83 }
84 }
85
86 // Generates the function body for `from_raw_token`.
generate_from_raw_token(enum_name: &Ident, variants: &[Variant]) -> TokenStream87 fn generate_from_raw_token(enum_name: &Ident, variants: &[Variant]) -> TokenStream {
88 let variant_bits = variant_bits(variants);
89 let variant_mask = ((1 << variant_bits) - 1) as u64;
90
91 // Each iteration corresponds to one variant's match arm.
92 let cases = variants.iter().enumerate().map(|(index, variant)| {
93 let variant_name = &variant.ident;
94 let index = index as u64;
95
96 // The data string is for extracting the enum variant's data bits out of the raw token
97 // data, which includes both variant index and data bits.
98 let data = variant.fields.iter().next().map(|field| {
99 let member = field_member(&field);
100 let ty = &field.ty;
101 quote!({ #member: (data >> #variant_bits) as #ty })
102 });
103
104 // Assembly of the match arm.
105 quote! {
106 #index => #enum_name::#variant_name #data
107 }
108 });
109
110 quote! {
111 // The match expression only matches the bits for the variant index.
112 match data & #variant_mask {
113 #(
114 #cases,
115 )*
116 _ => unreachable!(),
117 }
118 }
119 }
120
121 // The proc_macro::TokenStream type can only be constructed from within a
122 // procedural macro, meaning that unit tests are not able to invoke `fn
123 // poll_token` below as an ordinary Rust function. We factor out the logic into
124 // a signature that deals with Syn and proc-macro2 types only which are not
125 // restricted to a procedural macro invocation.
poll_token_inner(input: DeriveInput) -> TokenStream126 fn poll_token_inner(input: DeriveInput) -> TokenStream {
127 let variants: Vec<Variant> = match input.data {
128 Data::Enum(data) => data.variants.into_iter().collect(),
129 Data::Struct(_) | Data::Union(_) => panic!("input must be an enum"),
130 };
131
132 for variant in &variants {
133 assert!(variant.fields.iter().count() <= 1);
134 }
135
136 // Given our basic model of a user given enum that is suitable as a token, we generate the
137 // implementation. The implementation is NOT always well formed, such as when a variant's data
138 // type is not bit shiftable or castable to u64, but we let Rust generate such errors as it
139 // would be difficult to detect every kind of error. Importantly, every implementation that we
140 // generate here and goes on to compile succesfully is sound.
141
142 let enum_name = input.ident;
143 let as_raw_token = generate_as_raw_token(&enum_name, &variants);
144 let from_raw_token = generate_from_raw_token(&enum_name, &variants);
145
146 quote! {
147 impl PollToken for #enum_name {
148 fn as_raw_token(&self) -> u64 {
149 #as_raw_token
150 }
151
152 fn from_raw_token(data: u64) -> Self {
153 #from_raw_token
154 }
155 }
156 }
157 }
158
159 /// Implements the PollToken trait for a given `enum`.
160 ///
161 /// There are limitations on what `enum`s this custom derive will work on:
162 ///
163 /// * Each variant must be a unit variant (no data), or have a single (un)named data field.
164 /// * If a variant has data, it must be a primitive type castable to and from a `u64`.
165 /// * If a variant data has size greater than or equal to a `u64`, its most significant bits must be
166 /// zero. The number of bits truncated is equal to the number of bits used to store the variant
167 /// index plus the number of bits above 64.
168 #[proc_macro_derive(PollToken)]
poll_token(input: proc_macro::TokenStream) -> proc_macro::TokenStream169 pub fn poll_token(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
170 let input = parse_macro_input!(input as DeriveInput);
171 poll_token_inner(input).into()
172 }
173