1use inflector::Inflector;
20use proc_macro2::TokenStream as TokenStream2;
21use quote::{format_ident, quote};
22use syn::{
23 Data, DataEnum, DeriveInput, Error, Expr, ExprLit, Field, Fields, GenericArgument, Ident, Lit,
24 Meta, MetaNameValue, PathArguments, Result, Type, TypePath, Variant,
25};
26
27pub fn derive(input: DeriveInput) -> Result<TokenStream2> {
28 let data_enum = match &input.data {
29 Data::Enum(data_enum) => data_enum,
30 _ => return Err(Error::new_spanned(&input, "Expected the `Instruction` enum")),
31 };
32 let builder_raw_impl = generate_builder_raw_impl(&input.ident, data_enum)?;
33 let builder_impl = generate_builder_impl(&input.ident, data_enum)?;
34 let builder_unpaid_impl = generate_builder_unpaid_impl(&input.ident, data_enum)?;
35 let output = quote! {
36 pub trait XcmBuilderState {}
38
39 pub enum AnythingGoes {}
41 pub enum PaymentRequired {}
43 pub enum LoadedHolding {}
45 pub enum ExplicitUnpaidRequired {}
47
48 impl XcmBuilderState for AnythingGoes {}
49 impl XcmBuilderState for PaymentRequired {}
50 impl XcmBuilderState for LoadedHolding {}
51 impl XcmBuilderState for ExplicitUnpaidRequired {}
52
53 pub struct XcmBuilder<Call, S: XcmBuilderState> {
55 pub(crate) instructions: Vec<Instruction<Call>>,
56 pub state: core::marker::PhantomData<S>,
57 }
58
59 impl<Call> Xcm<Call> {
60 pub fn builder() -> XcmBuilder<Call, PaymentRequired> {
61 XcmBuilder::<Call, PaymentRequired> {
62 instructions: Vec::new(),
63 state: core::marker::PhantomData,
64 }
65 }
66 pub fn builder_unpaid() -> XcmBuilder<Call, ExplicitUnpaidRequired> {
67 XcmBuilder::<Call, ExplicitUnpaidRequired> {
68 instructions: Vec::new(),
69 state: core::marker::PhantomData,
70 }
71 }
72 pub fn builder_unsafe() -> XcmBuilder<Call, AnythingGoes> {
73 XcmBuilder::<Call, AnythingGoes> {
74 instructions: Vec::new(),
75 state: core::marker::PhantomData,
76 }
77 }
78 }
79 #builder_impl
80 #builder_unpaid_impl
81 #builder_raw_impl
82 };
83 Ok(output)
84}
85
86fn generate_builder_raw_impl(name: &Ident, data_enum: &DataEnum) -> Result<TokenStream2> {
87 let methods = data_enum
88 .variants
89 .iter()
90 .map(|variant| convert_variant_to_method(name, variant, None))
91 .collect::<Result<Vec<_>>>()?;
92 let output = quote! {
93 impl<Call> XcmBuilder<Call, AnythingGoes> {
94 #(#methods)*
95
96 pub fn build(self) -> Xcm<Call> {
97 Xcm(self.instructions)
98 }
99 }
100 };
101 Ok(output)
102}
103
104fn generate_builder_impl(name: &Ident, data_enum: &DataEnum) -> Result<TokenStream2> {
105 let load_holding_variants = data_enum
107 .variants
108 .iter()
109 .map(|variant| {
110 let maybe_builder_attr = variant.attrs.iter().find(|attr| match attr.meta {
111 Meta::List(ref list) => list.path.is_ident("builder"),
112 _ => false,
113 });
114 let builder_attr = match maybe_builder_attr {
115 Some(builder) => builder.clone(),
116 None => return Ok(None), };
119 let Meta::List(ref list) = builder_attr.meta else { unreachable!("We checked before") };
120 let inner_ident: Ident = syn::parse2(list.tokens.clone()).map_err(|_| {
121 Error::new_spanned(
122 &builder_attr,
123 "Expected `builder(loads_holding)` or `builder(pays_fees)`",
124 )
125 })?;
126 let loads_holding_ident: Ident = syn::parse_quote!(loads_holding);
127 let pays_fees_ident: Ident = syn::parse_quote!(pays_fees);
128 if inner_ident == loads_holding_ident {
129 Ok(Some(variant))
130 } else if inner_ident == pays_fees_ident {
131 Ok(None)
132 } else {
133 Err(Error::new_spanned(
134 &builder_attr,
135 "Expected `builder(loads_holding)` or `builder(pays_fees)`",
136 ))
137 }
138 })
139 .collect::<Result<Vec<_>>>()?;
140
141 let load_holding_methods = load_holding_variants
142 .into_iter()
143 .flatten()
144 .map(|variant| {
145 let method = convert_variant_to_method(
146 name,
147 variant,
148 Some(quote! { XcmBuilder<Call, LoadedHolding> }),
149 )?;
150 Ok(method)
151 })
152 .collect::<Result<Vec<_>>>()?;
153
154 let first_impl = quote! {
155 impl<Call> XcmBuilder<Call, PaymentRequired> {
156 #(#load_holding_methods)*
157 }
158 };
159
160 let allowed_after_load_holding_methods: Vec<TokenStream2> = data_enum
162 .variants
163 .iter()
164 .filter(|variant| variant.ident == "ClearOrigin" || variant.ident == "SetHints")
165 .map(|variant| {
166 let method = convert_variant_to_method(name, variant, None)?;
167 Ok(method)
168 })
169 .collect::<Result<Vec<_>>>()?;
170
171 let pay_fees_variants = data_enum
173 .variants
174 .iter()
175 .map(|variant| {
176 let maybe_builder_attr = variant.attrs.iter().find(|attr| match attr.meta {
177 Meta::List(ref list) => list.path.is_ident("builder"),
178 _ => false,
179 });
180 let builder_attr = match maybe_builder_attr {
181 Some(builder) => builder.clone(),
182 None => return Ok(None), };
184 let Meta::List(ref list) = builder_attr.meta else { unreachable!("We checked before") };
185 let inner_ident: Ident = syn::parse2(list.tokens.clone()).map_err(|_| {
186 Error::new_spanned(
187 &builder_attr,
188 "Expected `builder(loads_holding)` or `builder(pays_fees)`",
189 )
190 })?;
191 let ident_to_match: Ident = syn::parse_quote!(pays_fees);
192 if inner_ident == ident_to_match {
193 Ok(Some(variant))
194 } else {
195 Ok(None) }
197 })
198 .collect::<Result<Vec<_>>>()?;
199
200 let pay_fees_methods = pay_fees_variants
201 .into_iter()
202 .flatten()
203 .map(|variant| {
204 let method = convert_variant_to_method(
205 name,
206 variant,
207 Some(quote! { XcmBuilder<Call, AnythingGoes> }),
208 )?;
209 Ok(method)
210 })
211 .collect::<Result<Vec<_>>>()?;
212
213 let second_impl = quote! {
214 impl<Call> XcmBuilder<Call, LoadedHolding> {
215 #(#allowed_after_load_holding_methods)*
216 #(#pay_fees_methods)*
217 }
218 };
219
220 let output = quote! {
221 #first_impl
222 #second_impl
223 };
224
225 Ok(output)
226}
227
228fn generate_builder_unpaid_impl(name: &Ident, data_enum: &DataEnum) -> Result<TokenStream2> {
229 let unpaid_execution_variant = data_enum
230 .variants
231 .iter()
232 .find(|variant| variant.ident == "UnpaidExecution")
233 .ok_or(Error::new_spanned(&data_enum.variants, "No UnpaidExecution instruction"))?;
234 let method = convert_variant_to_method(
235 name,
236 &unpaid_execution_variant,
237 Some(quote! { XcmBuilder<Call, AnythingGoes> }),
238 )?;
239 Ok(quote! {
240 impl<Call> XcmBuilder<Call, ExplicitUnpaidRequired> {
241 #method
242 }
243 })
244}
245
246enum BoundedOrNormal {
249 Normal(Field),
250 Bounded(Field),
251}
252
253fn convert_variant_to_method(
255 name: &Ident,
256 variant: &Variant,
257 maybe_return_type: Option<TokenStream2>,
258) -> Result<TokenStream2> {
259 let variant_name = &variant.ident;
260 let method_name_string = &variant_name.to_string().to_snake_case();
261 let method_name = syn::Ident::new(method_name_string, variant_name.span());
262 let docs = get_doc_comments(variant);
263 let method = match &variant.fields {
264 Fields::Unit =>
265 if let Some(return_type) = maybe_return_type {
266 quote! {
267 pub fn #method_name(self) -> #return_type {
268 let mut new_instructions = self.instructions;
269 new_instructions.push(#name::<Call>::#variant_name);
270 XcmBuilder {
271 instructions: new_instructions,
272 state: core::marker::PhantomData,
273 }
274 }
275 }
276 } else {
277 quote! {
278 pub fn #method_name(mut self) -> Self {
279 self.instructions.push(#name::<Call>::#variant_name);
280 self
281 }
282 }
283 },
284 Fields::Unnamed(fields) => {
285 let arg_names: Vec<_> = fields
286 .unnamed
287 .iter()
288 .enumerate()
289 .map(|(index, _)| format_ident!("arg{}", index))
290 .collect();
291 let arg_types: Vec<_> = fields.unnamed.iter().map(|field| &field.ty).collect();
292 if let Some(return_type) = maybe_return_type {
293 quote! {
294 pub fn #method_name(self, #(#arg_names: impl Into<#arg_types>),*) -> #return_type {
295 let mut new_instructions = self.instructions;
296 #(let #arg_names = #arg_names.into();)*
297 new_instructions.push(#name::<Call>::#variant_name(#(#arg_names),*));
298 XcmBuilder {
299 instructions: new_instructions,
300 state: core::marker::PhantomData,
301 }
302 }
303 }
304 } else {
305 quote! {
306 pub fn #method_name(mut self, #(#arg_names: impl Into<#arg_types>),*) -> Self {
307 #(let #arg_names = #arg_names.into();)*
308 self.instructions.push(#name::<Call>::#variant_name(#(#arg_names),*));
309 self
310 }
311 }
312 }
313 },
314 Fields::Named(fields) => {
315 let fields: Vec<_> = fields
316 .named
317 .iter()
318 .map(|field| {
319 if let Type::Path(TypePath { path, .. }) = &field.ty {
320 for segment in &path.segments {
321 if segment.ident == format_ident!("BoundedVec") {
322 return BoundedOrNormal::Bounded(field.clone());
323 }
324 }
325 BoundedOrNormal::Normal(field.clone())
326 } else {
327 BoundedOrNormal::Normal(field.clone())
328 }
329 })
330 .collect();
331 let arg_names: Vec<_> = fields
332 .iter()
333 .map(|field| match field {
334 BoundedOrNormal::Bounded(field) => &field.ident,
335 BoundedOrNormal::Normal(field) => &field.ident,
336 })
337 .collect();
338 let arg_types: Vec<_> = fields
339 .iter()
340 .map(|field| match field {
341 BoundedOrNormal::Bounded(field) => {
342 let inner_type =
343 extract_generic_argument(&field.ty, 0, "BoundedVec's inner type")?;
344 Ok(quote! {
345 Vec<#inner_type>
346 })
347 },
348 BoundedOrNormal::Normal(field) => {
349 let inner_type = &field.ty;
350 Ok(quote! {
351 impl Into<#inner_type>
352 })
353 },
354 })
355 .collect::<Result<Vec<_>>>()?;
356 let bounded_names: Vec<_> = fields
357 .iter()
358 .filter_map(|field| match field {
359 BoundedOrNormal::Bounded(field) => Some(&field.ident),
360 BoundedOrNormal::Normal(_) => None,
361 })
362 .collect();
363 let normal_names: Vec<_> = fields
364 .iter()
365 .filter_map(|field| match field {
366 BoundedOrNormal::Normal(field) => Some(&field.ident),
367 BoundedOrNormal::Bounded(_) => None,
368 })
369 .collect();
370 let comma_in_the_middle = if normal_names.is_empty() {
371 quote! {}
372 } else {
373 quote! {,}
374 };
375 if let Some(return_type) = maybe_return_type {
376 quote! {
377 pub fn #method_name(self, #(#arg_names: #arg_types),*) -> #return_type {
378 let mut new_instructions = self.instructions;
379 #(let #normal_names = #normal_names.into();)*
380 #(let #bounded_names = BoundedVec::truncate_from(#bounded_names);)*
381 new_instructions.push(#name::<Call>::#variant_name { #(#normal_names),* #comma_in_the_middle #(#bounded_names),* });
382 XcmBuilder {
383 instructions: new_instructions,
384 state: core::marker::PhantomData,
385 }
386 }
387 }
388 } else {
389 quote! {
390 pub fn #method_name(mut self, #(#arg_names: #arg_types),*) -> Self {
391 #(let #normal_names = #normal_names.into();)*
392 #(let #bounded_names = BoundedVec::truncate_from(#bounded_names);)*
393 self.instructions.push(#name::<Call>::#variant_name { #(#normal_names),* #comma_in_the_middle #(#bounded_names),* });
394 self
395 }
396 }
397 }
398 },
399 };
400 Ok(quote! {
401 #(#docs)*
402 #method
403 })
404}
405
406fn get_doc_comments(variant: &Variant) -> Vec<TokenStream2> {
407 variant
408 .attrs
409 .iter()
410 .filter_map(|attr| match &attr.meta {
411 Meta::NameValue(MetaNameValue {
412 value: Expr::Lit(ExprLit { lit: Lit::Str(literal), .. }),
413 ..
414 }) if attr.path().is_ident("doc") => Some(literal.value()),
415 _ => None,
416 })
417 .map(|doc| syn::parse_str::<TokenStream2>(&format!("/// {}", doc)).unwrap())
418 .collect()
419}
420
421fn extract_generic_argument<'a>(
422 field_ty: &'a Type,
423 index: usize,
424 expected_msg: &str,
425) -> Result<&'a Ident> {
426 if let Type::Path(type_path) = field_ty {
427 if let Some(segment) = type_path.path.segments.last() {
428 if let PathArguments::AngleBracketed(angle_brackets) = &segment.arguments {
429 let args: Vec<_> = angle_brackets.args.iter().collect();
430 if let Some(GenericArgument::Type(Type::Path(TypePath { path, .. }))) =
431 args.get(index)
432 {
433 return path.get_ident().ok_or_else(|| {
434 Error::new_spanned(
435 path,
436 format!("Expected an identifier for {}", expected_msg),
437 )
438 });
439 }
440 return Err(Error::new_spanned(
441 angle_brackets,
442 format!("Expected a generic argument at index {} for {}", index, expected_msg),
443 ));
444 }
445 return Err(Error::new_spanned(
446 &segment.arguments,
447 format!("Expected angle-bracketed arguments for {}", expected_msg),
448 ));
449 }
450 return Err(Error::new_spanned(
451 &type_path.path,
452 format!("Expected at least one path segment for {}", expected_msg),
453 ));
454 }
455 Err(Error::new_spanned(field_ty, format!("Expected a path type for {}", expected_msg)))
456}