1use crate::data::{ByteSlice125, Data, Incoming};
13use crate::{
14 base::{self, Header, OpCode, MAX_HEADER_SIZE},
15 extension::Extension,
16 Parsing, Storage,
17};
18use bytes::{Buf, BytesMut};
19use futures::{
20 io::{ReadHalf, WriteHalf},
21 lock::BiLock,
22 prelude::*,
23};
24use std::{fmt, io, str};
25
/// Default limit (256 MiB) for the accumulated payload size of one complete
/// message; `Builder::new` uses it as the initial `max_message_size`.
const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;

/// Default limit for the payload size of a single frame; `Builder::new` hands
/// it to the codec. Equal to `MAX_MESSAGE_SIZE`.
const MAX_FRAME_SIZE: usize = MAX_MESSAGE_SIZE;
31
/// Which side of the connection this endpoint represents.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Mode {
    /// The client side of a connection.
    Client,
    /// The server side of a connection.
    Server,
}

impl Mode {
    /// Returns `true` if this is [`Mode::Client`].
    pub fn is_client(self) -> bool {
        match self {
            Mode::Client => true,
            Mode::Server => false,
        }
    }

    /// Returns `true` if this is [`Mode::Server`].
    pub fn is_server(self) -> bool {
        match self {
            Mode::Client => false,
            Mode::Server => true,
        }
    }
}
54
/// Connection identifier; it prefixes the log output of a connection.
#[derive(Clone, Copy, Debug)]
struct Id(u32);

impl fmt::Display for Id {
    /// Renders the identifier as eight zero-padded, lower-case hex digits.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Id(raw) = *self;
        f.write_fmt(format_args!("{:08x}", raw))
    }
}
64
/// The sending half of a connection, created by `Builder::finish`.
#[derive(Debug)]
pub struct Sender<T> {
    /// Identifier prefixed to this connection's log output.
    id: Id,
    /// Client or server side; frames written in client mode are masked.
    mode: Mode,
    /// Codec used to encode frame headers.
    codec: base::Codec,
    /// Write half of the socket; the other `BiLock` handle is held by the `Receiver`.
    writer: BiLock<WriteHalf<T>>,
    /// Scratch space into which shared (read-only) payloads are copied before masking.
    mask_buffer: Vec<u8>,
    /// Enabled extensions; the other `BiLock` handle is held by the `Receiver`.
    extensions: BiLock<Vec<Box<dyn Extension + Send>>>,
    /// `true` if `extensions` is non-empty, so the lock can be skipped otherwise.
    has_extensions: bool,
}
76
/// The receiving half of a connection, created by `Builder::finish`.
#[derive(Debug)]
pub struct Receiver<T> {
    /// Identifier prefixed to this connection's log output.
    id: Id,
    /// Client or server side; used when writing answers to control frames.
    mode: Mode,
    /// Codec used to decode incoming frame headers and encode control answers.
    codec: base::Codec,
    /// Read half of the socket.
    reader: ReadHalf<T>,
    /// Write half of the socket; the other `BiLock` handle is held by the `Sender`.
    writer: BiLock<WriteHalf<T>>,
    /// Enabled extensions; the other `BiLock` handle is held by the `Sender`.
    extensions: BiLock<Vec<Box<dyn Extension + Send>>>,
    /// `true` if `extensions` is non-empty, so the lock can be skipped otherwise.
    has_extensions: bool,
    /// Bytes already read from the socket but not yet consumed.
    buffer: BytesMut,
    /// Payload of the control frame currently being processed; cleared before
    /// every frame header is read.
    ctrl_buffer: BytesMut,
    /// Upper bound for the accumulated payload size of one message.
    max_message_size: usize,
    /// Set once a CLOSE frame has been received; `receive` then fails with `Error::Closed`.
    is_closed: bool,
}
92
/// Collects the configuration of a connection before it is split into a
/// `Sender`/`Receiver` pair by `finish`.
#[derive(Debug)]
pub struct Builder<T> {
    /// Identifier shared by both halves of the connection (randomly chosen in `new`).
    id: Id,
    /// Client or server side.
    mode: Mode,
    /// The underlying I/O resource; split into read and write halves in `finish`.
    socket: T,
    /// Codec that is handed (cloned) to both halves.
    codec: base::Codec,
    /// Enabled extensions registered through `add_extensions`.
    extensions: Vec<Box<dyn Extension + Send>>,
    /// Initial contents of the receiver's read buffer.
    buffer: BytesMut,
    /// Upper bound for the accumulated payload size of one message.
    max_message_size: usize,
}
108
109impl<T: AsyncRead + AsyncWrite + Unpin> Builder<T> {
110 pub fn new(socket: T, mode: Mode) -> Self {
119 let mut codec = base::Codec::default();
120 codec.set_max_data_size(MAX_FRAME_SIZE);
121 Builder {
122 id: Id(rand::random()),
123 mode,
124 socket,
125 codec,
126 extensions: Vec::new(),
127 buffer: BytesMut::new(),
128 max_message_size: MAX_MESSAGE_SIZE,
129 }
130 }
131
132 pub fn set_buffer(&mut self, b: BytesMut) {
134 self.buffer = b
135 }
136
137 pub fn add_extensions<I>(&mut self, extensions: I)
141 where
142 I: IntoIterator<Item = Box<dyn Extension + Send>>,
143 {
144 for e in extensions.into_iter().filter(|e| e.is_enabled()) {
145 log::debug!("{}: using extension: {}", self.id, e.name());
146 self.codec.add_reserved_bits(e.reserved_bits());
147 self.extensions.push(e)
148 }
149 }
150
151 pub fn set_max_message_size(&mut self, max: usize) {
158 self.max_message_size = max
159 }
160
161 pub fn set_max_frame_size(&mut self, max: usize) {
163 self.codec.set_max_data_size(max);
164 }
165
166 pub fn finish(self) -> (Sender<T>, Receiver<T>) {
168 let (rhlf, whlf) = self.socket.split();
169 let (wrt1, wrt2) = BiLock::new(whlf);
170 let has_extensions = !self.extensions.is_empty();
171 let (ext1, ext2) = BiLock::new(self.extensions);
172
173 let recv = Receiver {
174 id: self.id,
175 mode: self.mode,
176 reader: rhlf,
177 writer: wrt1,
178 codec: self.codec.clone(),
179 extensions: ext1,
180 has_extensions,
181 buffer: self.buffer,
182 ctrl_buffer: BytesMut::new(),
183 max_message_size: self.max_message_size,
184 is_closed: false,
185 };
186
187 let send = Sender {
188 id: self.id,
189 mode: self.mode,
190 writer: wrt2,
191 mask_buffer: Vec::new(),
192 codec: self.codec,
193 extensions: ext2,
194 has_extensions,
195 };
196
197 (send, recv)
198 }
199}
200
201impl<T: AsyncRead + AsyncWrite + Unpin> Receiver<T> {
202 pub async fn receive(&mut self, message: &mut Vec<u8>) -> Result<Incoming<'_>, Error> {
213 let mut first_fragment_opcode = None;
214 let mut length: usize = 0;
215 let message_len = message.len();
216 loop {
217 if self.is_closed {
218 log::debug!("{}: cannot receive, connection is closed", self.id);
219 return Err(Error::Closed);
220 }
221
222 self.ctrl_buffer.clear();
223 let mut header = self.receive_header().await?;
224 log::trace!("{}: recv: {}", self.id, header);
225
226 if header.opcode().is_control() {
228 self.read_buffer(&header).await?;
229 self.ctrl_buffer = self.buffer.split_to(header.payload_len());
230 base::Codec::apply_mask(&header, &mut self.ctrl_buffer);
231 if header.opcode() == OpCode::Pong {
232 return Ok(Incoming::Pong(&self.ctrl_buffer[..]));
233 }
234 if let Some(close_reason) = self.on_control(&header).await? {
235 log::trace!("{}: recv, incoming CLOSE: {:?}", self.id, close_reason);
236 return Ok(Incoming::Closed(close_reason));
237 }
238 continue;
239 }
240
241 length = length.saturating_add(header.payload_len());
242
243 if length > self.max_message_size {
245 log::warn!("{}: accumulated message length exceeds maximum", self.id);
246
247 discard_bytes(length as u64, &mut self.reader).await?;
249 return Err(Error::MessageTooLarge { current: length, maximum: self.max_message_size });
250 }
251
252 {
254 let old_msg_len = message.len();
255
256 let bytes_to_read = {
257 let required = header.payload_len();
258 let buffered = self.buffer.len();
259
260 if buffered == 0 {
261 required
262 } else if required > buffered {
263 message.extend_from_slice(&self.buffer);
264 self.buffer.clear();
265 required - buffered
266 } else {
267 message.extend_from_slice(&self.buffer.split_to(required));
268 0
269 }
270 };
271
272 if bytes_to_read > 0 {
273 let n = message.len();
274 message.resize(n + bytes_to_read, 0u8);
275 self.reader.read_exact(&mut message[n..]).await?
276 }
277
278 debug_assert_eq!(header.payload_len(), message.len() - old_msg_len);
279
280 base::Codec::apply_mask(&header, &mut message[old_msg_len..]);
281 }
282
283 match (header.is_fin(), header.opcode()) {
284 (false, OpCode::Continue) => {
285 if first_fragment_opcode.is_none() {
287 log::debug!("{}: continue frame while not processing message fragments", self.id);
288 return Err(Error::UnexpectedOpCode(OpCode::Continue));
289 }
290 continue;
291 }
292 (false, oc) => {
293 if first_fragment_opcode.is_some() {
295 log::debug!("{}: initial fragment while processing a fragmented message", self.id);
296 return Err(Error::UnexpectedOpCode(oc));
297 }
298 first_fragment_opcode = Some(oc);
299 self.decode_with_extensions(&mut header, message).await?;
300 continue;
301 }
302 (true, OpCode::Continue) => {
303 if let Some(oc) = first_fragment_opcode.take() {
305 header.set_payload_len(message.len());
306 log::trace!("{}: last fragment: total length = {} bytes", self.id, message.len());
307 self.decode_with_extensions(&mut header, message).await?;
308 header.set_opcode(oc);
309 } else {
310 log::debug!("{}: last continue frame while not processing message fragments", self.id);
311 return Err(Error::UnexpectedOpCode(OpCode::Continue));
312 }
313 }
314 (true, oc) => {
315 if first_fragment_opcode.is_some() {
317 log::debug!("{}: regular message while processing fragmented message", self.id);
318 return Err(Error::UnexpectedOpCode(oc));
319 }
320 self.decode_with_extensions(&mut header, message).await?
321 }
322 }
323
324 let num_bytes = message.len() - message_len;
325
326 if header.opcode() == OpCode::Text {
327 return Ok(Incoming::Data(Data::Text(num_bytes)));
328 } else {
329 return Ok(Incoming::Data(Data::Binary(num_bytes)));
330 }
331 }
332 }
333
334 pub async fn receive_data(&mut self, message: &mut Vec<u8>) -> Result<Data, Error> {
336 loop {
337 if let Incoming::Data(d) = self.receive(message).await? {
338 return Ok(d);
339 }
340 }
341 }
342
343 async fn receive_header(&mut self) -> Result<Header, Error> {
345 loop {
346 match self.codec.decode_header(&self.buffer)? {
347 Parsing::Done { value: header, offset } => {
348 debug_assert!(offset <= MAX_HEADER_SIZE);
349 self.buffer.advance(offset);
350 return Ok(header);
351 }
352 Parsing::NeedMore(n) => crate::read(&mut self.reader, &mut self.buffer, n).await?,
353 }
354 }
355 }
356
357 async fn read_buffer(&mut self, header: &Header) -> Result<(), Error> {
359 if header.payload_len() <= self.buffer.len() {
360 return Ok(());
361 }
362 let i = self.buffer.len();
363 let d = header.payload_len() - i;
364 self.buffer.resize(i + d, 0u8);
365 self.reader.read_exact(&mut self.buffer[i..]).await?;
366 Ok(())
367 }
368
369 async fn on_control(&mut self, header: &Header) -> Result<Option<CloseReason>, Error> {
375 match header.opcode() {
376 OpCode::Ping => {
377 let mut answer = Header::new(OpCode::Pong);
378 let mut unused = Vec::new();
379 let mut data = Storage::Unique(&mut self.ctrl_buffer);
380 write(self.id, self.mode, &mut self.codec, &mut self.writer, &mut answer, &mut data, &mut unused)
381 .await?;
382 self.flush().await?;
383 Ok(None)
384 }
385 OpCode::Pong => Ok(None),
386 OpCode::Close => {
387 log::trace!("{}: Acknowledging CLOSE to sender", self.id);
388 self.is_closed = true;
389 let (mut header, reason) = close_answer(&self.ctrl_buffer)?;
390 let mut unused = Vec::new();
392 if let Some(CloseReason { code, .. }) = reason {
393 let mut data = code.to_be_bytes();
394 let mut data = Storage::Unique(&mut data);
395 let _ = write(
396 self.id,
397 self.mode,
398 &mut self.codec,
399 &mut self.writer,
400 &mut header,
401 &mut data,
402 &mut unused,
403 )
404 .await;
405 } else {
406 let mut data = Storage::Unique(&mut []);
407 let _ = write(
408 self.id,
409 self.mode,
410 &mut self.codec,
411 &mut self.writer,
412 &mut header,
413 &mut data,
414 &mut unused,
415 )
416 .await;
417 }
418 self.flush().await?;
419 self.writer.lock().await.close().await?;
420 Ok(reason)
421 }
422 OpCode::Binary
423 | OpCode::Text
424 | OpCode::Continue
425 | OpCode::Reserved3
426 | OpCode::Reserved4
427 | OpCode::Reserved5
428 | OpCode::Reserved6
429 | OpCode::Reserved7
430 | OpCode::Reserved11
431 | OpCode::Reserved12
432 | OpCode::Reserved13
433 | OpCode::Reserved14
434 | OpCode::Reserved15 => Err(Error::UnexpectedOpCode(header.opcode())),
435 }
436 }
437
438 async fn decode_with_extensions(&mut self, header: &mut Header, message: &mut Vec<u8>) -> Result<(), Error> {
440 if !self.has_extensions {
441 return Ok(());
442 }
443 for e in self.extensions.lock().await.iter_mut() {
444 log::trace!("{}: decoding with extension: {}", self.id, e.name());
445 e.decode(header, message).map_err(Error::Extension)?
446 }
447 Ok(())
448 }
449
450 async fn flush(&mut self) -> Result<(), Error> {
452 log::trace!("{}: Receiver flushing connection", self.id);
453 if self.is_closed {
454 return Ok(());
455 }
456 self.writer.lock().await.flush().await.or(Err(Error::Closed))
457 }
458}
459
460impl<T: AsyncRead + AsyncWrite + Unpin> Sender<T> {
461 pub async fn send_text(&mut self, data: impl AsRef<str>) -> Result<(), Error> {
463 let mut header = Header::new(OpCode::Text);
464 self.send_frame(&mut header, &mut Storage::Shared(data.as_ref().as_bytes())).await
465 }
466
467 pub async fn send_text_owned(&mut self, data: String) -> Result<(), Error> {
471 let mut header = Header::new(OpCode::Text);
472 self.send_frame(&mut header, &mut Storage::Owned(data.into_bytes())).await
473 }
474
475 pub async fn send_binary(&mut self, data: impl AsRef<[u8]>) -> Result<(), Error> {
477 let mut header = Header::new(OpCode::Binary);
478 self.send_frame(&mut header, &mut Storage::Shared(data.as_ref())).await
479 }
480
481 pub async fn send_binary_mut(&mut self, mut data: impl AsMut<[u8]>) -> Result<(), Error> {
486 let mut header = Header::new(OpCode::Binary);
487 self.send_frame(&mut header, &mut Storage::Unique(data.as_mut())).await
488 }
489
490 pub async fn send_ping(&mut self, data: ByteSlice125<'_>) -> Result<(), Error> {
492 let mut header = Header::new(OpCode::Ping);
493 self.write(&mut header, &mut Storage::Shared(data.as_ref())).await
494 }
495
496 pub async fn send_pong(&mut self, data: ByteSlice125<'_>) -> Result<(), Error> {
498 let mut header = Header::new(OpCode::Pong);
499 self.write(&mut header, &mut Storage::Shared(data.as_ref())).await
500 }
501
502 pub async fn flush(&mut self) -> Result<(), Error> {
504 log::trace!("{}: Sender flushing connection", self.id);
505 self.writer.lock().await.flush().await.or(Err(Error::Closed))
506 }
507
508 pub async fn close(&mut self) -> Result<(), Error> {
510 log::trace!("{}: closing connection", self.id);
511 let mut header = Header::new(OpCode::Close);
512 let code = 1000_u16.to_be_bytes(); self.write(&mut header, &mut Storage::Shared(&code[..])).await?;
514 self.flush().await?;
515 self.writer.lock().await.close().await.or(Err(Error::Closed))
516 }
517
518 async fn send_frame(&mut self, header: &mut Header, data: &mut Storage<'_>) -> Result<(), Error> {
522 if !self.has_extensions {
523 return self.write(header, data).await;
524 }
525
526 for e in self.extensions.lock().await.iter_mut() {
527 log::trace!("{}: encoding with extension: {}", self.id, e.name());
528 e.encode(header, data).map_err(Error::Extension)?
529 }
530
531 self.write(header, data).await
532 }
533
534 async fn write(&mut self, header: &mut Header, data: &mut Storage<'_>) -> Result<(), Error> {
539 write(self.id, self.mode, &mut self.codec, &mut self.writer, header, data, &mut self.mask_buffer).await
540 }
541}
542
543async fn write<T: AsyncWrite + Unpin>(
545 id: Id,
546 mode: Mode,
547 codec: &mut base::Codec,
548 writer: &mut BiLock<WriteHalf<T>>,
549 header: &mut Header,
550 data: &mut Storage<'_>,
551 mask_buffer: &mut Vec<u8>,
552) -> Result<(), Error> {
553 if mode.is_client() {
554 header.set_masked(true);
555 header.set_mask(rand::random());
556 }
557 header.set_payload_len(data.as_ref().len());
558
559 log::trace!("{}: send: {}", id, header);
560
561 let header_bytes = codec.encode_header(&header);
562 let mut w = writer.lock().await;
563 w.write_all(&header_bytes).await.or(Err(Error::Closed))?;
564
565 if !header.is_masked() {
566 return w.write_all(data.as_ref()).await.or(Err(Error::Closed));
567 }
568
569 match data {
570 Storage::Shared(slice) => {
571 mask_buffer.clear();
572 mask_buffer.extend_from_slice(slice);
573 base::Codec::apply_mask(header, mask_buffer);
574 w.write_all(mask_buffer).await.or(Err(Error::Closed))
575 }
576 Storage::Unique(slice) => {
577 base::Codec::apply_mask(header, slice);
578 w.write_all(slice).await.or(Err(Error::Closed))
579 }
580 Storage::Owned(ref mut bytes) => {
581 base::Codec::apply_mask(header, bytes);
582 w.write_all(bytes).await.or(Err(Error::Closed))
583 }
584 }
585}
586
587fn close_answer(data: &[u8]) -> Result<(Header, Option<CloseReason>), Error> {
590 let answer = Header::new(OpCode::Close);
591 if data.len() < 2 {
592 return Ok((answer, None));
593 }
594 let descr = std::str::from_utf8(&data[2..])?.into();
596 let code = u16::from_be_bytes([data[0], data[1]]);
597 let reason = CloseReason { code, descr: Some(descr) };
598
599 match code {
603 | 1000 ..= 1003
604 | 1007 ..= 1011
605 | 1012 | 1013 | 1015
608 | 3000 ..= 4999 => Ok((answer, Some(reason))), _ => {
610 Ok((answer, Some(CloseReason { code: 1002, descr: None})))
612 }
613 }
614}
615
/// Errors which may occur when sending or receiving messages.
#[non_exhaustive]
#[derive(Debug)]
pub enum Error {
    /// An I/O error other than an unexpected end of stream.
    Io(io::Error),
    /// The frame codec reported an error.
    Codec(base::Error),
    /// An extension failed while encoding or decoding a message.
    Extension(crate::BoxedError),
    /// A frame with an opcode that is not valid at this point was received.
    UnexpectedOpCode(OpCode),
    /// The description of a CLOSE frame was not valid UTF-8.
    Utf8(str::Utf8Error),
    /// The accumulated payload size `current` exceeds the configured `maximum`.
    MessageTooLarge { current: usize, maximum: usize },
    /// The connection is closed.
    Closed,
}
635
/// The reason attached to a CLOSE frame.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CloseReason {
    /// The status code (big-endian `u16` at the start of the CLOSE payload).
    pub code: u16,
    /// The UTF-8 description following the status code, if any.
    pub descr: Option<String>,
}
642
643impl fmt::Display for Error {
644 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
645 match self {
646 Error::Io(e) => write!(f, "i/o error: {}", e),
647 Error::Codec(e) => write!(f, "codec error: {}", e),
648 Error::Extension(e) => write!(f, "extension error: {}", e),
649 Error::UnexpectedOpCode(c) => write!(f, "unexpected opcode: {}", c),
650 Error::Utf8(e) => write!(f, "utf-8 error: {}", e),
651 Error::MessageTooLarge { current, maximum } => {
652 write!(f, "message too large: len >= {}, maximum = {}", current, maximum)
653 }
654 Error::Closed => f.write_str("connection closed"),
655 }
656 }
657}
658
659impl std::error::Error for Error {
660 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
661 match self {
662 Error::Io(e) => Some(e),
663 Error::Codec(e) => Some(e),
664 Error::Extension(e) => Some(&**e),
665 Error::Utf8(e) => Some(e),
666 Error::UnexpectedOpCode(_) | Error::MessageTooLarge { .. } | Error::Closed => None,
667 }
668 }
669}
670
671impl From<io::Error> for Error {
672 fn from(e: io::Error) -> Self {
673 if e.kind() == io::ErrorKind::UnexpectedEof {
674 Error::Closed
675 } else {
676 Error::Io(e)
677 }
678 }
679}
680
681impl From<str::Utf8Error> for Error {
682 fn from(e: str::Utf8Error) -> Self {
683 Error::Utf8(e)
684 }
685}
686
687impl From<base::Error> for Error {
688 fn from(e: base::Error) -> Self {
689 Error::Codec(e)
690 }
691}
692
693async fn discard_bytes<R: AsyncRead + Unpin>(n: u64, reader: R) -> Result<u64, io::Error> {
695 futures::io::copy(&mut reader.take(n), &mut futures::io::sink()).await
696}
697
#[cfg(test)]
mod tests {
    use super::discard_bytes;
    use futures::{io::Cursor, AsyncReadExt};

    /// Discarding one byte must leave the remaining four readable, in order.
    #[tokio::test]
    async fn discard_bytes_works() {
        let mut cursor = Cursor::new(vec![0_u8, 1, 2, 3, 4]);
        discard_bytes(1_u64, &mut cursor).await.unwrap();

        let mut remaining = vec![0_u8; 4];
        cursor.read_exact(&mut remaining).await.unwrap();
        assert_eq!(remaining, vec![1, 2, 3, 4]);
    }
}