parity_scale_codec_derive/
trait_bounds.rs

1// Copyright 2019 Parity Technologies
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::iter;
16
17use proc_macro2::Ident;
18use syn::{
19	spanned::Spanned,
20	visit::{self, Visit},
21	Generics, Result, Type, TypePath,
22};
23
24use crate::utils::{self, CustomTraitBound};
25
26/// Visits the ast and checks if one of the given idents is found.
27struct ContainIdents<'a> {
28	result: bool,
29	idents: &'a [Ident],
30}
31
32impl<'a, 'ast> Visit<'ast> for ContainIdents<'a> {
33	fn visit_ident(&mut self, i: &'ast Ident) {
34		if self.idents.iter().any(|id| id == i) {
35			self.result = true;
36		}
37	}
38}
39
40/// Checks if the given type contains one of the given idents.
41fn type_contain_idents(ty: &Type, idents: &[Ident]) -> bool {
42	let mut visitor = ContainIdents { result: false, idents };
43	visitor.visit_type(ty);
44	visitor.result
45}
46
47/// Visits the ast and checks if the a type path starts with the given ident.
48struct TypePathStartsWithIdent<'a> {
49	result: bool,
50	ident: &'a Ident,
51}
52
53impl<'a, 'ast> Visit<'ast> for TypePathStartsWithIdent<'a> {
54	fn visit_type_path(&mut self, i: &'ast TypePath) {
55		if let Some(segment) = i.path.segments.first() {
56			if &segment.ident == self.ident {
57				self.result = true;
58				return
59			}
60		}
61
62		visit::visit_type_path(self, i);
63	}
64}
65
66/// Checks if the given type path or any containing type path starts with the given ident.
67fn type_path_or_sub_starts_with_ident(ty: &TypePath, ident: &Ident) -> bool {
68	let mut visitor = TypePathStartsWithIdent { result: false, ident };
69	visitor.visit_type_path(ty);
70	visitor.result
71}
72
73/// Checks if the given type or any containing type path starts with the given ident.
74fn type_or_sub_type_path_starts_with_ident(ty: &Type, ident: &Ident) -> bool {
75	let mut visitor = TypePathStartsWithIdent { result: false, ident };
76	visitor.visit_type(ty);
77	visitor.result
78}
79
80/// Visits the ast and collects all type paths that do not start or contain the given ident.
81///
82/// Returns `T`, `N`, `A` for `Vec<(Recursive<T, N>, A)>` with `Recursive` as ident.
83struct FindTypePathsNotStartOrContainIdent<'a> {
84	result: Vec<TypePath>,
85	ident: &'a Ident,
86}
87
88impl<'a, 'ast> Visit<'ast> for FindTypePathsNotStartOrContainIdent<'a> {
89	fn visit_type_path(&mut self, i: &'ast TypePath) {
90		if type_path_or_sub_starts_with_ident(i, self.ident) {
91			visit::visit_type_path(self, i);
92		} else {
93			self.result.push(i.clone());
94		}
95	}
96}
97
98/// Collects all type paths that do not start or contain the given ident in the given type.
99///
100/// Returns `T`, `N`, `A` for `Vec<(Recursive<T, N>, A)>` with `Recursive` as ident.
101fn find_type_paths_not_start_or_contain_ident(ty: &Type, ident: &Ident) -> Vec<TypePath> {
102	let mut visitor = FindTypePathsNotStartOrContainIdent { result: Vec::new(), ident };
103	visitor.visit_type(ty);
104	visitor.result
105}
106
107#[allow(clippy::too_many_arguments)]
108/// Add required trait bounds to all generic types.
109pub fn add<N>(
110	input_ident: &Ident,
111	generics: &mut Generics,
112	data: &syn::Data,
113	custom_trait_bound: Option<CustomTraitBound<N>>,
114	codec_bound: syn::Path,
115	codec_skip_bound: Option<syn::Path>,
116	dumb_trait_bounds: bool,
117	crate_path: &syn::Path,
118) -> Result<()> {
119	let skip_type_params = match custom_trait_bound {
120		Some(CustomTraitBound::SpecifiedBounds { bounds, .. }) => {
121			generics.make_where_clause().predicates.extend(bounds);
122			return Ok(())
123		},
124		Some(CustomTraitBound::SkipTypeParams { type_names, .. }) =>
125			type_names.into_iter().collect::<Vec<_>>(),
126		None => Vec::new(),
127	};
128
129	let ty_params = generics
130		.type_params()
131		.filter(|tp| skip_type_params.iter().all(|skip| skip != &tp.ident))
132		.map(|tp| tp.ident.clone())
133		.collect::<Vec<_>>();
134	if ty_params.is_empty() {
135		return Ok(())
136	}
137
138	let codec_types =
139		get_types_to_add_trait_bound(input_ident, data, &ty_params, dumb_trait_bounds)?;
140
141	let compact_types = collect_types(data, utils::is_compact)?
142		.into_iter()
143		// Only add a bound if the type uses a generic
144		.filter(|ty| type_contain_idents(ty, &ty_params))
145		.collect::<Vec<_>>();
146
147	let skip_types = if codec_skip_bound.is_some() {
148		let needs_default_bound = |f: &syn::Field| utils::should_skip(&f.attrs);
149		collect_types(data, needs_default_bound)?
150			.into_iter()
151			// Only add a bound if the type uses a generic
152			.filter(|ty| type_contain_idents(ty, &ty_params))
153			.collect::<Vec<_>>()
154	} else {
155		Vec::new()
156	};
157
158	if !codec_types.is_empty() || !compact_types.is_empty() || !skip_types.is_empty() {
159		let where_clause = generics.make_where_clause();
160
161		codec_types
162			.into_iter()
163			.for_each(|ty| where_clause.predicates.push(parse_quote!(#ty : #codec_bound)));
164
165		let has_compact_bound: syn::Path = parse_quote!(#crate_path::HasCompact);
166		compact_types
167			.into_iter()
168			.for_each(|ty| where_clause.predicates.push(parse_quote!(#ty : #has_compact_bound)));
169
170		skip_types.into_iter().for_each(|ty| {
171			let codec_skip_bound = codec_skip_bound.as_ref();
172			where_clause.predicates.push(parse_quote!(#ty : #codec_skip_bound))
173		});
174	}
175
176	Ok(())
177}
178
179/// Returns all types that must be added to the where clause with the respective trait bound.
180fn get_types_to_add_trait_bound(
181	input_ident: &Ident,
182	data: &syn::Data,
183	ty_params: &[Ident],
184	dumb_trait_bound: bool,
185) -> Result<Vec<Type>> {
186	if dumb_trait_bound {
187		Ok(ty_params.iter().map(|t| parse_quote!( #t )).collect())
188	} else {
189		let needs_codec_bound = |f: &syn::Field| {
190			!utils::is_compact(f) &&
191				utils::get_encoded_as_type(f).is_none() &&
192				!utils::should_skip(&f.attrs)
193		};
194		let res = collect_types(data, needs_codec_bound)?
195			.into_iter()
196			// Only add a bound if the type uses a generic
197			.filter(|ty| type_contain_idents(ty, ty_params))
198			// If a struct contains itself as field type, we can not add this type into the where
199			// clause. This is required to work a round the following compiler bug: https://github.com/rust-lang/rust/issues/47032
200			.flat_map(|ty| {
201				find_type_paths_not_start_or_contain_ident(&ty, input_ident)
202					.into_iter()
203					.map(Type::Path)
204					// Remove again types that do not contain any of our generic parameters
205					.filter(|ty| type_contain_idents(ty, ty_params))
206					// Add back the original type, as we don't want to loose it.
207					.chain(iter::once(ty))
208			})
209			// Remove all remaining types that start/contain the input ident to not have them in the
210			// where clause.
211			.filter(|ty| !type_or_sub_type_path_starts_with_ident(ty, input_ident))
212			.collect();
213
214		Ok(res)
215	}
216}
217
218fn collect_types(data: &syn::Data, type_filter: fn(&syn::Field) -> bool) -> Result<Vec<syn::Type>> {
219	use syn::*;
220
221	let types = match *data {
222		Data::Struct(ref data) => match &data.fields {
223			| Fields::Named(FieldsNamed { named: fields, .. }) |
224			Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) =>
225				fields.iter().filter(|f| type_filter(f)).map(|f| f.ty.clone()).collect(),
226
227			Fields::Unit => Vec::new(),
228		},
229
230		Data::Enum(ref data) => data
231			.variants
232			.iter()
233			.filter(|variant| !utils::should_skip(&variant.attrs))
234			.flat_map(|variant| match &variant.fields {
235				| Fields::Named(FieldsNamed { named: fields, .. }) |
236				Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) =>
237					fields.iter().filter(|f| type_filter(f)).map(|f| f.ty.clone()).collect(),
238
239				Fields::Unit => Vec::new(),
240			})
241			.collect(),
242
243		Data::Union(ref data) =>
244			return Err(Error::new(data.union_token.span(), "Union types are not supported.")),
245	};
246
247	Ok(types)
248}