use std::error::Error as StdError;
use std::future::Future;
use std::net::{SocketAddr, TcpListener as StdTcpListener};
use std::pin::Pin;
use std::sync::atomic::AtomicU32;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
use crate::future::{session_close, ConnectionGuard, ServerHandle, SessionClose, SessionClosedFuture, StopHandle};
use crate::middleware::rpc::{RpcService, RpcServiceBuilder, RpcServiceCfg, RpcServiceT};
use crate::transport::ws::BackgroundTaskParams;
use crate::transport::{http, ws};
use crate::utils::deserialize;
use crate::{Extensions, HttpBody, HttpRequest, HttpResponse, LOG_TARGET};
use futures_util::future::{self, Either, FutureExt};
use futures_util::io::{BufReader, BufWriter};
use hyper::body::Bytes;
use hyper_util::rt::{TokioExecutor, TokioIo};
use jsonrpsee_core::id_providers::RandomIntegerIdProvider;
use jsonrpsee_core::server::helpers::prepare_error;
use jsonrpsee_core::server::{
BatchResponseBuilder, BoundedSubscriptions, ConnectionId, MethodResponse, MethodSink, Methods,
};
use jsonrpsee_core::traits::IdProvider;
use jsonrpsee_core::{BoxError, JsonRawValue, TEN_MB_SIZE_BYTES};
use jsonrpsee_types::error::{
reject_too_big_batch_request, ErrorCode, BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG,
};
use jsonrpsee_types::{ErrorObject, Id, InvalidRequest, Notification};
use soketto::handshake::http::is_upgrade_request;
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio::sync::{mpsc, watch, OwnedSemaphorePermit};
use tokio_util::compat::TokioAsyncReadCompatExt;
use tower::layer::util::Identity;
use tower::{Layer, Service};
use tracing::{instrument, Instrument};
/// A JSON-RPC notification whose params are left as an optional raw JSON value.
///
/// Only used to recognise that a payload is a notification (see `handle_rpc_call`);
/// the params are never deserialized further here.
type Notif<'a> = Notification<'a, Option<&'a JsonRawValue>>;
/// Default maximum number of connection permits, used by `ServerConfig::default`.
const MAX_CONNECTIONS: u32 = 100;
/// A JSON-RPC server bound to a TCP listener, serving HTTP and/or WebSocket clients.
///
/// Created through [`Builder`]; call `start` to spawn the accept loop.
pub struct Server<HttpMiddleware = Identity, RpcMiddleware = Identity> {
    /// The bound listener that incoming connections are accepted from.
    listener: TcpListener,
    /// Server-wide settings, cloned into every connection.
    server_cfg: ServerConfig,
    /// Middleware wrapped around every JSON-RPC method call.
    rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
    /// Tower middleware wrapped around every HTTP request.
    http_middleware: tower::ServiceBuilder<HttpMiddleware>,
}
impl Server<Identity, Identity> {
    /// Returns a [`Builder`] holding the default configuration and no middleware.
    pub fn builder() -> Builder<Identity, Identity> {
        Builder::default()
    }
}
impl<HttpMiddleware, RpcMiddleware> std::fmt::Debug for Server<HttpMiddleware, RpcMiddleware> {
    /// Shows the listener and the configuration; the middleware stacks are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Server");
        out.field("listener", &self.listener);
        out.field("server_cfg", &self.server_cfg);
        out.finish()
    }
}
impl<HttpMiddleware, RpcMiddleware> Server<HttpMiddleware, RpcMiddleware> {
    /// Returns the local socket address the server's listener is bound to.
    ///
    /// # Errors
    /// Forwards any I/O error reported by the underlying listener.
    pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
        let bound = &self.listener;
        bound.local_addr()
    }
}
impl<HttpMiddleware, RpcMiddleware, Body> Server<HttpMiddleware, RpcMiddleware>
where
    RpcMiddleware: tower::Layer<RpcService> + Clone + Send + 'static,
    for<'a> <RpcMiddleware as Layer<RpcService>>::Service: RpcServiceT<'a>,
    HttpMiddleware: Layer<TowerServiceNoHttp<RpcMiddleware>> + Send + 'static,
    <HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service:
        Send + Clone + Service<HttpRequest, Response = HttpResponse<Body>, Error = BoxError>,
    <<HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service as Service<HttpRequest>>::Future: Send,
    Body: http_body::Body<Data = Bytes> + Send + 'static,
    <Body as http_body::Body>::Error: Into<BoxError>,
    <Body as http_body::Body>::Data: Send,
{
    /// Spawns the server's accept loop and returns a handle that can stop it.
    ///
    /// The loop runs on the runtime configured with `custom_tokio_runtime` if one was
    /// set, otherwise it is spawned with `tokio::spawn`. `tokio::spawn` panics when
    /// called outside a Tokio runtime.
    pub fn start(mut self, methods: impl Into<Methods>) -> ServerHandle {
        let methods = methods.into();
        // The sender lives in the returned `ServerHandle`; the receiver side drives shutdown.
        let (stop_tx, stop_rx) = watch::channel(());
        let stop_handle = StopHandle::new(stop_rx);
        match self.server_cfg.tokio_runtime.take() {
            Some(rt) => rt.spawn(self.start_inner(methods, stop_handle)),
            None => tokio::spawn(self.start_inner(methods, stop_handle)),
        };
        ServerHandle::new(stop_tx)
    }

    /// Accepts connections until shutdown is signalled, then waits for every
    /// spawned connection task to finish before returning.
    async fn start_inner(self, methods: Methods, stop_handle: StopHandle) {
        // Sequential connection ids, wrapping around on overflow.
        let mut id: u32 = 0;
        // Shared by all connections; permits are acquired per request in
        // `TowerServiceNoHttp::call`, not at accept time.
        let connection_guard = ConnectionGuard::new(self.server_cfg.max_connections as usize);
        let listener = self.listener;
        // The shutdown future is pinned once and then handed back and forth to
        // `try_accept_conn`, which returns it unless shutdown fired.
        let stopped = stop_handle.clone().shutdown();
        tokio::pin!(stopped);
        // Each connection task owns a clone of `drop_on_completion`. `recv()` yields
        // `None` only after every clone (including ours) has been dropped.
        let (drop_on_completion, mut process_connection_awaiter) = mpsc::channel::<()>(1);
        loop {
            match try_accept_conn(&listener, stopped).await {
                AcceptConnection::Established { socket, remote_addr, stop } => {
                    process_connection(ProcessConnection {
                        http_middleware: &self.http_middleware,
                        rpc_middleware: self.rpc_middleware.clone(),
                        remote_addr,
                        methods: methods.clone(),
                        stop_handle: stop_handle.clone(),
                        conn_id: id,
                        server_cfg: self.server_cfg.clone(),
                        conn_guard: &connection_guard,
                        socket,
                        drop_on_completion: drop_on_completion.clone(),
                    });
                    id = id.wrapping_add(1);
                    stopped = stop;
                }
                AcceptConnection::Err((e, stop)) => {
                    // NOTE(review): the accept is retried immediately with no back-off;
                    // a persistent error (e.g. fd exhaustion) may make this loop spin.
                    tracing::debug!(target: LOG_TARGET, "Error while awaiting a new connection: {:?}", e);
                    stopped = stop;
                }
                AcceptConnection::Shutdown => break,
            }
        }
        // Drop our own sender so the channel can close once all connections finish.
        drop(drop_on_completion);
        while process_connection_awaiter.recv().await.is_some() {
            // Nothing is expected to be sent; we only wait for the channel to close,
            // which means every connection task has completed.
        }
    }
}
/// Server-wide settings shared by every connection.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Maximum size in bytes of an HTTP request body; also the maximum size of a
    /// single incoming WebSocket message.
    pub(crate) max_request_body_size: u32,
    /// Maximum size in bytes of a response, including an accumulated batch response.
    pub(crate) max_response_body_size: u32,
    /// Number of connection permits in the server's `ConnectionGuard`.
    pub(crate) max_connections: u32,
    /// Subscription bound passed to `BoundedSubscriptions` for each WebSocket connection.
    pub(crate) max_subscriptions_per_connection: u32,
    /// How batch requests are handled (rejected, length-capped or unlimited).
    pub(crate) batch_requests_config: BatchRequestConfig,
    /// Runtime the accept loop is spawned on; `None` means `tokio::spawn`.
    pub(crate) tokio_runtime: Option<tokio::runtime::Handle>,
    /// Whether plain (non-upgrade) HTTP requests are served.
    pub(crate) enable_http: bool,
    /// Whether WebSocket upgrade requests are served.
    pub(crate) enable_ws: bool,
    /// Capacity of each WebSocket connection's outgoing message channel.
    pub(crate) message_buffer_capacity: u32,
    /// WebSocket ping settings; `None` disables pings.
    pub(crate) ping_config: Option<PingConfig>,
    /// Id provider handed to WebSocket connections for subscriptions.
    pub(crate) id_provider: Arc<dyn IdProvider>,
    /// Value passed to `set_nodelay` on every accepted TCP socket.
    pub(crate) tcp_no_delay: bool,
}
/// Builder for a [`ServerConfig`].
///
/// Holds the subset of `ServerConfig` fields that can be configured through it;
/// each field mirrors the `ServerConfig` field of the same name.
#[derive(Debug, Clone)]
pub struct ServerConfigBuilder {
    max_request_body_size: u32,
    max_response_body_size: u32,
    max_connections: u32,
    max_subscriptions_per_connection: u32,
    batch_requests_config: BatchRequestConfig,
    enable_http: bool,
    enable_ws: bool,
    message_buffer_capacity: u32,
    ping_config: Option<PingConfig>,
    id_provider: Arc<dyn IdProvider>,
}
/// Builder for standalone [`TowerService`]s, used when the server is embedded in an
/// externally managed HTTP stack rather than started with `Server::start`.
#[derive(Debug, Clone)]
pub struct TowerServiceBuilder<RpcMiddleware, HttpMiddleware> {
    /// Configuration copied into each built service.
    pub(crate) server_cfg: ServerConfig,
    /// Middleware wrapped around every JSON-RPC method call.
    pub(crate) rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
    /// Tower middleware wrapped around every HTTP request.
    pub(crate) http_middleware: tower::ServiceBuilder<HttpMiddleware>,
    /// Counter that hands out connection ids; incremented by every `build` call.
    pub(crate) conn_id: Arc<AtomicU32>,
    /// Connection-permit guard moved into built services.
    pub(crate) conn_guard: ConnectionGuard,
}
/// How the server treats JSON-RPC batch requests.
#[derive(Debug, Copy, Clone)]
pub enum BatchRequestConfig {
    /// Batches are rejected with a "batches not supported" error.
    Disabled,
    /// Batches with more entries than the given number are rejected.
    Limit(u32),
    /// Batches of any length are accepted.
    Unlimited,
}
/// Per-connection state that keeps the connection's permit alive.
#[derive(Debug, Clone)]
pub struct ConnectionState {
    /// Handle used to observe server shutdown.
    pub(crate) stop_handle: StopHandle,
    /// Identifier of this connection.
    pub(crate) conn_id: u32,
    /// Connection-limit permit, shared via `Arc` so clones of this state keep the
    /// slot reserved until the last clone is dropped.
    pub(crate) _conn_permit: Arc<OwnedSemaphorePermit>,
}
impl ConnectionState {
pub fn new(stop_handle: StopHandle, conn_id: u32, conn_permit: OwnedSemaphorePermit) -> ConnectionState {
Self { stop_handle, conn_id, _conn_permit: Arc::new(conn_permit) }
}
}
/// WebSocket keep-alive ping settings (defaults: 30 s interval, 40 s inactivity
/// limit, 1 failure).
///
/// NOTE(review): these values are consumed by the WebSocket background task,
/// which is not in this file; the field descriptions below are inferred from names.
#[derive(Debug, Copy, Clone)]
pub struct PingConfig {
    /// Presumably the period between outgoing pings.
    pub(crate) ping_interval: Duration,
    /// Presumably how long a connection may stay silent before it counts as a failure.
    pub(crate) inactive_limit: Duration,
    /// Number of tolerated failures; always greater than zero when set via `max_failures`.
    pub(crate) max_failures: usize,
}
impl Default for PingConfig {
fn default() -> Self {
Self { ping_interval: Duration::from_secs(30), max_failures: 1, inactive_limit: Duration::from_secs(40) }
}
}
impl PingConfig {
pub fn new() -> Self {
Self::default()
}
pub fn ping_interval(mut self, ping_interval: Duration) -> Self {
self.ping_interval = ping_interval;
self
}
pub fn inactive_limit(mut self, inactivity_limit: Duration) -> Self {
self.inactive_limit = inactivity_limit;
self
}
pub fn max_failures(mut self, max: usize) -> Self {
assert!(max > 0);
self.max_failures = max;
self
}
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
max_request_body_size: TEN_MB_SIZE_BYTES,
max_response_body_size: TEN_MB_SIZE_BYTES,
max_connections: MAX_CONNECTIONS,
max_subscriptions_per_connection: 1024,
batch_requests_config: BatchRequestConfig::Unlimited,
tokio_runtime: None,
enable_http: true,
enable_ws: true,
message_buffer_capacity: 1024,
ping_config: None,
id_provider: Arc::new(RandomIntegerIdProvider),
tcp_no_delay: true,
}
}
}
impl ServerConfig {
pub fn builder() -> ServerConfigBuilder {
ServerConfigBuilder::default()
}
}
impl Default for ServerConfigBuilder {
fn default() -> Self {
let this = ServerConfig::default();
ServerConfigBuilder {
max_request_body_size: this.max_request_body_size,
max_response_body_size: this.max_response_body_size,
max_connections: this.max_connections,
max_subscriptions_per_connection: this.max_subscriptions_per_connection,
batch_requests_config: this.batch_requests_config,
enable_http: this.enable_http,
enable_ws: this.enable_ws,
message_buffer_capacity: this.message_buffer_capacity,
ping_config: this.ping_config,
id_provider: this.id_provider,
}
}
}
impl ServerConfigBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn max_request_body_size(mut self, size: u32) -> Self {
self.max_request_body_size = size;
self
}
pub fn max_response_body_size(mut self, size: u32) -> Self {
self.max_response_body_size = size;
self
}
pub fn max_connections(mut self, max: u32) -> Self {
self.max_connections = max;
self
}
pub fn set_batch_request_config(mut self, cfg: BatchRequestConfig) -> Self {
self.batch_requests_config = cfg;
self
}
pub fn max_subscriptions_per_connection(mut self, max: u32) -> Self {
self.max_subscriptions_per_connection = max;
self
}
pub fn http_only(mut self) -> Self {
self.enable_http = true;
self.enable_ws = false;
self
}
pub fn ws_only(mut self) -> Self {
self.enable_http = false;
self.enable_ws = true;
self
}
pub fn set_message_buffer_capacity(mut self, c: u32) -> Self {
self.message_buffer_capacity = c;
self
}
pub fn enable_ws_ping(mut self, config: PingConfig) -> Self {
self.ping_config = Some(config);
self
}
pub fn disable_ws_ping(mut self) -> Self {
self.ping_config = None;
self
}
pub fn set_id_provider<I: IdProvider + 'static>(mut self, id_provider: I) -> Self {
self.id_provider = Arc::new(id_provider);
self
}
}
/// Builder for a [`Server`]: collects configuration and middleware, then binds a listener.
#[derive(Debug)]
pub struct Builder<HttpMiddleware, RpcMiddleware> {
    /// Settings that the built server will use.
    server_cfg: ServerConfig,
    /// Middleware wrapped around every JSON-RPC method call.
    rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
    /// Tower middleware wrapped around every HTTP request.
    http_middleware: tower::ServiceBuilder<HttpMiddleware>,
}
impl Default for Builder<Identity, Identity> {
    /// Default server settings with identity (no-op) RPC and HTTP middleware.
    fn default() -> Self {
        let server_cfg = ServerConfig::default();
        let rpc_middleware = RpcServiceBuilder::new();
        let http_middleware = tower::ServiceBuilder::new();
        Self { server_cfg, rpc_middleware, http_middleware }
    }
}
impl Builder<Identity, Identity> {
    /// Same as [`Builder::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}
impl<RpcMiddleware, HttpMiddleware> TowerServiceBuilder<RpcMiddleware, HttpMiddleware> {
    /// Consumes the builder and produces a [`TowerService`] for one connection.
    ///
    /// A fresh connection id is taken from the shared counter (`fetch_add` wraps
    /// on overflow), and the builder's connection guard moves into the service.
    pub fn build(
        self,
        methods: impl Into<Methods>,
        stop_handle: StopHandle,
    ) -> TowerService<RpcMiddleware, HttpMiddleware> {
        let TowerServiceBuilder { server_cfg, rpc_middleware, http_middleware, conn_id: id_counter, conn_guard } =
            self;
        let conn_id = id_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let inner = ServiceData { methods: methods.into(), stop_handle, conn_id, conn_guard, server_cfg };
        let without_http = TowerServiceNoHttp { inner, rpc_middleware, on_session_close: None };
        TowerService { rpc_middleware: without_http, http_middleware }
    }

    /// Restarts the connection-id counter at `id`, using a new, unshared counter.
    pub fn connection_id(self, id: u32) -> Self {
        Self { conn_id: Arc::new(AtomicU32::new(id)), ..self }
    }

    /// Replaces the connection guard with a new one allowing `limit` permits.
    pub fn max_connections(self, limit: u32) -> Self {
        Self { conn_guard: ConnectionGuard::new(limit as usize), ..self }
    }

    /// Swaps in a different RPC middleware stack, keeping every other setting.
    pub fn set_rpc_middleware<T>(self, rpc_middleware: RpcServiceBuilder<T>) -> TowerServiceBuilder<T, HttpMiddleware> {
        let TowerServiceBuilder { server_cfg, http_middleware, conn_id, conn_guard, .. } = self;
        TowerServiceBuilder { server_cfg, rpc_middleware, http_middleware, conn_id, conn_guard }
    }

    /// Swaps in a different HTTP middleware stack, keeping every other setting.
    pub fn set_http_middleware<T>(
        self,
        http_middleware: tower::ServiceBuilder<T>,
    ) -> TowerServiceBuilder<RpcMiddleware, T> {
        let TowerServiceBuilder { server_cfg, rpc_middleware, conn_id, conn_guard, .. } = self;
        TowerServiceBuilder { server_cfg, rpc_middleware, http_middleware, conn_id, conn_guard }
    }
}
impl<HttpMiddleware, RpcMiddleware> Builder<HttpMiddleware, RpcMiddleware> {
    /// Sets the maximum size, in bytes, of a request body.
    ///
    /// The same limit caps a single incoming WebSocket message.
    pub fn max_request_body_size(mut self, size: u32) -> Self {
        self.server_cfg.max_request_body_size = size;
        self
    }

    /// Sets the maximum size, in bytes, of a response, including batch responses.
    pub fn max_response_body_size(mut self, size: u32) -> Self {
        self.server_cfg.max_response_body_size = size;
        self
    }

    /// Sets the number of connection permits. Requests arriving while no permit is
    /// available receive a "too many requests" response.
    pub fn max_connections(mut self, max: u32) -> Self {
        self.server_cfg.max_connections = max;
        self
    }

    /// Sets how batch requests are handled.
    pub fn set_batch_request_config(mut self, cfg: BatchRequestConfig) -> Self {
        self.server_cfg.batch_requests_config = cfg;
        self
    }

    /// Sets the subscription bound for each WebSocket connection.
    pub fn max_subscriptions_per_connection(mut self, max: u32) -> Self {
        self.server_cfg.max_subscriptions_per_connection = max;
        self
    }

    /// Replaces the middleware wrapped around every JSON-RPC method call.
    pub fn set_rpc_middleware<T>(self, rpc_middleware: RpcServiceBuilder<T>) -> Builder<HttpMiddleware, T> {
        Builder { server_cfg: self.server_cfg, rpc_middleware, http_middleware: self.http_middleware }
    }

    /// Spawns the server on `rt` instead of the ambient runtime when it is started.
    pub fn custom_tokio_runtime(mut self, rt: tokio::runtime::Handle) -> Self {
        self.server_cfg.tokio_runtime = Some(rt);
        self
    }

    /// Enables WebSocket pings with the given settings.
    pub fn enable_ws_ping(mut self, config: PingConfig) -> Self {
        self.server_cfg.ping_config = Some(config);
        self
    }

    /// Disables WebSocket pings.
    pub fn disable_ws_ping(mut self) -> Self {
        self.server_cfg.ping_config = None;
        self
    }

    /// Sets the id provider used for subscriptions.
    pub fn set_id_provider<I: IdProvider + 'static>(mut self, id_provider: I) -> Self {
        self.server_cfg.id_provider = Arc::new(id_provider);
        self
    }

    /// Replaces the Tower middleware wrapped around every HTTP request.
    pub fn set_http_middleware<T>(self, http_middleware: tower::ServiceBuilder<T>) -> Builder<T, RpcMiddleware> {
        Builder { server_cfg: self.server_cfg, http_middleware, rpc_middleware: self.rpc_middleware }
    }

    /// Sets the value passed to `set_nodelay` on each accepted TCP socket.
    pub fn set_tcp_no_delay(mut self, no_delay: bool) -> Self {
        self.server_cfg.tcp_no_delay = no_delay;
        self
    }

    /// Serves only plain HTTP requests; WebSocket upgrades are disabled.
    pub fn http_only(mut self) -> Self {
        self.server_cfg.enable_http = true;
        self.server_cfg.enable_ws = false;
        self
    }

    /// Serves only WebSocket upgrades; plain HTTP requests are disabled.
    pub fn ws_only(mut self) -> Self {
        self.server_cfg.enable_http = false;
        self.server_cfg.enable_ws = true;
        self
    }

    /// Sets the capacity of each WebSocket connection's outgoing message channel.
    pub fn set_message_buffer_capacity(mut self, c: u32) -> Self {
        self.server_cfg.message_buffer_capacity = c;
        self
    }

    /// Converts this builder into a [`TowerServiceBuilder`] for embedding the server in
    /// an external HTTP stack. Connection ids start at 0, and a new connection guard
    /// is sized from the configured `max_connections`.
    pub fn to_service_builder(self) -> TowerServiceBuilder<RpcMiddleware, HttpMiddleware> {
        let max_conns = self.server_cfg.max_connections as usize;
        TowerServiceBuilder {
            server_cfg: self.server_cfg,
            rpc_middleware: self.rpc_middleware,
            http_middleware: self.http_middleware,
            conn_id: Arc::new(AtomicU32::new(0)),
            conn_guard: ConnectionGuard::new(max_conns),
        }
    }

    /// Binds a listener on `addrs` and returns the configured (not yet started) server.
    ///
    /// # Errors
    /// Returns the I/O error if binding fails.
    pub async fn build(self, addrs: impl ToSocketAddrs) -> std::io::Result<Server<HttpMiddleware, RpcMiddleware>> {
        let listener = TcpListener::bind(addrs).await?;
        Ok(Server {
            listener,
            server_cfg: self.server_cfg,
            rpc_middleware: self.rpc_middleware,
            http_middleware: self.http_middleware,
        })
    }

    /// Builds the server from an already bound standard-library TCP listener.
    ///
    /// The listener is switched to non-blocking mode before it is handed to Tokio.
    /// Tokio requires this, and a blocking listener could stall a runtime worker
    /// thread inside `accept`.
    ///
    /// # Errors
    /// Returns an I/O error if the listener cannot be made non-blocking or cannot
    /// be registered with the Tokio reactor.
    ///
    /// # Panics
    /// `TcpListener::from_std` panics when called outside a Tokio runtime.
    pub fn build_from_tcp(
        self,
        listener: impl Into<StdTcpListener>,
    ) -> std::io::Result<Server<HttpMiddleware, RpcMiddleware>> {
        let std_listener: StdTcpListener = listener.into();
        std_listener.set_nonblocking(true)?;
        let listener = TcpListener::from_std(std_listener)?;
        Ok(Server {
            listener,
            server_cfg: self.server_cfg,
            rpc_middleware: self.rpc_middleware,
            http_middleware: self.http_middleware,
        })
    }
}
/// Connection-scoped data shared by the HTTP and WebSocket handling paths.
#[derive(Debug, Clone)]
struct ServiceData {
    /// Registered JSON-RPC methods.
    methods: Methods,
    /// Handle used to observe server shutdown.
    stop_handle: StopHandle,
    /// Identifier of the connection this service serves.
    conn_id: u32,
    /// Connection-permit guard; one permit is acquired per incoming request.
    conn_guard: ConnectionGuard,
    /// Server settings.
    server_cfg: ServerConfig,
}
/// A Tower [`Service`] that applies the HTTP middleware stack on top of
/// [`TowerServiceNoHttp`].
#[derive(Debug, Clone)]
pub struct TowerService<RpcMiddleware, HttpMiddleware> {
    /// Inner JSON-RPC service, cloned for every request.
    rpc_middleware: TowerServiceNoHttp<RpcMiddleware>,
    /// HTTP middleware layered over the inner service on each call.
    http_middleware: tower::ServiceBuilder<HttpMiddleware>,
}
impl<RpcMiddleware, HttpMiddleware> TowerService<RpcMiddleware, HttpMiddleware> {
    /// Returns a future that resolves when the WebSocket session served by this
    /// service closes.
    ///
    /// The first call installs a session-close notifier; later calls obtain further
    /// listeners from that same notifier.
    pub fn on_session_closed(&mut self) -> SessionClosedFuture {
        match self.rpc_middleware.on_session_close.as_mut() {
            Some(existing) => existing.closed(),
            None => {
                let (notifier, closed_fut) = session_close();
                self.rpc_middleware.on_session_close = Some(notifier);
                closed_fut
            }
        }
    }
}
impl<Body, RpcMiddleware, HttpMiddleware> Service<HttpRequest<Body>> for TowerService<RpcMiddleware, HttpMiddleware>
where
    RpcMiddleware: for<'a> tower::Layer<RpcService> + Clone,
    <RpcMiddleware as Layer<RpcService>>::Service: Send + Sync + 'static,
    for<'a> <RpcMiddleware as Layer<RpcService>>::Service: RpcServiceT<'a>,
    HttpMiddleware: Layer<TowerServiceNoHttp<RpcMiddleware>> + Send + 'static,
    <HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service:
        Send + Service<HttpRequest<Body>, Response = HttpResponse, Error = Box<(dyn StdError + Send + Sync + 'static)>>,
    <<HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service as Service<HttpRequest<Body>>>::Future:
        Send + 'static,
    Body: http_body::Body<Data = Bytes> + Send + 'static,
    Body::Error: Into<BoxError>,
{
    type Response = HttpResponse;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always reports ready.
    ///
    /// NOTE(review): the layered HTTP service is built fresh in `call` and invoked
    /// without its own `poll_ready`.
    fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Layers the HTTP middleware over a clone of the inner JSON-RPC service and
    /// forwards the request to it.
    fn call(&mut self, request: HttpRequest<Body>) -> Self::Future {
        let inner = self.rpc_middleware.clone();
        let mut layered = self.http_middleware.service(inner);
        let pending = layered.call(request);
        Box::pin(pending)
    }
}
/// The JSON-RPC Tower [`Service`] without any HTTP middleware applied.
///
/// Handles both plain HTTP JSON-RPC calls and WebSocket upgrades.
#[derive(Debug, Clone)]
pub struct TowerServiceNoHttp<L> {
    /// Connection-scoped data (methods, config, guard, ids).
    inner: ServiceData,
    /// Middleware wrapped around every JSON-RPC method call.
    rpc_middleware: RpcServiceBuilder<L>,
    /// Optional session-close notifier, taken by the first request that reaches `call`.
    on_session_close: Option<SessionClose>,
}
impl<Body, RpcMiddleware> Service<HttpRequest<Body>> for TowerServiceNoHttp<RpcMiddleware>
where
    RpcMiddleware: for<'a> tower::Layer<RpcService>,
    <RpcMiddleware as Layer<RpcService>>::Service: Send + Sync + 'static,
    for<'a> <RpcMiddleware as Layer<RpcService>>::Service: RpcServiceT<'a>,
    Body: http_body::Body<Data = Bytes> + Send + 'static,
    Body::Error: Into<BoxError>,
{
    type Response = HttpResponse;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always reports ready; the connection limit is enforced per request in `call`.
    fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Handles one HTTP request. Depending on the config and the request, it either
    /// upgrades the request to a WebSocket connection, serves it as a plain
    /// JSON-RPC-over-HTTP call, or denies it.
    ///
    /// A connection permit is acquired first. If none is available, a "too many
    /// requests" response is returned. For WebSocket connections the permit moves into
    /// the background task. For HTTP calls it is held until the response has been
    /// produced, so `max_connections` bounds concurrent HTTP requests as well.
    fn call(&mut self, request: HttpRequest<Body>) -> Self::Future {
        let mut request = request.map(HttpBody::new);
        let conn_guard = &self.inner.conn_guard;
        let stop_handle = self.inner.stop_handle.clone();
        let conn_id = self.inner.conn_id;
        // Only the first request to reach here receives the session-close notifier.
        let on_session_close = self.on_session_close.take();
        tracing::trace!(target: LOG_TARGET, "{:?}", request);

        let Some(conn_permit) = conn_guard.try_acquire() else {
            return async move { Ok(http::response::too_many_requests()) }.boxed();
        };
        // `conn` owns the permit; the slot is released when `conn` is dropped.
        let conn = ConnectionState::new(stop_handle.clone(), conn_id, conn_permit);
        let max_conns = conn_guard.max_connections();
        let curr_conns = max_conns - conn_guard.available_connections();
        tracing::debug!(target: LOG_TARGET, "Accepting new connection {}/{}", curr_conns, max_conns);

        let req_ext = request.extensions_mut();
        req_ext.insert::<ConnectionGuard>(conn_guard.clone());
        req_ext.insert::<ConnectionId>(conn.conn_id.into());

        let is_upgrade_request = is_upgrade_request(&request);
        if self.inner.server_cfg.enable_ws && is_upgrade_request {
            let this = self.inner.clone();
            let mut server = soketto::handshake::http::Server::new();
            let response = match server.receive_request(&request) {
                Ok(response) => {
                    let (tx, rx) = mpsc::channel::<String>(this.server_cfg.message_buffer_capacity as usize);
                    let sink = MethodSink::new(tx);
                    // Each in-flight call holds a clone of `pending_calls`; the background
                    // task can observe via `pending_calls_completed` when all are dropped.
                    let (pending_calls, pending_calls_completed) = mpsc::channel::<()>(1);
                    let cfg = RpcServiceCfg::CallsAndSubscriptions {
                        bounded_subscriptions: BoundedSubscriptions::new(
                            this.server_cfg.max_subscriptions_per_connection,
                        ),
                        id_provider: this.server_cfg.id_provider.clone(),
                        sink: sink.clone(),
                        _pending_calls: pending_calls,
                    };
                    let rpc_service = RpcService::new(
                        this.methods.clone(),
                        this.server_cfg.max_response_body_size as usize,
                        this.conn_id.into(),
                        cfg,
                    );
                    let rpc_service = self.rpc_middleware.service(rpc_service);
                    tokio::spawn(
                        async move {
                            let extensions = request.extensions().clone();
                            let upgraded = match hyper::upgrade::on(request).await {
                                Ok(u) => u,
                                Err(e) => {
                                    tracing::debug!(target: LOG_TARGET, "Could not upgrade connection: {}", e);
                                    return;
                                }
                            };
                            let io = hyper_util::rt::TokioIo::new(upgraded);
                            let stream = BufReader::new(BufWriter::new(io.compat()));
                            let mut ws_builder = server.into_builder(stream);
                            ws_builder.set_max_message_size(this.server_cfg.max_request_body_size as usize);
                            let (sender, receiver) = ws_builder.finish();
                            let params = BackgroundTaskParams {
                                server_cfg: this.server_cfg,
                                conn,
                                ws_sender: sender,
                                ws_receiver: receiver,
                                rpc_service,
                                sink,
                                rx,
                                pending_calls_completed,
                                on_session_close,
                                extensions,
                            };
                            ws::background_task(params).await;
                        }
                        .in_current_span(),
                    );
                    response.map(|()| HttpBody::empty())
                }
                Err(e) => {
                    tracing::debug!(target: LOG_TARGET, "Could not upgrade connection: {}", e);
                    // RFC 6455 section 4.2.1: a malformed handshake must be answered with an
                    // HTTP error status, not the default 200 OK of `Response::new`.
                    let mut rp = HttpResponse::new(HttpBody::from(format!("Could not upgrade connection: {e}")));
                    *rp.status_mut() = hyper::StatusCode::BAD_REQUEST;
                    rp
                }
            };
            async { Ok(response) }.boxed()
        } else if self.inner.server_cfg.enable_http && !is_upgrade_request {
            let this = &self.inner;
            let max_response_size = this.server_cfg.max_response_body_size;
            let max_request_size = this.server_cfg.max_request_body_size;
            let methods = this.methods.clone();
            let batch_config = this.server_cfg.batch_requests_config;
            let rpc_service = self.rpc_middleware.service(RpcService::new(
                methods,
                max_response_size as usize,
                this.conn_id.into(),
                RpcServiceCfg::OnlyCalls,
            ));
            Box::pin(
                async move {
                    let response =
                        http::call_with_service(request, batch_config, max_request_size, rpc_service, max_response_size)
                            .await;
                    // Keep the connection permit until the response has been produced.
                    // Otherwise `max_connections` would not limit concurrent HTTP requests.
                    drop(conn);
                    response
                }
                .map(Ok),
            )
        } else {
            // Denied: the permit held by `conn` is released as soon as `call` returns.
            Box::pin(async { http::response::denied() }.map(Ok))
        }
    }
}
/// Everything `process_connection` needs to serve one accepted TCP connection.
struct ProcessConnection<'a, HttpMiddleware, RpcMiddleware> {
    /// Server-wide HTTP middleware, borrowed from the `Server`.
    http_middleware: &'a tower::ServiceBuilder<HttpMiddleware>,
    /// RPC middleware for this connection.
    rpc_middleware: RpcServiceBuilder<RpcMiddleware>,
    /// Server-wide connection guard, borrowed and cloned into the connection's service.
    conn_guard: &'a ConnectionGuard,
    /// Identifier assigned to this connection.
    conn_id: u32,
    /// Copy of the server configuration.
    server_cfg: ServerConfig,
    /// Handle used to observe server shutdown.
    stop_handle: StopHandle,
    /// The accepted socket.
    socket: TcpStream,
    /// Dropped when the connection task finishes; lets the accept loop wait for all connections.
    drop_on_completion: mpsc::Sender<()>,
    /// Peer address; only recorded in the tracing span.
    remote_addr: SocketAddr,
    /// Registered JSON-RPC methods.
    methods: Methods,
}
/// Serves one accepted TCP connection on a spawned task.
///
/// Applies `tcp_no_delay` to the socket and aborts (closing the socket) if that
/// fails. Otherwise it wraps the JSON-RPC service in the HTTP middleware and serves
/// the connection with hyper's auto (HTTP/1 + HTTP/2) builder, with upgrades
/// enabled. On server shutdown the connection is shut down gracefully and polled
/// until that completes.
#[instrument(name = "connection", skip_all, fields(remote_addr = %params.remote_addr, conn_id = %params.conn_id), level = "INFO")]
fn process_connection<'a, RpcMiddleware, HttpMiddleware, Body>(params: ProcessConnection<HttpMiddleware, RpcMiddleware>)
where
    RpcMiddleware: 'static,
    HttpMiddleware: Layer<TowerServiceNoHttp<RpcMiddleware>> + Send + 'static,
    <HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service:
        Send + 'static + Clone + Service<HttpRequest, Response = HttpResponse<Body>, Error = BoxError>,
    <<HttpMiddleware as Layer<TowerServiceNoHttp<RpcMiddleware>>>::Service as Service<HttpRequest>>::Future:
        Send + 'static,
    Body: http_body::Body<Data = Bytes> + Send + 'static,
    <Body as http_body::Body>::Error: Into<BoxError>,
    <Body as http_body::Body>::Data: Send,
{
    // `remote_addr` is only used by the `#[instrument]` span above.
    let ProcessConnection {
        http_middleware,
        rpc_middleware,
        conn_guard,
        conn_id,
        server_cfg,
        socket,
        stop_handle,
        drop_on_completion,
        methods,
        ..
    } = params;
    if let Err(e) = socket.set_nodelay(server_cfg.tcp_no_delay) {
        // Returning drops `socket` (closing it) and `drop_on_completion`.
        tracing::warn!(target: LOG_TARGET, "Could not set NODELAY on socket: {:?}", e);
        return;
    }
    let tower_service = TowerServiceNoHttp {
        inner: ServiceData {
            server_cfg,
            methods,
            stop_handle: stop_handle.clone(),
            conn_id,
            conn_guard: conn_guard.clone(),
        },
        rpc_middleware,
        on_session_close: None,
    };
    let service = http_middleware.service(tower_service);
    tokio::spawn(async {
        // The hyper adapter requires the service to be `Clone`.
        let service = crate::utils::TowerToHyperService::new(service);
        let io = TokioIo::new(socket);
        let builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new());
        let conn = builder.serve_connection_with_upgrades(io, service);
        let stopped = stop_handle.shutdown();
        tokio::pin!(stopped, conn);
        let res = match future::select(conn, stopped).await {
            Either::Left((conn, _)) => conn,
            Either::Right((_, mut conn)) => {
                // Both steps are required: request graceful shutdown, then keep
                // polling the connection until that shutdown finishes.
                conn.as_mut().graceful_shutdown();
                conn.await
            }
        };
        if let Err(e) = res {
            tracing::debug!(target: LOG_TARGET, "HTTP serve connection failed {:?}", e);
        }
        // Signals the accept loop's drain step that this connection is finished.
        drop(drop_on_completion)
    });
}
/// Outcome of racing a TCP accept against the shutdown future `S`.
///
/// Every variant except `Shutdown` hands the (still pending) shutdown future back so the
/// caller can reuse it for the next accept.
enum AcceptConnection<S> {
    /// The shutdown future resolved first.
    Shutdown,
    /// A connection was accepted.
    Established { socket: TcpStream, remote_addr: SocketAddr, stop: S },
    /// Accepting failed with an I/O error.
    Err((std::io::Error, S)),
}
/// Waits for either a new TCP connection on `listener` or for `stopped` to resolve,
/// whichever comes first.
///
/// If shutdown wins, the in-flight accept is dropped and `Shutdown` is returned.
/// Otherwise the still-pending `stopped` future is handed back inside the result.
async fn try_accept_conn<S>(listener: &TcpListener, stopped: S) -> AcceptConnection<S>
where
    S: Future + Unpin,
{
    let accept_fut = listener.accept();
    tokio::pin!(accept_fut);
    let raced = futures_util::future::select(accept_fut, stopped).await;
    match raced {
        Either::Right(_) => AcceptConnection::Shutdown,
        Either::Left((Ok((socket, remote_addr)), stop)) => AcceptConnection::Established { socket, remote_addr, stop },
        Either::Left((Err(e), stop)) => AcceptConnection::Err((e, stop)),
    }
}
/// Executes the JSON-RPC payload in `body` against `rpc_service`.
///
/// * `is_single` — `true` for one call or notification, `false` for a batch.
/// * `batch_config` — whether batches are rejected, length-capped or unlimited.
/// * `max_response_size` — byte limit for the accumulated batch response.
/// * `extensions` — attached to every deserialized request (cloned per batch entry).
///
/// Returns `None` when nothing should be sent back: a single notification, or a
/// batch consisting only of notifications. Otherwise returns the response or error
/// object to send.
pub(crate) async fn handle_rpc_call<S>(
    body: &[u8],
    is_single: bool,
    batch_config: BatchRequestConfig,
    max_response_size: u32,
    rpc_service: &S,
    extensions: Extensions,
) -> Option<MethodResponse>
where
    for<'a> S: RpcServiceT<'a> + Send,
{
    // Single request or notification.
    if is_single {
        if let Ok(req) = deserialize::from_slice_with_extensions(body, extensions) {
            Some(rpc_service.call(req).await)
        } else if let Ok(_notif) = serde_json::from_slice::<Notif>(body) {
            // Notifications are never answered.
            None
        } else {
            // Neither a request nor a notification: derive id and error code from the raw body.
            let (id, code) = prepare_error(body);
            Some(MethodResponse::error(id, ErrorObject::from(code)))
        }
    }
    // Batch of requests.
    else {
        let max_len = match batch_config {
            BatchRequestConfig::Disabled => {
                let rp = MethodResponse::error(
                    Id::Null,
                    ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, BATCHES_NOT_SUPPORTED_MSG, None),
                );
                return Some(rp);
            }
            BatchRequestConfig::Limit(limit) => limit as usize,
            BatchRequestConfig::Unlimited => usize::MAX,
        };
        // Entries are kept as raw JSON so each one can be classified individually.
        if let Ok(batch) = serde_json::from_slice::<Vec<&JsonRawValue>>(body) {
            if batch.len() > max_len {
                return Some(MethodResponse::error(Id::Null, reject_too_big_batch_request(max_len)));
            }
            let mut got_notif = false;
            let mut batch_response = BatchResponseBuilder::new_with_limit(max_response_size as usize);
            // Calls run sequentially; the first response that would exceed the size
            // limit replaces the whole batch response.
            for call in batch {
                if let Ok(req) = deserialize::from_str_with_extensions(call.get(), extensions.clone()) {
                    let rp = rpc_service.call(req).await;
                    if let Err(too_large) = batch_response.append(&rp) {
                        return Some(too_large);
                    }
                } else if let Ok(_notif) = serde_json::from_str::<Notif>(call.get()) {
                    // Notifications produce no entry in the batch response.
                    got_notif = true;
                } else {
                    // Valid JSON but not a request or notification: reply with
                    // InvalidRequest, reusing the entry's id when one can be extracted.
                    let id = match serde_json::from_str::<InvalidRequest>(call.get()) {
                        Ok(err) => err.id,
                        Err(_) => Id::Null,
                    };
                    if let Err(too_large) =
                        batch_response.append(&MethodResponse::error(id, ErrorObject::from(ErrorCode::InvalidRequest)))
                    {
                        return Some(too_large);
                    }
                }
            }
            // A batch made only of notifications gets no response at all.
            if got_notif && batch_response.is_empty() {
                None
            } else {
                let batch_rp = batch_response.finish();
                Some(MethodResponse::from_batch(batch_rp))
            }
        } else {
            // The body is not a JSON array at all.
            Some(MethodResponse::error(Id::Null, ErrorObject::from(ErrorCode::ParseError)))
        }
    }
}