use inflector::Inflector;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{
Data, DataEnum, DeriveInput, Error, Expr, ExprLit, Fields, GenericArgument, Ident, Lit, Meta,
MetaNameValue, PathArguments, Result, Type, TypePath, Variant,
};
pub fn derive(input: DeriveInput) -> Result<TokenStream2> {
let data_enum = match &input.data {
Data::Enum(data_enum) => data_enum,
_ => return Err(Error::new_spanned(&input, "Expected the `Instruction` enum")),
};
let builder_raw_impl = generate_builder_raw_impl(&input.ident, data_enum)?;
let builder_impl = generate_builder_impl(&input.ident, data_enum)?;
let builder_unpaid_impl = generate_builder_unpaid_impl(&input.ident, data_enum)?;
let output = quote! {
pub trait XcmBuilderState {}
pub enum AnythingGoes {}
pub enum PaymentRequired {}
pub enum LoadedHolding {}
pub enum ExplicitUnpaidRequired {}
impl XcmBuilderState for AnythingGoes {}
impl XcmBuilderState for PaymentRequired {}
impl XcmBuilderState for LoadedHolding {}
impl XcmBuilderState for ExplicitUnpaidRequired {}
pub struct XcmBuilder<Call, S: XcmBuilderState> {
pub(crate) instructions: Vec<Instruction<Call>>,
pub state: core::marker::PhantomData<S>,
}
impl<Call> Xcm<Call> {
pub fn builder() -> XcmBuilder<Call, PaymentRequired> {
XcmBuilder::<Call, PaymentRequired> {
instructions: Vec::new(),
state: core::marker::PhantomData,
}
}
pub fn builder_unpaid() -> XcmBuilder<Call, ExplicitUnpaidRequired> {
XcmBuilder::<Call, ExplicitUnpaidRequired> {
instructions: Vec::new(),
state: core::marker::PhantomData,
}
}
pub fn builder_unsafe() -> XcmBuilder<Call, AnythingGoes> {
XcmBuilder::<Call, AnythingGoes> {
instructions: Vec::new(),
state: core::marker::PhantomData,
}
}
}
#builder_impl
#builder_unpaid_impl
#builder_raw_impl
};
Ok(output)
}
/// Generates the unrestricted `AnythingGoes` builder.
///
/// It gets one chaining method per instruction variant, each appending its instruction and
/// returning `Self`, plus `build`, which wraps the collected instructions into an `Xcm`.
///
/// # Errors
///
/// Propagates the first error from turning a variant into a method.
fn generate_builder_raw_impl(name: &Ident, data_enum: &DataEnum) -> Result<TokenStream2> {
	let mut methods = Vec::with_capacity(data_enum.variants.len());
	for variant in &data_enum.variants {
		// `None`: every method keeps the `AnythingGoes` state.
		methods.push(convert_variant_to_method(name, variant, None)?);
	}

	Ok(quote! {
		impl<Call> XcmBuilder<Call, AnythingGoes> {
			#(#methods)*
			pub fn build(self) -> Xcm<Call> {
				Xcm(self.instructions)
			}
		}
	})
}
fn generate_builder_impl(name: &Ident, data_enum: &DataEnum) -> Result<TokenStream2> {
let load_holding_variants = data_enum
.variants
.iter()
.map(|variant| {
let maybe_builder_attr = variant.attrs.iter().find(|attr| match attr.meta {
Meta::List(ref list) => list.path.is_ident("builder"),
_ => false,
});
let builder_attr = match maybe_builder_attr {
Some(builder) => builder.clone(),
None => return Ok(None), };
let Meta::List(ref list) = builder_attr.meta else { unreachable!("We checked before") };
let inner_ident: Ident = syn::parse2(list.tokens.clone()).map_err(|_| {
Error::new_spanned(
&builder_attr,
"Expected `builder(loads_holding)` or `builder(pays_fees)`",
)
})?;
let loads_holding_ident: Ident = syn::parse_quote!(loads_holding);
let pays_fees_ident: Ident = syn::parse_quote!(pays_fees);
if inner_ident == loads_holding_ident {
Ok(Some(variant))
} else if inner_ident == pays_fees_ident {
Ok(None)
} else {
Err(Error::new_spanned(
&builder_attr,
"Expected `builder(loads_holding)` or `builder(pays_fees)`",
))
}
})
.collect::<Result<Vec<_>>>()?;
let load_holding_methods = load_holding_variants
.into_iter()
.flatten()
.map(|variant| {
let method = convert_variant_to_method(
name,
variant,
Some(quote! { XcmBuilder<Call, LoadedHolding> }),
)?;
Ok(method)
})
.collect::<Result<Vec<_>>>()?;
let first_impl = quote! {
impl<Call> XcmBuilder<Call, PaymentRequired> {
#(#load_holding_methods)*
}
};
let allowed_after_load_holding_methods: Vec<TokenStream2> = data_enum
.variants
.iter()
.filter(|variant| variant.ident == "ClearOrigin" || variant.ident == "SetHints")
.map(|variant| {
let method = convert_variant_to_method(name, variant, None)?;
Ok(method)
})
.collect::<Result<Vec<_>>>()?;
let pay_fees_variants = data_enum
.variants
.iter()
.map(|variant| {
let maybe_builder_attr = variant.attrs.iter().find(|attr| match attr.meta {
Meta::List(ref list) => list.path.is_ident("builder"),
_ => false,
});
let builder_attr = match maybe_builder_attr {
Some(builder) => builder.clone(),
None => return Ok(None), };
let Meta::List(ref list) = builder_attr.meta else { unreachable!("We checked before") };
let inner_ident: Ident = syn::parse2(list.tokens.clone()).map_err(|_| {
Error::new_spanned(
&builder_attr,
"Expected `builder(loads_holding)` or `builder(pays_fees)`",
)
})?;
let ident_to_match: Ident = syn::parse_quote!(pays_fees);
if inner_ident == ident_to_match {
Ok(Some(variant))
} else {
Ok(None) }
})
.collect::<Result<Vec<_>>>()?;
let pay_fees_methods = pay_fees_variants
.into_iter()
.flatten()
.map(|variant| {
let method = convert_variant_to_method(
name,
variant,
Some(quote! { XcmBuilder<Call, AnythingGoes> }),
)?;
Ok(method)
})
.collect::<Result<Vec<_>>>()?;
let second_impl = quote! {
impl<Call> XcmBuilder<Call, LoadedHolding> {
#(#allowed_after_load_holding_methods)*
#(#pay_fees_methods)*
}
};
let output = quote! {
#first_impl
#second_impl
};
Ok(output)
}
fn generate_builder_unpaid_impl(name: &Ident, data_enum: &DataEnum) -> Result<TokenStream2> {
let unpaid_execution_variant = data_enum
.variants
.iter()
.find(|variant| variant.ident == "UnpaidExecution")
.ok_or(Error::new_spanned(&data_enum.variants, "No UnpaidExecution instruction"))?;
let method = convert_variant_to_method(
name,
&unpaid_execution_variant,
Some(quote! { XcmBuilder<Call, AnythingGoes> }),
)?;
Ok(quote! {
impl<Call> XcmBuilder<Call, ExplicitUnpaidRequired> {
#method
}
})
}
fn convert_variant_to_method(
name: &Ident,
variant: &Variant,
maybe_return_type: Option<TokenStream2>,
) -> Result<TokenStream2> {
let variant_name = &variant.ident;
let method_name_string = &variant_name.to_string().to_snake_case();
let method_name = syn::Ident::new(method_name_string, variant_name.span());
let docs = get_doc_comments(variant);
let method = match &variant.fields {
Fields::Unit =>
if let Some(return_type) = maybe_return_type {
quote! {
pub fn #method_name(self) -> #return_type {
let mut new_instructions = self.instructions;
new_instructions.push(#name::<Call>::#variant_name);
XcmBuilder {
instructions: new_instructions,
state: core::marker::PhantomData,
}
}
}
} else {
quote! {
pub fn #method_name(mut self) -> Self {
self.instructions.push(#name::<Call>::#variant_name);
self
}
}
},
Fields::Unnamed(fields) => {
let arg_names: Vec<_> = fields
.unnamed
.iter()
.enumerate()
.map(|(index, _)| format_ident!("arg{}", index))
.collect();
let arg_types: Vec<_> = fields.unnamed.iter().map(|field| &field.ty).collect();
if let Some(return_type) = maybe_return_type {
quote! {
pub fn #method_name(self, #(#arg_names: impl Into<#arg_types>),*) -> #return_type {
let mut new_instructions = self.instructions;
#(let #arg_names = #arg_names.into();)*
new_instructions.push(#name::<Call>::#variant_name(#(#arg_names),*));
XcmBuilder {
instructions: new_instructions,
state: core::marker::PhantomData,
}
}
}
} else {
quote! {
pub fn #method_name(mut self, #(#arg_names: impl Into<#arg_types>),*) -> Self {
#(let #arg_names = #arg_names.into();)*
self.instructions.push(#name::<Call>::#variant_name(#(#arg_names),*));
self
}
}
}
},
Fields::Named(fields) => {
let normal_fields: Vec<_> = fields
.named
.iter()
.filter(|field| {
if let Type::Path(TypePath { path, .. }) = &field.ty {
for segment in &path.segments {
if segment.ident == format_ident!("BoundedVec") {
return false;
}
}
true
} else {
true
}
})
.collect();
let bounded_fields: Vec<_> = fields
.named
.iter()
.filter(|field| {
if let Type::Path(TypePath { path, .. }) = &field.ty {
for segment in &path.segments {
if segment.ident == format_ident!("BoundedVec") {
return true;
}
}
false
} else {
false
}
})
.collect();
let arg_names: Vec<_> = normal_fields.iter().map(|field| &field.ident).collect();
let arg_types: Vec<_> = normal_fields.iter().map(|field| &field.ty).collect();
let bounded_names: Vec<_> = bounded_fields.iter().map(|field| &field.ident).collect();
let bounded_types = bounded_fields
.iter()
.map(|field| extract_generic_argument(&field.ty, 0, "BoundedVec's inner type"))
.collect::<Result<Vec<_>>>()?;
let bounded_sizes = bounded_fields
.iter()
.map(|field| extract_generic_argument(&field.ty, 1, "BoundedVec's size"))
.collect::<Result<Vec<_>>>()?;
let comma_in_the_middle = if normal_fields.is_empty() {
quote! {}
} else {
quote! {,}
};
if let Some(return_type) = maybe_return_type {
quote! {
pub fn #method_name(self, #(#arg_names: impl Into<#arg_types>),* #comma_in_the_middle #(#bounded_names: Vec<#bounded_types>),*) -> #return_type {
let mut new_instructions = self.instructions;
#(let #arg_names = #arg_names.into();)*
#(let #bounded_names = BoundedVec::<#bounded_types, #bounded_sizes>::truncate_from(#bounded_names);)*
new_instructions.push(#name::<Call>::#variant_name { #(#arg_names),* #comma_in_the_middle #(#bounded_names),* });
XcmBuilder {
instructions: new_instructions,
state: core::marker::PhantomData,
}
}
}
} else {
quote! {
pub fn #method_name(mut self, #(#arg_names: impl Into<#arg_types>),* #comma_in_the_middle #(#bounded_names: Vec<#bounded_types>),*) -> Self {
#(let #arg_names = #arg_names.into();)*
#(let #bounded_names = BoundedVec::<#bounded_types, #bounded_sizes>::truncate_from(#bounded_names);)*
self.instructions.push(#name::<Call>::#variant_name { #(#arg_names),* #comma_in_the_middle #(#bounded_names),* });
self
}
}
}
},
};
Ok(quote! {
#(#docs)*
#method
})
}
fn get_doc_comments(variant: &Variant) -> Vec<TokenStream2> {
variant
.attrs
.iter()
.filter_map(|attr| match &attr.meta {
Meta::NameValue(MetaNameValue {
value: Expr::Lit(ExprLit { lit: Lit::Str(literal), .. }),
..
}) if attr.path().is_ident("doc") => Some(literal.value()),
_ => None,
})
.map(|doc| syn::parse_str::<TokenStream2>(&format!("/// {}", doc)).unwrap())
.collect()
}
fn extract_generic_argument<'a>(
field_ty: &'a Type,
index: usize,
expected_msg: &str,
) -> Result<&'a Ident> {
if let Type::Path(type_path) = field_ty {
if let Some(segment) = type_path.path.segments.last() {
if let PathArguments::AngleBracketed(angle_brackets) = &segment.arguments {
let args: Vec<_> = angle_brackets.args.iter().collect();
if let Some(GenericArgument::Type(Type::Path(TypePath { path, .. }))) =
args.get(index)
{
return path.get_ident().ok_or_else(|| {
Error::new_spanned(
path,
format!("Expected an identifier for {}", expected_msg),
)
});
}
return Err(Error::new_spanned(
angle_brackets,
format!("Expected a generic argument at index {} for {}", index, expected_msg),
));
}
return Err(Error::new_spanned(
&segment.arguments,
format!("Expected angle-bracketed arguments for {}", expected_msg),
));
}
return Err(Error::new_spanned(
&type_path.path,
format!("Expected at least one path segment for {}", expected_msg),
));
}
Err(Error::new_spanned(field_ty, format!("Expected a path type for {}", expected_msg)))
}