1use proc_macro2::{Span, TokenStream};
20use proc_macro_crate::{crate_name, FoundCrate};
21use quote::quote;
22use syn::{DeriveInput, Error, Ident, Path};
23
/// Cargo package name of the crate whose `Extension`, `Group`, `Fork` and `NoExtension`
/// items the generated code refers to; it is looked up with `proc_macro_crate::crate_name`.
const CRATE_NAME: &str = "sc-chain-spec";
/// Name of the field attribute (`#[forks]`) that selects the field whose type becomes the
/// generated `Extension::Forks` associated type.
const ATTRIBUTE_NAME: &str = "forks";
26
27pub fn extension_derive(ast: &DeriveInput) -> proc_macro::TokenStream {
32 derive(ast, |crate_name, name, generics: &syn::Generics, field_names, field_types, fields| {
33 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
34 let forks = fields
35 .named
36 .iter()
37 .find_map(|f| {
38 if f.attrs.iter().any(|attr| attr.path().is_ident(ATTRIBUTE_NAME)) {
39 let typ = &f.ty;
40 Some(quote! { #typ })
41 } else {
42 None
43 }
44 })
45 .unwrap_or_else(|| quote! { #crate_name::NoExtension });
46
47 quote! {
48 impl #impl_generics #crate_name::Extension for #name #ty_generics #where_clause {
49 type Forks = #forks;
50
51 fn get<T: 'static>(&self) -> Option<&T> {
52 use std::any::{Any, TypeId};
53
54 match TypeId::of::<T>() {
55 #( x if x == TypeId::of::<#field_types>() => <dyn Any>::downcast_ref(&self.#field_names) ),*,
56 _ => None,
57 }
58 }
59
60 fn get_any(&self, t: std::any::TypeId) -> &dyn std::any::Any {
61 use std::any::{Any, TypeId};
62
63 match t {
64 #( x if x == TypeId::of::<#field_types>() => &self.#field_names ),*,
65 _ => self,
66 }
67 }
68
69 fn get_any_mut(&mut self, t: std::any::TypeId) -> &mut dyn std::any::Any {
70 use std::any::{Any, TypeId};
71
72 match t {
73 #( x if x == TypeId::of::<#field_types>() => &mut self.#field_names ),*,
74 _ => self,
75 }
76 }
77 }
78 }
79 })
80}
81
82pub fn group_derive(ast: &DeriveInput) -> proc_macro::TokenStream {
84 derive(ast, |crate_name, name, generics: &syn::Generics, field_names, field_types, _fields| {
85 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
86 let fork_name = Ident::new(&format!("{}Fork", name), Span::call_site());
87
88 let fork_fields = generate_fork_fields(crate_name, &field_names, &field_types);
89 let to_fork = generate_base_to_fork(&fork_name, &field_names);
90 let combine_with = generate_combine_with(&field_names);
91 let to_base = generate_fork_to_base(name, &field_names);
92 let serde_crate_name = match proc_macro_crate::crate_name("serde") {
93 Ok(FoundCrate::Itself) => Ident::new("serde", Span::call_site()),
94 Ok(FoundCrate::Name(name)) => Ident::new(&name, Span::call_site()),
95 Err(e) => {
96 let err =
97 Error::new(Span::call_site(), &format!("Could not find `serde` crate: {}", e))
98 .to_compile_error();
99
100 return quote!( #err )
101 },
102 };
103
104 quote! {
105 #[derive(
106 Debug,
107 Clone,
108 PartialEq,
109 #serde_crate_name::Serialize,
110 #serde_crate_name::Deserialize,
111 ChainSpecExtension,
112 )]
113 pub struct #fork_name #ty_generics #where_clause {
114 #fork_fields
115 }
116
117 impl #impl_generics #crate_name::Group for #name #ty_generics #where_clause {
118 type Fork = #fork_name #ty_generics;
119
120 fn to_fork(self) -> Self::Fork {
121 use #crate_name::Group;
122 #to_fork
123 }
124 }
125
126 impl #impl_generics #crate_name::Fork for #fork_name #ty_generics #where_clause {
127 type Base = #name #ty_generics;
128
129 fn combine_with(&mut self, other: Self) {
130 use #crate_name::Fork;
131 #combine_with
132 }
133
134 fn to_base(self) -> Option<Self::Base> {
135 use #crate_name::Fork;
136 #to_base
137 }
138 }
139 }
140 })
141}
142
143pub fn derive(
144 ast: &DeriveInput,
145 derive: impl Fn(
146 &Path,
147 &Ident,
148 &syn::Generics,
149 Vec<&Ident>,
150 Vec<&syn::Type>,
151 &syn::FieldsNamed,
152 ) -> TokenStream,
153) -> proc_macro::TokenStream {
154 let err = || {
155 let err = Error::new(
156 Span::call_site(),
157 "ChainSpecGroup is only available for structs with named fields.",
158 )
159 .to_compile_error();
160 quote!( #err ).into()
161 };
162
163 let data = match &ast.data {
164 syn::Data::Struct(ref data) => data,
165 _ => return err(),
166 };
167
168 let fields = match &data.fields {
169 syn::Fields::Named(ref named) => named,
170 _ => return err(),
171 };
172
173 let name = &ast.ident;
174 let crate_path = match crate_name(CRATE_NAME) {
175 Ok(FoundCrate::Itself) => CRATE_NAME.replace("-", "_"),
176 Ok(FoundCrate::Name(chain_spec_name)) => chain_spec_name,
177 Err(e) => match crate_name("polkadot-sdk") {
178 Ok(FoundCrate::Name(sdk)) => format!("{sdk}::{CRATE_NAME}").replace("-", "_"),
179 _ => {
180 return Error::new(Span::call_site(), &e).to_compile_error().into();
181 },
182 },
183 };
184 let crate_path =
185 syn::parse_str::<Path>(&crate_path).expect("crate_name returns valid path; qed");
186 let field_names = fields.named.iter().flat_map(|x| x.ident.as_ref()).collect::<Vec<_>>();
187 let field_types = fields.named.iter().map(|x| &x.ty).collect::<Vec<_>>();
188
189 derive(&crate_path, name, &ast.generics, field_names, field_types, fields).into()
190}
191
192fn generate_fork_fields(crate_path: &Path, names: &[&Ident], types: &[&syn::Type]) -> TokenStream {
193 let crate_path = std::iter::repeat(crate_path);
194 quote! {
195 #( pub #names: Option<<#types as #crate_path::Group>::Fork>, )*
196 }
197}
198
199fn generate_base_to_fork(fork_name: &Ident, names: &[&Ident]) -> TokenStream {
200 let names2 = names.to_vec();
201
202 quote! {
203 #fork_name {
204 #( #names: Some(self.#names2.to_fork()), )*
205 }
206 }
207}
208
209fn generate_combine_with(names: &[&Ident]) -> TokenStream {
210 let names2 = names.to_vec();
211
212 quote! {
213 #( self.#names.combine_with(other.#names2); )*
214 }
215}
216
217fn generate_fork_to_base(fork: &Ident, names: &[&Ident]) -> TokenStream {
218 let names2 = names.to_vec();
219
220 quote! {
221 Some(#fork {
222 #( #names: self.#names2?.to_base()?, )*
223 })
224 }
225}