jsonrpsee_server/middleware/http/host_filter.rs

1// Copyright 2019-2023 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27//! HTTP host validation middleware.
28
29use crate::middleware::http::authority::{Authority, AuthorityError, Port};
30use crate::transport::http;
31use crate::{HttpBody, HttpRequest, LOG_TARGET};
32use futures_util::{Future, FutureExt, TryFutureExt};
33use hyper::body::Bytes;
34use hyper::Response;
35use jsonrpsee_core::BoxError;
36use route_recognizer::Router;
37use std::collections::BTreeMap;
38use std::pin::Pin;
39use std::sync::Arc;
40use std::task::{Context, Poll};
41use tower::{Layer, Service};
42
43type Ports = Vec<Port>;
44
45/// Middleware to enable host filtering.
46#[derive(Debug, Clone)]
47pub struct HostFilterLayer(Option<Arc<WhitelistedHosts>>);
48
49impl HostFilterLayer {
50	/// Enables host filtering and allow only the specified hosts.
51	pub fn new<T, U>(allow_only: T) -> Result<Self, AuthorityError>
52	where
53		T: IntoIterator<Item = U>,
54		U: TryInto<Authority, Error = AuthorityError>,
55	{
56		let allow_only: Result<Vec<_>, _> = allow_only.into_iter().map(|a| a.try_into()).collect();
57		Ok(Self(Some(Arc::new(WhitelistedHosts::from(allow_only?)))))
58	}
59
60	/// Convenience method to disable host filtering but less efficient
61	/// than to not enable the middleware at all.
62	///
63	/// Because is the `tower middleware` returns a different type
64	/// depending on which Layers are configured it and may not compile
65	/// in some contexts.
66	///
67	/// For example the following won't compile:
68	///
69	/// ```ignore
70	/// use jsonrpsee_server::middleware::{ProxyGetRequestLayer, HostFilterLayer};
71	///
72	/// let host_filter = false;
73	///
74	/// let middleware = if host_filter {
75	///     tower::ServiceBuilder::new()
76	///        .layer(HostFilterLayer::new(["example.com"]).unwrap())
77	///        .layer(ProxyGetRequestLayer::new("/health", "system_health").unwrap())
78	/// } else {
79	///    tower::ServiceBuilder::new()
80	/// };
81	/// ```
82	pub fn disable() -> Self {
83		Self(None)
84	}
85}
86
87impl<S> Layer<S> for HostFilterLayer {
88	type Service = HostFilter<S>;
89
90	fn layer(&self, inner: S) -> Self::Service {
91		HostFilter { inner, filter: self.0.clone() }
92	}
93}
94
95/// Middleware to enable host filtering.
96#[derive(Debug, Clone)]
97pub struct HostFilter<S> {
98	inner: S,
99	filter: Option<Arc<WhitelistedHosts>>,
100}
101
102impl<S, B> Service<HttpRequest<B>> for HostFilter<S>
103where
104	S: Service<HttpRequest<B>, Response = Response<HttpBody>>,
105	S::Response: 'static,
106	S::Error: Into<BoxError> + 'static,
107	S::Future: Send + 'static,
108	B: http_body::Body<Data = Bytes> + Send + std::fmt::Debug + 'static,
109	B::Data: Send,
110	B::Error: Into<BoxError>,
111{
112	type Response = S::Response;
113	type Error = BoxError;
114	type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
115
116	fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
117		self.inner.poll_ready(cx).map_err(Into::into)
118	}
119
120	fn call(&mut self, request: HttpRequest<B>) -> Self::Future {
121		let Some(authority) = Authority::from_http_request(&request) else {
122			return async { Ok(http::response::malformed()) }.boxed();
123		};
124
125		if self.filter.as_ref().map_or(true, |f| f.recognize(&authority)) {
126			Box::pin(self.inner.call(request).map_err(Into::into))
127		} else {
128			tracing::debug!(target: LOG_TARGET, "Denied request: {:?}", request);
129			async { Ok(http::response::host_not_allowed()) }.boxed()
130		}
131	}
132}
133
134/// Represent the URL patterns that is whitelisted.
135#[derive(Default, Debug, Clone)]
136pub struct WhitelistedHosts(Router<Ports>);
137
138impl<T> From<T> for WhitelistedHosts
139where
140	T: IntoIterator<Item = Authority>,
141{
142	fn from(value: T) -> Self {
143		let mut router = Router::new();
144		let mut uniq_hosts: BTreeMap<String, Ports> = BTreeMap::new();
145
146		// Ensure that no ports are "overwritten"
147		// since it's possible add the same hostname with
148		// several port numbers.
149		for auth in value.into_iter() {
150			uniq_hosts
151				.entry(auth.host)
152				.and_modify(|v| {
153					v.push(auth.port);
154				})
155				.or_insert_with(|| vec![auth.port]);
156		}
157
158		for (host, ports) in uniq_hosts.into_iter() {
159			router.add(&host, ports);
160		}
161
162		Self(router)
163	}
164}
165
166impl WhitelistedHosts {
167	fn recognize(&self, other: &Authority) -> bool {
168		if let Ok(p) = self.0.recognize(&other.host) {
169			let ports = p.handler();
170
171			ports.iter().any(|p| match (p, &other.port) {
172				(Port::Any, _) => true,
173				(Port::Default, Port::Default) => true,
174				(Port::Fixed(p1), Port::Fixed(p2)) if p1 == p2 => true,
175				_ => false,
176			})
177		} else {
178			false
179		}
180	}
181}
182
183#[cfg(test)]
184mod tests {
185	use super::{Authority, WhitelistedHosts};
186
187	fn unwrap_auth(a: &str) -> Authority {
188		a.try_into().unwrap()
189	}
190
191	fn unwrap_filter(list: &[&str]) -> WhitelistedHosts {
192		let l: Vec<_> = list.iter().map(|&a| a.try_into().unwrap()).collect();
193		WhitelistedHosts::from(l)
194	}
195
196	#[test]
197	fn should_reject_if_header_not_on_the_list() {
198		let filter = unwrap_filter(&[]);
199		assert!(!filter.recognize(&unwrap_auth("parity.io")));
200	}
201
202	#[test]
203	fn should_accept_if_on_the_list() {
204		let filter = unwrap_filter(&["parity.io"]);
205		assert!(filter.recognize(&unwrap_auth("parity.io")));
206	}
207
208	#[test]
209	fn should_accept_if_on_the_list_with_port() {
210		let filter = unwrap_filter(&["parity.io:443", "parity.io:9944"]);
211		assert!(filter.recognize(&unwrap_auth("parity.io:443")));
212		assert!(filter.recognize(&unwrap_auth("parity.io:9944")));
213		assert!(!filter.recognize(&unwrap_auth("parity.io")));
214	}
215
216	#[test]
217	fn should_support_wildcards() {
218		let filter = unwrap_filter(&["*.web3.site:*"]);
219		assert!(filter.recognize(&unwrap_auth("parity.web3.site:8180")));
220		assert!(filter.recognize(&unwrap_auth("parity.web3.site")));
221	}
222
223	#[test]
224	fn should_accept_with_and_without_default_port() {
225		let filter = unwrap_filter(&["https://parity.io:443"]);
226		assert!(filter.recognize(&unwrap_auth("https://parity.io")));
227		assert!(filter.recognize(&unwrap_auth("https://parity.io:443")));
228	}
229}