use super::{
append_extensions, configure_extensions, expect_ascii_header, with_first_header, Error, WebSocketKey, KEY,
MAX_NUM_HEADERS, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_PROTOCOL,
};
use crate::connection::{self, Mode};
use crate::{extension::Extension, Parsing};
use base64::Engine;
use bytes::{Buf, BytesMut};
use futures::prelude::*;
use sha1::{Digest, Sha1};
use std::{mem, str};
pub use httparse::Header;
/// Block size passed to `crate::read` while reading the server's handshake response
/// (presumably the number of bytes requested per read — confirm against `crate::read`).
const BLOCK_SIZE: usize = 8 * 1024;
/// Client-side state for performing a websocket opening handshake over `socket`.
#[derive(Debug)]
pub struct Client<'a, T> {
/// Connection the handshake request is written to and the response is read from.
socket: T,
/// Value sent verbatim in the `Host` request header.
host: &'a str,
/// Request target written verbatim into the `GET` request line.
resource: &'a str,
/// Additional headers appended verbatim to the request (see `set_headers`).
headers: &'a [Header<'a>],
/// Base64-encoded random key sent as `Sec-WebSocket-Key`; regenerated every time a
/// request is encoded and used to verify the server's `Sec-WebSocket-Accept`.
nonce: WebSocketKey,
/// Subprotocols offered to the server in `Sec-WebSocket-Protocol`.
protocols: Vec<&'a str>,
/// Extensions offered to the server and configured from its response.
extensions: Vec<Box<dyn Extension + Send>>,
/// Scratch buffer used to encode the request and accumulate the response; any bytes
/// following the response headers remain here once `handshake` returns a response.
buffer: BytesMut,
}
impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> {
    /// Creates a new client handshake over `socket`.
    ///
    /// `host` is sent as the `Host` header and `resource` as the request target of
    /// the `GET` line. Both are written into the request verbatim, without any
    /// escaping, so they must not contain `\r` or `\n`.
    pub fn new(socket: T, host: &'a str, resource: &'a str) -> Self {
        Client {
            socket,
            host,
            resource,
            headers: &[],
            nonce: [0; 24],
            protocols: Vec::new(),
            extensions: Vec::new(),
            buffer: BytesMut::new(),
        }
    }

    /// Replaces the internal buffer.
    ///
    /// [`Client::handshake`] clears the buffer before use, so only the allocation of
    /// `b` is reused; its current contents are discarded.
    pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self {
        self.buffer = b;
        self
    }

    /// Takes the internal buffer, leaving an empty one in its place.
    pub fn take_buffer(&mut self) -> BytesMut {
        mem::take(&mut self.buffer)
    }

    /// Sets additional headers to send with the request, replacing any set earlier.
    ///
    /// Names and values are written verbatim and must not contain `\r` or `\n`.
    pub fn set_headers(&mut self, h: &'a [Header]) -> &mut Self {
        self.headers = h;
        self
    }

    /// Adds a subprotocol to offer to the server via `Sec-WebSocket-Protocol`.
    pub fn add_protocol(&mut self, p: &'a str) -> &mut Self {
        self.protocols.push(p);
        self
    }

    /// Adds an extension to offer to the server; it is configured from the
    /// server's `Sec-WebSocket-Extensions` response headers.
    pub fn add_extension(&mut self, e: Box<dyn Extension + Send>) -> &mut Self {
        self.extensions.push(e);
        self
    }

    /// Removes all extensions from this client and yields them.
    pub fn drain_extensions(&mut self) -> impl Iterator<Item = Box<dyn Extension + Send>> + '_ {
        self.extensions.drain(..)
    }

    /// Sends the opening handshake request and reads the server's response.
    ///
    /// Returns [`ServerResponse::Accepted`] on a validated `101` response,
    /// [`ServerResponse::Redirect`] on a 301/302/303/307/308 response, and
    /// [`ServerResponse::Rejected`] on any other status. Bytes received after the
    /// response headers are kept in the internal buffer and handed on by
    /// [`Client::into_builder`].
    ///
    /// # Errors
    ///
    /// Fails on socket I/O errors, on a malformed or non-HTTP/1.1 response, when a
    /// redirect lacks a valid `Location` header, when the `Upgrade`, `Connection` or
    /// `Sec-WebSocket-Accept` checks fail, when the server selects a subprotocol
    /// that was not offered, or when configuring an extension fails.
    pub async fn handshake(&mut self) -> Result<ServerResponse, Error> {
        self.buffer.clear();
        self.encode_request();
        self.socket.write_all(&self.buffer).await?;
        self.socket.flush().await?;
        self.buffer.clear();

        loop {
            crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?;
            if let Parsing::Done { value, offset } = self.decode_response()? {
                // Drop the header section; whatever follows belongs to the connection.
                self.buffer.advance(offset);
                return Ok(value);
            }
        }
    }

    /// Consumes the client and creates a client-mode [`connection::Builder`] from its
    /// socket, its remaining buffered bytes and its extensions.
    pub fn into_builder(mut self) -> connection::Builder<T> {
        let mut builder = connection::Builder::new(self.socket, Mode::Client);
        builder.set_buffer(self.buffer);
        builder.add_extensions(self.extensions.drain(..));
        builder
    }

    /// Consumes the client and returns the underlying socket.
    pub fn into_inner(self) -> T {
        self.socket
    }

    /// Encodes the HTTP/1.1 upgrade request into `self.buffer`, generating a fresh
    /// `Sec-WebSocket-Key` nonce first.
    fn encode_request(&mut self) {
        let nonce: [u8; 16] = rand::random();
        base64::engine::general_purpose::STANDARD
            .encode_slice(nonce, &mut self.nonce)
            .expect("base64 encoding of 16 bytes is exactly 24 bytes; qed");

        self.buffer.extend_from_slice(b"GET ");
        self.buffer.extend_from_slice(self.resource.as_bytes());
        self.buffer.extend_from_slice(b" HTTP/1.1");
        self.buffer.extend_from_slice(b"\r\nHost: ");
        self.buffer.extend_from_slice(self.host.as_bytes());
        self.buffer.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: Upgrade");
        self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Key: ");
        self.buffer.extend_from_slice(&self.nonce);

        // `self.headers` is a shared slice reference, so iterating a copy of it does
        // not conflict with mutating `self.buffer`.
        for h in self.headers {
            self.buffer.extend_from_slice(b"\r\n");
            self.buffer.extend_from_slice(h.name.as_bytes());
            self.buffer.extend_from_slice(b": ");
            self.buffer.extend_from_slice(h.value);
        }

        if let Some((last, prefix)) = self.protocols.split_last() {
            self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: ");
            for p in prefix {
                self.buffer.extend_from_slice(p.as_bytes());
                self.buffer.extend_from_slice(b",");
            }
            self.buffer.extend_from_slice(last.as_bytes());
        }

        append_extensions(&self.extensions, &mut self.buffer);
        self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Version: 13\r\n\r\n");
    }

    /// Attempts to decode the server's response from `self.buffer`.
    ///
    /// Returns `NeedMore` while the header section is incomplete; otherwise `Done`
    /// with the decoded response and `offset`, the length of the header section.
    fn decode_response(&mut self) -> Result<Parsing<ServerResponse>, Error> {
        let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS];
        let mut response = httparse::Response::new(&mut header_buf);

        let offset = match response.parse(self.buffer.as_ref()) {
            Ok(httparse::Status::Complete(off)) => off,
            Ok(httparse::Status::Partial) => return Ok(Parsing::NeedMore(())),
            Err(e) => return Err(Error::Http(Box::new(e))),
        };

        // httparse reports the minor version; only HTTP/1.1 is accepted.
        if response.version != Some(1) {
            return Err(Error::UnsupportedHttpVersion);
        }

        match response.code {
            Some(101) => (),
            Some(code @ (301..=303 | 307 | 308)) => {
                let location = with_first_header(response.headers, "Location", |loc| {
                    Ok(String::from(std::str::from_utf8(loc)?))
                })?;
                let response = ServerResponse::Redirect { status_code: code, location };
                return Ok(Parsing::Done { value: response, offset });
            }
            other => {
                let response = ServerResponse::Rejected { status_code: other.unwrap_or(0) };
                return Ok(Parsing::Done { value: response, offset });
            }
        }

        expect_ascii_header(response.headers, "Upgrade", "websocket")?;
        expect_ascii_header(response.headers, "Connection", "upgrade")?;

        // The expected accept value is base64(SHA-1(nonce ++ KEY)).
        with_first_header(response.headers, "Sec-WebSocket-Accept", |theirs| {
            let mut digest = Sha1::new();
            digest.update(&self.nonce);
            digest.update(KEY);
            let ours = base64::engine::general_purpose::STANDARD.encode(digest.finalize());
            if ours.as_bytes() != theirs {
                return Err(Error::InvalidSecWebSocketAccept);
            }
            Ok(())
        })?;

        for h in response.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) {
            configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)?
        }

        // A server may only select one of the subprotocols we offered.
        let selected_proto = match response
            .headers
            .iter()
            .find(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL))
        {
            None => None,
            Some(tp) => match self.protocols.iter().find(|x| x.as_bytes() == tp.value) {
                Some(&p) => Some(String::from(p)),
                None => return Err(Error::UnsolicitedProtocol),
            },
        };

        let response = ServerResponse::Accepted { protocol: selected_proto };
        Ok(Parsing::Done { value: response, offset })
    }
}
/// Outcome of a client handshake, as decoded from the server's response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerResponse {
    /// The server answered `101` and the response passed validation.
    Accepted {
        /// The subprotocol the server selected from those offered, if any.
        protocol: Option<String>,
    },
    /// The server answered with a redirect status (301, 302, 303, 307 or 308).
    Redirect {
        /// The HTTP status code of the response.
        status_code: u16,
        /// The value of the response's `Location` header.
        location: String,
    },
    /// The server answered with any other status code.
    Rejected {
        /// The HTTP status code of the response (0 if the response carried none).
        status_code: u16,
    },
}