// xcm_procedural/builder_pattern.rs

// Copyright (C) Parity Technologies (UK) Ltd.
// This file is part of Polkadot.

// Polkadot is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// Polkadot is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with Polkadot.  If not, see <http://www.gnu.org/licenses/>.

//! Derive macro for creating XCMs with a builder pattern.
18
19use inflector::Inflector;
20use proc_macro2::TokenStream as TokenStream2;
21use quote::{format_ident, quote};
22use syn::{
23	Data, DataEnum, DeriveInput, Error, Expr, ExprLit, Field, Fields, GenericArgument, Ident, Lit,
24	Meta, MetaNameValue, PathArguments, Result, Type, TypePath, Variant,
25};
26
27pub fn derive(input: DeriveInput) -> Result<TokenStream2> {
28	let data_enum = match &input.data {
29		Data::Enum(data_enum) => data_enum,
30		_ => return Err(Error::new_spanned(&input, "Expected the `Instruction` enum")),
31	};
32	let builder_raw_impl = generate_builder_raw_impl(&input.ident, data_enum)?;
33	let builder_impl = generate_builder_impl(&input.ident, data_enum)?;
34	let builder_unpaid_impl = generate_builder_unpaid_impl(&input.ident, data_enum)?;
35	let output = quote! {
36		/// A trait for types that track state inside the XcmBuilder
37		pub trait XcmBuilderState {}
38
39		/// Access to all the instructions
40		pub enum AnythingGoes {}
41		/// You need to pay for execution
42		pub enum PaymentRequired {}
43		/// The holding register was loaded, now to buy execution
44		pub enum LoadedHolding {}
45		/// Need to explicitly state it won't pay for fees
46		pub enum ExplicitUnpaidRequired {}
47
48		impl XcmBuilderState for AnythingGoes {}
49		impl XcmBuilderState for PaymentRequired {}
50		impl XcmBuilderState for LoadedHolding {}
51		impl XcmBuilderState for ExplicitUnpaidRequired {}
52
53		/// Type used to build XCM programs
54		pub struct XcmBuilder<Call, S: XcmBuilderState> {
55			pub(crate) instructions: Vec<Instruction<Call>>,
56			pub state: core::marker::PhantomData<S>,
57		}
58
59		impl<Call> Xcm<Call> {
60			pub fn builder() -> XcmBuilder<Call, PaymentRequired> {
61				XcmBuilder::<Call, PaymentRequired> {
62					instructions: Vec::new(),
63					state: core::marker::PhantomData,
64				}
65			}
66			pub fn builder_unpaid() -> XcmBuilder<Call, ExplicitUnpaidRequired> {
67				XcmBuilder::<Call, ExplicitUnpaidRequired> {
68					instructions: Vec::new(),
69					state: core::marker::PhantomData,
70				}
71			}
72			pub fn builder_unsafe() -> XcmBuilder<Call, AnythingGoes> {
73				XcmBuilder::<Call, AnythingGoes> {
74					instructions: Vec::new(),
75					state: core::marker::PhantomData,
76				}
77			}
78		}
79		#builder_impl
80		#builder_unpaid_impl
81		#builder_raw_impl
82	};
83	Ok(output)
84}
85
86fn generate_builder_raw_impl(name: &Ident, data_enum: &DataEnum) -> Result<TokenStream2> {
87	let methods = data_enum
88		.variants
89		.iter()
90		.map(|variant| convert_variant_to_method(name, variant, None))
91		.collect::<Result<Vec<_>>>()?;
92	let output = quote! {
93		impl<Call> XcmBuilder<Call, AnythingGoes> {
94			#(#methods)*
95
96			pub fn build(self) -> Xcm<Call> {
97				Xcm(self.instructions)
98			}
99		}
100	};
101	Ok(output)
102}
103
104fn generate_builder_impl(name: &Ident, data_enum: &DataEnum) -> Result<TokenStream2> {
105	// We first require an instruction that load the holding register
106	let load_holding_variants = data_enum
107		.variants
108		.iter()
109		.map(|variant| {
110			let maybe_builder_attr = variant.attrs.iter().find(|attr| match attr.meta {
111				Meta::List(ref list) => list.path.is_ident("builder"),
112				_ => false,
113			});
114			let builder_attr = match maybe_builder_attr {
115				Some(builder) => builder.clone(),
116				None => return Ok(None), /* It's not going to be an instruction that loads the
117				                          * holding register */
118			};
119			let Meta::List(ref list) = builder_attr.meta else { unreachable!("We checked before") };
120			let inner_ident: Ident = syn::parse2(list.tokens.clone()).map_err(|_| {
121				Error::new_spanned(
122					&builder_attr,
123					"Expected `builder(loads_holding)` or `builder(pays_fees)`",
124				)
125			})?;
126			let loads_holding_ident: Ident = syn::parse_quote!(loads_holding);
127			let pays_fees_ident: Ident = syn::parse_quote!(pays_fees);
128			if inner_ident == loads_holding_ident {
129				Ok(Some(variant))
130			} else if inner_ident == pays_fees_ident {
131				Ok(None)
132			} else {
133				Err(Error::new_spanned(
134					&builder_attr,
135					"Expected `builder(loads_holding)` or `builder(pays_fees)`",
136				))
137			}
138		})
139		.collect::<Result<Vec<_>>>()?;
140
141	let load_holding_methods = load_holding_variants
142		.into_iter()
143		.flatten()
144		.map(|variant| {
145			let method = convert_variant_to_method(
146				name,
147				variant,
148				Some(quote! { XcmBuilder<Call, LoadedHolding> }),
149			)?;
150			Ok(method)
151		})
152		.collect::<Result<Vec<_>>>()?;
153
154	let first_impl = quote! {
155		impl<Call> XcmBuilder<Call, PaymentRequired> {
156			#(#load_holding_methods)*
157		}
158	};
159
160	// Some operations are allowed after the holding register is loaded
161	let allowed_after_load_holding_methods: Vec<TokenStream2> = data_enum
162		.variants
163		.iter()
164		.filter(|variant| variant.ident == "ClearOrigin" || variant.ident == "SetHints")
165		.map(|variant| {
166			let method = convert_variant_to_method(name, variant, None)?;
167			Ok(method)
168		})
169		.collect::<Result<Vec<_>>>()?;
170
171	// Then we require fees to be paid
172	let pay_fees_variants = data_enum
173		.variants
174		.iter()
175		.map(|variant| {
176			let maybe_builder_attr = variant.attrs.iter().find(|attr| match attr.meta {
177				Meta::List(ref list) => list.path.is_ident("builder"),
178				_ => false,
179			});
180			let builder_attr = match maybe_builder_attr {
181				Some(builder) => builder.clone(),
182				None => return Ok(None), /* It's not going to be an instruction that pays fees */
183			};
184			let Meta::List(ref list) = builder_attr.meta else { unreachable!("We checked before") };
185			let inner_ident: Ident = syn::parse2(list.tokens.clone()).map_err(|_| {
186				Error::new_spanned(
187					&builder_attr,
188					"Expected `builder(loads_holding)` or `builder(pays_fees)`",
189				)
190			})?;
191			let ident_to_match: Ident = syn::parse_quote!(pays_fees);
192			if inner_ident == ident_to_match {
193				Ok(Some(variant))
194			} else {
195				Ok(None) // Must have been `loads_holding` instead.
196			}
197		})
198		.collect::<Result<Vec<_>>>()?;
199
200	let pay_fees_methods = pay_fees_variants
201		.into_iter()
202		.flatten()
203		.map(|variant| {
204			let method = convert_variant_to_method(
205				name,
206				variant,
207				Some(quote! { XcmBuilder<Call, AnythingGoes> }),
208			)?;
209			Ok(method)
210		})
211		.collect::<Result<Vec<_>>>()?;
212
213	let second_impl = quote! {
214		impl<Call> XcmBuilder<Call, LoadedHolding> {
215			#(#allowed_after_load_holding_methods)*
216			#(#pay_fees_methods)*
217		}
218	};
219
220	let output = quote! {
221		#first_impl
222		#second_impl
223	};
224
225	Ok(output)
226}
227
228fn generate_builder_unpaid_impl(name: &Ident, data_enum: &DataEnum) -> Result<TokenStream2> {
229	let unpaid_execution_variant = data_enum
230		.variants
231		.iter()
232		.find(|variant| variant.ident == "UnpaidExecution")
233		.ok_or(Error::new_spanned(&data_enum.variants, "No UnpaidExecution instruction"))?;
234	let method = convert_variant_to_method(
235		name,
236		&unpaid_execution_variant,
237		Some(quote! { XcmBuilder<Call, AnythingGoes> }),
238	)?;
239	Ok(quote! {
240		impl<Call> XcmBuilder<Call, ExplicitUnpaidRequired> {
241			#method
242		}
243	})
244}
245
246// Small helper enum to differentiate between fields that use a `BoundedVec`
247// and the rest.
248enum BoundedOrNormal {
249	Normal(Field),
250	Bounded(Field),
251}
252
253// Have to call with `XcmBuilder<Call, LoadedHolding>` in allowed_after_load_holding_methods.
254fn convert_variant_to_method(
255	name: &Ident,
256	variant: &Variant,
257	maybe_return_type: Option<TokenStream2>,
258) -> Result<TokenStream2> {
259	let variant_name = &variant.ident;
260	let method_name_string = &variant_name.to_string().to_snake_case();
261	let method_name = syn::Ident::new(method_name_string, variant_name.span());
262	let docs = get_doc_comments(variant);
263	let method = match &variant.fields {
264		Fields::Unit =>
265			if let Some(return_type) = maybe_return_type {
266				quote! {
267					pub fn #method_name(self) -> #return_type {
268						let mut new_instructions = self.instructions;
269						new_instructions.push(#name::<Call>::#variant_name);
270						XcmBuilder {
271							instructions: new_instructions,
272							state: core::marker::PhantomData,
273						}
274					}
275				}
276			} else {
277				quote! {
278					pub fn #method_name(mut self) -> Self {
279						self.instructions.push(#name::<Call>::#variant_name);
280						self
281					}
282				}
283			},
284		Fields::Unnamed(fields) => {
285			let arg_names: Vec<_> = fields
286				.unnamed
287				.iter()
288				.enumerate()
289				.map(|(index, _)| format_ident!("arg{}", index))
290				.collect();
291			let arg_types: Vec<_> = fields.unnamed.iter().map(|field| &field.ty).collect();
292			if let Some(return_type) = maybe_return_type {
293				quote! {
294					pub fn #method_name(self, #(#arg_names: impl Into<#arg_types>),*) -> #return_type {
295						let mut new_instructions = self.instructions;
296						#(let #arg_names = #arg_names.into();)*
297						new_instructions.push(#name::<Call>::#variant_name(#(#arg_names),*));
298						XcmBuilder {
299							instructions: new_instructions,
300							state: core::marker::PhantomData,
301						}
302					}
303				}
304			} else {
305				quote! {
306					pub fn #method_name(mut self, #(#arg_names: impl Into<#arg_types>),*) -> Self {
307						#(let #arg_names = #arg_names.into();)*
308						self.instructions.push(#name::<Call>::#variant_name(#(#arg_names),*));
309						self
310					}
311				}
312			}
313		},
314		Fields::Named(fields) => {
315			let fields: Vec<_> = fields
316				.named
317				.iter()
318				.map(|field| {
319					if let Type::Path(TypePath { path, .. }) = &field.ty {
320						for segment in &path.segments {
321							if segment.ident == format_ident!("BoundedVec") {
322								return BoundedOrNormal::Bounded(field.clone());
323							}
324						}
325						BoundedOrNormal::Normal(field.clone())
326					} else {
327						BoundedOrNormal::Normal(field.clone())
328					}
329				})
330				.collect();
331			let arg_names: Vec<_> = fields
332				.iter()
333				.map(|field| match field {
334					BoundedOrNormal::Bounded(field) => &field.ident,
335					BoundedOrNormal::Normal(field) => &field.ident,
336				})
337				.collect();
338			let arg_types: Vec<_> = fields
339				.iter()
340				.map(|field| match field {
341					BoundedOrNormal::Bounded(field) => {
342						let inner_type =
343							extract_generic_argument(&field.ty, 0, "BoundedVec's inner type")?;
344						Ok(quote! {
345							Vec<#inner_type>
346						})
347					},
348					BoundedOrNormal::Normal(field) => {
349						let inner_type = &field.ty;
350						Ok(quote! {
351							impl Into<#inner_type>
352						})
353					},
354				})
355				.collect::<Result<Vec<_>>>()?;
356			let bounded_names: Vec<_> = fields
357				.iter()
358				.filter_map(|field| match field {
359					BoundedOrNormal::Bounded(field) => Some(&field.ident),
360					BoundedOrNormal::Normal(_) => None,
361				})
362				.collect();
363			let normal_names: Vec<_> = fields
364				.iter()
365				.filter_map(|field| match field {
366					BoundedOrNormal::Normal(field) => Some(&field.ident),
367					BoundedOrNormal::Bounded(_) => None,
368				})
369				.collect();
370			let comma_in_the_middle = if normal_names.is_empty() {
371				quote! {}
372			} else {
373				quote! {,}
374			};
375			if let Some(return_type) = maybe_return_type {
376				quote! {
377					pub fn #method_name(self, #(#arg_names: #arg_types),*) -> #return_type {
378						let mut new_instructions = self.instructions;
379						#(let #normal_names = #normal_names.into();)*
380						#(let #bounded_names = BoundedVec::truncate_from(#bounded_names);)*
381						new_instructions.push(#name::<Call>::#variant_name { #(#normal_names),* #comma_in_the_middle #(#bounded_names),* });
382						XcmBuilder {
383							instructions: new_instructions,
384							state: core::marker::PhantomData,
385						}
386					}
387				}
388			} else {
389				quote! {
390					pub fn #method_name(mut self, #(#arg_names: #arg_types),*) -> Self {
391						#(let #normal_names = #normal_names.into();)*
392						#(let #bounded_names = BoundedVec::truncate_from(#bounded_names);)*
393						self.instructions.push(#name::<Call>::#variant_name { #(#normal_names),* #comma_in_the_middle #(#bounded_names),* });
394						self
395					}
396				}
397			}
398		},
399	};
400	Ok(quote! {
401		#(#docs)*
402		#method
403	})
404}
405
406fn get_doc_comments(variant: &Variant) -> Vec<TokenStream2> {
407	variant
408		.attrs
409		.iter()
410		.filter_map(|attr| match &attr.meta {
411			Meta::NameValue(MetaNameValue {
412				value: Expr::Lit(ExprLit { lit: Lit::Str(literal), .. }),
413				..
414			}) if attr.path().is_ident("doc") => Some(literal.value()),
415			_ => None,
416		})
417		.map(|doc| syn::parse_str::<TokenStream2>(&format!("/// {}", doc)).unwrap())
418		.collect()
419}
420
421fn extract_generic_argument<'a>(
422	field_ty: &'a Type,
423	index: usize,
424	expected_msg: &str,
425) -> Result<&'a Ident> {
426	if let Type::Path(type_path) = field_ty {
427		if let Some(segment) = type_path.path.segments.last() {
428			if let PathArguments::AngleBracketed(angle_brackets) = &segment.arguments {
429				let args: Vec<_> = angle_brackets.args.iter().collect();
430				if let Some(GenericArgument::Type(Type::Path(TypePath { path, .. }))) =
431					args.get(index)
432				{
433					return path.get_ident().ok_or_else(|| {
434						Error::new_spanned(
435							path,
436							format!("Expected an identifier for {}", expected_msg),
437						)
438					});
439				}
440				return Err(Error::new_spanned(
441					angle_brackets,
442					format!("Expected a generic argument at index {} for {}", index, expected_msg),
443				));
444			}
445			return Err(Error::new_spanned(
446				&segment.arguments,
447				format!("Expected angle-bracketed arguments for {}", expected_msg),
448			));
449		}
450		return Err(Error::new_spanned(
451			&type_path.path,
452			format!("Expected at least one path segment for {}", expected_msg),
453		));
454	}
455	Err(Error::new_spanned(field_ty, format!("Expected a path type for {}", expected_msg)))
456}