soketto/
handshake.rs

1// Copyright (c) 2019 Parity Technologies (UK) Ltd.
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9//! Websocket [handshake]s.
10//!
11//! [handshake]: https://tools.ietf.org/html/rfc6455#section-4
12
13pub mod client;
14#[cfg(feature = "http")]
15pub mod http;
16pub mod server;
17
18use crate::extension::{Extension, Param};
19use base64::Engine;
20use bytes::BytesMut;
21use sha1::{Digest, Sha1};
22use std::{fmt, io, str};
23
24pub use client::{Client, ServerResponse};
25pub use server::{ClientRequest, Server};
26
// Fixed GUID defined in RFC 6455 (section 1.3). It is appended to the client's
// `Sec-WebSocket-Key` value before hashing, to generate the `Sec-WebSocket-Accept`
// header in the server handshake response.
const KEY: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

// Maximum number of HTTP headers we support when parsing a handshake.
const MAX_NUM_HEADERS: usize = 32;

// Names of HTTP headers we need to check during parsing.
const SEC_WEBSOCKET_EXTENSIONS: &str = "Sec-WebSocket-Extensions";
const SEC_WEBSOCKET_PROTOCOL: &str = "Sec-WebSocket-Protocol";
37
38/// Check a set of headers contains a specific one.
39fn expect_ascii_header(headers: &[httparse::Header], name: &str, ours: &str) -> Result<(), Error> {
40	enum State {
41		Init,  // Start state
42		Name,  // Header name found
43		Match, // Header value matches
44	}
45
46	headers
47		.iter()
48		.filter(|h| h.name.eq_ignore_ascii_case(name))
49		.fold(Ok(State::Init), |result, header| {
50			if let Ok(State::Match) = result {
51				return result;
52			}
53			if str::from_utf8(header.value)?.split(',').any(|v| v.trim().eq_ignore_ascii_case(ours)) {
54				return Ok(State::Match);
55			}
56			Ok(State::Name)
57		})
58		.and_then(|state| match state {
59			State::Init => Err(Error::HeaderNotFound(name.into())),
60			State::Name => Err(Error::UnexpectedHeader(name.into())),
61			State::Match => Ok(()),
62		})
63}
64
65/// Pick the first header with the given name and apply the given closure to it.
66fn with_first_header<'a, F, R>(headers: &[httparse::Header<'a>], name: &str, f: F) -> Result<R, Error>
67where
68	F: Fn(&'a [u8]) -> Result<R, Error>,
69{
70	if let Some(h) = headers.iter().find(|h| h.name.eq_ignore_ascii_case(name)) {
71		f(h.value)
72	} else {
73		Err(Error::HeaderNotFound(name.into()))
74	}
75}
76
77// Configure all extensions with parsed parameters.
78fn configure_extensions(extensions: &mut [Box<dyn Extension + Send>], line: &str) -> Result<(), Error> {
79	for e in line.split(',') {
80		let mut ext_parts = e.split(';');
81		if let Some(name) = ext_parts.next() {
82			let name = name.trim();
83			if let Some(ext) = extensions.iter_mut().find(|x| x.name().eq_ignore_ascii_case(name)) {
84				let mut params = Vec::new();
85				for p in ext_parts {
86					let mut key_value = p.split('=');
87					if let Some(key) = key_value.next().map(str::trim) {
88						let val = key_value.next().map(|v| v.trim().trim_matches('"'));
89						let mut p = Param::new(key);
90						p.set_value(val);
91						params.push(p)
92					}
93				}
94				ext.configure(&params).map_err(Error::Extension)?
95			}
96		}
97	}
98	Ok(())
99}
100
101// Write all extensions to the given buffer.
102fn append_extensions<'a, I>(extensions: I, bytes: &mut BytesMut)
103where
104	I: IntoIterator<Item = &'a Box<dyn Extension + Send>>,
105{
106	let mut iter = extensions.into_iter().peekable();
107
108	if iter.peek().is_some() {
109		bytes.extend_from_slice(b"\r\nSec-WebSocket-Extensions: ")
110	}
111
112	append_extension_header_value(iter, bytes)
113}
114
115// Write the extension header value to the given buffer.
116fn append_extension_header_value<'a, I>(mut extensions_iter: std::iter::Peekable<I>, bytes: &mut BytesMut)
117where
118	I: Iterator<Item = &'a Box<dyn Extension + Send>>,
119{
120	while let Some(e) = extensions_iter.next() {
121		bytes.extend_from_slice(e.name().as_bytes());
122		for p in e.params() {
123			bytes.extend_from_slice(b"; ");
124			bytes.extend_from_slice(p.name().as_bytes());
125			if let Some(v) = p.value() {
126				bytes.extend_from_slice(b"=");
127				bytes.extend_from_slice(v.as_bytes())
128			}
129		}
130		if extensions_iter.peek().is_some() {
131			bytes.extend_from_slice(b", ")
132		}
133	}
134}
135
136// This function takes a 16 byte key (base64 encoded, and so 24 bytes of input) that is expected via
137// the `Sec-WebSocket-Key` header during a websocket handshake, and writes the response that's expected
138// to be handed back in the response header `Sec-WebSocket-Accept`.
139//
140// The response is a base64 encoding of a 160bit hash. base64 encoding uses 1 ascii character per 6 bits,
141// and 160 / 6 = 26.66 characters. The output is padded with '=' to the nearest 4 characters, so we need 28
142// bytes in total for all of the characters.
143//
144// See https://datatracker.ietf.org/doc/html/rfc6455#section-1.3 for more information on this.
145fn generate_accept_key<'k>(key_base64: &WebSocketKey) -> [u8; 28] {
146	let mut digest = Sha1::new();
147	digest.update(key_base64);
148	digest.update(KEY);
149	let d = digest.finalize();
150
151	let mut output_buf = [0; 28];
152	let n = base64::engine::general_purpose::STANDARD
153		.encode_slice(d, &mut output_buf)
154		.expect("encoding to base64 is exactly 28 bytes; qed");
155	debug_assert_eq!(n, 28, "encoding to base64 should be exactly 28 bytes");
156	output_buf
157}
158
/// Enumeration of possible handshake errors.
#[non_exhaustive]
#[derive(Debug)]
pub enum Error {
	/// An I/O error has been encountered.
	Io(io::Error),
	/// An HTTP version other than 1.1 was encountered.
	UnsupportedHttpVersion,
	/// An incomplete HTTP request.
	IncompleteHttpRequest,
	/// The value of the `Sec-WebSocket-Key` header is of unexpected length.
	///
	/// Carries the length that was found; a length of 24 bytes is expected.
	SecWebSocketKeyInvalidLength(usize),
	/// The handshake request was not a GET request.
	InvalidRequestMethod,
	/// An HTTP header has not been present. Carries the header name.
	HeaderNotFound(String),
	/// An HTTP header value was not expected. Carries the header name.
	UnexpectedHeader(String),
	/// The `Sec-WebSocket-Accept` header value did not match.
	InvalidSecWebSocketAccept,
	/// The server returned an extension we did not ask for.
	UnsolicitedExtension,
	/// The server returned a protocol we did not ask for.
	UnsolicitedProtocol,
	/// An extension produced an error, e.g. while being configured, encoding or decoding.
	Extension(crate::BoxedError),
	/// The HTTP entity could not be parsed successfully.
	Http(crate::BoxedError),
	/// UTF-8 decoding failed (for example, of an HTTP header value).
	Utf8(str::Utf8Error),
}
190
191impl fmt::Display for Error {
192	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
193		match self {
194			Error::Io(e) => write!(f, "i/o error: {}", e),
195			Error::UnsupportedHttpVersion => f.write_str("http version was not 1.1"),
196			Error::IncompleteHttpRequest => f.write_str("http request was incomplete"),
197			Error::SecWebSocketKeyInvalidLength(len) => {
198				write!(f, "Sec-WebSocket-Key header was {} bytes long, expected 24", len)
199			}
200			Error::InvalidRequestMethod => f.write_str("handshake was not a GET request"),
201			Error::HeaderNotFound(name) => write!(f, "header {} not found", name),
202			Error::UnexpectedHeader(name) => write!(f, "header {} had an unexpected value", name),
203			Error::InvalidSecWebSocketAccept => f.write_str("websocket key mismatch"),
204			Error::UnsolicitedExtension => f.write_str("unsolicited extension returned"),
205			Error::UnsolicitedProtocol => f.write_str("unsolicited protocol returned"),
206			Error::Extension(e) => write!(f, "extension error: {}", e),
207			Error::Http(e) => write!(f, "http parser error: {}", e),
208			Error::Utf8(e) => write!(f, "utf-8 decoding error: {}", e),
209		}
210	}
211}
212
213impl std::error::Error for Error {
214	fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
215		match self {
216			Error::Io(e) => Some(e),
217			Error::Extension(e) => Some(&**e),
218			Error::Http(e) => Some(&**e),
219			Error::Utf8(e) => Some(e),
220			Error::UnsupportedHttpVersion
221			| Error::IncompleteHttpRequest
222			| Error::SecWebSocketKeyInvalidLength(_)
223			| Error::InvalidRequestMethod
224			| Error::HeaderNotFound(_)
225			| Error::UnexpectedHeader(_)
226			| Error::InvalidSecWebSocketAccept
227			| Error::UnsolicitedExtension
228			| Error::UnsolicitedProtocol => None,
229		}
230	}
231}
232
233impl From<io::Error> for Error {
234	fn from(e: io::Error) -> Self {
235		Error::Io(e)
236	}
237}
238
239impl From<str::Utf8Error> for Error {
240	fn from(e: str::Utf8Error) -> Self {
241		Error::Utf8(e)
242	}
243}
244
/// Owned value of the `Sec-WebSocket-Key` header.
///
/// Per [RFC 6455](https://datatracker.ietf.org/doc/html/rfc6455#section-4.1):
///
/// ```text
/// (...) The value of this header field MUST be a
/// nonce consisting of a randomly selected 16-byte value that has
/// been base64-encoded (see Section 4 of [RFC4648]). (...)
/// ```
///
/// Base64 encoding of the nonce produces 24 ASCII bytes, padding included.
/// These 24 bytes are the base64 text itself, not the decoded 16-byte nonce;
/// the text is what gets hashed when computing `Sec-WebSocket-Accept`.
pub type WebSocketKey = [u8; 24];
257
#[cfg(test)]
mod tests {
	use super::{expect_ascii_header, Error};

	#[test]
	fn header_match() {
		let headers = &[
			httparse::Header { name: "foo", value: b"a,b,c,d" },
			httparse::Header { name: "foo", value: b"x" },
			httparse::Header { name: "foo", value: b"y, z, a" },
			httparse::Header { name: "bar", value: b"xxx" },
			httparse::Header { name: "bar", value: b"sdfsdf 423 42 424" },
			httparse::Header { name: "baz", value: b"123" },
		];

		assert!(expect_ascii_header(headers, "foo", "a").is_ok());
		assert!(expect_ascii_header(headers, "foo", "b").is_ok());
		assert!(expect_ascii_header(headers, "foo", "c").is_ok());
		assert!(expect_ascii_header(headers, "foo", "d").is_ok());
		assert!(expect_ascii_header(headers, "foo", "x").is_ok());
		assert!(expect_ascii_header(headers, "foo", "y").is_ok());
		assert!(expect_ascii_header(headers, "foo", "z").is_ok());
		// Header names and values are compared ASCII case-insensitively.
		assert!(expect_ascii_header(headers, "FOO", "A").is_ok());
		assert!(expect_ascii_header(headers, "bar", "xxx").is_ok());
		assert!(expect_ascii_header(headers, "bar", "sdfsdf 423 42 424").is_ok());
		assert!(expect_ascii_header(headers, "baz", "123").is_ok());
		assert!(expect_ascii_header(headers, "baz", "???").is_err());
		assert!(expect_ascii_header(headers, "???", "x").is_err());
	}

	#[test]
	fn header_errors() {
		let headers = &[
			httparse::Header { name: "foo", value: b"a, b" },
			httparse::Header { name: "bad", value: b"\xff\xfe" },
		];

		// No header with the requested name at all.
		assert!(matches!(
			expect_ascii_header(headers, "missing", "a"),
			Err(Error::HeaderNotFound(ref n)) if n == "missing"
		));
		// Header present, but none of its comma-separated values match.
		assert!(matches!(
			expect_ascii_header(headers, "foo", "c"),
			Err(Error::UnexpectedHeader(ref n)) if n == "foo"
		));
		// A value spanning several list elements does not match any single element.
		assert!(expect_ascii_header(headers, "foo", "a, b").is_err());
		// The only header with the requested name is not valid UTF-8.
		assert!(matches!(expect_ascii_header(headers, "bad", "a"), Err(Error::Utf8(_))));
	}
}