jsonrpsee_proc_macros/
visitor.rs1use std::collections::HashSet;
28
29use syn::visit::{self, Visit};
30use syn::Ident;
31
32#[derive(Default, Debug)]
36pub(crate) struct FindSubscriptionParams {
37 pub(crate) generic_sub_params: HashSet<Ident>,
38 pub(crate) all_type_params: HashSet<Ident>,
39}
40
41pub struct FindAllParams {
43 pub(crate) trait_generics: HashSet<syn::Ident>,
44 pub(crate) input_params: HashSet<syn::Ident>,
45 pub(crate) ret_params: HashSet<syn::Ident>,
46 pub(crate) sub_params: HashSet<syn::Ident>,
47 pub(crate) visiting_return_type: bool,
48 pub(crate) visiting_fn_arg: bool,
49}
50
51impl FindAllParams {
52 pub fn new(sub_params: HashSet<syn::Ident>) -> Self {
55 Self {
56 trait_generics: HashSet::new(),
57 input_params: HashSet::new(),
58 ret_params: HashSet::new(),
59 sub_params,
60 visiting_return_type: false,
61 visiting_fn_arg: false,
62 }
63 }
64}
65
66impl<'ast> Visit<'ast> for FindAllParams {
67 fn visit_type_param(&mut self, ty_param: &'ast syn::TypeParam) {
69 self.trait_generics.insert(ty_param.ident.clone());
70 }
71
72 fn visit_return_type(&mut self, return_type: &'ast syn::ReturnType) {
75 self.visiting_return_type = true;
76 visit::visit_return_type(self, return_type);
77 self.visiting_return_type = false
78 }
79
80 fn visit_ident(&mut self, ident: &'ast syn::Ident) {
82 if self.trait_generics.contains(ident) {
83 if self.visiting_return_type {
84 self.ret_params.insert(ident.clone());
85 }
86 if self.visiting_fn_arg {
87 self.input_params.insert(ident.clone());
88 }
89 }
90 }
91
92 fn visit_fn_arg(&mut self, arg: &'ast syn::FnArg) {
95 self.visiting_fn_arg = true;
96 visit::visit_fn_arg(self, arg);
97 self.visiting_fn_arg = false;
98 }
99}
100
101impl FindSubscriptionParams {
102 pub fn visit(mut self, tys: &[syn::Type]) -> HashSet<Ident> {
104 for ty in tys {
105 self.visit_type(ty);
106 }
107 self.generic_sub_params
108 }
109
110 pub fn new(all_type_params: HashSet<Ident>) -> Self {
115 Self { generic_sub_params: HashSet::new(), all_type_params }
116 }
117
118 fn visit_path(&mut self, path: &syn::Path) {
120 if path.leading_colon.is_none() && path.segments.len() == 1 {
121 let id = &path.segments[0].ident;
122 if self.all_type_params.contains(id) {
123 self.generic_sub_params.insert(id.clone());
124 }
125 }
126 for segment in &path.segments {
127 self.visit_path_segment(segment);
128 }
129 }
130
131 fn visit_type(&mut self, ty: &syn::Type) {
133 match ty {
134 syn::Type::Array(ty) => self.visit_type(&ty.elem),
135 syn::Type::BareFn(ty) => {
136 for arg in &ty.inputs {
137 self.visit_type(&arg.ty);
138 }
139 self.visit_return_type(&ty.output);
140 }
141 syn::Type::Group(ty) => self.visit_type(&ty.elem),
142 syn::Type::ImplTrait(ty) => {
143 for bound in &ty.bounds {
144 self.visit_type_param_bound(bound);
145 }
146 }
147 syn::Type::Macro(ty) => self.visit_macro(&ty.mac),
148 syn::Type::Paren(ty) => self.visit_type(&ty.elem),
149 syn::Type::Path(ty) => {
150 if let Some(qself) = &ty.qself {
151 self.visit_type(&qself.ty);
152 }
153 self.visit_path(&ty.path);
154 }
155 syn::Type::Ptr(ty) => self.visit_type(&ty.elem),
156 syn::Type::Reference(ty) => self.visit_type(&ty.elem),
157 syn::Type::Slice(ty) => self.visit_type(&ty.elem),
158 syn::Type::TraitObject(ty) => {
159 for bound in &ty.bounds {
160 self.visit_type_param_bound(bound);
161 }
162 }
163 syn::Type::Tuple(ty) => {
164 for elem in &ty.elems {
165 self.visit_type(elem);
166 }
167 }
168 syn::Type::Infer(_) | syn::Type::Never(_) | syn::Type::Verbatim(_) => {}
169 _ => {}
170 }
171 }
172
173 fn visit_path_segment(&mut self, segment: &syn::PathSegment) {
175 self.visit_path_arguments(&segment.arguments);
176 }
177
178 fn visit_path_arguments(&mut self, arguments: &syn::PathArguments) {
180 match arguments {
181 syn::PathArguments::None => {}
182 syn::PathArguments::AngleBracketed(arguments) => {
183 for arg in &arguments.args {
184 match arg {
185 syn::GenericArgument::Type(arg) => self.visit_type(arg),
186 syn::GenericArgument::AssocType(arg) => self.visit_type(&arg.ty),
187 _ => {}
188 }
189 }
190 }
191 syn::PathArguments::Parenthesized(arguments) => {
192 for argument in &arguments.inputs {
193 self.visit_type(argument);
194 }
195 self.visit_return_type(&arguments.output);
196 }
197 }
198 }
199
200 fn visit_return_type(&mut self, return_type: &syn::ReturnType) {
202 match return_type {
203 syn::ReturnType::Default => {}
204 syn::ReturnType::Type(_, output) => self.visit_type(output),
205 }
206 }
207
208 fn visit_type_param_bound(&mut self, bound: &syn::TypeParamBound) {
210 if let syn::TypeParamBound::Trait(bound) = bound {
211 self.visit_path(&bound.path);
212 }
213 }
214
215 fn visit_macro(&mut self, _mac: &syn::Macro) {}
222}
223
#[cfg(test)]
mod tests {
	use super::*;
	use syn::{parse_quote, Type};

	/// Runs the subscription-parameter visitor over a single type.
	fn find(generics: HashSet<Ident>, ty: Type) -> HashSet<Ident> {
		FindSubscriptionParams::new(generics).visit(&[ty])
	}

	#[test]
	fn it_works() {
		let only_t: [Ident; 1] = [parse_quote!(T)];
		let generics = HashSet::from(only_t);

		assert_eq!(generics.clone(), find(generics, parse_quote!(Vec<T>)));
	}

	#[test]
	fn several_type_params() {
		let candidates: [Ident; 4] =
			[parse_quote!(A), parse_quote!(B), parse_quote!(C), parse_quote!(D)];
		// `D` is a candidate but never appears in the type, so it must not be reported.
		let expected: [Ident; 3] = [parse_quote!(A), parse_quote!(B), parse_quote!(C)];

		assert_eq!(
			HashSet::from(expected),
			find(HashSet::from(candidates), parse_quote!(Vec<(A, B, C)>))
		);
	}

	#[test]
	fn nested_type() {
		let candidates: [Ident; 4] =
			[parse_quote!(A), parse_quote!(B), parse_quote!(C), parse_quote!(D)];
		let expected: [Ident; 2] = [parse_quote!(A), parse_quote!(B)];

		assert_eq!(
			HashSet::from(expected),
			find(HashSet::from(candidates), parse_quote!(Vec<Foo<A, B>>))
		);
	}
}