use crate::middleware::http::authority::{Authority, AuthorityError, Port};
use crate::transport::http;
use crate::{HttpBody, HttpRequest, LOG_TARGET};
use futures_util::{Future, FutureExt, TryFutureExt};
use hyper::body::Bytes;
use hyper::Response;
use jsonrpsee_core::BoxError;
use route_recognizer::Router;
use std::collections::BTreeMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tower::{Layer, Service};
/// Ports accepted for a single whitelisted host pattern.
type Ports = Vec<Port>;
/// Tower [`Layer`] that wraps services in a [`HostFilter`].
///
/// Holds an optional, shared host allow-list; `None` means the allow-list check is disabled.
#[derive(Debug, Clone)]
pub struct HostFilterLayer(Option<Arc<WhitelistedHosts>>);

impl HostFilterLayer {
	/// Enables host filtering, accepting only requests whose authority matches one of `allow_only`.
	///
	/// An empty `allow_only` yields an allow-list that matches no host, so every request
	/// with a parseable authority is rejected.
	///
	/// # Errors
	///
	/// Returns the first [`AuthorityError`] produced while converting an entry of `allow_only`;
	/// conversion stops at that entry.
	pub fn new<T, U>(allow_only: T) -> Result<Self, AuthorityError>
	where
		T: IntoIterator<Item = U>,
		U: TryInto<Authority, Error = AuthorityError>,
	{
		let mut authorities: Vec<Authority> = Vec::new();
		for entry in allow_only {
			authorities.push(entry.try_into()?);
		}
		let whitelist = Arc::new(WhitelistedHosts::from(authorities));
		Ok(Self(Some(whitelist)))
	}

	/// Disables the allow-list check.
	///
	/// Note that [`HostFilter`] still answers requests whose authority cannot be
	/// determined with a "malformed" response.
	pub fn disable() -> Self {
		HostFilterLayer(None)
	}
}

impl<S> Layer<S> for HostFilterLayer {
	type Service = HostFilter<S>;

	/// Wraps `inner`, sharing this layer's allow-list (if any) with the new service.
	fn layer(&self, inner: S) -> Self::Service {
		let filter = self.0.as_ref().map(Arc::clone);
		HostFilter { inner, filter }
	}
}
/// Tower [`Service`] that checks each request's authority against an optional allow-list.
///
/// Requests with no determinable authority get a "malformed" response, requests whose
/// authority is not whitelisted get a "host not allowed" response, and all others are
/// forwarded to `inner`.
#[derive(Debug, Clone)]
pub struct HostFilter<S> {
	inner: S,
	filter: Option<Arc<WhitelistedHosts>>,
}

impl<S, B> Service<HttpRequest<B>> for HostFilter<S>
where
	S: Service<HttpRequest<B>, Response = Response<HttpBody>>,
	S::Response: 'static,
	S::Error: Into<BoxError> + 'static,
	S::Future: Send + 'static,
	B: http_body::Body<Data = Bytes> + Send + std::fmt::Debug + 'static,
	B::Data: Send,
	B::Error: Into<BoxError>,
{
	type Response = S::Response;
	type Error = BoxError;
	type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

	/// Readiness is delegated entirely to the wrapped service.
	fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
		self.inner.poll_ready(cx).map_err(|err| err.into())
	}

	fn call(&mut self, request: HttpRequest<B>) -> Self::Future {
		let authority = match Authority::from_http_request(&request) {
			Some(authority) => authority,
			// No authority could be extracted from the request.
			None => return async { Ok(http::response::malformed()) }.boxed(),
		};

		let allowed = match self.filter.as_deref() {
			Some(whitelist) => whitelist.recognize(&authority),
			// Filtering disabled: any authority is accepted.
			None => true,
		};

		if !allowed {
			tracing::debug!(target: LOG_TARGET, "Denied request: {:?}", request);
			return async { Ok(http::response::host_not_allowed()) }.boxed();
		}

		Box::pin(self.inner.call(request).map_err(Into::into))
	}
}
#[derive(Default, Debug, Clone)]
pub struct WhitelistedHosts(Router<Ports>);
impl<T> From<T> for WhitelistedHosts
where
T: IntoIterator<Item = Authority>,
{
fn from(value: T) -> Self {
let mut router = Router::new();
let mut uniq_hosts: BTreeMap<String, Ports> = BTreeMap::new();
for auth in value.into_iter() {
uniq_hosts
.entry(auth.host)
.and_modify(|v| {
v.push(auth.port);
})
.or_insert_with(|| vec![auth.port]);
}
for (host, ports) in uniq_hosts.into_iter() {
router.add(&host, ports);
}
Self(router)
}
}
impl WhitelistedHosts {
fn recognize(&self, other: &Authority) -> bool {
if let Ok(p) = self.0.recognize(&other.host) {
let ports = p.handler();
ports.iter().any(|p| match (p, &other.port) {
(Port::Any, _) => true,
(Port::Default, Port::Default) => true,
(Port::Fixed(p1), Port::Fixed(p2)) if p1 == p2 => true,
_ => false,
})
} else {
false
}
}
}
#[cfg(test)]
mod tests {
	use super::{Authority, WhitelistedHosts};

	/// Parses a test authority, panicking if it is malformed.
	fn auth(raw: &str) -> Authority {
		raw.try_into().unwrap()
	}

	/// Builds an allow-list from test authorities, panicking if any is malformed.
	fn whitelist(entries: &[&str]) -> WhitelistedHosts {
		WhitelistedHosts::from(entries.iter().map(|entry| auth(entry)))
	}

	#[test]
	fn should_reject_if_header_not_on_the_list() {
		let hosts = whitelist(&[]);
		assert!(!hosts.recognize(&auth("parity.io")));
	}

	#[test]
	fn should_accept_if_on_the_list() {
		let hosts = whitelist(&["parity.io"]);
		assert!(hosts.recognize(&auth("parity.io")));
	}

	#[test]
	fn should_accept_if_on_the_list_with_port() {
		let hosts = whitelist(&["parity.io:443", "parity.io:9944"]);
		for allowed in ["parity.io:443", "parity.io:9944"] {
			assert!(hosts.recognize(&auth(allowed)));
		}
		// Only explicit ports were listed, so the default port is not accepted.
		assert!(!hosts.recognize(&auth("parity.io")));
	}

	#[test]
	fn should_support_wildcards() {
		let hosts = whitelist(&["*.web3.site:*"]);
		for allowed in ["parity.web3.site:8180", "parity.web3.site"] {
			assert!(hosts.recognize(&auth(allowed)));
		}
	}

	#[test]
	fn should_accept_with_and_without_default_port() {
		let hosts = whitelist(&["https://parity.io:443"]);
		for allowed in ["https://parity.io", "https://parity.io:443"] {
			assert!(hosts.recognize(&auth(allowed)));
		}
	}
}