referrerpolicy=no-referrer-when-downgrade

frame_support_procedural/
derive_impl.rs

1// This file is part of Substrate.
2
3// Copyright (C) Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: Apache-2.0
5
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! Implementation of the `derive_impl` attribute macro.
19
20use macro_magic::mm_core::ForeignPath;
21use proc_macro2::TokenStream as TokenStream2;
22use quote::{quote, ToTokens};
23use std::collections::HashSet;
24use syn::{
25	parse2, parse_quote, spanned::Spanned, token, AngleBracketedGenericArguments, Ident, ImplItem,
26	ItemImpl, Path, PathArguments, PathSegment, Result, Token,
27};
28
29mod keyword {
30	syn::custom_keyword!(inject_runtime_type);
31	syn::custom_keyword!(no_aggregated_types);
32}
33
/// The kind of marker attribute recognised on type items of a default impl.
///
/// Currently the only kind is `inject_runtime_type`; see `is_runtime_type`.
#[derive(derive_syn_parse::Parse, PartialEq, Eq)]
pub enum PalletAttrType {
	// Chosen when the next token peeked is the `inject_runtime_type` keyword.
	#[peek(keyword::inject_runtime_type, name = "inject_runtime_type")]
	RuntimeType(keyword::inject_runtime_type),
}
39
/// An outer attribute of the form `#[<PalletAttrType>]`, e.g. `#[inject_runtime_type]`.
///
/// `is_runtime_type` parses each attribute's token stream into this type.
#[derive(derive_syn_parse::Parse)]
pub struct PalletAttr {
	// The leading `#` of the attribute.
	_pound: Token![#],
	// The `[...]` delimiters; their contents are parsed into `typ`.
	#[bracket]
	_bracket: token::Bracket,
	#[inside(_bracket)]
	typ: PalletAttrType,
}
48
49fn is_runtime_type(item: &syn::ImplItemType) -> bool {
50	item.attrs.iter().any(|attr| {
51		if let Ok(PalletAttr { typ: PalletAttrType::RuntimeType(_), .. }) =
52			parse2::<PalletAttr>(attr.into_token_stream())
53		{
54			return true
55		}
56		false
57	})
58}
/// Parsed arguments of the `derive_impl` attribute:
/// `default_impl_path[<generics>] [as disambiguation_path] [, no_aggregated_types]`.
pub struct DeriveImplAttrArgs {
	/// Path of the default (foreign) impl, with any angle-bracketed generic arguments on its
	/// last segment stripped off into `generics`.
	pub default_impl_path: Path,
	/// Generic arguments that were attached to the last segment of `default_impl_path`.
	pub generics: Option<AngleBracketedGenericArguments>,
	/// The `as` token, kept so the arguments can be re-emitted via `ToTokens`.
	_as: Option<Token![as]>,
	/// Explicit trait path used to qualify imported defaults, given after `as`.
	pub disambiguation_path: Option<Path>,
	/// The `,` token preceding `no_aggregated_types`, kept for `ToTokens`.
	_comma: Option<Token![,]>,
	/// Present when the `no_aggregated_types` flag was passed.
	pub no_aggregated_types: Option<keyword::no_aggregated_types>,
}
67
68impl syn::parse::Parse for DeriveImplAttrArgs {
69	fn parse(input: syn::parse::ParseStream) -> Result<Self> {
70		let mut default_impl_path: Path = input.parse()?;
71		// Extract the generics if any
72		let (default_impl_path, generics) = match default_impl_path.clone().segments.last() {
73			Some(PathSegment { ident, arguments: PathArguments::AngleBracketed(args) }) => {
74				default_impl_path.segments.pop();
75				default_impl_path
76					.segments
77					.push(PathSegment { ident: ident.clone(), arguments: PathArguments::None });
78				(default_impl_path, Some(args.clone()))
79			},
80			Some(PathSegment { arguments: PathArguments::None, .. }) => (default_impl_path, None),
81			_ => return Err(syn::Error::new(default_impl_path.span(), "Invalid default impl path")),
82		};
83
84		let lookahead = input.lookahead1();
85		let (_as, disambiguation_path) = if lookahead.peek(Token![as]) {
86			let _as: Token![as] = input.parse()?;
87			let disambiguation_path: Path = input.parse()?;
88			(Some(_as), Some(disambiguation_path))
89		} else {
90			(None, None)
91		};
92
93		let lookahead = input.lookahead1();
94		let (_comma, no_aggregated_types) = if lookahead.peek(Token![,]) {
95			let _comma: Token![,] = input.parse()?;
96			let no_aggregated_types: keyword::no_aggregated_types = input.parse()?;
97			(Some(_comma), Some(no_aggregated_types))
98		} else {
99			(None, None)
100		};
101
102		Ok(DeriveImplAttrArgs {
103			default_impl_path,
104			generics,
105			_as,
106			disambiguation_path,
107			_comma,
108			no_aggregated_types,
109		})
110	}
111}
112
113impl ForeignPath for DeriveImplAttrArgs {
114	fn foreign_path(&self) -> &Path {
115		&self.default_impl_path
116	}
117}
118
119impl ToTokens for DeriveImplAttrArgs {
120	fn to_tokens(&self, tokens: &mut TokenStream2) {
121		tokens.extend(self.default_impl_path.to_token_stream());
122		tokens.extend(self.generics.to_token_stream());
123		tokens.extend(self._as.to_token_stream());
124		tokens.extend(self.disambiguation_path.to_token_stream());
125		tokens.extend(self._comma.to_token_stream());
126		tokens.extend(self.no_aggregated_types.to_token_stream());
127	}
128}
129
130/// Gets the [`Ident`] representation of the given [`ImplItem`], if one exists. Otherwise
131/// returns [`None`].
132///
133/// Used by [`combine_impls`] to determine whether we can compare [`ImplItem`]s by [`Ident`]
134/// or not.
135fn impl_item_ident(impl_item: &ImplItem) -> Option<&Ident> {
136	match impl_item {
137		ImplItem::Const(item) => Some(&item.ident),
138		ImplItem::Fn(item) => Some(&item.sig.ident),
139		ImplItem::Type(item) => Some(&item.ident),
140		ImplItem::Macro(item) => item.mac.path.get_ident(),
141		_ => None,
142	}
143}
144
145/// The real meat behind `derive_impl`. Takes in a `local_impl`, which is the impl for which we
146/// want to implement defaults (i.e. the one the attribute macro is attached to), and a
147/// `foreign_impl`, which is the impl containing the defaults we want to use, and returns an
148/// [`ItemImpl`] containing the final generated impl.
149///
150/// This process has the following caveats:
151/// * Colliding items that have an ident are not copied into `local_impl`
152/// * Uncolliding items that have an ident are copied into `local_impl` but are qualified as `type
153///   #ident = <#default_impl_path as #disambiguation_path>::#ident;`
154/// * Items that lack an ident are de-duplicated so only unique items that lack an ident are copied
155///   into `local_impl`. Items that lack an ident and also exist verbatim in `local_impl` are not
156///   copied over.
157fn combine_impls(
158	local_impl: ItemImpl,
159	foreign_impl: ItemImpl,
160	default_impl_path: Path,
161	disambiguation_path: Path,
162	inject_runtime_types: bool,
163	generics: Option<AngleBracketedGenericArguments>,
164) -> ItemImpl {
165	let (existing_local_keys, existing_unsupported_items): (HashSet<ImplItem>, HashSet<ImplItem>) =
166		local_impl
167			.items
168			.iter()
169			.cloned()
170			.partition(|impl_item| impl_item_ident(impl_item).is_some());
171	let existing_local_keys: HashSet<Ident> = existing_local_keys
172		.into_iter()
173		.filter_map(|item| impl_item_ident(&item).cloned())
174		.collect();
175	let mut final_impl = local_impl;
176	let extended_items = foreign_impl.items.into_iter().filter_map(|item| {
177		if let Some(ident) = impl_item_ident(&item) {
178			if existing_local_keys.contains(&ident) {
179				// do not copy colliding items that have an ident
180				return None
181			}
182			if let ImplItem::Type(typ) = item.clone() {
183				let cfg_attrs = typ
184					.attrs
185					.iter()
186					.filter(|attr| attr.path().get_ident().map_or(false, |ident| ident == "cfg"))
187					.map(|attr| attr.to_token_stream());
188				if is_runtime_type(&typ) {
189					let item: ImplItem = if inject_runtime_types {
190						parse_quote! {
191							#( #cfg_attrs )*
192							type #ident = #ident;
193						}
194					} else {
195						item
196					};
197					return Some(item)
198				}
199				// modify and insert uncolliding type items
200				let modified_item: ImplItem = parse_quote! {
201					#( #cfg_attrs )*
202					type #ident = <#default_impl_path #generics as #disambiguation_path>::#ident;
203				};
204				return Some(modified_item)
205			}
206			// copy uncolliding non-type items that have an ident
207			Some(item)
208		} else {
209			// do not copy colliding items that lack an ident
210			(!existing_unsupported_items.contains(&item))
211				// copy uncolliding items without an ident verbatim
212				.then_some(item)
213		}
214	});
215	final_impl.items.extend(extended_items);
216	final_impl
217}
218
219/// Computes the disambiguation path for the `derive_impl` attribute macro.
220///
221/// When specified explicitly using `as [disambiguation_path]` in the macro attr, the
222/// disambiguation is used as is. If not, we infer the disambiguation path from the
223/// `foreign_impl_path` and the computed scope.
224fn compute_disambiguation_path(
225	disambiguation_path: Option<Path>,
226	foreign_impl: ItemImpl,
227	default_impl_path: Path,
228) -> Result<Path> {
229	match (disambiguation_path, foreign_impl.clone().trait_) {
230		(Some(disambiguation_path), _) => Ok(disambiguation_path),
231		(None, Some((_, foreign_impl_path, _))) =>
232			if default_impl_path.segments.len() > 1 {
233				let scope = default_impl_path.segments.first();
234				Ok(parse_quote!(#scope :: #foreign_impl_path))
235			} else {
236				Ok(foreign_impl_path)
237			},
238		_ => Err(syn::Error::new(
239			default_impl_path.span(),
240			"Impl statement must have a defined type being implemented \
241			for a defined type such as `impl A for B`",
242		)),
243	}
244}
245
246/// Internal implementation behind [`#[derive_impl(..)]`](`macro@crate::derive_impl`).
247///
248/// `default_impl_path`: the module path of the external `impl` statement whose tokens we are
249///	                     importing via `macro_magic`
250///
251/// `foreign_tokens`: the tokens for the external `impl` statement
252///
253/// `local_tokens`: the tokens for the local `impl` statement this attribute is attached to
254///
255/// `disambiguation_path`: the module path of the external trait we will use to qualify
256///                        defaults imported from the external `impl` statement
257pub fn derive_impl(
258	default_impl_path: TokenStream2,
259	foreign_tokens: TokenStream2,
260	local_tokens: TokenStream2,
261	disambiguation_path: Option<Path>,
262	no_aggregated_types: Option<keyword::no_aggregated_types>,
263	generics: Option<AngleBracketedGenericArguments>,
264) -> Result<TokenStream2> {
265	let local_impl = parse2::<ItemImpl>(local_tokens)?;
266	let foreign_impl = parse2::<ItemImpl>(foreign_tokens)?;
267	let default_impl_path = parse2::<Path>(default_impl_path)?;
268
269	let disambiguation_path = compute_disambiguation_path(
270		disambiguation_path,
271		foreign_impl.clone(),
272		default_impl_path.clone(),
273	)?;
274
275	// generate the combined impl
276	let combined_impl = combine_impls(
277		local_impl,
278		foreign_impl,
279		default_impl_path,
280		disambiguation_path,
281		no_aggregated_types.is_none(),
282		generics,
283	);
284
285	Ok(quote!(#combined_impl))
286}
287
#[test]
fn test_derive_impl_attr_args_parsing() {
	// `path as path` in single- and multi-segment combinations, and a bare path.
	parse2::<DeriveImplAttrArgs>(quote!(
		some::path::TestDefaultConfig as some::path::DefaultConfig
	))
	.unwrap();
	parse2::<DeriveImplAttrArgs>(quote!(
		frame_system::prelude::testing::TestDefaultConfig as DefaultConfig
	))
	.unwrap();
	parse2::<DeriveImplAttrArgs>(quote!(Something as some::path::DefaultConfig)).unwrap();
	parse2::<DeriveImplAttrArgs>(quote!(Something as DefaultConfig)).unwrap();
	parse2::<DeriveImplAttrArgs>(quote!(DefaultConfig)).unwrap();

	// The optional trailing `, no_aggregated_types` flag.
	let args =
		parse2::<DeriveImplAttrArgs>(quote!(Something as DefaultConfig, no_aggregated_types))
			.unwrap();
	assert!(args.no_aggregated_types.is_some());
	assert_eq!(args.disambiguation_path, Some(parse_quote!(DefaultConfig)));
	let args = parse2::<DeriveImplAttrArgs>(quote!(Something as DefaultConfig)).unwrap();
	assert!(args.no_aggregated_types.is_none());

	assert!(parse2::<DeriveImplAttrArgs>(quote!()).is_err());
	assert!(parse2::<DeriveImplAttrArgs>(quote!(Config Config)).is_err());
	// A `,` must be followed by `no_aggregated_types`.
	assert!(parse2::<DeriveImplAttrArgs>(quote!(Something as DefaultConfig,)).is_err());
	assert!(parse2::<DeriveImplAttrArgs>(quote!(Something as DefaultConfig, other)).is_err());
}
304
#[test]
fn test_runtime_type_with_doc() {
	// A doc comment before `#[inject_runtime_type]` must not hide the marker, and a doc
	// comment alone must not be mistaken for it.
	let p = parse2::<ItemImpl>(quote!(
		impl TestTrait for TestStruct {
			/// Some doc
			#[inject_runtime_type]
			type Test = u32;
			/// Some doc
			type Other = u64;
		}
	))
	.unwrap();
	let mut checked = 0;
	for item in p.items {
		if let ImplItem::Type(typ) = item {
			assert_eq!(is_runtime_type(&typ), typ.ident == "Test");
			checked += 1;
		}
	}
	// Guard against the loop silently checking nothing.
	assert_eq!(checked, 2);
}
327
#[test]
fn test_disambiguation_path() {
	let foreign_impl: ItemImpl = parse_quote!(impl SomeTrait for SomeType {});
	let default_impl_path: Path = parse_quote!(SomeScope::SomeType);

	// disambiguation path is specified
	let disambiguation_path = compute_disambiguation_path(
		Some(parse_quote!(SomeScope::SomePath)),
		foreign_impl.clone(),
		default_impl_path.clone(),
	);
	assert_eq!(disambiguation_path.unwrap(), parse_quote!(SomeScope::SomePath));

	// disambiguation path is not specified and the default_impl_path has more than one segment
	let disambiguation_path =
		compute_disambiguation_path(None, foreign_impl.clone(), default_impl_path.clone());
	assert_eq!(disambiguation_path.unwrap(), parse_quote!(SomeScope::SomeTrait));

	// disambiguation path is not specified and the default_impl_path has only one segment
	let disambiguation_path =
		compute_disambiguation_path(None, foreign_impl.clone(), parse_quote!(SomeType));
	assert_eq!(disambiguation_path.unwrap(), parse_quote!(SomeTrait));

	// an inherent impl has no trait to infer from: an explicit path is still used as is...
	let inherent_impl: ItemImpl = parse_quote!(impl SomeType {});
	let disambiguation_path = compute_disambiguation_path(
		Some(parse_quote!(SomeScope::SomePath)),
		inherent_impl.clone(),
		default_impl_path.clone(),
	);
	assert_eq!(disambiguation_path.unwrap(), parse_quote!(SomeScope::SomePath));

	// ...but without one, it is an error
	assert!(compute_disambiguation_path(None, inherent_impl, default_impl_path).is_err());
}
351
#[test]
fn test_derive_impl_attr_args_parsing_with_generic() {
	// Generics on the last segment are split off the path, with or without an `as` clause.
	let DeriveImplAttrArgs { default_impl_path, generics, .. } =
		parse2::<DeriveImplAttrArgs>(quote!(
			some::path::TestDefaultConfig<Config> as some::path::DefaultConfig
		))
		.unwrap();
	assert_eq!(default_impl_path, parse_quote!(some::path::TestDefaultConfig));
	let generics = generics.expect("generics were given");
	assert_eq!(generics.args[0], parse_quote!(Config));

	let DeriveImplAttrArgs { default_impl_path, generics, .. } =
		parse2::<DeriveImplAttrArgs>(quote!(TestDefaultConfig<Config2>)).unwrap();
	assert_eq!(default_impl_path, parse_quote!(TestDefaultConfig));
	assert_eq!(generics.expect("generics were given").args[0], parse_quote!(Config2));
}