jsonrpsee_proc_macros/
visitor.rs

1// Copyright 2021 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27use std::collections::HashSet;
28
29use syn::visit::{self, Visit};
30use syn::Ident;
31
32/// Visitor that parses generic type parameters from `syn::Type` by traversing the AST.
33/// A `syn::Type` can any type such as `Vec<T>, T, Foo<A<B<V>>>, usize or similar`.
34/// The implementation is based on <https://github.com/serde-rs/serde/blob/master/serde_derive/src/bound.rs>.
35#[derive(Default, Debug)]
36pub(crate) struct FindSubscriptionParams {
37	pub(crate) generic_sub_params: HashSet<Ident>,
38	pub(crate) all_type_params: HashSet<Ident>,
39}
40
41/// Visitor for the entire `RPC trait`.
42pub struct FindAllParams {
43	pub(crate) trait_generics: HashSet<syn::Ident>,
44	pub(crate) input_params: HashSet<syn::Ident>,
45	pub(crate) ret_params: HashSet<syn::Ident>,
46	pub(crate) sub_params: HashSet<syn::Ident>,
47	pub(crate) visiting_return_type: bool,
48	pub(crate) visiting_fn_arg: bool,
49}
50
51impl FindAllParams {
52	/// Create a visitor to traverse the entire RPC trait.
53	/// It takes the already visited subscription parameters as input.
54	pub fn new(sub_params: HashSet<syn::Ident>) -> Self {
55		Self {
56			trait_generics: HashSet::new(),
57			input_params: HashSet::new(),
58			ret_params: HashSet::new(),
59			sub_params,
60			visiting_return_type: false,
61			visiting_fn_arg: false,
62		}
63	}
64}
65
66impl<'ast> Visit<'ast> for FindAllParams {
67	/// Visit generic type param.
68	fn visit_type_param(&mut self, ty_param: &'ast syn::TypeParam) {
69		self.trait_generics.insert(ty_param.ident.clone());
70	}
71
72	/// Visit return type and mark it as `visiting_return_type`.
73	/// To know whether a given Ident is a function argument or return type when traversing.
74	fn visit_return_type(&mut self, return_type: &'ast syn::ReturnType) {
75		self.visiting_return_type = true;
76		visit::visit_return_type(self, return_type);
77		self.visiting_return_type = false
78	}
79
80	/// Visit ident.
81	fn visit_ident(&mut self, ident: &'ast syn::Ident) {
82		if self.trait_generics.contains(ident) {
83			if self.visiting_return_type {
84				self.ret_params.insert(ident.clone());
85			}
86			if self.visiting_fn_arg {
87				self.input_params.insert(ident.clone());
88			}
89		}
90	}
91
92	/// Visit function argument and mark it as `visiting_fn_arg`.
93	/// To know whether a given Ident is a function argument or return type when traversing.
94	fn visit_fn_arg(&mut self, arg: &'ast syn::FnArg) {
95		self.visiting_fn_arg = true;
96		visit::visit_fn_arg(self, arg);
97		self.visiting_fn_arg = false;
98	}
99}
100
101impl FindSubscriptionParams {
102	/// Visit all types and returns all generic [`struct@syn::Ident`]'s that are subscriptions.
103	pub fn visit(mut self, tys: &[syn::Type]) -> HashSet<Ident> {
104		for ty in tys {
105			self.visit_type(ty);
106		}
107		self.generic_sub_params
108	}
109
110	/// Create a new subscription parameters visitor that takes all
111	/// generic parameters on the RPC trait as input in order to determine
112	/// whether a given ident is a generic type param or not when traversing
113	/// one or more types in `FindSubscriptionParams::visit`.
114	pub fn new(all_type_params: HashSet<Ident>) -> Self {
115		Self { generic_sub_params: HashSet::new(), all_type_params }
116	}
117
118	/// Visit path, if it's a leaf path and generic type param then add it as a subscription param.
119	fn visit_path(&mut self, path: &syn::Path) {
120		if path.leading_colon.is_none() && path.segments.len() == 1 {
121			let id = &path.segments[0].ident;
122			if self.all_type_params.contains(id) {
123				self.generic_sub_params.insert(id.clone());
124			}
125		}
126		for segment in &path.segments {
127			self.visit_path_segment(segment);
128		}
129	}
130
131	/// Traverse syntax tree.
132	fn visit_type(&mut self, ty: &syn::Type) {
133		match ty {
134			syn::Type::Array(ty) => self.visit_type(&ty.elem),
135			syn::Type::BareFn(ty) => {
136				for arg in &ty.inputs {
137					self.visit_type(&arg.ty);
138				}
139				self.visit_return_type(&ty.output);
140			}
141			syn::Type::Group(ty) => self.visit_type(&ty.elem),
142			syn::Type::ImplTrait(ty) => {
143				for bound in &ty.bounds {
144					self.visit_type_param_bound(bound);
145				}
146			}
147			syn::Type::Macro(ty) => self.visit_macro(&ty.mac),
148			syn::Type::Paren(ty) => self.visit_type(&ty.elem),
149			syn::Type::Path(ty) => {
150				if let Some(qself) = &ty.qself {
151					self.visit_type(&qself.ty);
152				}
153				self.visit_path(&ty.path);
154			}
155			syn::Type::Ptr(ty) => self.visit_type(&ty.elem),
156			syn::Type::Reference(ty) => self.visit_type(&ty.elem),
157			syn::Type::Slice(ty) => self.visit_type(&ty.elem),
158			syn::Type::TraitObject(ty) => {
159				for bound in &ty.bounds {
160					self.visit_type_param_bound(bound);
161				}
162			}
163			syn::Type::Tuple(ty) => {
164				for elem in &ty.elems {
165					self.visit_type(elem);
166				}
167			}
168			syn::Type::Infer(_) | syn::Type::Never(_) | syn::Type::Verbatim(_) => {}
169			_ => {}
170		}
171	}
172
173	/// Traverse syntax tree.
174	fn visit_path_segment(&mut self, segment: &syn::PathSegment) {
175		self.visit_path_arguments(&segment.arguments);
176	}
177
178	/// Traverse syntax tree.
179	fn visit_path_arguments(&mut self, arguments: &syn::PathArguments) {
180		match arguments {
181			syn::PathArguments::None => {}
182			syn::PathArguments::AngleBracketed(arguments) => {
183				for arg in &arguments.args {
184					match arg {
185						syn::GenericArgument::Type(arg) => self.visit_type(arg),
186						syn::GenericArgument::AssocType(arg) => self.visit_type(&arg.ty),
187						_ => {}
188					}
189				}
190			}
191			syn::PathArguments::Parenthesized(arguments) => {
192				for argument in &arguments.inputs {
193					self.visit_type(argument);
194				}
195				self.visit_return_type(&arguments.output);
196			}
197		}
198	}
199
200	/// Traverse syntax tree.
201	fn visit_return_type(&mut self, return_type: &syn::ReturnType) {
202		match return_type {
203			syn::ReturnType::Default => {}
204			syn::ReturnType::Type(_, output) => self.visit_type(output),
205		}
206	}
207
208	/// Traverse syntax tree.
209	fn visit_type_param_bound(&mut self, bound: &syn::TypeParamBound) {
210		if let syn::TypeParamBound::Trait(bound) = bound {
211			self.visit_path(&bound.path);
212		}
213	}
214
215	// Type parameter should not be considered used by a macro path.
216	//
217	//     struct TypeMacro<T> {
218	//         mac: T!(),
219	//         marker: PhantomData<T>,
220	//     }
221	fn visit_macro(&mut self, _mac: &syn::Macro) {}
222}
223
#[cfg(test)]
mod tests {
	use super::*;
	use syn::{parse_quote, Type};

	/// Parse each name into an `Ident` and gather the results into a set.
	fn ident_set(names: &[&str]) -> HashSet<Ident> {
		names.iter().map(|name| syn::parse_str::<Ident>(name).expect("test identifiers are valid")).collect()
	}

	/// Run `FindSubscriptionParams` over `ty`, with `generics` as the declared type parameters.
	fn subscription_params(ty: Type, generics: &[&str]) -> HashSet<Ident> {
		FindSubscriptionParams::new(ident_set(generics)).visit(&[ty])
	}

	#[test]
	fn it_works() {
		let ty: Type = parse_quote!(Vec<T>);
		assert_eq!(ident_set(&["T"]), subscription_params(ty, &["T"]));
	}

	#[test]
	fn several_type_params() {
		// `D` is declared but never used by the type, so it must not be reported.
		let ty: Type = parse_quote!(Vec<(A, B, C)>);
		assert_eq!(ident_set(&["A", "B", "C"]), subscription_params(ty, &["A", "B", "C", "D"]));
	}

	#[test]
	fn nested_type() {
		let ty: Type = parse_quote!(Vec<Foo<A, B>>);
		assert_eq!(ident_set(&["A", "B"]), subscription_params(ty, &["A", "B", "C", "D"]));
	}
}