parity_scale_codec_derive/
trait_bounds.rs1use std::iter;
16
17use proc_macro2::Ident;
18use syn::{
19 spanned::Spanned,
20 visit::{self, Visit},
21 Generics, Result, Type, TypePath,
22};
23
24use crate::utils::{self, CustomTraitBound};
25
26struct ContainIdents<'a> {
28 result: bool,
29 idents: &'a [Ident],
30}
31
32impl<'a, 'ast> Visit<'ast> for ContainIdents<'a> {
33 fn visit_ident(&mut self, i: &'ast Ident) {
34 if self.idents.iter().any(|id| id == i) {
35 self.result = true;
36 }
37 }
38}
39
40fn type_contain_idents(ty: &Type, idents: &[Ident]) -> bool {
42 let mut visitor = ContainIdents { result: false, idents };
43 visitor.visit_type(ty);
44 visitor.result
45}
46
47struct TypePathStartsWithIdent<'a> {
49 result: bool,
50 ident: &'a Ident,
51}
52
53impl<'a, 'ast> Visit<'ast> for TypePathStartsWithIdent<'a> {
54 fn visit_type_path(&mut self, i: &'ast TypePath) {
55 if let Some(segment) = i.path.segments.first() {
56 if &segment.ident == self.ident {
57 self.result = true;
58 return
59 }
60 }
61
62 visit::visit_type_path(self, i);
63 }
64}
65
66fn type_path_or_sub_starts_with_ident(ty: &TypePath, ident: &Ident) -> bool {
68 let mut visitor = TypePathStartsWithIdent { result: false, ident };
69 visitor.visit_type_path(ty);
70 visitor.result
71}
72
73fn type_or_sub_type_path_starts_with_ident(ty: &Type, ident: &Ident) -> bool {
75 let mut visitor = TypePathStartsWithIdent { result: false, ident };
76 visitor.visit_type(ty);
77 visitor.result
78}
79
80struct FindTypePathsNotStartOrContainIdent<'a> {
84 result: Vec<TypePath>,
85 ident: &'a Ident,
86}
87
88impl<'a, 'ast> Visit<'ast> for FindTypePathsNotStartOrContainIdent<'a> {
89 fn visit_type_path(&mut self, i: &'ast TypePath) {
90 if type_path_or_sub_starts_with_ident(i, self.ident) {
91 visit::visit_type_path(self, i);
92 } else {
93 self.result.push(i.clone());
94 }
95 }
96}
97
98fn find_type_paths_not_start_or_contain_ident(ty: &Type, ident: &Ident) -> Vec<TypePath> {
102 let mut visitor = FindTypePathsNotStartOrContainIdent { result: Vec::new(), ident };
103 visitor.visit_type(ty);
104 visitor.result
105}
106
107#[allow(clippy::too_many_arguments)]
108pub fn add<N>(
110 input_ident: &Ident,
111 generics: &mut Generics,
112 data: &syn::Data,
113 custom_trait_bound: Option<CustomTraitBound<N>>,
114 codec_bound: syn::Path,
115 codec_skip_bound: Option<syn::Path>,
116 dumb_trait_bounds: bool,
117 crate_path: &syn::Path,
118) -> Result<()> {
119 let skip_type_params = match custom_trait_bound {
120 Some(CustomTraitBound::SpecifiedBounds { bounds, .. }) => {
121 generics.make_where_clause().predicates.extend(bounds);
122 return Ok(())
123 },
124 Some(CustomTraitBound::SkipTypeParams { type_names, .. }) =>
125 type_names.into_iter().collect::<Vec<_>>(),
126 None => Vec::new(),
127 };
128
129 let ty_params = generics
130 .type_params()
131 .filter(|tp| skip_type_params.iter().all(|skip| skip != &tp.ident))
132 .map(|tp| tp.ident.clone())
133 .collect::<Vec<_>>();
134 if ty_params.is_empty() {
135 return Ok(())
136 }
137
138 let codec_types =
139 get_types_to_add_trait_bound(input_ident, data, &ty_params, dumb_trait_bounds)?;
140
141 let compact_types = collect_types(data, utils::is_compact)?
142 .into_iter()
143 .filter(|ty| type_contain_idents(ty, &ty_params))
145 .collect::<Vec<_>>();
146
147 let skip_types = if codec_skip_bound.is_some() {
148 let needs_default_bound = |f: &syn::Field| utils::should_skip(&f.attrs);
149 collect_types(data, needs_default_bound)?
150 .into_iter()
151 .filter(|ty| type_contain_idents(ty, &ty_params))
153 .collect::<Vec<_>>()
154 } else {
155 Vec::new()
156 };
157
158 if !codec_types.is_empty() || !compact_types.is_empty() || !skip_types.is_empty() {
159 let where_clause = generics.make_where_clause();
160
161 codec_types
162 .into_iter()
163 .for_each(|ty| where_clause.predicates.push(parse_quote!(#ty : #codec_bound)));
164
165 let has_compact_bound: syn::Path = parse_quote!(#crate_path::HasCompact);
166 compact_types
167 .into_iter()
168 .for_each(|ty| where_clause.predicates.push(parse_quote!(#ty : #has_compact_bound)));
169
170 skip_types.into_iter().for_each(|ty| {
171 let codec_skip_bound = codec_skip_bound.as_ref();
172 where_clause.predicates.push(parse_quote!(#ty : #codec_skip_bound))
173 });
174 }
175
176 Ok(())
177}
178
179fn get_types_to_add_trait_bound(
181 input_ident: &Ident,
182 data: &syn::Data,
183 ty_params: &[Ident],
184 dumb_trait_bound: bool,
185) -> Result<Vec<Type>> {
186 if dumb_trait_bound {
187 Ok(ty_params.iter().map(|t| parse_quote!( #t )).collect())
188 } else {
189 let needs_codec_bound = |f: &syn::Field| {
190 !utils::is_compact(f) &&
191 utils::get_encoded_as_type(f).is_none() &&
192 !utils::should_skip(&f.attrs)
193 };
194 let res = collect_types(data, needs_codec_bound)?
195 .into_iter()
196 .filter(|ty| type_contain_idents(ty, ty_params))
198 .flat_map(|ty| {
201 find_type_paths_not_start_or_contain_ident(&ty, input_ident)
202 .into_iter()
203 .map(Type::Path)
204 .filter(|ty| type_contain_idents(ty, ty_params))
206 .chain(iter::once(ty))
208 })
209 .filter(|ty| !type_or_sub_type_path_starts_with_ident(ty, input_ident))
212 .collect();
213
214 Ok(res)
215 }
216}
217
218fn collect_types(data: &syn::Data, type_filter: fn(&syn::Field) -> bool) -> Result<Vec<syn::Type>> {
219 use syn::*;
220
221 let types = match *data {
222 Data::Struct(ref data) => match &data.fields {
223 | Fields::Named(FieldsNamed { named: fields, .. }) |
224 Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) =>
225 fields.iter().filter(|f| type_filter(f)).map(|f| f.ty.clone()).collect(),
226
227 Fields::Unit => Vec::new(),
228 },
229
230 Data::Enum(ref data) => data
231 .variants
232 .iter()
233 .filter(|variant| !utils::should_skip(&variant.attrs))
234 .flat_map(|variant| match &variant.fields {
235 | Fields::Named(FieldsNamed { named: fields, .. }) |
236 Fields::Unnamed(FieldsUnnamed { unnamed: fields, .. }) =>
237 fields.iter().filter(|f| type_filter(f)).map(|f| f.ty.clone()).collect(),
238
239 Fields::Unit => Vec::new(),
240 })
241 .collect(),
242
243 Data::Union(ref data) =>
244 return Err(Error::new(data.union_token.span(), "Union types are not supported.")),
245 };
246
247 Ok(types)
248}