soketto/handshake/
client.rs1use super::{
14 append_extensions, configure_extensions, expect_ascii_header, with_first_header, Error, WebSocketKey, KEY,
15 MAX_NUM_HEADERS, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_PROTOCOL,
16};
17use crate::connection::{self, Mode};
18use crate::{extension::Extension, Parsing};
19use base64::Engine;
20use bytes::{Buf, BytesMut};
21use futures::prelude::*;
22use sha1::{Digest, Sha1};
23use std::{mem, str};
24
25pub use httparse::Header;
26
/// Size argument (8 KiB) handed to `crate::read` each time `Client::handshake` reads more of
/// the server's response; presumably the maximum number of bytes fetched per read.
const BLOCK_SIZE: usize = 8_192;
28
/// Client side of a WebSocket opening handshake.
///
/// Holds the I/O resource the handshake runs over, the request target, the optional extra
/// headers, sub-protocols and extensions offered to the server, and the buffer used both to
/// encode the request and to accumulate the server's response.
#[derive(Debug)]
pub struct Client<'a, T> {
    // The underlying I/O resource the request is written to and the response read from.
    socket: T,
    // Value written into the request's `Host` header.
    host: &'a str,
    // Request target placed on the request line (`GET <resource> HTTP/1.1`).
    resource: &'a str,
    // Extra headers appended verbatim (unvalidated) to the request.
    headers: &'a [Header<'a>],
    // Base64-encoded `Sec-WebSocket-Key` of the most recently encoded request; all zeros until
    // a request has been encoded. Also used to verify the server's `Sec-WebSocket-Accept`.
    nonce: WebSocketKey,
    // Sub-protocols offered in the `Sec-WebSocket-Protocol` request header.
    protocols: Vec<&'a str>,
    // Extensions offered to the server and configured from its response.
    extensions: Vec<Box<dyn Extension + Send>>,
    // Encoding/decoding buffer; after a successful handshake it holds any bytes that followed
    // the response headers.
    buffer: BytesMut,
}
49
50impl<'a, T: AsyncRead + AsyncWrite + Unpin> Client<'a, T> {
51 pub fn new(socket: T, host: &'a str, resource: &'a str) -> Self {
53 Client {
54 socket,
55 host,
56 resource,
57 headers: &[],
58 nonce: [0; 24],
59 protocols: Vec::new(),
60 extensions: Vec::new(),
61 buffer: BytesMut::new(),
62 }
63 }
64
65 pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self {
67 self.buffer = b;
68 self
69 }
70
71 pub fn take_buffer(&mut self) -> BytesMut {
73 mem::take(&mut self.buffer)
74 }
75
76 pub fn set_headers(&mut self, h: &'a [Header]) -> &mut Self {
80 self.headers = h;
81 self
82 }
83
84 pub fn add_protocol(&mut self, p: &'a str) -> &mut Self {
86 self.protocols.push(p);
87 self
88 }
89
90 pub fn add_extension(&mut self, e: Box<dyn Extension + Send>) -> &mut Self {
92 self.extensions.push(e);
93 self
94 }
95
96 pub fn drain_extensions(&mut self) -> impl Iterator<Item = Box<dyn Extension + Send>> + '_ {
98 self.extensions.drain(..)
99 }
100
101 pub async fn handshake(&mut self) -> Result<ServerResponse, Error> {
103 self.buffer.clear();
104 self.encode_request();
105 self.socket.write_all(&self.buffer).await?;
106 self.socket.flush().await?;
107 self.buffer.clear();
108
109 loop {
110 crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?;
111 if let Parsing::Done { value, offset } = self.decode_response()? {
112 self.buffer.advance(offset);
113 return Ok(value);
114 }
115 }
116 }
117
118 pub fn into_builder(mut self) -> connection::Builder<T> {
120 let mut builder = connection::Builder::new(self.socket, Mode::Client);
121 builder.set_buffer(self.buffer);
122 builder.add_extensions(self.extensions.drain(..));
123 builder
124 }
125
126 pub fn into_inner(self) -> T {
128 self.socket
129 }
130
131 fn encode_request(&mut self) {
133 let nonce: [u8; 16] = rand::random();
134 base64::engine::general_purpose::STANDARD
135 .encode_slice(nonce, &mut self.nonce)
136 .expect("encoding to base64 is exactly 16 bytes; qed");
137 self.buffer.extend_from_slice(b"GET ");
138 self.buffer.extend_from_slice(self.resource.as_bytes());
139 self.buffer.extend_from_slice(b" HTTP/1.1");
140 self.buffer.extend_from_slice(b"\r\nHost: ");
141 self.buffer.extend_from_slice(self.host.as_bytes());
142 self.buffer.extend_from_slice(b"\r\nUpgrade: websocket\r\nConnection: Upgrade");
143 self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Key: ");
144 self.buffer.extend_from_slice(&self.nonce);
145 self.headers.iter().for_each(|h| {
146 self.buffer.extend_from_slice(b"\r\n");
147 self.buffer.extend_from_slice(h.name.as_bytes());
148 self.buffer.extend_from_slice(b": ");
149 self.buffer.extend_from_slice(h.value);
150 });
151 if let Some((last, prefix)) = self.protocols.split_last() {
152 self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: ");
153 for p in prefix {
154 self.buffer.extend_from_slice(p.as_bytes());
155 self.buffer.extend_from_slice(b",")
156 }
157 self.buffer.extend_from_slice(last.as_bytes())
158 }
159 append_extensions(&self.extensions, &mut self.buffer);
160 self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Version: 13\r\n\r\n")
161 }
162
163 fn decode_response(&mut self) -> Result<Parsing<ServerResponse>, Error> {
165 let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS];
166 let mut response = httparse::Response::new(&mut header_buf);
167
168 let offset = match response.parse(self.buffer.as_ref()) {
169 Ok(httparse::Status::Complete(off)) => off,
170 Ok(httparse::Status::Partial) => return Ok(Parsing::NeedMore(())),
171 Err(e) => return Err(Error::Http(Box::new(e))),
172 };
173
174 if response.version != Some(1) {
175 return Err(Error::UnsupportedHttpVersion);
176 }
177
178 match response.code {
179 Some(101) => (),
180 Some(code @ (301..=303)) | Some(code @ 307) | Some(code @ 308) => {
181 let location =
183 with_first_header(response.headers, "Location", |loc| Ok(String::from(std::str::from_utf8(loc)?)))?;
184 let response = ServerResponse::Redirect { status_code: code, location };
185 return Ok(Parsing::Done { value: response, offset });
186 }
187 other => {
188 let response = ServerResponse::Rejected { status_code: other.unwrap_or(0) };
189 return Ok(Parsing::Done { value: response, offset });
190 }
191 }
192
193 expect_ascii_header(response.headers, "Upgrade", "websocket")?;
194 expect_ascii_header(response.headers, "Connection", "upgrade")?;
195
196 with_first_header(&response.headers, "Sec-WebSocket-Accept", |theirs| {
197 let mut digest = Sha1::new();
198 digest.update(&self.nonce);
199 digest.update(KEY);
200 let ours = base64::engine::general_purpose::STANDARD.encode(digest.finalize());
201 if ours.as_bytes() != theirs {
202 return Err(Error::InvalidSecWebSocketAccept);
203 }
204 Ok(())
205 })?;
206
207 for h in response.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) {
210 configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)?
211 }
212
213 let mut selected_proto = None;
216 if let Some(tp) = response.headers.iter().find(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL)) {
217 if let Some(&p) = self.protocols.iter().find(|x| x.as_bytes() == tp.value) {
218 selected_proto = Some(String::from(p))
219 } else {
220 return Err(Error::UnsolicitedProtocol);
221 }
222 }
223
224 let response = ServerResponse::Accepted { protocol: selected_proto };
225 Ok(Parsing::Done { value: response, offset })
226 }
227}
228
/// Outcome of a client handshake, decoded from the server's HTTP response.
///
/// Only a `101` response that passes the WebSocket handshake checks becomes
/// [`ServerResponse::Accepted`]. Redirects and all other status codes are reported through
/// the other variants rather than as errors. The type is plain data, so it can be cloned and
/// compared (e.g. in tests or when retrying after a redirect).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ServerResponse {
    /// The server switched protocols (`101`) and the handshake succeeded.
    Accepted {
        /// Sub-protocol selected by the server via `Sec-WebSocket-Protocol`, if any.
        protocol: Option<String>,
    },
    /// The server answered with a redirect (`301`-`303`, `307` or `308`).
    Redirect {
        /// The HTTP status code of the response.
        status_code: u16,
        /// Value of the response's `Location` header.
        location: String,
    },
    /// The server answered with any other status code.
    Rejected {
        /// The HTTP status code of the response (`0` if the response carried none).
        status_code: u16,
    },
}