soketto/handshake/
http.rs1use super::{WebSocketKey, SEC_WEBSOCKET_EXTENSIONS};
18use crate::connection::{self, Mode};
19use crate::extension::Extension;
20use crate::handshake;
21use bytes::BytesMut;
22use futures::prelude::*;
23use http::{header, HeaderMap, Response};
24use std::mem;
25
26pub type Error = handshake::Error;
28
29pub struct Server {
33 extensions: Vec<Box<dyn Extension + Send>>,
35 buffer: BytesMut,
37}
38
39impl Server {
40 pub fn new() -> Self {
42 Server { extensions: Vec::new(), buffer: BytesMut::new() }
43 }
44
45 pub fn set_buffer(&mut self, b: BytesMut) -> &mut Self {
47 self.buffer = b;
48 self
49 }
50
51 pub fn take_buffer(&mut self) -> BytesMut {
53 mem::take(&mut self.buffer)
54 }
55
56 pub fn add_extension(&mut self, e: Box<dyn Extension + Send>) -> &mut Self {
58 self.extensions.push(e);
59 self
60 }
61
62 pub fn drain_extensions(&mut self) -> impl Iterator<Item = Box<dyn Extension + Send>> + '_ {
64 self.extensions.drain(..)
65 }
66
67 pub fn receive_request<B>(&mut self, req: &http::Request<B>) -> Result<http::Response<()>, Error> {
70 if !is_upgrade_request(&req) {
71 return Err(Error::InvalidSecWebSocketAccept);
72 }
73
74 let key = match req.headers().get("Sec-WebSocket-Key") {
75 Some(key) => key,
76 None => {
77 return Err(Error::HeaderNotFound("Sec-WebSocket-Key".into()).into());
78 }
79 };
80
81 if req.headers().get("Sec-WebSocket-Version").map(|v| v.as_bytes()) != Some(b"13") {
82 return Err(Error::HeaderNotFound("Sec-WebSocket-Version".into()).into());
83 }
84
85 let key: &WebSocketKey = match key.as_bytes().try_into() {
87 Ok(key) => key,
88 Err(_) => return Err(Error::InvalidSecWebSocketAccept),
89 };
90 let accept_key = handshake::generate_accept_key(key);
91
92 let extension_config = req
94 .headers()
95 .iter()
96 .filter(|&(name, _)| name.as_str().eq_ignore_ascii_case(SEC_WEBSOCKET_EXTENSIONS))
97 .map(|(_, value)| Ok(std::str::from_utf8(value.as_bytes())?.to_string()))
98 .collect::<Result<Vec<_>, Error>>()?;
99
100 for config_str in &extension_config {
102 handshake::configure_extensions(&mut self.extensions, &config_str)?;
103 }
104
105 let mut response = Response::builder()
107 .status(http::StatusCode::SWITCHING_PROTOCOLS)
108 .header(http::header::CONNECTION, "upgrade")
109 .header(http::header::UPGRADE, "websocket")
110 .header("Sec-WebSocket-Accept", &accept_key[..]);
111
112 if !self.extensions.is_empty() {
115 let mut buf = bytes::BytesMut::new();
116 let enabled_extensions = self.extensions.iter().filter(|e| e.is_enabled()).peekable();
117 handshake::append_extension_header_value(enabled_extensions, &mut buf);
118 response = response.header("Sec-WebSocket-Extensions", buf.as_ref());
119 }
120
121 let response = response.body(()).expect("bug: failed to build response");
122 Ok(response)
123 }
124
125 pub fn into_builder<T: AsyncRead + AsyncWrite + Unpin>(mut self, socket: T) -> connection::Builder<T> {
127 let mut builder = connection::Builder::new(socket, Mode::Server);
128 builder.set_buffer(self.buffer);
129 builder.add_extensions(self.extensions.drain(..));
130 builder
131 }
132}
133
134pub fn is_upgrade_request<B>(request: &http::Request<B>) -> bool {
136 header_contains_value(request.headers(), header::CONNECTION, b"upgrade")
137 && header_contains_value(request.headers(), header::UPGRADE, b"websocket")
138}
139
140fn header_contains_value(headers: &HeaderMap, header: header::HeaderName, value: &[u8]) -> bool {
142 pub fn trim(x: &[u8]) -> &[u8] {
143 let from = match x.iter().position(|x| !x.is_ascii_whitespace()) {
144 Some(i) => i,
145 None => return &[],
146 };
147 let to = x.iter().rposition(|x| !x.is_ascii_whitespace()).unwrap();
148 &x[from..=to]
149 }
150
151 for header in headers.get_all(header) {
152 if header.as_bytes().split(|&c| c == b',').any(|x| trim(x).eq_ignore_ascii_case(value)) {
153 return true;
154 }
155 }
156 false
157}