referrerpolicy=no-referrer-when-downgrade

frame_election_provider_solution_type/
lib.rs

1// This file is part of Substrate.
2
3// Copyright (C) Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: Apache-2.0
5
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
//! Proc macro for generating an NPoS solution type.
19
20use proc_macro::TokenStream;
21use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
22use proc_macro_crate::{crate_name, FoundCrate};
23use quote::quote;
24use syn::parse::{Parse, ParseStream, Result};
25
26mod codec;
27mod from_assignment_helpers;
28mod index_assignment;
29mod single_page;
30
31/// Get the name of a filed based on voter count.
32pub(crate) fn vote_field(n: usize) -> Ident {
33	quote::format_ident!("votes{}", n)
34}
35
36/// Generate a `syn::Error`.
37pub(crate) fn syn_err(message: &'static str) -> syn::Error {
38	syn::Error::new(Span::call_site(), message)
39}
40
41/// Generates a struct to store the election result in a small/compact way. This can encode a
42/// structure which is the equivalent of a `sp_npos_elections::Assignment<_>`.
43///
44/// The following data types can be configured by the macro.
45///
46/// - The identifier of the voter. This can be any type that supports `parity-scale-codec`'s compact
47///   encoding.
48/// - The identifier of the target. This can be any type that supports `parity-scale-codec`'s
49///   compact encoding.
50/// - The accuracy of the ratios. This must be one of the `PerThing` types defined in
51///   `sp-arithmetic`.
52/// - The maximum number of voters. This must be of type `Get<u32>`. Check <https://github.com/paritytech/substrate/issues/10866>
53///   for more details. This is used to bound the struct, by leveraging the fact that `votes1.len()
54///   < votes2.len() < ... < votesn.len()` (the details of the struct is explained further below).
55///   We know that `sum_i votes_i.len() <= MaxVoters`, and we know that the maximum size of the
56///   struct would be achieved if all voters fall in the last bucket. One can also check the tests
57///   and more specifically `max_encoded_len_exact` for a concrete example.
58///
59/// Moreover, the maximum number of edges per voter (distribution per assignment) also need to be
60/// specified. Attempting to convert from/to an assignment with more distributions will fail.
61///
62/// For example, the following generates a public struct with name `TestSolution` with `u16` voter
63/// type, `u8` target type and `Perbill` accuracy with maximum of 4 edges per voter.
64///
65/// ```
66/// # use frame_election_provider_solution_type::generate_solution_type;
67/// # use sp_arithmetic::per_things::Perbill;
68/// # use frame_support::traits::ConstU32;
69/// generate_solution_type!(pub struct TestSolution::<
70///     VoterIndex = u16,
71///     TargetIndex = u8,
72///     Accuracy = Perbill,
73///     MaxVoters = ConstU32::<10>,
74/// >(4));
75/// ```
76///
77/// The output of this macro will roughly look like:
78///
79/// ```ignore
80/// struct TestSolution {
81/// 	voters1: vec![(u16 /* voter */, u8 /* target */)]
82/// 	voters2: vec![
83/// 		(u16 /* voter */, [u8 /* first target*/, Perbill /* proportion for first target */], u8 /* last target */)
84/// 	]
85/// 	voters3: vec![
86/// 		(u16 /* voter */,  [
87/// 			(u8 /* first target*/, Perbill /* proportion for first target */ ),
88/// 			(u8 /* second target */, Perbill /* proportion for second target*/)
89/// 		], u8 /* last target */)
90/// 		],
91/// 	voters4: ...,
92/// }
93///
94/// impl NposSolution for TestSolution {};
95/// impl Solution for TestSolution {};
96/// ```
97///
98/// The given struct provides function to convert from/to `Assignment` as part of
99/// `frame_election_provider_support::NposSolution` trait:
100///
101/// - `fn from_assignment<..>(..)`
102/// - `fn into_assignment<..>(..)`
103///
104/// ## Compact Encoding
105///
106/// The generated struct is by default deriving both `Encode` and `Decode`. This is okay but could
107/// lead to many `0`s in the solution. If prefixed with `#[compact]`, then a custom compact encoding
108/// for numbers will be used, similar to how `parity-scale-codec`'s `Compact` works.
109///
110/// ```
111/// # use frame_election_provider_solution_type::generate_solution_type;
112/// # use frame_election_provider_support::NposSolution;
113/// # use sp_arithmetic::per_things::Perbill;
114/// # use frame_support::traits::ConstU32;
115/// generate_solution_type!(
116///     #[compact]
117///     pub struct TestSolutionCompact::<
118///          VoterIndex = u16,
119///          TargetIndex = u8,
120///          Accuracy = Perbill,
121///          MaxVoters = ConstU32::<10>,
122///     >(8)
123/// );
124/// ```
125#[proc_macro]
126pub fn generate_solution_type(item: TokenStream) -> TokenStream {
127	let solution_def = syn::parse_macro_input!(item as SolutionDef);
128
129	let imports = imports().unwrap_or_else(|e| e.to_compile_error());
130
131	let def = single_page::generate(solution_def).unwrap_or_else(|e| e.to_compile_error());
132
133	quote!(
134		#imports
135		#def
136	)
137	.into()
138}
139
140struct SolutionDef {
141	vis: syn::Visibility,
142	ident: syn::Ident,
143	voter_type: syn::Type,
144	target_type: syn::Type,
145	weight_type: syn::Type,
146	max_voters: syn::Type,
147	count: usize,
148	compact_encoding: bool,
149}
150
151fn check_attributes(input: ParseStream) -> syn::Result<bool> {
152	let mut attrs = input.call(syn::Attribute::parse_outer).unwrap_or_default();
153	if attrs.len() > 1 {
154		let extra_attr = attrs.pop().expect("attributes vec with len > 1 can be popped");
155		return Err(syn::Error::new_spanned(
156			extra_attr,
157			"compact solution can accept only #[compact]",
158		))
159	}
160	if attrs.is_empty() {
161		return Ok(false)
162	}
163	let attr = attrs.pop().expect("attributes vec with len 1 can be popped.");
164	if attr.path().is_ident("compact") {
165		Ok(true)
166	} else {
167		Err(syn::Error::new_spanned(attr, "compact solution can accept only #[compact]"))
168	}
169}
170
171impl Parse for SolutionDef {
172	fn parse(input: ParseStream) -> syn::Result<Self> {
173		// optional #[compact]
174		let compact_encoding = check_attributes(input)?;
175
176		// <vis> struct <name>
177		let vis: syn::Visibility = input.parse()?;
178		<syn::Token![struct]>::parse(input)?;
179		let ident: syn::Ident = input.parse()?;
180
181		// ::<V, T, W>
182		<syn::Token![::]>::parse(input)?;
183		let generics: syn::AngleBracketedGenericArguments = input.parse()?;
184
185		if generics.args.len() != 4 {
186			return Err(syn_err("Must provide 4 generic args."))
187		}
188
189		let expected_types = ["VoterIndex", "TargetIndex", "Accuracy", "MaxVoters"];
190
191		let mut types: Vec<syn::Type> = generics
192			.args
193			.iter()
194			.zip(expected_types.iter())
195			.map(|(t, expected)| match t {
196				syn::GenericArgument::Type(ty) => {
197					// this is now an error
198					Err(syn::Error::new_spanned(
199						ty,
200						format!("Expected binding: `{} = ...`", expected),
201					))
202				},
203				syn::GenericArgument::AssocType(syn::AssocType { ident, ty, .. }) => {
204					// check that we have the right keyword for this position in the argument list
205					if ident == expected {
206						Ok(ty.clone())
207					} else {
208						Err(syn::Error::new_spanned(ident, format!("Expected `{}`", expected)))
209					}
210				},
211				_ => Err(syn_err("Wrong type of generic provided. Must be a `type`.")),
212			})
213			.collect::<Result<_>>()?;
214
215		let max_voters = types.pop().expect("Vector of length 4 can be popped; qed");
216		let weight_type = types.pop().expect("Vector of length 3 can be popped; qed");
217		let target_type = types.pop().expect("Vector of length 2 can be popped; qed");
218		let voter_type = types.pop().expect("Vector of length 1 can be popped; qed");
219
220		// (<count>)
221		let count_expr: syn::ExprParen = input.parse()?;
222		let count = parse_parenthesized_number::<usize>(count_expr)?;
223
224		Ok(Self {
225			vis,
226			ident,
227			voter_type,
228			target_type,
229			weight_type,
230			max_voters,
231			count,
232			compact_encoding,
233		})
234	}
235}
236
237fn parse_parenthesized_number<N: std::str::FromStr>(input_expr: syn::ExprParen) -> syn::Result<N>
238where
239	<N as std::str::FromStr>::Err: std::fmt::Display,
240{
241	let expr = input_expr.expr;
242	let expr_lit = match *expr {
243		syn::Expr::Lit(count_lit) => count_lit.lit,
244		_ => return Err(syn_err("Count must be literal.")),
245	};
246	let int_lit = match expr_lit {
247		syn::Lit::Int(int_lit) => int_lit,
248		_ => return Err(syn_err("Count must be int literal.")),
249	};
250	int_lit.base10_parse::<N>()
251}
252
253fn imports() -> Result<TokenStream2> {
254	match crate_name("frame-election-provider-support") {
255		Ok(FoundCrate::Itself) => Ok(quote! {
256			use crate as _feps;
257			use _feps::private as _fepsp;
258		}),
259		Ok(FoundCrate::Name(frame_election_provider_support)) => {
260			let ident = syn::Ident::new(&frame_election_provider_support, Span::call_site());
261			Ok(quote!(
262					use #ident as _feps;
263					use _feps::private as _fepsp;
264			))
265		},
266		Err(e) => match crate_name("polkadot-sdk") {
267			Ok(FoundCrate::Name(polkadot_sdk)) => {
268				let ident = syn::Ident::new(&polkadot_sdk, Span::call_site());
269				Ok(quote!(
270					use #ident::frame_election_provider_support as _feps;
271					use _feps::private as _fepsp;
272				))
273			},
274			_ => Err(syn::Error::new(Span::call_site(), e)),
275		},
276	}
277}
278
#[cfg(test)]
mod tests {
	/// Runs the compile-fail UI tests matching `tests/ui/fail/*.rs`.
	///
	/// Opt-in: this does nothing unless `RUN_UI_TESTS` is set to a valid Unicode value.
	#[test]
	fn ui_fail() {
		let enabled = std::env::var("RUN_UI_TESTS").is_ok();
		if enabled {
			let cases = trybuild::TestCases::new();
			cases.compile_fail("tests/ui/fail/*.rs");
		}
	}
}