1pub mod client;
14#[cfg(feature = "http")]
15pub mod http;
16pub mod server;
17
18use crate::extension::{Extension, Param};
19use base64::Engine;
20use bytes::BytesMut;
21use sha1::{Digest, Sha1};
22use std::{fmt, io, str};
23
24pub use client::{Client, ServerResponse};
25pub use server::{ClientRequest, Server};
26
/// The fixed GUID defined by RFC 6455 that is appended to the client's
/// `Sec-WebSocket-Key` before hashing, when deriving `Sec-WebSocket-Accept`.
const KEY: &[u8] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".as_bytes();

/// Maximum number of HTTP headers handled per handshake message.
/// NOTE(review): presumably sizes the `httparse` header buffers in the client/server modules.
const MAX_NUM_HEADERS: usize = 32;

/// Name of the header used to negotiate WebSocket extensions.
const SEC_WEBSOCKET_EXTENSIONS: &str = "Sec-WebSocket-Extensions";
/// Name of the header used to negotiate WebSocket subprotocols.
const SEC_WEBSOCKET_PROTOCOL: &str = "Sec-WebSocket-Protocol";
37
38fn expect_ascii_header(headers: &[httparse::Header], name: &str, ours: &str) -> Result<(), Error> {
40 enum State {
41 Init, Name, Match, }
45
46 headers
47 .iter()
48 .filter(|h| h.name.eq_ignore_ascii_case(name))
49 .fold(Ok(State::Init), |result, header| {
50 if let Ok(State::Match) = result {
51 return result;
52 }
53 if str::from_utf8(header.value)?.split(',').any(|v| v.trim().eq_ignore_ascii_case(ours)) {
54 return Ok(State::Match);
55 }
56 Ok(State::Name)
57 })
58 .and_then(|state| match state {
59 State::Init => Err(Error::HeaderNotFound(name.into())),
60 State::Name => Err(Error::UnexpectedHeader(name.into())),
61 State::Match => Ok(()),
62 })
63}
64
65fn with_first_header<'a, F, R>(headers: &[httparse::Header<'a>], name: &str, f: F) -> Result<R, Error>
67where
68 F: Fn(&'a [u8]) -> Result<R, Error>,
69{
70 if let Some(h) = headers.iter().find(|h| h.name.eq_ignore_ascii_case(name)) {
71 f(h.value)
72 } else {
73 Err(Error::HeaderNotFound(name.into()))
74 }
75}
76
77fn configure_extensions(extensions: &mut [Box<dyn Extension + Send>], line: &str) -> Result<(), Error> {
79 for e in line.split(',') {
80 let mut ext_parts = e.split(';');
81 if let Some(name) = ext_parts.next() {
82 let name = name.trim();
83 if let Some(ext) = extensions.iter_mut().find(|x| x.name().eq_ignore_ascii_case(name)) {
84 let mut params = Vec::new();
85 for p in ext_parts {
86 let mut key_value = p.split('=');
87 if let Some(key) = key_value.next().map(str::trim) {
88 let val = key_value.next().map(|v| v.trim().trim_matches('"'));
89 let mut p = Param::new(key);
90 p.set_value(val);
91 params.push(p)
92 }
93 }
94 ext.configure(¶ms).map_err(Error::Extension)?
95 }
96 }
97 }
98 Ok(())
99}
100
101fn append_extensions<'a, I>(extensions: I, bytes: &mut BytesMut)
103where
104 I: IntoIterator<Item = &'a Box<dyn Extension + Send>>,
105{
106 let mut iter = extensions.into_iter().peekable();
107
108 if iter.peek().is_some() {
109 bytes.extend_from_slice(b"\r\nSec-WebSocket-Extensions: ")
110 }
111
112 append_extension_header_value(iter, bytes)
113}
114
115fn append_extension_header_value<'a, I>(mut extensions_iter: std::iter::Peekable<I>, bytes: &mut BytesMut)
117where
118 I: Iterator<Item = &'a Box<dyn Extension + Send>>,
119{
120 while let Some(e) = extensions_iter.next() {
121 bytes.extend_from_slice(e.name().as_bytes());
122 for p in e.params() {
123 bytes.extend_from_slice(b"; ");
124 bytes.extend_from_slice(p.name().as_bytes());
125 if let Some(v) = p.value() {
126 bytes.extend_from_slice(b"=");
127 bytes.extend_from_slice(v.as_bytes())
128 }
129 }
130 if extensions_iter.peek().is_some() {
131 bytes.extend_from_slice(b", ")
132 }
133 }
134}
135
136fn generate_accept_key<'k>(key_base64: &WebSocketKey) -> [u8; 28] {
146 let mut digest = Sha1::new();
147 digest.update(key_base64);
148 digest.update(KEY);
149 let d = digest.finalize();
150
151 let mut output_buf = [0; 28];
152 let n = base64::engine::general_purpose::STANDARD
153 .encode_slice(d, &mut output_buf)
154 .expect("encoding to base64 is exactly 28 bytes; qed");
155 debug_assert_eq!(n, 28, "encoding to base64 should be exactly 28 bytes");
156 output_buf
157}
158
/// Errors that can occur while performing a WebSocket handshake.
#[non_exhaustive]
#[derive(Debug)]
pub enum Error {
    /// An I/O error occurred.
    Io(io::Error),
    /// The HTTP version was not 1.1.
    UnsupportedHttpVersion,
    /// The HTTP request was incomplete.
    IncompleteHttpRequest,
    /// The `Sec-WebSocket-Key` header value had the given length in bytes
    /// instead of the expected 24.
    SecWebSocketKeyInvalidLength(usize),
    /// The handshake request was not a `GET` request.
    InvalidRequestMethod,
    /// A required header, whose name is carried in the payload, was not found.
    HeaderNotFound(String),
    /// A header, whose name is carried in the payload, was present but had an
    /// unexpected value.
    UnexpectedHeader(String),
    /// The `Sec-WebSocket-Accept` value did not match the expected key.
    InvalidSecWebSocketAccept,
    /// An extension was returned that had not been solicited.
    UnsolicitedExtension,
    /// A protocol was returned that had not been solicited.
    UnsolicitedProtocol,
    /// An extension reported an error (for example, when rejecting its
    /// configuration parameters).
    Extension(crate::BoxedError),
    /// The HTTP parser reported an error.
    Http(crate::BoxedError),
    /// Decoding bytes (such as a header value) as UTF-8 failed.
    Utf8(str::Utf8Error),
}
190
191impl fmt::Display for Error {
192 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
193 match self {
194 Error::Io(e) => write!(f, "i/o error: {}", e),
195 Error::UnsupportedHttpVersion => f.write_str("http version was not 1.1"),
196 Error::IncompleteHttpRequest => f.write_str("http request was incomplete"),
197 Error::SecWebSocketKeyInvalidLength(len) => {
198 write!(f, "Sec-WebSocket-Key header was {} bytes long, expected 24", len)
199 }
200 Error::InvalidRequestMethod => f.write_str("handshake was not a GET request"),
201 Error::HeaderNotFound(name) => write!(f, "header {} not found", name),
202 Error::UnexpectedHeader(name) => write!(f, "header {} had an unexpected value", name),
203 Error::InvalidSecWebSocketAccept => f.write_str("websocket key mismatch"),
204 Error::UnsolicitedExtension => f.write_str("unsolicited extension returned"),
205 Error::UnsolicitedProtocol => f.write_str("unsolicited protocol returned"),
206 Error::Extension(e) => write!(f, "extension error: {}", e),
207 Error::Http(e) => write!(f, "http parser error: {}", e),
208 Error::Utf8(e) => write!(f, "utf-8 decoding error: {}", e),
209 }
210 }
211}
212
213impl std::error::Error for Error {
214 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
215 match self {
216 Error::Io(e) => Some(e),
217 Error::Extension(e) => Some(&**e),
218 Error::Http(e) => Some(&**e),
219 Error::Utf8(e) => Some(e),
220 Error::UnsupportedHttpVersion
221 | Error::IncompleteHttpRequest
222 | Error::SecWebSocketKeyInvalidLength(_)
223 | Error::InvalidRequestMethod
224 | Error::HeaderNotFound(_)
225 | Error::UnexpectedHeader(_)
226 | Error::InvalidSecWebSocketAccept
227 | Error::UnsolicitedExtension
228 | Error::UnsolicitedProtocol => None,
229 }
230 }
231}
232
233impl From<io::Error> for Error {
234 fn from(e: io::Error) -> Self {
235 Error::Io(e)
236 }
237}
238
239impl From<str::Utf8Error> for Error {
240 fn from(e: str::Utf8Error) -> Self {
241 Error::Utf8(e)
242 }
243}
244
/// The raw bytes of a `Sec-WebSocket-Key` header value, which must be exactly
/// 24 bytes long (per RFC 6455, the base64 encoding of a 16-byte nonce).
pub type WebSocketKey = [u8; 24];
257
#[cfg(test)]
mod tests {
    use super::{expect_ascii_header, Error};

    #[test]
    fn header_match() {
        let headers = &[
            httparse::Header { name: "foo", value: b"a,b,c,d" },
            httparse::Header { name: "foo", value: b"x" },
            httparse::Header { name: "foo", value: b"y, z, a" },
            httparse::Header { name: "bar", value: b"xxx" },
            httparse::Header { name: "bar", value: b"sdfsdf 423 42 424" },
            httparse::Header { name: "baz", value: b"123" },
        ];

        // Every token of every `foo` header matches, wherever it appears
        // (`a` appears in two different headers).
        for token in &["a", "b", "c", "d", "x", "y", "z"] {
            assert!(expect_ascii_header(headers, "foo", token).is_ok(), "foo: {}", token);
        }

        // Names and tokens are compared ASCII case-insensitively.
        assert!(expect_ascii_header(headers, "FOO", "A").is_ok());

        // A value without commas is one token, inner spaces included.
        assert!(expect_ascii_header(headers, "bar", "xxx").is_ok());
        assert!(expect_ascii_header(headers, "bar", "sdfsdf 423 42 424").is_ok());
        assert!(expect_ascii_header(headers, "baz", "123").is_ok());

        // A prefix of a token is not a match.
        assert!(matches!(
            expect_ascii_header(headers, "bar", "sdfsdf"),
            Err(Error::UnexpectedHeader(_))
        ));

        // A present header with the wrong value and an absent header are
        // reported differently.
        assert!(matches!(
            expect_ascii_header(headers, "baz", "???"),
            Err(Error::UnexpectedHeader(_))
        ));
        assert!(matches!(
            expect_ascii_header(headers, "???", "x"),
            Err(Error::HeaderNotFound(_))
        ));
    }

    #[test]
    fn header_invalid_utf8() {
        // An undecodable value on its own is reported as a UTF-8 error.
        let only_invalid = &[httparse::Header { name: "foo", value: b"\xff" }];
        assert!(matches!(expect_ascii_header(only_invalid, "foo", "a"), Err(Error::Utf8(_))));

        // It does not prevent a later header from matching.
        let invalid_then_valid = &[
            httparse::Header { name: "foo", value: b"\xff" },
            httparse::Header { name: "foo", value: b"a" },
        ];
        assert!(expect_ascii_header(invalid_then_valid, "foo", "a").is_ok());
    }
}