jsonrpsee_proc_macros/
render_server.rs

1// Copyright 2019-2021 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27use std::collections::HashSet;
28use std::str::FromStr;
29
30use super::RpcDescription;
31use crate::{
32	helpers::{generate_where_clause, is_option},
33	rpc_macro::RpcFnArg,
34};
35use proc_macro2::{Span, TokenStream as TokenStream2};
36use quote::{quote, quote_spanned};
37use syn::Attribute;
38
39impl RpcDescription {
40	pub(super) fn render_server(&self) -> Result<TokenStream2, syn::Error> {
41		let trait_name = quote::format_ident!("{}Server", &self.trait_def.ident);
42		let generics = self.trait_def.generics.clone();
43		let (impl_generics, _, where_clause) = generics.split_for_impl();
44
45		let method_impls = self.render_methods()?;
46		let into_rpc_impl = self.render_into_rpc()?;
47		let async_trait = self.jrps_server_item(quote! { core::__reexports::async_trait });
48
49		// Doc-comment to be associated with the server.
50		let doc_comment = format!("Server trait implementation for the `{}` RPC API.", &self.trait_def.ident);
51
52		let trait_impl = quote! {
53			#[#async_trait]
54			#[doc = #doc_comment]
55			pub trait #trait_name #impl_generics: Sized + Send + Sync + 'static #where_clause {
56				#method_impls
57				#into_rpc_impl
58			}
59		};
60
61		Ok(trait_impl)
62	}
63
64	fn render_methods(&self) -> Result<TokenStream2, syn::Error> {
65		let methods = self.methods.iter().map(|method| {
66			let docs = &method.docs;
67			let mut method_sig = method.signature.clone();
68
69			if method.with_extensions {
70				let ext_ty = self.jrps_server_item(quote! { Extensions });
71				// Add `Extension` as the second parameter to the signature.
72				let ext: syn::FnArg = syn::parse_quote!(ext: &#ext_ty);
73				method_sig.sig.inputs.insert(1, ext);
74			}
75
76			quote! {
77				#docs
78				#method_sig
79			}
80		});
81
82		let subscriptions = self.subscriptions.iter().map(|sub| {
83			let docs = &sub.docs;
84			let subscription_sink_ty = self.jrps_server_item(quote! { PendingSubscriptionSink });
85
86			// Add `SubscriptionSink` as the second input parameter to the signature.
87			let subscription_sink: syn::FnArg = syn::parse_quote!(subscription_sink: #subscription_sink_ty);
88			let mut sub_sig = sub.signature.clone();
89			sub_sig.sig.inputs.insert(1, subscription_sink);
90
91			if sub.with_extensions {
92				let ext_ty = self.jrps_server_item(quote! { Extensions });
93				// Add `Extension` as the third parameter to the signature.
94				let ext: syn::FnArg = syn::parse_quote!(ext: &#ext_ty);
95				sub_sig.sig.inputs.insert(2, ext);
96			}
97
98			quote! {
99				#docs
100				#sub_sig
101			}
102		});
103
104		Ok(quote! {
105			#(#methods)*
106			#(#subscriptions)*
107		})
108	}
109
110	/// Helper that will ignore results of `register_*` method calls, and panic if there have been
111	/// any errors in debug builds.
112	///
113	/// The debug assert is a safeguard should the contract that guarantees the method names to
114	/// never conflict in the macro be broken in the future.
115	fn handle_register_result(&self, tokens: TokenStream2) -> TokenStream2 {
116		let reexports = self.jrps_server_item(quote! { core::__reexports });
117		quote! {{
118			let _res = #tokens;
119			#[cfg(debug_assertions)]
120			if _res.is_err() {
121				#reexports::panic_fail_register();
122			}
123		}}
124	}
125
126	fn render_into_rpc(&self) -> Result<TokenStream2, syn::Error> {
127		let rpc_module = self.jrps_server_item(quote! { RpcModule });
128
129		let mut registered = HashSet::new();
130		let mut errors = Vec::new();
131		let mut check_name = |name: &str, span: Span| {
132			if registered.contains(name) {
133				let message = format!("{name:?} is already defined");
134				errors.push(quote_spanned!(span => compile_error!(#message);));
135			} else {
136				registered.insert(name.to_string());
137			}
138		};
139
140		let methods = self
141			.methods
142			.iter()
143			.map(|method| {
144				// Rust method to invoke (e.g. `self.<foo>(...)`).
145				let rust_method_name = &method.signature.sig.ident;
146				// Name of the RPC method (e.g. `foo_makeSpam`).
147				let rpc_method_name = self.rpc_identifier(&method.name);
148				// `parsing` is the code associated with parsing structure from the
149				// provided `Params` object.
150				// `params_seq` is the comma-delimited sequence of parameters we're passing to the rust function
151				// called..
152				let (parsing, params_seq) = self.render_params_decoding(&method.params, None);
153
154				let into_response = self.jrps_server_item(quote! { IntoResponse });
155
156				check_name(&rpc_method_name, rust_method_name.span());
157
158				if method.signature.sig.asyncness.is_some() {
159					if method.with_extensions {
160						self.handle_register_result(quote! {
161							rpc.register_async_method(#rpc_method_name, |params, context, ext| async move {
162								#parsing
163								#into_response::into_response(context.as_ref().#rust_method_name(&ext, #params_seq).await)
164							})
165						})
166					} else {
167						self.handle_register_result(quote! {
168							rpc.register_async_method(#rpc_method_name, |params, context, _| async move {
169								#parsing
170								#into_response::into_response(context.as_ref().#rust_method_name(#params_seq).await)
171							})
172						})
173					}
174				} else {
175					let register_kind =
176						if method.blocking { quote!(register_blocking_method) } else { quote!(register_method) };
177
178					if method.with_extensions {
179						self.handle_register_result(quote! {
180							rpc.#register_kind(#rpc_method_name, |params, context, ext| {
181								#parsing
182								#into_response::into_response(context.#rust_method_name(&ext, #params_seq))
183							})
184						})
185					} else {
186						self.handle_register_result(quote! {
187							rpc.#register_kind(#rpc_method_name, |params, context, _| {
188								#parsing
189								#into_response::into_response(context.#rust_method_name(#params_seq))
190							})
191						})
192					}
193				}
194			})
195			.collect::<Vec<_>>();
196
197		let subscriptions = self
198			.subscriptions
199			.iter()
200			.map(|sub| {
201				// Rust method to invoke (e.g. `self.<foo>(...)`).
202				let rust_method_name = &sub.signature.sig.ident;
203				// Name of the RPC method to subscribe to (e.g. `foo_sub`).
204				let rpc_sub_name = self.rpc_identifier(&sub.name);
205				// Name of `method` in the subscription response.
206				let rpc_notif_name_override = sub.notif_name_override.as_ref().map(|m| self.rpc_identifier(m));
207				// Name of the RPC method to unsubscribe (e.g. `foo_sub`).
208				let rpc_unsub_name = self.rpc_identifier(&sub.unsubscribe);
209				// `parsing` is the code associated with parsing structure from the
210				// provided `Params` object.
211				// `params_seq` is the comma-delimited sequence of parameters.
212				let pending = proc_macro2::Ident::new("pending", rust_method_name.span());
213				let (parsing, params_seq) = self.render_params_decoding(&sub.params, Some(pending));
214				let sub_err = self.jrps_server_item(quote! { SubscriptionCloseResponse });
215				let into_sub_response = self.jrps_server_item(quote! { IntoSubscriptionCloseResponse });
216
217				check_name(&rpc_sub_name, rust_method_name.span());
218				check_name(&rpc_unsub_name, rust_method_name.span());
219
220				let rpc_notif_name = match rpc_notif_name_override {
221					Some(notif) => {
222						check_name(&notif, rust_method_name.span());
223						notif
224					}
225					None => rpc_sub_name.clone(),
226				};
227
228				if sub.signature.sig.asyncness.is_some() {
229					if sub.with_extensions {
230						self.handle_register_result(quote! {
231							rpc.register_subscription(#rpc_sub_name, #rpc_notif_name, #rpc_unsub_name, |params, mut pending, context, ext| async move {
232								#parsing
233								#into_sub_response::into_response(context.as_ref().#rust_method_name(pending, &ext, #params_seq).await)
234							})
235						})
236					} else {
237						self.handle_register_result(quote! {
238							rpc.register_subscription(#rpc_sub_name, #rpc_notif_name, #rpc_unsub_name, |params, mut pending, context, _| async move {
239								#parsing
240								#into_sub_response::into_response(context.as_ref().#rust_method_name(pending, #params_seq).await)
241							})
242						})
243					}
244				} else if sub.with_extensions {
245					self.handle_register_result(quote! {
246						rpc.register_subscription_raw(#rpc_sub_name, #rpc_notif_name, #rpc_unsub_name, |params, mut pending, context, ext| {
247							#parsing
248							let _ = context.as_ref().#rust_method_name(pending, &ext, #params_seq);
249							#sub_err::None
250						})
251					})
252				} else {
253					self.handle_register_result(quote! {
254						rpc.register_subscription_raw(#rpc_sub_name, #rpc_notif_name, #rpc_unsub_name, |params, mut pending, context, _| {
255							#parsing
256							let _ = context.as_ref().#rust_method_name(pending, #params_seq);
257							#sub_err::None
258						})
259					})
260				}
261			})
262			.collect::<Vec<_>>();
263
264		let method_aliases = self
265			.methods
266			.iter()
267			.map(|method| {
268				let rpc_name = self.rpc_identifier(&method.name);
269				let rust_method_name = &method.signature.sig.ident;
270
271				// Rust method to invoke (e.g. `self.<foo>(...)`).
272				let aliases: Vec<TokenStream2> = method
273					.aliases
274					.iter()
275					.map(|alias| {
276						check_name(alias, rust_method_name.span());
277						self.handle_register_result(quote! {
278							rpc.register_alias(#alias, #rpc_name)
279						})
280					})
281					.collect();
282
283				quote!( #(#aliases)* )
284			})
285			.collect::<Vec<_>>();
286
287		let subscription_aliases = self
288			.subscriptions
289			.iter()
290			.map(|method| {
291				let sub_name = self.rpc_identifier(&method.name);
292				let unsub_name = self.rpc_identifier(&method.unsubscribe);
293				let rust_method_name = &method.signature.sig.ident;
294
295				let sub: Vec<TokenStream2> = method
296					.aliases
297					.iter()
298					.map(|alias| {
299						check_name(alias, rust_method_name.span());
300						self.handle_register_result(quote! {
301							rpc.register_alias(#alias, #sub_name)
302						})
303					})
304					.collect();
305				let unsub: Vec<TokenStream2> = method
306					.unsubscribe_aliases
307					.iter()
308					.map(|alias| {
309						check_name(alias, rust_method_name.span());
310						self.handle_register_result(quote! {
311							rpc.register_alias(#alias, #unsub_name)
312						})
313					})
314					.collect();
315
316				quote! (
317					#(#sub)*
318					#(#unsub)*
319				)
320			})
321			.collect::<Vec<_>>();
322
323		let doc_comment = "Collects all the methods and subscriptions defined in the trait \
324								and adds them into a single `RpcModule`.";
325
326		let sub_tys: Vec<syn::Type> = self.subscriptions.clone().into_iter().map(|s| s.item).collect();
327		let where_clause = generate_where_clause(&self.trait_def, &sub_tys, false, self.server_bounds.as_ref());
328
329		// NOTE(niklasad1): empty where clause is valid rust syntax.
330		Ok(quote! {
331			#[doc = #doc_comment]
332			fn into_rpc(self) -> #rpc_module<Self> where #(#where_clause,)* {
333				let mut rpc = #rpc_module::new(self);
334
335				#(#errors)*
336				#(#methods)*
337				#(#subscriptions)*
338				#(#method_aliases)*
339				#(#subscription_aliases)*
340
341				rpc
342			}
343		})
344	}
345
346	fn render_params_decoding(
347		&self,
348		params: &[RpcFnArg],
349		sub: Option<proc_macro2::Ident>,
350	) -> (TokenStream2, TokenStream2) {
351		if params.is_empty() {
352			return (TokenStream2::default(), TokenStream2::default());
353		}
354
355		let params_fields_seq = params.iter().map(RpcFnArg::arg_pat);
356		let params_fields = quote! { #(#params_fields_seq),* };
357
358		let reexports = self.jrps_server_item(quote! { core::__reexports });
359
360		let error_ret = if let Some(pending) = &sub {
361			let tokio = quote! { #reexports::tokio };
362			let sub_err = self.jrps_server_item(quote! { SubscriptionCloseResponse });
363			quote! {
364				#tokio::spawn(#pending.reject(e));
365				return #sub_err::None;
366			}
367		} else {
368			let response_payload = self.jrps_server_item(quote! { ResponsePayload });
369			quote! {
370				return #response_payload::error(e);
371			}
372		};
373
374		// Code to decode sequence of parameters from a JSON array.
375		let decode_array = {
376			let decode_fields = params.iter().map(|RpcFnArg { arg_pat, ty, .. }| {
377				let is_option = is_option(ty);
378				let next_method = if is_option { quote!(optional_next) } else { quote!(next) };
379				quote! {
380					let #arg_pat: #ty = match seq.#next_method() {
381						Ok(v) => v,
382						Err(e) => {
383							#reexports::log_fail_parse(stringify!(#arg_pat), stringify!(#ty), &e, #is_option);
384							#error_ret
385						}
386					};
387				}
388			});
389
390			quote! {
391				let mut seq = params.sequence();
392				#(#decode_fields);*
393				(#params_fields)
394			}
395		};
396
397		// Code to decode sequence of parameters from a JSON object (aka map).
398		let decode_map = {
399			let generics = (0..params.len()).map(|n| quote::format_ident!("G{}", n));
400
401			let serde = self.jrps_server_item(quote! { core::__reexports::serde });
402			let serde_crate = serde.to_string();
403
404			let fields = params.iter().zip(generics.clone()).map(|(fn_arg, ty)| {
405				let arg_pat = fn_arg.arg_pat();
406				let name = fn_arg.name();
407
408				let mut alias_vals = String::new();
409				alias_vals.push_str(&format!(r#"alias = "{}""#, heck::ToSnakeCase::to_snake_case(name.as_str())));
410				alias_vals.push(',');
411				alias_vals
412					.push_str(&format!(r#"alias = "{}""#, heck::ToLowerCamelCase::to_lower_camel_case(name.as_str())));
413
414				let serde_rename = quote!(#[serde(rename = #name)]);
415
416				let alias = TokenStream2::from_str(alias_vals.as_str()).unwrap();
417
418				let serde_alias: Attribute = syn::parse_quote! {
419					#[serde(#alias)]
420				};
421
422				quote! {
423					#serde_alias
424					#serde_rename
425					#arg_pat: #ty,
426				}
427			});
428			let destruct = params.iter().map(RpcFnArg::arg_pat).map(|a| quote!(parsed.#a));
429			let types = params.iter().map(RpcFnArg::ty);
430
431			quote! {
432				#[derive(#serde::Deserialize)]
433				#[serde(crate = #serde_crate)]
434				struct ParamsObject<#(#generics,)*> {
435					#(#fields)*
436				}
437
438				let parsed: ParamsObject<#(#types,)*> = match params.parse() {
439					Ok(p) => p,
440					Err(e) => {
441						#reexports::log_fail_parse_as_object(&e);
442						#error_ret
443					}
444				};
445
446				(#(#destruct),*)
447			}
448		};
449
450		let parsing = quote! {
451			let (#params_fields) = if params.is_object() {
452				#decode_map
453			} else {
454				#decode_array
455			};
456		};
457
458		(parsing, params_fields)
459	}
460}