frame_support_procedural/
derive_impl.rs1use macro_magic::mm_core::ForeignPath;
21use proc_macro2::TokenStream as TokenStream2;
22use quote::{quote, ToTokens};
23use std::collections::HashSet;
24use syn::{
25 parse2, parse_quote, spanned::Spanned, token, AngleBracketedGenericArguments, Ident, ImplItem,
26 ItemImpl, Path, PathArguments, PathSegment, Result, Token,
27};
28
29mod keyword {
30 syn::custom_keyword!(inject_runtime_type);
31 syn::custom_keyword!(no_aggregated_types);
32}
33
34#[derive(derive_syn_parse::Parse, PartialEq, Eq)]
35pub enum PalletAttrType {
36 #[peek(keyword::inject_runtime_type, name = "inject_runtime_type")]
37 RuntimeType(keyword::inject_runtime_type),
38}
39
40#[derive(derive_syn_parse::Parse)]
41pub struct PalletAttr {
42 _pound: Token![#],
43 #[bracket]
44 _bracket: token::Bracket,
45 #[inside(_bracket)]
46 typ: PalletAttrType,
47}
48
49fn is_runtime_type(item: &syn::ImplItemType) -> bool {
50 item.attrs.iter().any(|attr| {
51 if let Ok(PalletAttr { typ: PalletAttrType::RuntimeType(_), .. }) =
52 parse2::<PalletAttr>(attr.into_token_stream())
53 {
54 return true
55 }
56 false
57 })
58}
59pub struct DeriveImplAttrArgs {
60 pub default_impl_path: Path,
61 pub generics: Option<AngleBracketedGenericArguments>,
62 _as: Option<Token![as]>,
63 pub disambiguation_path: Option<Path>,
64 _comma: Option<Token![,]>,
65 pub no_aggregated_types: Option<keyword::no_aggregated_types>,
66}
67
68impl syn::parse::Parse for DeriveImplAttrArgs {
69 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
70 let mut default_impl_path: Path = input.parse()?;
71 let (default_impl_path, generics) = match default_impl_path.clone().segments.last() {
73 Some(PathSegment { ident, arguments: PathArguments::AngleBracketed(args) }) => {
74 default_impl_path.segments.pop();
75 default_impl_path
76 .segments
77 .push(PathSegment { ident: ident.clone(), arguments: PathArguments::None });
78 (default_impl_path, Some(args.clone()))
79 },
80 Some(PathSegment { arguments: PathArguments::None, .. }) => (default_impl_path, None),
81 _ => return Err(syn::Error::new(default_impl_path.span(), "Invalid default impl path")),
82 };
83
84 let lookahead = input.lookahead1();
85 let (_as, disambiguation_path) = if lookahead.peek(Token![as]) {
86 let _as: Token![as] = input.parse()?;
87 let disambiguation_path: Path = input.parse()?;
88 (Some(_as), Some(disambiguation_path))
89 } else {
90 (None, None)
91 };
92
93 let lookahead = input.lookahead1();
94 let (_comma, no_aggregated_types) = if lookahead.peek(Token![,]) {
95 let _comma: Token![,] = input.parse()?;
96 let no_aggregated_types: keyword::no_aggregated_types = input.parse()?;
97 (Some(_comma), Some(no_aggregated_types))
98 } else {
99 (None, None)
100 };
101
102 Ok(DeriveImplAttrArgs {
103 default_impl_path,
104 generics,
105 _as,
106 disambiguation_path,
107 _comma,
108 no_aggregated_types,
109 })
110 }
111}
112
113impl ForeignPath for DeriveImplAttrArgs {
114 fn foreign_path(&self) -> &Path {
115 &self.default_impl_path
116 }
117}
118
119impl ToTokens for DeriveImplAttrArgs {
120 fn to_tokens(&self, tokens: &mut TokenStream2) {
121 tokens.extend(self.default_impl_path.to_token_stream());
122 tokens.extend(self.generics.to_token_stream());
123 tokens.extend(self._as.to_token_stream());
124 tokens.extend(self.disambiguation_path.to_token_stream());
125 tokens.extend(self._comma.to_token_stream());
126 tokens.extend(self.no_aggregated_types.to_token_stream());
127 }
128}
129
130fn impl_item_ident(impl_item: &ImplItem) -> Option<&Ident> {
136 match impl_item {
137 ImplItem::Const(item) => Some(&item.ident),
138 ImplItem::Fn(item) => Some(&item.sig.ident),
139 ImplItem::Type(item) => Some(&item.ident),
140 ImplItem::Macro(item) => item.mac.path.get_ident(),
141 _ => None,
142 }
143}
144
145fn combine_impls(
158 local_impl: ItemImpl,
159 foreign_impl: ItemImpl,
160 default_impl_path: Path,
161 disambiguation_path: Path,
162 inject_runtime_types: bool,
163 generics: Option<AngleBracketedGenericArguments>,
164) -> ItemImpl {
165 let (existing_local_keys, existing_unsupported_items): (HashSet<ImplItem>, HashSet<ImplItem>) =
166 local_impl
167 .items
168 .iter()
169 .cloned()
170 .partition(|impl_item| impl_item_ident(impl_item).is_some());
171 let existing_local_keys: HashSet<Ident> = existing_local_keys
172 .into_iter()
173 .filter_map(|item| impl_item_ident(&item).cloned())
174 .collect();
175 let mut final_impl = local_impl;
176 let extended_items = foreign_impl.items.into_iter().filter_map(|item| {
177 if let Some(ident) = impl_item_ident(&item) {
178 if existing_local_keys.contains(&ident) {
179 return None
181 }
182 if let ImplItem::Type(typ) = item.clone() {
183 let cfg_attrs = typ
184 .attrs
185 .iter()
186 .filter(|attr| attr.path().get_ident().map_or(false, |ident| ident == "cfg"))
187 .map(|attr| attr.to_token_stream());
188 if is_runtime_type(&typ) {
189 let item: ImplItem = if inject_runtime_types {
190 parse_quote! {
191 #( #cfg_attrs )*
192 type #ident = #ident;
193 }
194 } else {
195 item
196 };
197 return Some(item)
198 }
199 let modified_item: ImplItem = parse_quote! {
201 #( #cfg_attrs )*
202 type #ident = <#default_impl_path #generics as #disambiguation_path>::#ident;
203 };
204 return Some(modified_item)
205 }
206 Some(item)
208 } else {
209 (!existing_unsupported_items.contains(&item))
211 .then_some(item)
213 }
214 });
215 final_impl.items.extend(extended_items);
216 final_impl
217}
218
219fn compute_disambiguation_path(
225 disambiguation_path: Option<Path>,
226 foreign_impl: ItemImpl,
227 default_impl_path: Path,
228) -> Result<Path> {
229 match (disambiguation_path, foreign_impl.clone().trait_) {
230 (Some(disambiguation_path), _) => Ok(disambiguation_path),
231 (None, Some((_, foreign_impl_path, _))) =>
232 if default_impl_path.segments.len() > 1 {
233 let scope = default_impl_path.segments.first();
234 Ok(parse_quote!(#scope :: #foreign_impl_path))
235 } else {
236 Ok(foreign_impl_path)
237 },
238 _ => Err(syn::Error::new(
239 default_impl_path.span(),
240 "Impl statement must have a defined type being implemented \
241 for a defined type such as `impl A for B`",
242 )),
243 }
244}
245
246pub fn derive_impl(
258 default_impl_path: TokenStream2,
259 foreign_tokens: TokenStream2,
260 local_tokens: TokenStream2,
261 disambiguation_path: Option<Path>,
262 no_aggregated_types: Option<keyword::no_aggregated_types>,
263 generics: Option<AngleBracketedGenericArguments>,
264) -> Result<TokenStream2> {
265 let local_impl = parse2::<ItemImpl>(local_tokens)?;
266 let foreign_impl = parse2::<ItemImpl>(foreign_tokens)?;
267 let default_impl_path = parse2::<Path>(default_impl_path)?;
268
269 let disambiguation_path = compute_disambiguation_path(
270 disambiguation_path,
271 foreign_impl.clone(),
272 default_impl_path.clone(),
273 )?;
274
275 let combined_impl = combine_impls(
277 local_impl,
278 foreign_impl,
279 default_impl_path,
280 disambiguation_path,
281 no_aggregated_types.is_none(),
282 generics,
283 );
284
285 Ok(quote!(#combined_impl))
286}
287
/// Checks which argument lists `DeriveImplAttrArgs` accepts and rejects.
#[test]
fn test_derive_impl_attr_args_parsing() {
    let accepted = vec![
        quote!(some::path::TestDefaultConfig as some::path::DefaultConfig),
        quote!(frame_system::prelude::testing::TestDefaultConfig as DefaultConfig),
        quote!(Something as some::path::DefaultConfig),
        quote!(Something as DefaultConfig),
        quote!(DefaultConfig),
    ];
    for tokens in accepted {
        parse2::<DeriveImplAttrArgs>(tokens).unwrap();
    }

    // Empty input and two adjacent paths without `as` must both be rejected.
    let rejected = vec![quote!(), quote!(Config Config)];
    for tokens in rejected {
        assert!(parse2::<DeriveImplAttrArgs>(tokens).is_err());
    }
}
304
/// Checks that `is_runtime_type` detects `#[inject_runtime_type]` even when the
/// item also carries a doc attribute that does not parse as a `PalletAttr`.
#[test]
fn test_runtime_type_with_doc() {
    #[allow(dead_code)]
    trait TestTrait {
        type Test;
    }
    #[allow(unused)]
    struct TestStruct;
    let p = parse2::<ItemImpl>(quote!(
        impl TestTrait for TestStruct {
            #[doc = "Some doc"]
            #[inject_runtime_type]
            type Test = u32;
        }
    ))
    .unwrap();

    let type_items: Vec<&syn::ImplItemType> = p
        .items
        .iter()
        .filter_map(|item| match item {
            ImplItem::Type(typ) => Some(typ),
            _ => None,
        })
        .collect();

    // Guard against the test passing vacuously when no type item is found.
    assert_eq!(type_items.len(), 1);
    assert!(type_items.iter().all(|typ| is_runtime_type(typ)));
}
327
/// Checks the three resolution paths of `compute_disambiguation_path`.
#[test]
fn test_disambiguation_path() {
    let foreign_impl: ItemImpl = parse_quote!(impl SomeTrait for SomeType {});
    let default_impl_path: Path = parse_quote!(SomeScope::SomeType);

    // An explicit disambiguation path is returned unchanged.
    let explicit: Path = parse_quote!(SomeScope::SomePath);
    let resolved = compute_disambiguation_path(
        Some(explicit.clone()),
        foreign_impl.clone(),
        default_impl_path.clone(),
    )
    .unwrap();
    assert_eq!(resolved, explicit);

    // Without one, the foreign trait is scoped by the first segment of a
    // multi-segment default impl path.
    let scoped: Path = parse_quote!(SomeScope::SomeTrait);
    let resolved =
        compute_disambiguation_path(None, foreign_impl.clone(), default_impl_path.clone())
            .unwrap();
    assert_eq!(resolved, scoped);

    // A single-segment default impl path leaves the foreign trait path as is.
    let unscoped: Path = parse_quote!(SomeTrait);
    let resolved =
        compute_disambiguation_path(None, foreign_impl.clone(), parse_quote!(SomeType)).unwrap();
    assert_eq!(resolved, unscoped);
}
351
/// Checks that generic arguments on the default impl path are split off into
/// `generics`, with and without an `as` clause.
#[test]
fn test_derive_impl_attr_args_parsing_with_generic() {
    let with_as = parse2::<DeriveImplAttrArgs>(quote!(
        some::path::TestDefaultConfig<Config> as some::path::DefaultConfig
    ))
    .unwrap();
    let expected_path: Path = parse_quote!(some::path::TestDefaultConfig);
    assert_eq!(with_as.default_impl_path, expected_path);
    let generics = with_as.generics.unwrap();
    assert_eq!(generics.args[0], parse_quote!(Config));

    let without_as = parse2::<DeriveImplAttrArgs>(quote!(TestDefaultConfig<Config2>)).unwrap();
    let expected_path: Path = parse_quote!(TestDefaultConfig);
    assert_eq!(without_as.default_impl_path, expected_path);
    let generics = without_as.generics.unwrap();
    assert_eq!(generics.args[0], parse_quote!(Config2));
}