use super::{WebSocketKey, SEC_WEBSOCKET_EXTENSIONS};
use crate::connection::{self, Mode};
use crate::extension::Extension;
use crate::handshake;
use bytes::BytesMut;
use futures::prelude::*;
use http::{header, HeaderMap, Response};
use std::mem;
/// Error type returned by the server handshake; an alias for [`handshake::Error`].
pub type Error = handshake::Error;

/// Server side of a WebSocket handshake expressed in terms of `http` request/response types.
///
/// Holds the extensions the server is willing to negotiate and a buffer that is
/// passed on to the connection builder once the handshake is done.
#[derive(Default)]
pub struct Server {
    /// Extensions registered with this server; they are configured from the client's
    /// `Sec-WebSocket-Extensions` header(s) in `Server::receive_request`.
    extensions: Vec<Box<dyn Extension + Send>>,
    /// Buffer handed to the connection builder by `Server::into_builder`.
    buffer: BytesMut,
}
impl Server {
pub fn new() -> Self {
Server { extensions: Vec::new(), buffer: BytesMut::new() }
}
pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self {
self.buffer = b;
self
}
pub fn take_buffer(&mut self) -> BytesMut {
mem::take(&mut self.buffer)
}
pub fn add_extension(&mut self, e: Box<dyn Extension + Send>) -> &mut Self {
self.extensions.push(e);
self
}
pub fn drain_extensions(&mut self) -> impl Iterator<Item = Box<dyn Extension + Send>> + '_ {
self.extensions.drain(..)
}
pub fn receive_request<B>(&mut self, req: &http::Request<B>) -> Result<http::Response<()>, Error> {
if !is_upgrade_request(&req) {
return Err(Error::InvalidSecWebSocketAccept);
}
let key = match req.headers().get("Sec-WebSocket-Key") {
Some(key) => key,
None => {
return Err(Error::HeaderNotFound("Sec-WebSocket-Key".into()).into());
}
};
if req.headers().get("Sec-WebSocket-Version").map(|v| v.as_bytes()) != Some(b"13") {
return Err(Error::HeaderNotFound("Sec-WebSocket-Version".into()).into());
}
let key: &WebSocketKey = match key.as_bytes().try_into() {
Ok(key) => key,
Err(_) => return Err(Error::InvalidSecWebSocketAccept),
};
let accept_key = handshake::generate_accept_key(key);
let extension_config = req
.headers()
.iter()
.filter(|&(name, _)| name.as_str().eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS))
.map(|(_, value)| Ok(std::str::from_utf8(value.as_bytes())?.to_string()))
.collect::<Result<Vec<_>, Error>>()?;
for config_str in &extension_config {
handshake::configure_extensions(&mut self.extensions, &config_str)?;
}
let mut response = Response::builder()
.status(http::StatusCode::SWITCHING_PROTOCOLS)
.header(http::header::CONNECTION, "upgrade")
.header(http::header::UPGRADE, "websocket")
.header("Sec-WebSocket-Accept", &accept_key[..]);
if !self.extensions.is_empty() {
let mut buf = bytes::BytesMut::new();
let enabled_extensions = self.extensions.iter().filter(|e| e.is_enabled()).peekable();
handshake::append_extension_header_value(enabled_extensions, &mut buf);
response = response.header("Sec-WebSocket-Extensions", buf.as_ref());
}
let response = response.body(()).expect("bug: failed to build response");
Ok(response)
}
pub fn into_builder<T: AsyncRead + AsyncWrite + Unpin>(mut self, socket: T) -> connection::Builder<T> {
let mut builder = connection::Builder::new(socket, Mode::Server);
builder.set_buffer(self.buffer);
builder.add_extensions(self.extensions.drain(..));
builder
}
}
/// Report whether `request` asks to switch to the WebSocket protocol.
///
/// This is the case when the `Connection` header lists the `upgrade` token and the
/// `Upgrade` header lists `websocket`. Tokens are matched by `header_contains_value`,
/// which compares ASCII-case-insensitively. The `Upgrade` header is only inspected
/// if the `Connection` check passes.
pub fn is_upgrade_request<B>(request: &http::Request<B>) -> bool {
    let headers = request.headers();
    let connection_upgrade = header_contains_value(headers, header::CONNECTION, b"upgrade");
    connection_upgrade && header_contains_value(headers, header::UPGRADE, b"websocket")
}
/// Check whether any occurrence of `header` in `headers` lists `value` among its
/// comma-separated elements.
///
/// Each element has leading and trailing ASCII whitespace removed and is then
/// compared with `value` ASCII-case-insensitively. Every occurrence of `header`
/// is searched, and the search stops at the first match.
fn header_contains_value(headers: &HeaderMap, header: header::HeaderName, value: &[u8]) -> bool {
    /// Strip leading and trailing ASCII whitespace from `bytes`.
    fn strip_ascii_ws(bytes: &[u8]) -> &[u8] {
        let is_content = |b: &u8| !b.is_ascii_whitespace();
        match (bytes.iter().position(is_content), bytes.iter().rposition(is_content)) {
            (Some(start), Some(end)) => &bytes[start..=end],
            _ => &[],
        }
    }

    headers
        .get_all(header)
        .into_iter()
        .flat_map(|field| field.as_bytes().split(|&b| b == b','))
        .any(|element| strip_ascii_ws(element).eq_ignore_ascii_case(value))
}