scale_info_derive/
trait_bounds.rs

1// Copyright 2019-2022 Parity Technologies (UK) Ltd.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use alloc::vec::Vec;
16use proc_macro2::Ident;
17use syn::{
18    parse_quote,
19    punctuated::Punctuated,
20    spanned::Spanned,
21    visit::{self, Visit},
22    Generics, Result, Type, TypePath, WhereClause,
23};
24
25use crate::{attr::Attributes, utils};
26
27/// Generates a where clause for a `TypeInfo` impl, adding `TypeInfo + 'static` bounds to all
28/// relevant generic types including associated types (e.g. `T::A: TypeInfo`), correctly dealing
29/// with self-referential types.
30///
31/// # Effect of attributes
32///
33/// `#[scale_info(skip_type_params(..))]`
34///
35/// Will not add `TypeInfo` bounds for any type parameters skipped via this attribute.
36///
37/// `#[scale_info(bounds(..))]`
38///
39/// Replaces *all* auto-generated trait bounds with the user-defined ones.
40pub fn make_where_clause<'a>(
41    attrs: &'a Attributes,
42    input_ident: &'a Ident,
43    generics: &'a Generics,
44    data: &'a syn::Data,
45    scale_info: &syn::Path,
46) -> Result<WhereClause> {
47    let mut where_clause = generics
48        .where_clause
49        .clone()
50        .unwrap_or_else(|| WhereClause {
51            where_token: <syn::Token![where]>::default(),
52            predicates: Punctuated::new(),
53        });
54
55    // Use custom bounds as where clause.
56    if let Some(custom_bounds) = attrs.bounds() {
57        custom_bounds.extend_where_clause(&mut where_clause);
58
59        // `'static` lifetime bounds are always required for type parameters, because of the
60        // requirement on `std::any::TypeId::of` for any field type constructor.
61        for type_param in generics.type_params() {
62            let ident = &type_param.ident;
63            where_clause.predicates.push(parse_quote!(#ident: 'static))
64        }
65
66        return Ok(where_clause);
67    }
68
69    for lifetime in generics.lifetimes() {
70        where_clause
71            .predicates
72            .push(parse_quote!(#lifetime: 'static))
73    }
74
75    let ty_params_ids = generics
76        .type_params()
77        .map(|type_param| type_param.ident.clone())
78        .collect::<Vec<Ident>>();
79
80    if ty_params_ids.is_empty() {
81        return Ok(where_clause);
82    }
83
84    let types = collect_types_to_bind(input_ident, data, &ty_params_ids)?;
85
86    types.into_iter().for_each(|(ty, is_compact)| {
87        if is_compact {
88            where_clause
89                .predicates
90                .push(parse_quote!(#ty : #scale_info :: scale::HasCompact));
91        } else {
92            where_clause
93                .predicates
94                .push(parse_quote!(#ty : #scale_info ::TypeInfo + 'static));
95        }
96    });
97
98    generics.type_params().for_each(|type_param| {
99        let ident = type_param.ident.clone();
100        let mut bounds = type_param.bounds.clone();
101        if attrs
102            .skip_type_params()
103            .map_or(true, |skip| !skip.skip(type_param))
104        {
105            bounds.push(parse_quote!(#scale_info ::TypeInfo));
106        }
107        bounds.push(parse_quote!('static));
108        where_clause
109            .predicates
110            .push(parse_quote!( #ident : #bounds));
111    });
112
113    Ok(where_clause)
114}
115
116/// Visits the ast and checks if the given type contains one of the given
117/// idents.
118fn type_contains_idents(ty: &Type, idents: &[Ident]) -> bool {
119    struct ContainIdents<'a> {
120        result: bool,
121        idents: &'a [Ident],
122    }
123
124    impl<'a, 'ast> Visit<'ast> for ContainIdents<'a> {
125        fn visit_ident(&mut self, i: &'ast Ident) {
126            if self.idents.iter().any(|id| id == i) {
127                self.result = true;
128            }
129        }
130    }
131
132    let mut visitor = ContainIdents {
133        result: false,
134        idents,
135    };
136    visitor.visit_type(ty);
137    visitor.result
138}
139
140/// Checks if the given type or any containing type path starts with the given ident.
141fn type_or_sub_type_path_starts_with_ident(ty: &Type, ident: &Ident) -> bool {
142    // Visits the ast and checks if the a type path starts with the given ident.
143    struct TypePathStartsWithIdent<'a> {
144        result: bool,
145        ident: &'a Ident,
146    }
147
148    impl<'a, 'ast> Visit<'ast> for TypePathStartsWithIdent<'a> {
149        fn visit_type_path(&mut self, i: &'ast TypePath) {
150            if i.qself.is_none() {
151                if let Some(segment) = i.path.segments.first() {
152                    if &segment.ident == self.ident {
153                        self.result = true;
154                        return;
155                    }
156                }
157            }
158            visit::visit_type_path(self, i);
159        }
160    }
161
162    let mut visitor = TypePathStartsWithIdent {
163        result: false,
164        ident,
165    };
166    visitor.visit_type(ty);
167    visitor.result
168}
169
170/// Returns all types that must be added to the where clause with a boolean
171/// indicating if the field is [`scale::Compact`] or not.
172fn collect_types_to_bind(
173    input_ident: &Ident,
174    data: &syn::Data,
175    ty_params: &[Ident],
176) -> Result<Vec<(Type, bool)>> {
177    let types_from_fields = |fields: &Punctuated<syn::Field, _>| -> Vec<(Type, bool)> {
178        fields
179            .iter()
180            .filter(|field| {
181                // Only add a bound if the type uses a generic.
182                type_contains_idents(&field.ty, ty_params)
183                &&
184                // Remove all remaining types that start/contain the input ident
185                // to not have them in the where clause.
186                !type_or_sub_type_path_starts_with_ident(&field.ty, input_ident)
187            })
188            .map(|f| (f.ty.clone(), utils::is_compact(f)))
189            .collect()
190    };
191
192    let types = match *data {
193        syn::Data::Struct(ref data) => match &data.fields {
194            syn::Fields::Named(syn::FieldsNamed { named: fields, .. })
195            | syn::Fields::Unnamed(syn::FieldsUnnamed {
196                unnamed: fields, ..
197            }) => types_from_fields(fields),
198            syn::Fields::Unit => Vec::new(),
199        },
200
201        syn::Data::Enum(ref data) => data
202            .variants
203            .iter()
204            .flat_map(|variant| match &variant.fields {
205                syn::Fields::Named(syn::FieldsNamed { named: fields, .. })
206                | syn::Fields::Unnamed(syn::FieldsUnnamed {
207                    unnamed: fields, ..
208                }) => types_from_fields(fields),
209                syn::Fields::Unit => Vec::new(),
210            })
211            .collect(),
212
213        syn::Data::Union(ref data) => {
214            return Err(syn::Error::new(
215                data.union_token.span(),
216                "Union types are not supported.",
217            ))
218        }
219    };
220
221    Ok(types)
222}