frame_election_provider_solution_type/
lib.rs1use proc_macro::TokenStream;
21use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
22use proc_macro_crate::{crate_name, FoundCrate};
23use quote::quote;
24use syn::parse::{Parse, ParseStream, Result};
25
26mod codec;
27mod from_assignment_helpers;
28mod index_assignment;
29mod single_page;
30
31pub(crate) fn vote_field(n: usize) -> Ident {
33 quote::format_ident!("votes{}", n)
34}
35
36pub(crate) fn syn_err(message: &'static str) -> syn::Error {
38 syn::Error::new(Span::call_site(), message)
39}
40
41#[proc_macro]
126pub fn generate_solution_type(item: TokenStream) -> TokenStream {
127 let solution_def = syn::parse_macro_input!(item as SolutionDef);
128
129 let imports = imports().unwrap_or_else(|e| e.to_compile_error());
130
131 let def = single_page::generate(solution_def).unwrap_or_else(|e| e.to_compile_error());
132
133 quote!(
134 #imports
135 #def
136 )
137 .into()
138}
139
140struct SolutionDef {
141 vis: syn::Visibility,
142 ident: syn::Ident,
143 voter_type: syn::Type,
144 target_type: syn::Type,
145 weight_type: syn::Type,
146 max_voters: syn::Type,
147 count: usize,
148 compact_encoding: bool,
149}
150
151fn check_attributes(input: ParseStream) -> syn::Result<bool> {
152 let mut attrs = input.call(syn::Attribute::parse_outer).unwrap_or_default();
153 if attrs.len() > 1 {
154 let extra_attr = attrs.pop().expect("attributes vec with len > 1 can be popped");
155 return Err(syn::Error::new_spanned(
156 extra_attr,
157 "compact solution can accept only #[compact]",
158 ))
159 }
160 if attrs.is_empty() {
161 return Ok(false)
162 }
163 let attr = attrs.pop().expect("attributes vec with len 1 can be popped.");
164 if attr.path().is_ident("compact") {
165 Ok(true)
166 } else {
167 Err(syn::Error::new_spanned(attr, "compact solution can accept only #[compact]"))
168 }
169}
170
171impl Parse for SolutionDef {
172 fn parse(input: ParseStream) -> syn::Result<Self> {
173 let compact_encoding = check_attributes(input)?;
175
176 let vis: syn::Visibility = input.parse()?;
178 <syn::Token![struct]>::parse(input)?;
179 let ident: syn::Ident = input.parse()?;
180
181 <syn::Token![::]>::parse(input)?;
183 let generics: syn::AngleBracketedGenericArguments = input.parse()?;
184
185 if generics.args.len() != 4 {
186 return Err(syn_err("Must provide 4 generic args."))
187 }
188
189 let expected_types = ["VoterIndex", "TargetIndex", "Accuracy", "MaxVoters"];
190
191 let mut types: Vec<syn::Type> = generics
192 .args
193 .iter()
194 .zip(expected_types.iter())
195 .map(|(t, expected)| match t {
196 syn::GenericArgument::Type(ty) => {
197 Err(syn::Error::new_spanned(
199 ty,
200 format!("Expected binding: `{} = ...`", expected),
201 ))
202 },
203 syn::GenericArgument::AssocType(syn::AssocType { ident, ty, .. }) => {
204 if ident == expected {
206 Ok(ty.clone())
207 } else {
208 Err(syn::Error::new_spanned(ident, format!("Expected `{}`", expected)))
209 }
210 },
211 _ => Err(syn_err("Wrong type of generic provided. Must be a `type`.")),
212 })
213 .collect::<Result<_>>()?;
214
215 let max_voters = types.pop().expect("Vector of length 4 can be popped; qed");
216 let weight_type = types.pop().expect("Vector of length 3 can be popped; qed");
217 let target_type = types.pop().expect("Vector of length 2 can be popped; qed");
218 let voter_type = types.pop().expect("Vector of length 1 can be popped; qed");
219
220 let count_expr: syn::ExprParen = input.parse()?;
222 let count = parse_parenthesized_number::<usize>(count_expr)?;
223
224 Ok(Self {
225 vis,
226 ident,
227 voter_type,
228 target_type,
229 weight_type,
230 max_voters,
231 count,
232 compact_encoding,
233 })
234 }
235}
236
237fn parse_parenthesized_number<N: std::str::FromStr>(input_expr: syn::ExprParen) -> syn::Result<N>
238where
239 <N as std::str::FromStr>::Err: std::fmt::Display,
240{
241 let expr = input_expr.expr;
242 let expr_lit = match *expr {
243 syn::Expr::Lit(count_lit) => count_lit.lit,
244 _ => return Err(syn_err("Count must be literal.")),
245 };
246 let int_lit = match expr_lit {
247 syn::Lit::Int(int_lit) => int_lit,
248 _ => return Err(syn_err("Count must be int literal.")),
249 };
250 int_lit.base10_parse::<N>()
251}
252
253fn imports() -> Result<TokenStream2> {
254 match crate_name("frame-election-provider-support") {
255 Ok(FoundCrate::Itself) => Ok(quote! {
256 use crate as _feps;
257 use _feps::private as _fepsp;
258 }),
259 Ok(FoundCrate::Name(frame_election_provider_support)) => {
260 let ident = syn::Ident::new(&frame_election_provider_support, Span::call_site());
261 Ok(quote!(
262 use #ident as _feps;
263 use _feps::private as _fepsp;
264 ))
265 },
266 Err(e) => match crate_name("polkadot-sdk") {
267 Ok(FoundCrate::Name(polkadot_sdk)) => {
268 let ident = syn::Ident::new(&polkadot_sdk, Span::call_site());
269 Ok(quote!(
270 use #ident::frame_election_provider_support as _feps;
271 use _feps::private as _fepsp;
272 ))
273 },
274 _ => Err(syn::Error::new(Span::call_site(), e)),
275 },
276 }
277}
278
#[cfg(test)]
mod tests {
    /// Runs the compile-fail UI tests matching `tests/ui/fail/*.rs` through `trybuild`.
    ///
    /// These tests only run when the `RUN_UI_TESTS` environment variable is set to a valid
    /// Unicode value. Otherwise the test passes without doing anything.
    #[test]
    fn ui_fail() {
        let ui_tests_enabled = std::env::var("RUN_UI_TESTS").is_ok();
        if ui_tests_enabled {
            let cases = trybuild::TestCases::new();
            cases.compile_fail("tests/ui/fail/*.rs");
        }
    }
}