soketto/handshake/
client.rs

1// Copyright (c) 2019 Parity Technologies (UK) Ltd.
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9//! Websocket client [handshake].
10//!
11//! [handshake]: https://tools.ietf.org/html/rfc6455#section-4
12
13use super::{
14	append_extensions, configure_extensions, expect_ascii_header, with_first_header, Error, WebSocketKey, KEY,
15	MAX_NUM_HEADERS, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_PROTOCOL,
16};
17use crate::connection::{self, Mode};
18use crate::{extension::Extension, Parsing};
19use base64::Engine;
20use bytes::{Buf, BytesMut};
21use futures::prelude::*;
22use sha1::{Digest, Sha1};
23use std::{mem, str};
24
25pub use httparse::Header;
26
/// Block size (8 KiB) handed to `crate::read` for every read of the server response.
const BLOCK_SIZE: usize = 8 * 1024;
28
/// Websocket client handshake.
///
/// Bundles the socket with everything needed to write the HTTP upgrade
/// request and to validate the server's response.
#[derive(Debug)]
pub struct Client<'a, T> {
	/// The underlying async I/O resource.
	socket: T,
	/// The value sent in the `Host` request header.
	host: &'a str,
	/// The resource placed on the `GET` request line.
	resource: &'a str,
	/// Additional HTTP headers, written to the request as given.
	headers: &'a [Header<'a>],
	/// A buffer holding the base-64 encoded request nonce (sent as `Sec-WebSocket-Key`).
	nonce: WebSocketKey,
	/// The protocols to offer in the `Sec-WebSocket-Protocol` request header.
	protocols: Vec<&'a str>,
	/// The extensions the client wishes to include in the request.
	extensions: Vec<Box<dyn Extension + Send>>,
	/// Encoding/decoding buffer, used for both the request and the response.
	buffer: BytesMut,
}
49
50impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> {
51	/// Create a new client handshake for some host and resource.
52	pub fn new(socket: T, host: &'a str, resource: &'a str) -> Self {
53		Client {
54			socket,
55			host,
56			resource,
57			headers: &[],
58			nonce: [0; 24],
59			protocols: Vec::new(),
60			extensions: Vec::new(),
61			buffer: BytesMut::new(),
62		}
63	}
64
65	/// Override the buffer to use for request/response handling.
66	pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self {
67		self.buffer = b;
68		self
69	}
70
71	/// Extract the buffer.
72	pub fn take_buffer(&mut self) -> BytesMut {
73		mem::take(&mut self.buffer)
74	}
75
76	/// Set connection headers to a slice. These headers are not checked for validity,
77	/// the caller of this method is responsible for verification as well as avoiding
78	/// conflicts with internally set headers.
79	pub fn set_headers(&mut self, h: &'a [Header]) -> &mut Self {
80		self.headers = h;
81		self
82	}
83
84	/// Add a protocol to be included in the handshake.
85	pub fn add_protocol(&mut self, p: &'a str) -> &mut Self {
86		self.protocols.push(p);
87		self
88	}
89
90	/// Add an extension to be included in the handshake.
91	pub fn add_extension(&mut self, e: Box<dyn Extension + Send>) -> &mut Self {
92		self.extensions.push(e);
93		self
94	}
95
96	/// Get back all extensions.
97	pub fn drain_extensions(&mut self) -> impl Iterator<Item = Box<dyn Extension + Send>> + '_ {
98		self.extensions.drain(..)
99	}
100
101	/// Initiate client handshake request to server and get back the response.
102	pub async fn handshake(&mut self) -> Result<ServerResponse, Error> {
103		self.buffer.clear();
104		self.encode_request();
105		self.socket.write_all(&self.buffer).await?;
106		self.socket.flush().await?;
107		self.buffer.clear();
108
109		loop {
110			crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?;
111			if let Parsing::Done { value, offset } = self.decode_response()? {
112				self.buffer.advance(offset);
113				return Ok(value);
114			}
115		}
116	}
117
118	/// Turn this handshake into a [`connection::Builder`].
119	pub fn into_builder(mut self) -> connection::Builder<T> {
120		let mut builder = connection::Builder::new(self.socket, Mode::Client);
121		builder.set_buffer(self.buffer);
122		builder.add_extensions(self.extensions.drain(..));
123		builder
124	}
125
126	/// Get out the inner socket of the client.
127	pub fn into_inner(self) -> T {
128		self.socket
129	}
130
131	/// Encode the client handshake as a request, ready to be sent to the server.
132	fn encode_request(&mut self) {
133		let nonce: [u8; 16] = rand::random();
134		base64::engine::general_purpose::STANDARD
135			.encode_slice(nonce, &mut self.nonce)
136			.expect("encoding to base64 is exactly 16 bytes; qed");
137		self.buffer.extend_from_slice(b"GET ");
138		self.buffer.extend_from_slice(self.resource.as_bytes());
139		self.buffer.extend_from_slice(b" HTTP/1.1");
140		self.buffer.extend_from_slice(b"\r\nHost: ");
141		self.buffer.extend_from_slice(self.host.as_bytes());
142		self.buffer.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: Upgrade");
143		self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Key: ");
144		self.buffer.extend_from_slice(&self.nonce);
145		self.headers.iter().for_each(|h| {
146			self.buffer.extend_from_slice(b"\r\n");
147			self.buffer.extend_from_slice(h.name.as_bytes());
148			self.buffer.extend_from_slice(b": ");
149			self.buffer.extend_from_slice(h.value);
150		});
151		if let Some((last, prefix)) = self.protocols.split_last() {
152			self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: ");
153			for p in prefix {
154				self.buffer.extend_from_slice(p.as_bytes());
155				self.buffer.extend_from_slice(b",")
156			}
157			self.buffer.extend_from_slice(last.as_bytes())
158		}
159		append_extensions(&self.extensions, &mut self.buffer);
160		self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Version: 13\r\n\r\n")
161	}
162
163	/// Decode the server response to this client request.
164	fn decode_response(&mut self) -> Result<Parsing<ServerResponse>, Error> {
165		let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS];
166		let mut response = httparse::Response::new(&mut header_buf);
167
168		let offset = match response.parse(self.buffer.as_ref()) {
169			Ok(httparse::Status::Complete(off)) => off,
170			Ok(httparse::Status::Partial) => return Ok(Parsing::NeedMore(())),
171			Err(e) => return Err(Error::Http(Box::new(e))),
172		};
173
174		if response.version != Some(1) {
175			return Err(Error::UnsupportedHttpVersion);
176		}
177
178		match response.code {
179			Some(101) => (),
180			Some(code @ (301..=303)) | Some(code @ 307) | Some(code @ 308) => {
181				// redirect response
182				let location =
183					with_first_header(response.headers, "Location", |loc| Ok(String::from(std::str::from_utf8(loc)?)))?;
184				let response = ServerResponse::Redirect { status_code: code, location };
185				return Ok(Parsing::Done { value: response, offset });
186			}
187			other => {
188				let response = ServerResponse::Rejected { status_code: other.unwrap_or(0) };
189				return Ok(Parsing::Done { value: response, offset });
190			}
191		}
192
193		expect_ascii_header(response.headers, "Upgrade", "websocket")?;
194		expect_ascii_header(response.headers, "Connection", "upgrade")?;
195
196		with_first_header(&response.headers, "Sec-WebSocket-Accept", |theirs| {
197			let mut digest = Sha1::new();
198			digest.update(&self.nonce);
199			digest.update(KEY);
200			let ours = base64::engine::general_purpose::STANDARD.encode(digest.finalize());
201			if ours.as_bytes() != theirs {
202				return Err(Error::InvalidSecWebSocketAccept);
203			}
204			Ok(())
205		})?;
206
207		// Parse `Sec-WebSocket-Extensions` headers.
208
209		for h in response.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) {
210			configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)?
211		}
212
213		// Match `Sec-WebSocket-Protocol` header.
214
215		let mut selected_proto = None;
216		if let Some(tp) = response.headers.iter().find(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL)) {
217			if let Some(&p) = self.protocols.iter().find(|x| x.as_bytes() == tp.value) {
218				selected_proto = Some(String::from(p))
219			} else {
220				return Err(Error::UnsolicitedProtocol);
221			}
222		}
223
224		let response = ServerResponse::Accepted { protocol: selected_proto };
225		Ok(Parsing::Done { value: response, offset })
226	}
227}
228
/// Handshake response received from the server.
///
/// Returned by `Client::handshake` once a complete HTTP response head has
/// been parsed.
#[derive(Debug)]
pub enum ServerResponse {
	/// The server has accepted our request (status 101 and all handshake
	/// checks passed).
	Accepted {
		/// The protocol (if any) the server has selected.
		protocol: Option<String>,
	},
	/// The server is redirecting us to some other location (status 301, 302,
	/// 303, 307 or 308).
	Redirect {
		/// The HTTP response status code.
		status_code: u16,
		/// The location URL we should go to, taken from the `Location` header.
		location: String,
	},
	/// The server rejected our request (any other status code).
	Rejected {
		/// HTTP response status code, or 0 if the response carried none.
		status_code: u16,
	},
}