use std;
use std::sync::{atomic, Arc};
use crate::core;
use crate::core::futures::sync::oneshot;
use crate::core::futures::{Async, Future, Poll};
use parking_lot::Mutex;
use slab::Slab;
use crate::server_utils::cors::Origin;
use crate::server_utils::hosts::Host;
use crate::server_utils::session::{SessionId, SessionStats};
use crate::server_utils::tokio::runtime::TaskExecutor;
use crate::server_utils::Pattern;
use crate::ws;
use crate::error;
use crate::metadata;
/// Hook that inspects an incoming `ws::Request` before the session builds its own response.
pub trait RequestMiddleware: Send + Sync + 'static {
/// Examines `req` and returns the action the session should take for it.
fn process(&self, req: &ws::Request) -> MiddlewareAction;
}
impl<F> RequestMiddleware for F
where
F: Fn(&ws::Request) -> Option<ws::Response> + Send + Sync + 'static,
{
fn process(&self, req: &ws::Request) -> MiddlewareAction {
(*self)(req).into()
}
}
/// Outcome of request middleware: what the session should do with an incoming request.
#[derive(Debug)]
pub enum MiddlewareAction {
/// Continue with the session's regular response to the request.
Proceed,
/// Reply with the supplied response instead of the regular one.
Respond {
/// Response returned to the client.
response: ws::Response,
/// Whether the request origin is still checked against the allowed origins first.
validate_origin: bool,
/// Whether the request host is still checked against the allowed hosts first.
validate_hosts: bool,
},
}
impl MiddlewareAction {
fn should_verify_origin(&self) -> bool {
use self::MiddlewareAction::*;
match *self {
Proceed => true,
Respond { validate_origin, .. } => validate_origin,
}
}
fn should_verify_hosts(&self) -> bool {
use self::MiddlewareAction::*;
match *self {
Proceed => true,
Respond { validate_hosts, .. } => validate_hosts,
}
}
}
/// `Some(response)` becomes a `Respond` action with both origin and host
/// validation enabled; `None` becomes `Proceed`.
impl From<Option<ws::Response>> for MiddlewareAction {
    fn from(opt: Option<ws::Response>) -> Self {
        opt.map_or(MiddlewareAction::Proceed, |response| MiddlewareAction::Respond {
            response,
            validate_origin: true,
            validate_hosts: true,
        })
    }
}
/// Mutex-guarded slab of per-task shutdown senders; a slot holds `None` once its sender has been taken.
type TaskSlab = Mutex<Slab<Option<oneshot::Sender<()>>>>;
/// Future that becomes ready once its shutdown channel fires or the sending half is dropped.
///
/// Each instance owns one slot in the shared task slab and frees that slot when dropped.
#[derive(Debug)]
struct LivenessPoll {
/// Slab in which this poll's sender is registered.
task_slab: Arc<TaskSlab>,
/// Key of this poll's slot inside `task_slab`.
slab_handle: usize,
/// Receiving half of the shutdown channel.
rx: oneshot::Receiver<()>,
}
impl LivenessPoll {
    /// Registers a fresh shutdown channel in `task_slab` and returns the poll that
    /// owns the receiving half together with the slot key.
    fn create(task_slab: Arc<TaskSlab>) -> Self {
        const INITIAL_SIZE: usize = 4;

        let (slab_handle, rx) = {
            let mut slab = task_slab.lock();

            // When the slab is full, grow it by its current capacity, but never by
            // fewer than `INITIAL_SIZE` slots.
            let is_full = slab.len() == slab.capacity();
            if is_full {
                let additional = ::std::cmp::max(slab.capacity(), INITIAL_SIZE);
                slab.reserve_exact(additional);
            }

            let (sender, receiver) = oneshot::channel();
            let key = slab.insert(Some(sender));
            (key, receiver)
        };

        LivenessPoll {
            task_slab,
            slab_handle,
            rx,
        }
    }
}
impl Future for LivenessPoll {
    type Item = ();
    type Error = ();

    /// Never fails: both a received signal and a dropped sender count as "ready".
    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        let state = match self.rx.poll() {
            Ok(Async::NotReady) => Async::NotReady,
            Ok(Async::Ready(_)) | Err(_) => Async::Ready(()),
        };
        Ok(state)
    }
}
impl Drop for LivenessPoll {
    /// Frees this poll's slot in the shared slab.
    fn drop(&mut self) {
        let mut slab = self.task_slab.lock();
        slab.remove(self.slab_handle);
    }
}
/// Per-connection handler: validates the incoming request and forwards each message to the RPC handler.
pub struct Session<M: core::Metadata, S: core::Middleware<M>> {
/// Shared flag; cleared when the session is dropped or a send reports a closed connection.
active: Arc<atomic::AtomicBool>,
/// Request context (session id, origin, protocols, output sender, executor).
context: metadata::RequestContext,
/// Handler that processes each incoming request.
handler: Arc<core::MetaIoHandler<M, S>>,
/// Builds the per-session metadata from the request context.
meta_extractor: Arc<dyn metadata::MetaExtractor<M>>,
/// Allowed origins; `None` disables origin validation.
allowed_origins: Option<Vec<Origin>>,
/// Allowed hosts; `None` disables host validation.
allowed_hosts: Option<Vec<Host>>,
/// Optional middleware consulted first for every incoming request.
request_middleware: Option<Arc<dyn RequestMiddleware>>,
/// Optional statistics sink, notified when the session closes.
stats: Option<Arc<dyn SessionStats>>,
/// Metadata extracted in `on_request`; `None` until then.
metadata: Option<M>,
/// Executor on which per-message futures are spawned.
executor: TaskExecutor,
/// Shutdown senders of the tasks spawned for this session.
task_slab: Arc<TaskSlab>,
}
impl<M: core::Metadata, S: core::Middleware<M>> Drop for Session<M, S> {
    /// Marks the session inactive, reports the closure to the stats sink (if any)
    /// and signals every task still registered in the slab.
    fn drop(&mut self) {
        self.active.store(false, atomic::Ordering::SeqCst);

        if let Some(ref stats) = self.stats {
            stats.close_session(self.context.session_id);
        }

        // Take each remaining sender out of its slot and fire it; a failed send
        // is deliberately ignored.
        let mut slab = self.task_slab.lock();
        slab.iter_mut()
            .filter_map(|(_key, slot)| slot.take())
            .for_each(|sender| {
                let _ = sender.send(());
            });
    }
}
impl<M: core::Metadata, S: core::Middleware<M>> Session<M, S> {
    /// Returns the raw bytes of the request's `origin` header, if present.
    fn read_origin<'a>(&self, req: &'a ws::Request) -> Option<&'a [u8]> {
        let raw = req.header("origin")?;
        Some(&raw[..])
    }

    /// Checks `origin` against the allowed origins.
    ///
    /// Returns `None` when the origin is acceptable, otherwise a 403 response
    /// (after logging a warning).
    fn verify_origin(&self, origin: Option<&[u8]>) -> Option<ws::Response> {
        if header_is_allowed(&self.allowed_origins, origin) {
            return None;
        }

        warn!(
            "Blocked connection to WebSockets server from untrusted origin: {:?}",
            origin.and_then(|s| std::str::from_utf8(s).ok()),
        );
        Some(forbidden("URL Blocked", "Connection Origin has been rejected."))
    }

    /// Checks the request's `host` header against the allowed hosts.
    ///
    /// Returns `None` when the host is acceptable, otherwise a 403 response
    /// (after logging a warning).
    fn verify_host(&self, req: &ws::Request) -> Option<ws::Response> {
        let host = req.header("host").map(|raw| &raw[..]);
        if header_is_allowed(&self.allowed_hosts, host) {
            return None;
        }

        warn!(
            "Blocked connection to WebSockets server with untrusted host: {:?}",
            host.and_then(|s| std::str::from_utf8(s).ok()),
        );
        Some(forbidden("URL Blocked", "Connection Host has been rejected."))
    }
}
impl<M: core::Metadata, S: core::Middleware<M>> ws::Handler for Session<M, S> {
fn on_request(&mut self, req: &ws::Request) -> ws::Result<ws::Response> {
let action = if let Some(ref middleware) = self.request_middleware {
middleware.process(req)
} else {
MiddlewareAction::Proceed
};
let origin = self.read_origin(req);
if action.should_verify_origin() {
if let Some(response) = self.verify_origin(origin) {
return Ok(response);
}
}
if action.should_verify_hosts() {
if let Some(response) = self.verify_host(req) {
return Ok(response);
}
}
self.context.origin = origin
.and_then(|origin| ::std::str::from_utf8(origin).ok())
.map(Into::into);
self.context.protocols = req
.protocols()
.ok()
.map(|protos| protos.into_iter().map(Into::into).collect())
.unwrap_or_else(Vec::new);
self.metadata = Some(self.meta_extractor.extract(&self.context));
match action {
MiddlewareAction::Proceed => ws::Response::from_request(req).map(|mut res| {
if let Some(protocol) = self.context.protocols.get(0) {
res.set_protocol(protocol);
}
res
}),
MiddlewareAction::Respond { response, .. } => Ok(response),
}
}
fn on_message(&mut self, msg: ws::Message) -> ws::Result<()> {
let req = msg.as_text()?;
let out = self.context.out.clone();
let metadata = self
.metadata
.clone()
.expect("Metadata is always set in on_request; qed");
let poll_liveness = LivenessPoll::create(self.task_slab.clone());
let active_lock = self.active.clone();
let future = self
.handler
.handle_request(req, metadata)
.map(move |response| {
if !active_lock.load(atomic::Ordering::SeqCst) {
return;
}
if let Some(result) = response {
let res = out.send(result);
match res {
Err(error::Error::ConnectionClosed) => {
active_lock.store(false, atomic::Ordering::SeqCst);
}
Err(e) => {
warn!("Error while sending response: {:?}", e);
}
_ => {}
}
}
})
.select(poll_liveness)
.map(|_| ())
.map_err(|_| ());
self.executor.spawn(future);
Ok(())
}
}
/// Creates a `Session` for every new connection, sharing the configuration below with each of them.
pub struct Factory<M: core::Metadata, S: core::Middleware<M>> {
/// Id of the most recently created session; incremented for every new connection.
session_id: SessionId,
/// Handler shared by all sessions.
handler: Arc<core::MetaIoHandler<M, S>>,
/// Metadata extractor shared by all sessions.
meta_extractor: Arc<dyn metadata::MetaExtractor<M>>,
/// Allowed origins copied into each session.
allowed_origins: Option<Vec<Origin>>,
/// Allowed hosts copied into each session.
allowed_hosts: Option<Vec<Host>>,
/// Optional request middleware shared by all sessions.
request_middleware: Option<Arc<dyn RequestMiddleware>>,
/// Optional statistics sink, notified when a session opens.
stats: Option<Arc<dyn SessionStats>>,
/// Executor handed to each session.
executor: TaskExecutor,
}
impl<M: core::Metadata, S: core::Middleware<M>> Factory<M, S> {
    /// Builds a factory from the shared session configuration.
    ///
    /// The session counter starts at `0`.
    pub fn new(
        handler: Arc<core::MetaIoHandler<M, S>>,
        meta_extractor: Arc<dyn metadata::MetaExtractor<M>>,
        allowed_origins: Option<Vec<Origin>>,
        allowed_hosts: Option<Vec<Host>>,
        request_middleware: Option<Arc<dyn RequestMiddleware>>,
        stats: Option<Arc<dyn SessionStats>>,
        executor: TaskExecutor,
    ) -> Self {
        let session_id = 0;
        Self {
            session_id,
            handler,
            meta_extractor,
            allowed_origins,
            allowed_hosts,
            request_middleware,
            stats,
            executor,
        }
    }
}
impl<M: core::Metadata, S: core::Middleware<M>> ws::Factory for Factory<M, S> {
    type Handler = Session<M, S>;

    /// Allocates the next session id, reports the new session to the stats sink
    /// (if any) and builds the `Session` for this connection.
    fn connection_made(&mut self, sender: ws::Sender) -> Self::Handler {
        self.session_id += 1;
        let session_id = self.session_id;

        if let Some(ref stats) = self.stats {
            stats.open_session(session_id);
        }

        // The same flag is shared by the session and its output sender.
        let active = Arc::new(atomic::AtomicBool::new(true));
        let context = metadata::RequestContext {
            session_id,
            origin: None,
            protocols: Vec::new(),
            out: metadata::Sender::new(sender, active.clone()),
            executor: self.executor.clone(),
        };

        Session {
            active,
            context,
            handler: self.handler.clone(),
            meta_extractor: self.meta_extractor.clone(),
            allowed_origins: self.allowed_origins.clone(),
            allowed_hosts: self.allowed_hosts.clone(),
            stats: self.stats.clone(),
            request_middleware: self.request_middleware.clone(),
            metadata: None,
            executor: self.executor.clone(),
            task_slab: Arc::new(Mutex::new(Slab::with_capacity(0))),
        }
    }
}
/// Decides whether a raw header value passes an optional allow-list.
///
/// Returns `true` when the header is absent or no allow-list is configured.
/// Otherwise the header must be valid UTF-8 and match at least one pattern.
fn header_is_allowed<T>(allowed: &Option<Vec<T>>, header: Option<&[u8]>) -> bool
where
    T: Pattern,
{
    // Without both a header and an allow-list there is nothing to validate.
    let (raw, patterns) = match (header, allowed.as_ref()) {
        (Some(raw), Some(patterns)) => (raw, patterns),
        _ => return true,
    };

    match std::str::from_utf8(raw) {
        Ok(text) => patterns.iter().any(|pattern| pattern.matches(text)),
        // A header that is not valid UTF-8 can never match.
        Err(_) => false,
    }
}
/// Builds a `403 Forbidden` response whose body is `title` and `message`, each
/// on its own line, and which carries a `Connection: close` header.
fn forbidden(title: &str, message: &str) -> ws::Response {
    let body = format!("{}\n{}\n", title, message).into_bytes();
    let mut response = ws::Response::new(403, "Forbidden", body);
    response
        .headers_mut()
        .push(("Connection".to_owned(), b"close".to_vec()));
    response
}