use std::{
num::NonZeroU32,
time::{Duration, Instant},
};
use futures::future::{BoxFuture, FutureExt};
use governor::{clock::Clock, Jitter};
use jsonrpsee::{
server::middleware::rpc::RpcServiceT,
types::{ErrorObject, Id, Request},
MethodResponse,
};
mod metrics;
mod node_health;
mod rate_limit;
pub use metrics::*;
pub use node_health::*;
pub use rate_limit::*;
const MAX_JITTER: Duration = Duration::from_millis(50);
const MAX_RETRIES: usize = 10;
#[derive(Debug, Clone, Default)]
pub struct MiddlewareLayer {
rate_limit: Option<RateLimit>,
metrics: Option<Metrics>,
}
impl MiddlewareLayer {
pub fn new() -> Self {
Self::default()
}
pub fn with_rate_limit_per_minute(self, n: NonZeroU32) -> Self {
Self { rate_limit: Some(RateLimit::per_minute(n)), metrics: self.metrics }
}
pub fn with_metrics(self, metrics: Metrics) -> Self {
Self { rate_limit: self.rate_limit, metrics: Some(metrics) }
}
pub fn ws_connect(&self) {
self.metrics.as_ref().map(|m| m.ws_connect());
}
pub fn ws_disconnect(&self, now: Instant) {
self.metrics.as_ref().map(|m| m.ws_disconnect(now));
}
}
impl<S> tower::Layer<S> for MiddlewareLayer {
type Service = Middleware<S>;
fn layer(&self, service: S) -> Self::Service {
Middleware { service, rate_limit: self.rate_limit.clone(), metrics: self.metrics.clone() }
}
}
pub struct Middleware<S> {
service: S,
rate_limit: Option<RateLimit>,
metrics: Option<Metrics>,
}
impl<'a, S> RpcServiceT<'a> for Middleware<S>
where
S: Send + Sync + RpcServiceT<'a> + Clone + 'static,
{
type Future = BoxFuture<'a, MethodResponse>;
fn call(&self, req: Request<'a>) -> Self::Future {
let now = Instant::now();
self.metrics.as_ref().map(|m| m.on_call(&req));
let service = self.service.clone();
let rate_limit = self.rate_limit.clone();
let metrics = self.metrics.clone();
async move {
let mut is_rate_limited = false;
if let Some(limit) = rate_limit.as_ref() {
let mut attempts = 0;
let jitter = Jitter::up_to(MAX_JITTER);
loop {
if attempts >= MAX_RETRIES {
return reject_too_many_calls(req.id);
}
if let Err(rejected) = limit.inner.check() {
tokio::time::sleep(jitter + rejected.wait_time_from(limit.clock.now()))
.await;
} else {
break;
}
is_rate_limited = true;
attempts += 1;
}
}
let rp = service.call(req.clone()).await;
metrics.as_ref().map(|m| m.on_response(&req, &rp, is_rate_limited, now));
rp
}
.boxed()
}
}
fn reject_too_many_calls(id: Id) -> MethodResponse {
MethodResponse::error(id, ErrorObject::owned(-32999, "RPC rate limit exceeded", None::<()>))
}