use crate::BatchRequestConfig;
use std::{
error::Error as StdError,
net::{IpAddr, SocketAddr},
num::NonZeroU32,
str::FromStr,
};
use forwarded_header_value::ForwardedHeaderValue;
use http::header::{HeaderName, HeaderValue};
use ip_network::IpNetwork;
use jsonrpsee::{server::middleware::http::HostFilterLayer, RpcModule};
use sc_rpc_api::DenyUnsafe;
use tower_http::cors::{AllowOrigin, CorsLayer};
/// `x-forwarded-for` header name; the second header consulted by `get_proxy_ip`.
const X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
/// `x-real-ip` header name; the last header consulted by `get_proxy_ip`.
const X_REAL_IP: HeaderName = HeaderName::from_static("x-real-ip");
/// `forwarded` header name; the first header consulted by `get_proxy_ip`.
const FORWARDED: HeaderName = HeaderName::from_static("forwarded");
/// Error whose message states that no listen address could be bound.
///
/// It carries no payload; the place where it is raised is outside this file.
#[derive(Debug)]
pub(crate) struct ListenAddrError;

impl std::fmt::Display for ListenAddrError {
    /// Writes the fixed, human-readable error message.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("No listen address was successfully bound")
    }
}

impl std::error::Error for ListenAddrError {}
/// Selects which RPC methods an endpoint exposes.
///
/// Parsed from the exact strings `"safe"`, `"unsafe"` and `"auto"`. How each
/// variant is interpreted for a given listen address is decided by `deny_unsafe`.
#[derive(Debug, Copy, Clone)]
pub enum RpcMethods {
    /// Unsafe calls are denied regardless of the listen address.
    Safe,
    /// Unsafe calls are allowed regardless of the listen address.
    Unsafe,
    /// Unsafe calls are allowed only when the listen address is a loopback address.
    Auto,
}

impl Default for RpcMethods {
    /// The default policy is [`RpcMethods::Auto`].
    fn default() -> Self {
        Self::Auto
    }
}

impl FromStr for RpcMethods {
    type Err = String;

    /// Parses the lowercase policy name.
    ///
    /// # Errors
    ///
    /// Returns a message containing the rejected input when `s` is not one of
    /// `"safe"`, `"unsafe"` or `"auto"` (matching is case-sensitive and exact).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(match s {
            "safe" => Self::Safe,
            "unsafe" => Self::Unsafe,
            "auto" => Self::Auto,
            invalid => return Err(format!("Invalid rpc methods {invalid}")),
        })
    }
}
/// Settings attached to a bound listener.
///
/// Assembled by `RpcEndpoint::bind` from the endpoint configuration and returned
/// (as a clone) by `Listener::rpc_settings`. Nothing in this file enforces the
/// limits below; they are only carried along for the consumer.
#[derive(Debug, Clone)]
pub(crate) struct RpcSettings {
/// Batch request configuration, copied from the endpoint.
pub(crate) batch_config: BatchRequestConfig,
/// Connection limit, copied from the endpoint.
pub(crate) max_connections: u32,
/// Inbound payload limit; the `_mb` suffix suggests megabytes — confirm at the consumer.
pub(crate) max_payload_in_mb: u32,
/// Outbound payload limit; the `_mb` suffix suggests megabytes — confirm at the consumer.
pub(crate) max_payload_out_mb: u32,
/// Per-connection subscription limit, copied from the endpoint.
pub(crate) max_subscriptions_per_connection: u32,
/// Per-connection buffer capacity, copied from the endpoint.
pub(crate) max_buffer_capacity_per_connection: u32,
/// Method exposure policy, copied from the endpoint.
pub(crate) rpc_methods: RpcMethods,
/// Optional rate limit; presumably `None` disables rate limiting — the unit/window is
/// defined by the consumer, not here.
pub(crate) rate_limit: Option<NonZeroU32>,
/// Whether rate limiting should trust proxy headers; presumably gates the use of
/// `get_proxy_ip` — confirm at the call site.
pub(crate) rate_limit_trust_proxy_headers: bool,
/// Networks presumably exempt from rate limiting — confirm against the consumer.
pub(crate) rate_limit_whitelisted_ips: Vec<IpNetwork>,
/// CORS layer produced by `try_into_cors` from the endpoint's origin list.
pub(crate) cors: CorsLayer,
/// Host filter produced by `host_filtering`; `None` when the endpoint had no CORS
/// origin list configured.
pub(crate) host_filter: Option<HostFilterLayer>,
}
/// Configuration of a single RPC endpoint, consumed by `RpcEndpoint::bind`.
#[derive(Debug, Clone)]
pub struct RpcEndpoint {
/// Address the TCP listener is bound to.
pub listen_addr: SocketAddr,
/// Batch request configuration, forwarded unchanged into the listener settings.
pub batch_config: BatchRequestConfig,
/// Connection limit, forwarded unchanged into the listener settings.
pub max_connections: u32,
/// Inbound payload limit; the `_mb` suffix suggests megabytes — confirm at the consumer.
pub max_payload_in_mb: u32,
/// Outbound payload limit; the `_mb` suffix suggests megabytes — confirm at the consumer.
pub max_payload_out_mb: u32,
/// Per-connection subscription limit, forwarded unchanged into the listener settings.
pub max_subscriptions_per_connection: u32,
/// Per-connection buffer capacity, forwarded unchanged into the listener settings.
pub max_buffer_capacity_per_connection: u32,
/// Optional rate limit; presumably `None` disables rate limiting — the unit/window is
/// defined by the consumer, not here.
pub rate_limit: Option<NonZeroU32>,
/// Whether rate limiting should trust proxy headers — confirm usage at the consumer.
pub rate_limit_trust_proxy_headers: bool,
/// Networks presumably exempt from rate limiting — confirm against the consumer.
pub rate_limit_whitelisted_ips: Vec<IpNetwork>,
/// Allowed CORS origins. `Some(list)` restricts origins to `list` and also enables host
/// filtering in `bind`; `None` yields a permissive CORS layer and no host filter.
pub cors: Option<Vec<String>>,
/// Method exposure policy, forwarded unchanged into the listener settings.
pub rpc_methods: RpcMethods,
/// Not read in this file; presumably marks an endpoint whose bind failure may be
/// ignored by the caller — TODO confirm.
pub is_optional: bool,
/// When `true`, `bind` retries on the same IP with port `0` if binding `listen_addr` fails.
pub retry_random_port: bool,
}
impl RpcEndpoint {
    /// Binds a TCP listener for this endpoint and bundles it with its settings.
    ///
    /// If binding `listen_addr` fails and `retry_random_port` is set, a second
    /// attempt is made on the same address with the port replaced by `0`.
    /// Host filtering is enabled exactly when a CORS origin list was configured.
    ///
    /// # Errors
    ///
    /// Returns the I/O error of the failed bind (the retry's error when a retry
    /// was attempted), the error from querying the local address, or the error
    /// produced by `try_into_cors` for an invalid origin.
    pub(crate) async fn bind(self) -> Result<Listener, Box<dyn StdError + Send + Sync>> {
        let Self {
            listen_addr,
            batch_config,
            max_connections,
            max_payload_in_mb,
            max_payload_out_mb,
            max_subscriptions_per_connection,
            max_buffer_capacity_per_connection,
            rate_limit,
            rate_limit_trust_proxy_headers,
            rate_limit_whitelisted_ips,
            cors,
            rpc_methods,
            retry_random_port,
            ..
        } = self;

        let listener = match tokio::net::TcpListener::bind(listen_addr).await {
            Ok(bound) => bound,
            Err(_) if retry_random_port => {
                // Only the port is changed on a copy, so every other part of the
                // configured address is kept for the retry.
                let mut fallback = listen_addr;
                fallback.set_port(0);
                tokio::net::TcpListener::bind(fallback).await?
            },
            Err(err) => return Err(err.into()),
        };

        // Use the address actually bound: after a retry the port differs from
        // the configured one.
        let local_addr = listener.local_addr()?;
        let host_filter = host_filtering(cors.is_some(), local_addr);
        let cors = try_into_cors(cors)?;

        let cfg = RpcSettings {
            batch_config,
            max_connections,
            max_payload_in_mb,
            max_payload_out_mb,
            max_subscriptions_per_connection,
            max_buffer_capacity_per_connection,
            rpc_methods,
            rate_limit,
            rate_limit_trust_proxy_headers,
            rate_limit_whitelisted_ips,
            host_filter,
            cors,
        };

        Ok(Listener { listener, local_addr, cfg })
    }
}
/// A bound TCP listener together with the settings of the endpoint it serves.
///
/// Created by `RpcEndpoint::bind`.
pub(crate) struct Listener {
/// The underlying tokio TCP listener.
listener: tokio::net::TcpListener,
/// Address reported by the listener right after binding.
local_addr: SocketAddr,
/// Settings derived from the endpoint configuration.
cfg: RpcSettings,
}
impl Listener {
    /// Waits for the next incoming connection.
    ///
    /// Returns the accepted stream and the peer's address, or the I/O error
    /// reported by the underlying listener.
    pub(crate) async fn accept(&mut self) -> std::io::Result<(tokio::net::TcpStream, SocketAddr)> {
        self.listener.accept().await
    }

    /// Returns the address recorded when the listener was bound.
    pub fn local_addr(&self) -> SocketAddr {
        self.local_addr
    }

    /// Returns a clone of the settings attached to this listener.
    pub fn rpc_settings(&self) -> RpcSettings {
        self.cfg.clone()
    }
}
/// Builds the optional host filter for a listener.
///
/// Returns `None` when `enabled` is `false`. Otherwise returns a
/// `HostFilterLayer` constructed from the three loopback host names
/// (`localhost`, `127.0.0.1`, `[::1]`), each suffixed with the port of `addr`.
/// Only the port of `addr` is used; its IP is ignored.
///
/// # Panics
///
/// Panics if `HostFilterLayer::new` rejects the generated host list.
pub(crate) fn host_filtering(enabled: bool, addr: SocketAddr) -> Option<HostFilterLayer> {
    if !enabled {
        return None;
    }

    let port = addr.port();
    let allowed_hosts = [
        format!("localhost:{}", port),
        format!("127.0.0.1:{}", port),
        format!("[::1]:{}", port),
    ];

    Some(HostFilterLayer::new(allowed_hosts).expect("Valid hosts; qed"))
}
/// Adds an `rpc_methods` method to the module and returns the module.
///
/// The new method answers with a JSON object whose `methods` key holds the
/// sorted names of every method registered on `rpc_api` at call time, plus
/// `rpc_methods` itself.
///
/// # Panics
///
/// Panics if registering `rpc_methods` fails (for instance because the module
/// already contains a method with that name).
pub(crate) fn build_rpc_api<M: Send + Sync + 'static>(mut rpc_api: RpcModule<M>) -> RpcModule<M> {
    // `rpc_methods` is registered below, so it is added to the report up front.
    let mut method_names: Vec<_> =
        rpc_api.method_names().chain(std::iter::once("rpc_methods")).collect();
    method_names.sort();

    rpc_api
        .register_method("rpc_methods", move |_, _, _| {
            serde_json::json!({ "methods": method_names })
        })
        .expect("infallible all other methods have their own address space; qed");

    rpc_api
}
/// Converts an optional list of allowed origins into a `CorsLayer`.
///
/// * `Some(origins)` — every origin is parsed as a header value and the layer
///   allows exactly that list (an empty list yields a layer with an empty
///   allow-list).
/// * `None` — returns `CorsLayer::permissive()`.
///
/// # Errors
///
/// Returns the parse error of the first origin that is not a valid header value.
///
/// NOTE(review): tower-http documents `AllowOrigin::list` as panicking on a
/// wildcard (`*`) origin — confirm callers never pass one in `maybe_cors`.
pub(crate) fn try_into_cors(
    maybe_cors: Option<Vec<String>>,
) -> Result<CorsLayer, Box<dyn StdError + Send + Sync>> {
    match maybe_cors {
        Some(origins) => {
            // Collecting into `Result` stops at the first invalid origin.
            let allowed = origins
                .iter()
                .map(|origin| HeaderValue::from_str(origin))
                .collect::<Result<Vec<_>, _>>()?;
            Ok(CorsLayer::new().allow_origin(AllowOrigin::list(allowed)))
        },
        None => Ok(CorsLayer::permissive()),
    }
}
/// Extracts a client IP address from the proxy headers of `req`.
///
/// The headers are consulted in this order and the first one that yields an
/// IP wins:
///
/// 1. `Forwarded`
/// 2. `X-Forwarded-For`
/// 3. `X-Real-Ip`
///
/// For the first two, the address chosen is whatever
/// `ForwardedHeaderValue::remotest_forwarded_for_ip` reports. A header that is
/// missing, not valid as a string, or unparsable is skipped silently.
/// Returns `None` when no header produces an address.
pub(crate) fn get_proxy_ip<B>(req: &http::Request<B>) -> Option<IpAddr> {
    let headers = req.headers();

    // Each source is a lazy closure so later headers are only inspected when
    // the earlier ones produced nothing.
    let from_forwarded = || {
        headers
            .get(&FORWARDED)
            .and_then(|value| value.to_str().ok())
            .and_then(|raw| ForwardedHeaderValue::from_forwarded(raw).ok())
            .and_then(|parsed| parsed.remotest_forwarded_for_ip())
    };
    let from_x_forwarded_for = || {
        headers
            .get(&X_FORWARDED_FOR)
            .and_then(|value| value.to_str().ok())
            .and_then(|raw| ForwardedHeaderValue::from_x_forwarded_for(raw).ok())
            .and_then(|parsed| parsed.remotest_forwarded_for_ip())
    };
    let from_x_real_ip = || {
        headers
            .get(&X_REAL_IP)
            .and_then(|value| value.to_str().ok())
            .and_then(|raw| IpAddr::from_str(raw).ok())
    };

    from_forwarded().or_else(from_x_forwarded_for).or_else(from_x_real_ip)
}
/// Decides whether unsafe RPC calls are denied for a listener.
///
/// * `RpcMethods::Unsafe` — never denied.
/// * `RpcMethods::Auto` — denied unless the IP of `addr` is a loopback address.
/// * `RpcMethods::Safe` — always denied.
pub fn deny_unsafe(addr: &SocketAddr, methods: &RpcMethods) -> DenyUnsafe {
    match methods {
        RpcMethods::Unsafe => DenyUnsafe::No,
        RpcMethods::Auto if addr.ip().is_loopback() => DenyUnsafe::No,
        RpcMethods::Auto | RpcMethods::Safe => DenyUnsafe::Yes,
    }
}
/// Renders socket addresses as a comma-separated string.
///
/// An empty slice yields an empty string. Addresses are joined with `,` and no
/// surrounding whitespace. A slice holding exactly one address additionally
/// gets a trailing `,`.
///
/// NOTE(review): the trailing comma in the single-address case looks
/// deliberate (consumers of the formatted string may rely on it) — confirm
/// before changing.
pub(crate) fn format_listen_addrs(addr: &[SocketAddr]) -> String {
    let mut joined = addr.iter().map(ToString::to_string).collect::<Vec<_>>().join(",");

    if let [_] = addr {
        joined.push(',');
    }

    joined
}
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::header::HeaderValue;
    use jsonrpsee::server::{HttpBody, HttpRequest};

    /// Builds a request with an empty body and no headers set.
    fn request() -> http::Request<HttpBody> {
        HttpRequest::builder().body(HttpBody::empty()).unwrap()
    }

    /// Wraps a literal IP address in the shape returned by `get_proxy_ip`.
    fn expected_ip(literal: &str) -> Option<IpAddr> {
        Some(IpAddr::from_str(literal).unwrap())
    }

    /// Without any proxy header there is no IP to report.
    #[test]
    fn empty_works() {
        assert!(get_proxy_ip(&request()).is_none())
    }

    /// A lone `X-Real-Ip` header is parsed as the IP.
    #[test]
    fn host_from_x_real_ip() {
        let mut req = request();
        req.headers_mut().insert(&X_REAL_IP, HeaderValue::from_static("127.0.0.1"));

        assert_eq!(expected_ip("127.0.0.1"), get_proxy_ip(&req));
    }

    /// The `for=` element of a single `Forwarded` header is used.
    #[test]
    fn ip_from_forwarded_works() {
        let mut req = request();
        req.headers_mut().insert(
            &FORWARDED,
            HeaderValue::from_static("for=192.0.2.60;proto=http;by=203.0.113.43;host=example.com"),
        );

        assert_eq!(expected_ip("192.0.2.60"), get_proxy_ip(&req));
    }

    /// With several `Forwarded` headers appended, the IP of the first one is expected.
    #[test]
    fn ip_from_forwarded_multiple() {
        let mut req = request();
        for value in ["for=127.0.0.1", "for=192.0.2.60", "for=192.0.2.61"] {
            req.headers_mut().append(&FORWARDED, HeaderValue::from_static(value));
        }

        assert_eq!(expected_ip("127.0.0.1"), get_proxy_ip(&req));
    }

    /// The first entry of an `X-Forwarded-For` list is expected.
    #[test]
    fn ip_from_x_forwarded_works() {
        let mut req = request();
        req.headers_mut()
            .insert(&X_FORWARDED_FOR, HeaderValue::from_static("127.0.0.1,192.0.2.60,0.0.0.1"));

        assert_eq!(expected_ip("127.0.0.1"), get_proxy_ip(&req));
    }
}