1use std::collections::HashSet;
28use std::str::FromStr;
29
30use super::RpcDescription;
31use crate::{
32 helpers::{generate_where_clause, is_option},
33 rpc_macro::RpcFnArg,
34};
35use proc_macro2::{Span, TokenStream as TokenStream2};
36use quote::{quote, quote_spanned};
37use syn::Attribute;
38
39impl RpcDescription {
40 pub(super) fn render_server(&self) -> Result<TokenStream2, syn::Error> {
41 let trait_name = quote::format_ident!("{}Server", &self.trait_def.ident);
42 let generics = self.trait_def.generics.clone();
43 let (impl_generics, _, where_clause) = generics.split_for_impl();
44
45 let method_impls = self.render_methods()?;
46 let into_rpc_impl = self.render_into_rpc()?;
47 let async_trait = self.jrps_server_item(quote! { core::__reexports::async_trait });
48
49 let doc_comment = format!("Server trait implementation for the `{}` RPC API.", &self.trait_def.ident);
51
52 let trait_impl = quote! {
53 #[#async_trait]
54 #[doc = #doc_comment]
55 pub trait #trait_name #impl_generics: Sized + Send + Sync + 'static #where_clause {
56 #method_impls
57 #into_rpc_impl
58 }
59 };
60
61 Ok(trait_impl)
62 }
63
64 fn render_methods(&self) -> Result<TokenStream2, syn::Error> {
65 let methods = self.methods.iter().map(|method| {
66 let docs = &method.docs;
67 let mut method_sig = method.signature.clone();
68
69 if method.with_extensions {
70 let ext_ty = self.jrps_server_item(quote! { Extensions });
71 let ext: syn::FnArg = syn::parse_quote!(ext: &#ext_ty);
73 method_sig.sig.inputs.insert(1, ext);
74 }
75
76 quote! {
77 #docs
78 #method_sig
79 }
80 });
81
82 let subscriptions = self.subscriptions.iter().map(|sub| {
83 let docs = &sub.docs;
84 let subscription_sink_ty = self.jrps_server_item(quote! { PendingSubscriptionSink });
85
86 let subscription_sink: syn::FnArg = syn::parse_quote!(subscription_sink: #subscription_sink_ty);
88 let mut sub_sig = sub.signature.clone();
89 sub_sig.sig.inputs.insert(1, subscription_sink);
90
91 if sub.with_extensions {
92 let ext_ty = self.jrps_server_item(quote! { Extensions });
93 let ext: syn::FnArg = syn::parse_quote!(ext: &#ext_ty);
95 sub_sig.sig.inputs.insert(2, ext);
96 }
97
98 quote! {
99 #docs
100 #sub_sig
101 }
102 });
103
104 Ok(quote! {
105 #(#methods)*
106 #(#subscriptions)*
107 })
108 }
109
110 fn handle_register_result(&self, tokens: TokenStream2) -> TokenStream2 {
116 let reexports = self.jrps_server_item(quote! { core::__reexports });
117 quote! {{
118 let _res = #tokens;
119 #[cfg(debug_assertions)]
120 if _res.is_err() {
121 #reexports::panic_fail_register();
122 }
123 }}
124 }
125
126 fn render_into_rpc(&self) -> Result<TokenStream2, syn::Error> {
127 let rpc_module = self.jrps_server_item(quote! { RpcModule });
128
129 let mut registered = HashSet::new();
130 let mut errors = Vec::new();
131 let mut check_name = |name: &str, span: Span| {
132 if registered.contains(name) {
133 let message = format!("{name:?} is already defined");
134 errors.push(quote_spanned!(span => compile_error!(#message);));
135 } else {
136 registered.insert(name.to_string());
137 }
138 };
139
140 let methods = self
141 .methods
142 .iter()
143 .map(|method| {
144 let rust_method_name = &method.signature.sig.ident;
146 let rpc_method_name = self.rpc_identifier(&method.name);
148 let (parsing, params_seq) = self.render_params_decoding(&method.params, None);
153
154 let into_response = self.jrps_server_item(quote! { IntoResponse });
155
156 check_name(&rpc_method_name, rust_method_name.span());
157
158 if method.signature.sig.asyncness.is_some() {
159 if method.with_extensions {
160 self.handle_register_result(quote! {
161 rpc.register_async_method(#rpc_method_name, |params, context, ext| async move {
162 #parsing
163 #into_response::into_response(context.as_ref().#rust_method_name(&ext, #params_seq).await)
164 })
165 })
166 } else {
167 self.handle_register_result(quote! {
168 rpc.register_async_method(#rpc_method_name, |params, context, _| async move {
169 #parsing
170 #into_response::into_response(context.as_ref().#rust_method_name(#params_seq).await)
171 })
172 })
173 }
174 } else {
175 let register_kind =
176 if method.blocking { quote!(register_blocking_method) } else { quote!(register_method) };
177
178 if method.with_extensions {
179 self.handle_register_result(quote! {
180 rpc.#register_kind(#rpc_method_name, |params, context, ext| {
181 #parsing
182 #into_response::into_response(context.#rust_method_name(&ext, #params_seq))
183 })
184 })
185 } else {
186 self.handle_register_result(quote! {
187 rpc.#register_kind(#rpc_method_name, |params, context, _| {
188 #parsing
189 #into_response::into_response(context.#rust_method_name(#params_seq))
190 })
191 })
192 }
193 }
194 })
195 .collect::<Vec<_>>();
196
197 let subscriptions = self
198 .subscriptions
199 .iter()
200 .map(|sub| {
201 let rust_method_name = &sub.signature.sig.ident;
203 let rpc_sub_name = self.rpc_identifier(&sub.name);
205 let rpc_notif_name_override = sub.notif_name_override.as_ref().map(|m| self.rpc_identifier(m));
207 let rpc_unsub_name = self.rpc_identifier(&sub.unsubscribe);
209 let pending = proc_macro2::Ident::new("pending", rust_method_name.span());
213 let (parsing, params_seq) = self.render_params_decoding(&sub.params, Some(pending));
214 let sub_err = self.jrps_server_item(quote! { SubscriptionCloseResponse });
215 let into_sub_response = self.jrps_server_item(quote! { IntoSubscriptionCloseResponse });
216
217 check_name(&rpc_sub_name, rust_method_name.span());
218 check_name(&rpc_unsub_name, rust_method_name.span());
219
220 let rpc_notif_name = match rpc_notif_name_override {
221 Some(notif) => {
222 check_name(¬if, rust_method_name.span());
223 notif
224 }
225 None => rpc_sub_name.clone(),
226 };
227
228 if sub.signature.sig.asyncness.is_some() {
229 if sub.with_extensions {
230 self.handle_register_result(quote! {
231 rpc.register_subscription(#rpc_sub_name, #rpc_notif_name, #rpc_unsub_name, |params, mut pending, context, ext| async move {
232 #parsing
233 #into_sub_response::into_response(context.as_ref().#rust_method_name(pending, &ext, #params_seq).await)
234 })
235 })
236 } else {
237 self.handle_register_result(quote! {
238 rpc.register_subscription(#rpc_sub_name, #rpc_notif_name, #rpc_unsub_name, |params, mut pending, context, _| async move {
239 #parsing
240 #into_sub_response::into_response(context.as_ref().#rust_method_name(pending, #params_seq).await)
241 })
242 })
243 }
244 } else if sub.with_extensions {
245 self.handle_register_result(quote! {
246 rpc.register_subscription_raw(#rpc_sub_name, #rpc_notif_name, #rpc_unsub_name, |params, mut pending, context, ext| {
247 #parsing
248 let _ = context.as_ref().#rust_method_name(pending, &ext, #params_seq);
249 #sub_err::None
250 })
251 })
252 } else {
253 self.handle_register_result(quote! {
254 rpc.register_subscription_raw(#rpc_sub_name, #rpc_notif_name, #rpc_unsub_name, |params, mut pending, context, _| {
255 #parsing
256 let _ = context.as_ref().#rust_method_name(pending, #params_seq);
257 #sub_err::None
258 })
259 })
260 }
261 })
262 .collect::<Vec<_>>();
263
264 let method_aliases = self
265 .methods
266 .iter()
267 .map(|method| {
268 let rpc_name = self.rpc_identifier(&method.name);
269 let rust_method_name = &method.signature.sig.ident;
270
271 let aliases: Vec<TokenStream2> = method
273 .aliases
274 .iter()
275 .map(|alias| {
276 check_name(alias, rust_method_name.span());
277 self.handle_register_result(quote! {
278 rpc.register_alias(#alias, #rpc_name)
279 })
280 })
281 .collect();
282
283 quote!( #(#aliases)* )
284 })
285 .collect::<Vec<_>>();
286
287 let subscription_aliases = self
288 .subscriptions
289 .iter()
290 .map(|method| {
291 let sub_name = self.rpc_identifier(&method.name);
292 let unsub_name = self.rpc_identifier(&method.unsubscribe);
293 let rust_method_name = &method.signature.sig.ident;
294
295 let sub: Vec<TokenStream2> = method
296 .aliases
297 .iter()
298 .map(|alias| {
299 check_name(alias, rust_method_name.span());
300 self.handle_register_result(quote! {
301 rpc.register_alias(#alias, #sub_name)
302 })
303 })
304 .collect();
305 let unsub: Vec<TokenStream2> = method
306 .unsubscribe_aliases
307 .iter()
308 .map(|alias| {
309 check_name(alias, rust_method_name.span());
310 self.handle_register_result(quote! {
311 rpc.register_alias(#alias, #unsub_name)
312 })
313 })
314 .collect();
315
316 quote! (
317 #(#sub)*
318 #(#unsub)*
319 )
320 })
321 .collect::<Vec<_>>();
322
323 let doc_comment = "Collects all the methods and subscriptions defined in the trait \
324 and adds them into a single `RpcModule`.";
325
326 let sub_tys: Vec<syn::Type> = self.subscriptions.clone().into_iter().map(|s| s.item).collect();
327 let where_clause = generate_where_clause(&self.trait_def, &sub_tys, false, self.server_bounds.as_ref());
328
329 Ok(quote! {
331 #[doc = #doc_comment]
332 fn into_rpc(self) -> #rpc_module<Self> where #(#where_clause,)* {
333 let mut rpc = #rpc_module::new(self);
334
335 #(#errors)*
336 #(#methods)*
337 #(#subscriptions)*
338 #(#method_aliases)*
339 #(#subscription_aliases)*
340
341 rpc
342 }
343 })
344 }
345
346 fn render_params_decoding(
347 &self,
348 params: &[RpcFnArg],
349 sub: Option<proc_macro2::Ident>,
350 ) -> (TokenStream2, TokenStream2) {
351 if params.is_empty() {
352 return (TokenStream2::default(), TokenStream2::default());
353 }
354
355 let params_fields_seq = params.iter().map(RpcFnArg::arg_pat);
356 let params_fields = quote! { #(#params_fields_seq),* };
357
358 let reexports = self.jrps_server_item(quote! { core::__reexports });
359
360 let error_ret = if let Some(pending) = &sub {
361 let tokio = quote! { #reexports::tokio };
362 let sub_err = self.jrps_server_item(quote! { SubscriptionCloseResponse });
363 quote! {
364 #tokio::spawn(#pending.reject(e));
365 return #sub_err::None;
366 }
367 } else {
368 let response_payload = self.jrps_server_item(quote! { ResponsePayload });
369 quote! {
370 return #response_payload::error(e);
371 }
372 };
373
374 let decode_array = {
376 let decode_fields = params.iter().map(|RpcFnArg { arg_pat, ty, .. }| {
377 let is_option = is_option(ty);
378 let next_method = if is_option { quote!(optional_next) } else { quote!(next) };
379 quote! {
380 let #arg_pat: #ty = match seq.#next_method() {
381 Ok(v) => v,
382 Err(e) => {
383 #reexports::log_fail_parse(stringify!(#arg_pat), stringify!(#ty), &e, #is_option);
384 #error_ret
385 }
386 };
387 }
388 });
389
390 quote! {
391 let mut seq = params.sequence();
392 #(#decode_fields);*
393 (#params_fields)
394 }
395 };
396
397 let decode_map = {
399 let generics = (0..params.len()).map(|n| quote::format_ident!("G{}", n));
400
401 let serde = self.jrps_server_item(quote! { core::__reexports::serde });
402 let serde_crate = serde.to_string();
403
404 let fields = params.iter().zip(generics.clone()).map(|(fn_arg, ty)| {
405 let arg_pat = fn_arg.arg_pat();
406 let name = fn_arg.name();
407
408 let mut alias_vals = String::new();
409 alias_vals.push_str(&format!(r#"alias = "{}""#, heck::ToSnakeCase::to_snake_case(name.as_str())));
410 alias_vals.push(',');
411 alias_vals
412 .push_str(&format!(r#"alias = "{}""#, heck::ToLowerCamelCase::to_lower_camel_case(name.as_str())));
413
414 let serde_rename = quote!(#[serde(rename = #name)]);
415
416 let alias = TokenStream2::from_str(alias_vals.as_str()).unwrap();
417
418 let serde_alias: Attribute = syn::parse_quote! {
419 #[serde(#alias)]
420 };
421
422 quote! {
423 #serde_alias
424 #serde_rename
425 #arg_pat: #ty,
426 }
427 });
428 let destruct = params.iter().map(RpcFnArg::arg_pat).map(|a| quote!(parsed.#a));
429 let types = params.iter().map(RpcFnArg::ty);
430
431 quote! {
432 #[derive(#serde::Deserialize)]
433 #[serde(crate = #serde_crate)]
434 struct ParamsObject<#(#generics,)*> {
435 #(#fields)*
436 }
437
438 let parsed: ParamsObject<#(#types,)*> = match params.parse() {
439 Ok(p) => p,
440 Err(e) => {
441 #reexports::log_fail_parse_as_object(&e);
442 #error_ret
443 }
444 };
445
446 (#(#destruct),*)
447 }
448 };
449
450 let parsing = quote! {
451 let (#params_fields) = if params.is_object() {
452 #decode_map
453 } else {
454 #decode_array
455 };
456 };
457
458 (parsing, params_fields)
459 }
460}