1use super::{
14 append_extensions, configure_extensions, expect_ascii_header, with_first_header, Error, WebSocketKey,
15 MAX_NUM_HEADERS, SEC_WEBSOCKET_EXTENSIONS, SEC_WEBSOCKET_PROTOCOL,
16};
17use crate::connection::{self, Mode};
18use crate::extension::Extension;
19use bytes::BytesMut;
20use futures::prelude::*;
21use std::{mem, str};
22
/// Header-size cap (8 KiB): `Server::receive_request` stops reading once at least this many
/// bytes are buffered without the `\r\n\r\n` terminator having been found.
const MAX_HEADERS_SIZE: usize = 8192;
/// Size argument (8 KiB) passed to `crate::read` on each iteration of the request-reading loop.
const BLOCK_SIZE: usize = 8192;
26
27#[derive(Debug)]
29pub struct Server<'a, T> {
30 socket: T,
31 protocols: Vec<&'a str>,
33 extensions: Vec<Box<dyn Extension + Send>>,
35 buffer: BytesMut,
37}
38
39impl<'a, T: AsyncRead + AsyncWrite + Unpin> Server<'a, T> {
40 pub fn new(socket: T) -> Self {
42 Server { socket, protocols: Vec::new(), extensions: Vec::new(), buffer: BytesMut::new() }
43 }
44
45 pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self {
47 self.buffer = b;
48 self
49 }
50
51 pub fn take_buffer(&mut self) -> BytesMut {
53 mem::take(&mut self.buffer)
54 }
55
56 pub fn add_protocol(&mut self, p: &'a str) -> &mut Self {
58 self.protocols.push(p);
59 self
60 }
61
62 pub fn add_extension(&mut self, e: Box<dyn Extension + Send>) -> &mut Self {
64 self.extensions.push(e);
65 self
66 }
67
68 pub fn drain_extensions(&mut self) -> impl Iterator<Item = Box<dyn Extension + Send>> + '_ {
70 self.extensions.drain(..)
71 }
72
73 pub async fn receive_request(&mut self) -> Result<ClientRequest<'_>, Error> {
75 self.buffer.clear();
76
77 let mut skip = 0;
78
79 loop {
80 crate::read(&mut self.socket, &mut self.buffer, BLOCK_SIZE).await?;
81
82 let limit = std::cmp::min(self.buffer.len(), MAX_HEADERS_SIZE);
83
84 if self.buffer[skip..limit].windows(4).rev().any(|w| w == b"\r\n\r\n") {
87 break;
88 }
89
90 if limit == MAX_HEADERS_SIZE {
93 break;
94 }
95
96 skip = self.buffer.len().saturating_sub(4);
100 }
101
102 self.decode_request()
103 }
104
105 pub async fn send_response(&mut self, r: &Response<'_>) -> Result<(), Error> {
107 self.buffer.clear();
108 self.encode_response(r);
109 self.socket.write_all(&self.buffer).await?;
110 self.socket.flush().await?;
111 self.buffer.clear();
112 Ok(())
113 }
114
115 pub fn into_builder(mut self) -> connection::Builder<T> {
117 let mut builder = connection::Builder::new(self.socket, Mode::Server);
118 builder.set_buffer(self.buffer);
119 builder.add_extensions(self.extensions.drain(..));
120 builder
121 }
122
123 pub fn into_inner(self) -> T {
125 self.socket
126 }
127
128 fn decode_request(&mut self) -> Result<ClientRequest, Error> {
130 let mut header_buf = [httparse::EMPTY_HEADER; MAX_NUM_HEADERS];
131 let mut request = httparse::Request::new(&mut header_buf);
132
133 match request.parse(self.buffer.as_ref()) {
134 Ok(httparse::Status::Complete(_)) => (),
135 Ok(httparse::Status::Partial) => return Err(Error::IncompleteHttpRequest),
136 Err(e) => return Err(Error::Http(Box::new(e))),
137 };
138 if request.method != Some("GET") {
139 return Err(Error::InvalidRequestMethod);
140 }
141 if request.version != Some(1) {
142 return Err(Error::UnsupportedHttpVersion);
143 }
144
145 let host = with_first_header(&request.headers, "Host", Ok)?;
146
147 expect_ascii_header(request.headers, "Upgrade", "websocket")?;
148 expect_ascii_header(request.headers, "Connection", "upgrade")?;
149 expect_ascii_header(request.headers, "Sec-WebSocket-Version", "13")?;
150
151 let origin =
152 request.headers.iter().find_map(
153 |h| {
154 if h.name.eq_ignore_ascii_case("Origin") {
155 Some(h.value)
156 } else {
157 None
158 }
159 },
160 );
161 let headers = RequestHeaders { host, origin };
162
163 let ws_key = with_first_header(&request.headers, "Sec-WebSocket-Key", |k| {
164 WebSocketKey::try_from(k).map_err(|_| Error::SecWebSocketKeyInvalidLength(k.len()))
165 })?;
166
167 for h in request.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS)) {
168 configure_extensions(&mut self.extensions, std::str::from_utf8(h.value)?)?
169 }
170
171 let mut protocols = Vec::new();
172 for p in request.headers.iter().filter(|h| h.name.eq_ignore_ascii_case(SEC_WEBSOCKET_PROTOCOL)) {
173 if let Some(&p) = self.protocols.iter().find(|x| x.as_bytes() == p.value) {
174 protocols.push(p)
175 }
176 }
177
178 let path = request.path.unwrap_or("/");
179
180 Ok(ClientRequest { ws_key, protocols, path, headers })
181 }
182
183 fn encode_response(&mut self, response: &Response<'_>) {
185 match response {
186 Response::Accept { key, protocol } => {
187 let accept_value = super::generate_accept_key(&key);
188 self.buffer.extend_from_slice(
189 concat![
190 "HTTP/1.1 101 Switching Protocols",
191 "\r\nServer: soketto-",
192 env!("CARGO_PKG_VERSION"),
193 "\r\nUpgrade: websocket",
194 "\r\nConnection: upgrade",
195 "\r\nSec-WebSocket-Accept: ",
196 ]
197 .as_bytes(),
198 );
199 self.buffer.extend_from_slice(&accept_value);
200 if let Some(p) = protocol {
201 self.buffer.extend_from_slice(b"\r\nSec-WebSocket-Protocol: ");
202 self.buffer.extend_from_slice(p.as_bytes())
203 }
204 append_extensions(self.extensions.iter().filter(|e| e.is_enabled()), &mut self.buffer);
205 self.buffer.extend_from_slice(b"\r\n\r\n")
206 }
207 Response::Reject { status_code } => {
208 self.buffer.extend_from_slice(b"HTTP/1.1 ");
209 let (_, reason) = if let Ok(i) = STATUSCODES.binary_search_by_key(status_code, |(n, _)| *n) {
210 STATUSCODES[i]
211 } else {
212 (500, "500 Internal Server Error")
213 };
214 self.buffer.extend_from_slice(reason.as_bytes());
215 self.buffer.extend_from_slice(b"\r\n\r\n")
216 }
217 }
218 }
219}
220
221#[derive(Debug)]
223pub struct ClientRequest<'a> {
224 ws_key: WebSocketKey,
225 protocols: Vec<&'a str>,
226 path: &'a str,
227 headers: RequestHeaders<'a>,
228}
229
/// Raw values of selected headers from a client's handshake request.
#[derive(Debug, Copy, Clone)]
pub struct RequestHeaders<'a> {
    /// Value of the first `Host` header.
    pub host: &'a [u8],
    /// Value of the first `Origin` header, or `None` if the client sent none.
    pub origin: Option<&'a [u8]>,
}
238
239impl<'a> ClientRequest<'a> {
240 pub fn key(&self) -> WebSocketKey {
242 self.ws_key
243 }
244
245 pub fn protocols(&self) -> impl Iterator<Item = &str> {
247 self.protocols.iter().cloned()
248 }
249
250 pub fn path(&self) -> &str {
252 self.path
253 }
254
255 pub fn headers(&self) -> RequestHeaders {
257 self.headers
258 }
259}
260
261#[derive(Debug)]
263pub enum Response<'a> {
264 Accept { key: WebSocketKey, protocol: Option<&'a str> },
266 Reject { status_code: u16 },
268}
269
/// Status codes and their full status-line text (code plus reason phrase) for
/// `Response::Reject`.
///
/// Invariant: entries are strictly ascending by code. `Server::encode_response` looks codes
/// up with `binary_search_by_key`, which relies on this ordering.
const STATUSCODES: &[(u16, &str)] = &[
    (100, "100 Continue"),
    (101, "101 Switching Protocols"),
    (102, "102 Processing"),
    (200, "200 OK"),
    (201, "201 Created"),
    (202, "202 Accepted"),
    (203, "203 Non-Authoritative Information"),
    (204, "204 No Content"),
    (205, "205 Reset Content"),
    (206, "206 Partial Content"),
    (207, "207 Multi-Status"),
    (208, "208 Already Reported"),
    (226, "226 IM Used"),
    (300, "300 Multiple Choices"),
    (301, "301 Moved Permanently"),
    (302, "302 Found"),
    (303, "303 See Other"),
    (304, "304 Not Modified"),
    (305, "305 Use Proxy"),
    (307, "307 Temporary Redirect"),
    (308, "308 Permanent Redirect"),
    (400, "400 Bad Request"),
    (401, "401 Unauthorized"),
    (402, "402 Payment Required"),
    (403, "403 Forbidden"),
    (404, "404 Not Found"),
    (405, "405 Method Not Allowed"),
    (406, "406 Not Acceptable"),
    (407, "407 Proxy Authentication Required"),
    (408, "408 Request Timeout"),
    (409, "409 Conflict"),
    (410, "410 Gone"),
    (411, "411 Length Required"),
    (412, "412 Precondition Failed"),
    (413, "413 Payload Too Large"),
    (414, "414 URI Too Long"),
    (415, "415 Unsupported Media Type"),
    (416, "416 Range Not Satisfiable"),
    (417, "417 Expectation Failed"),
    (418, "418 I'm a teapot"),
    (421, "421 Misdirected Request"),
    (422, "422 Unprocessable Entity"),
    (423, "423 Locked"),
    (424, "424 Failed Dependency"),
    (425, "425 Too Early"),
    (426, "426 Upgrade Required"),
    (428, "428 Precondition Required"),
    (429, "429 Too Many Requests"),
    (431, "431 Request Header Fields Too Large"),
    (451, "451 Unavailable For Legal Reasons"),
    (500, "500 Internal Server Error"),
    (501, "501 Not Implemented"),
    (502, "502 Bad Gateway"),
    (503, "503 Service Unavailable"),
    (504, "504 Gateway Timeout"),
    (505, "505 HTTP Version Not Supported"),
    (506, "506 Variant Also Negotiates"),
    (507, "507 Insufficient Storage"),
    (508, "508 Loop Detected"),
    (510, "510 Not Extended"),
    (511, "511 Network Authentication Required"),
];