sc_rpc_server/middleware/
mod.rs1use std::{
22	num::NonZeroU32,
23	time::{Duration, Instant},
24};
25
26use futures::future::{BoxFuture, FutureExt};
27use governor::{clock::Clock, Jitter};
28use jsonrpsee::{
29	server::middleware::rpc::RpcServiceT,
30	types::{ErrorObject, Id, Request},
31	MethodResponse,
32};
33
34mod metrics;
35mod node_health;
36mod rate_limit;
37
38pub use metrics::*;
39pub use node_health::*;
40pub use rate_limit::*;
41
/// Number of failed rate-limit checks after which a call is rejected instead of retried.
const MAX_RETRIES: usize = 10;
/// Upper bound passed to `Jitter::up_to`; the resulting jitter is added to every
/// wait between rate-limit checks.
const MAX_JITTER: Duration = Duration::from_millis(50);
44
45#[derive(Debug, Clone, Default)]
47pub struct MiddlewareLayer {
48	rate_limit: Option<RateLimit>,
49	metrics: Option<Metrics>,
50}
51
52impl MiddlewareLayer {
53	pub fn new() -> Self {
55		Self::default()
56	}
57
58	pub fn with_rate_limit_per_minute(self, n: NonZeroU32) -> Self {
60		Self { rate_limit: Some(RateLimit::per_minute(n)), metrics: self.metrics }
61	}
62
63	pub fn with_metrics(self, metrics: Metrics) -> Self {
65		Self { rate_limit: self.rate_limit, metrics: Some(metrics) }
66	}
67
68	pub fn ws_connect(&self) {
70		self.metrics.as_ref().map(|m| m.ws_connect());
71	}
72
73	pub fn ws_disconnect(&self, now: Instant) {
75		self.metrics.as_ref().map(|m| m.ws_disconnect(now));
76	}
77}
78
79impl<S> tower::Layer<S> for MiddlewareLayer {
80	type Service = Middleware<S>;
81
82	fn layer(&self, service: S) -> Self::Service {
83		Middleware { service, rate_limit: self.rate_limit.clone(), metrics: self.metrics.clone() }
84	}
85}
86
87pub struct Middleware<S> {
95	service: S,
96	rate_limit: Option<RateLimit>,
97	metrics: Option<Metrics>,
98}
99
100impl<'a, S> RpcServiceT<'a> for Middleware<S>
101where
102	S: Send + Sync + RpcServiceT<'a> + Clone + 'static,
103{
104	type Future = BoxFuture<'a, MethodResponse>;
105
106	fn call(&self, req: Request<'a>) -> Self::Future {
107		let now = Instant::now();
108
109		self.metrics.as_ref().map(|m| m.on_call(&req));
110
111		let service = self.service.clone();
112		let rate_limit = self.rate_limit.clone();
113		let metrics = self.metrics.clone();
114
115		async move {
116			let mut is_rate_limited = false;
117
118			if let Some(limit) = rate_limit.as_ref() {
119				let mut attempts = 0;
120				let jitter = Jitter::up_to(MAX_JITTER);
121
122				loop {
123					if attempts >= MAX_RETRIES {
124						return reject_too_many_calls(req.id);
125					}
126
127					if let Err(rejected) = limit.inner.check() {
128						tokio::time::sleep(jitter + rejected.wait_time_from(limit.clock.now()))
129							.await;
130					} else {
131						break;
132					}
133
134					is_rate_limited = true;
135					attempts += 1;
136				}
137			}
138
139			let rp = service.call(req.clone()).await;
140			metrics.as_ref().map(|m| m.on_response(&req, &rp, is_rate_limited, now));
141
142			rp
143		}
144		.boxed()
145	}
146}
147
148fn reject_too_many_calls(id: Id) -> MethodResponse {
149	MethodResponse::error(id, ErrorObject::owned(-32999, "RPC rate limit exceeded", None::<()>))
150}