jsonrpsee_server/middleware/http/
host_filter.rs1use crate::middleware::http::authority::{Authority, AuthorityError, Port};
30use crate::transport::http;
31use crate::{HttpBody, HttpRequest, LOG_TARGET};
32use futures_util::{Future, FutureExt, TryFutureExt};
33use hyper::body::Bytes;
34use hyper::Response;
35use jsonrpsee_core::BoxError;
36use route_recognizer::Router;
37use std::collections::BTreeMap;
38use std::pin::Pin;
39use std::sync::Arc;
40use std::task::{Context, Poll};
41use tower::{Layer, Service};
42
/// All ports accepted for a single whitelisted host (one host may be listed with several ports).
type Ports = Vec<Port>;
44
/// [`tower::Layer`] that wraps a service in a [`HostFilter`].
///
/// Holds the shared host whitelist; `None` means host filtering is disabled
/// (see [`HostFilterLayer::disable`]).
#[derive(Debug, Clone)]
pub struct HostFilterLayer(Option<Arc<WhitelistedHosts>>);
48
49impl HostFilterLayer {
50 pub fn new<T, U>(allow_only: T) -> Result<Self, AuthorityError>
52 where
53 T: IntoIterator<Item = U>,
54 U: TryInto<Authority, Error = AuthorityError>,
55 {
56 let allow_only: Result<Vec<_>, _> = allow_only.into_iter().map(|a| a.try_into()).collect();
57 Ok(Self(Some(Arc::new(WhitelistedHosts::from(allow_only?)))))
58 }
59
60 pub fn disable() -> Self {
83 Self(None)
84 }
85}
86
87impl<S> Layer<S> for HostFilterLayer {
88 type Service = HostFilter<S>;
89
90 fn layer(&self, inner: S) -> Self::Service {
91 HostFilter { inner, filter: self.0.clone() }
92 }
93}
94
/// Middleware service that rejects requests whose authority is not on the host whitelist.
#[derive(Debug, Clone)]
pub struct HostFilter<S> {
    /// Wrapped service that receives the requests that pass the filter.
    inner: S,
    /// Shared whitelist; `None` lets every request with an extractable authority through.
    filter: Option<Arc<WhitelistedHosts>>,
}
101
102impl<S, B> Service<HttpRequest<B>> for HostFilter<S>
103where
104 S: Service<HttpRequest<B>, Response = Response<HttpBody>>,
105 S::Response: 'static,
106 S::Error: Into<BoxError> + 'static,
107 S::Future: Send + 'static,
108 B: http_body::Body<Data = Bytes> + Send + std::fmt::Debug + 'static,
109 B::Data: Send,
110 B::Error: Into<BoxError>,
111{
112 type Response = S::Response;
113 type Error = BoxError;
114 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
115
116 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
117 self.inner.poll_ready(cx).map_err(Into::into)
118 }
119
120 fn call(&mut self, request: HttpRequest<B>) -> Self::Future {
121 let Some(authority) = Authority::from_http_request(&request) else {
122 return async { Ok(http::response::malformed()) }.boxed();
123 };
124
125 if self.filter.as_ref().map_or(true, |f| f.recognize(&authority)) {
126 Box::pin(self.inner.call(request).map_err(Into::into))
127 } else {
128 tracing::debug!(target: LOG_TARGET, "Denied request: {:?}", request);
129 async { Ok(http::response::host_not_allowed()) }.boxed()
130 }
131 }
132}
133
/// Host whitelist: routes each allowed host pattern to the list of ports accepted for it.
///
/// Host patterns are stored in a `route_recognizer::Router`; the `should_support_wildcards`
/// test shows a `*.web3.site` entry matching `parity.web3.site`.
#[derive(Default, Debug, Clone)]
pub struct WhitelistedHosts(Router<Ports>);
137
138impl<T> From<T> for WhitelistedHosts
139where
140 T: IntoIterator<Item = Authority>,
141{
142 fn from(value: T) -> Self {
143 let mut router = Router::new();
144 let mut uniq_hosts: BTreeMap<String, Ports> = BTreeMap::new();
145
146 for auth in value.into_iter() {
150 uniq_hosts
151 .entry(auth.host)
152 .and_modify(|v| {
153 v.push(auth.port);
154 })
155 .or_insert_with(|| vec![auth.port]);
156 }
157
158 for (host, ports) in uniq_hosts.into_iter() {
159 router.add(&host, ports);
160 }
161
162 Self(router)
163 }
164}
165
166impl WhitelistedHosts {
167 fn recognize(&self, other: &Authority) -> bool {
168 if let Ok(p) = self.0.recognize(&other.host) {
169 let ports = p.handler();
170
171 ports.iter().any(|p| match (p, &other.port) {
172 (Port::Any, _) => true,
173 (Port::Default, Port::Default) => true,
174 (Port::Fixed(p1), Port::Fixed(p2)) if p1 == p2 => true,
175 _ => false,
176 })
177 } else {
178 false
179 }
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::{Authority, WhitelistedHosts};

    /// Parses `a` into an [`Authority`], panicking if it is malformed.
    fn unwrap_auth(a: &str) -> Authority {
        a.try_into().unwrap()
    }

    /// Builds a whitelist from string authorities, panicking if any entry is malformed.
    fn unwrap_filter(list: &[&str]) -> WhitelistedHosts {
        let l: Vec<_> = list.iter().map(|&a| a.try_into().unwrap()).collect();
        WhitelistedHosts::from(l)
    }

    #[test]
    fn should_reject_if_header_not_on_the_list() {
        let filter = unwrap_filter(&[]);
        assert!(!filter.recognize(&unwrap_auth("parity.io")));
    }

    #[test]
    fn should_accept_if_on_the_list() {
        let filter = unwrap_filter(&["parity.io"]);
        assert!(filter.recognize(&unwrap_auth("parity.io")));
    }

    #[test]
    fn should_accept_if_on_the_list_with_port() {
        let filter = unwrap_filter(&["parity.io:443", "parity.io:9944"]);
        assert!(filter.recognize(&unwrap_auth("parity.io:443")));
        assert!(filter.recognize(&unwrap_auth("parity.io:9944")));
        assert!(!filter.recognize(&unwrap_auth("parity.io")));
    }

    #[test]
    fn should_support_wildcards() {
        let filter = unwrap_filter(&["*.web3.site:*"]);
        assert!(filter.recognize(&unwrap_auth("parity.web3.site:8180")));
        assert!(filter.recognize(&unwrap_auth("parity.web3.site")));
    }

    #[test]
    fn should_accept_with_and_without_default_port() {
        let filter = unwrap_filter(&["https://parity.io:443"]);
        assert!(filter.recognize(&unwrap_auth("https://parity.io")));
        assert!(filter.recognize(&unwrap_auth("https://parity.io:443")));
    }

    // A host that is not on the list must be rejected even if the list is non-empty.
    #[test]
    fn should_reject_other_host() {
        let filter = unwrap_filter(&["parity.io"]);
        assert!(!filter.recognize(&unwrap_auth("example.com")));
    }

    // A whitelisted host with a fixed port must not accept a different fixed port.
    #[test]
    fn should_reject_port_not_on_the_list() {
        let filter = unwrap_filter(&["parity.io:443"]);
        assert!(!filter.recognize(&unwrap_auth("parity.io:9944")));
    }

    // Repeated hosts have their ports merged, and nothing beyond those ports is accepted.
    #[test]
    fn should_merge_ports_of_repeated_host() {
        let filter = unwrap_filter(&["parity.io", "parity.io:9944"]);
        assert!(filter.recognize(&unwrap_auth("parity.io")));
        assert!(filter.recognize(&unwrap_auth("parity.io:9944")));
        assert!(!filter.recognize(&unwrap_auth("parity.io:443")));
    }
}