1use crate::data::{ByteSlice125, Data, Incoming};
13use crate::{
14 base::{self, Header, OpCode, MAX_HEADER_SIZE},
15 extension::Extension,
16 Parsing, Storage,
17};
18use bytes::{Buf, BytesMut};
19use futures::{
20 io::{ReadHalf, WriteHalf},
21 lock::BiLock,
22 prelude::*,
23};
24use std::{fmt, io, str};
25
/// Default upper bound, in bytes, on the accumulated payload of one message
/// (256 MiB). `Builder::new` uses it as the initial `max_message_size`.
const MAX_MESSAGE_SIZE: usize = 256 * 1024 * 1024;

/// Default upper bound, in bytes, on a single frame's payload. `Builder::new`
/// passes it to the codec via `set_max_data_size`; it equals the message limit.
const MAX_FRAME_SIZE: usize = MAX_MESSAGE_SIZE;
31
/// Which side of the connection this endpoint plays.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Mode {
    /// The endpoint acts as the client.
    Client,
    /// The endpoint acts as the server.
    Server,
}

impl Mode {
    /// Returns `true` exactly when this is [`Mode::Client`].
    pub fn is_client(self) -> bool {
        matches!(self, Mode::Client)
    }

    /// Returns `true` exactly when this is [`Mode::Server`].
    pub fn is_server(self) -> bool {
        matches!(self, Mode::Server)
    }
}
54
/// Connection identifier; in this file it is only used to tag log messages.
#[derive(Clone, Copy, Debug)]
struct Id(u32);

impl fmt::Display for Id {
    /// Renders the id as eight zero-padded, lowercase hexadecimal digits.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Id(raw) = *self;
        write!(f, "{:08x}", raw)
    }
}
64
/// The sending half of a connection, created by `Builder::finish`.
#[derive(Debug)]
pub struct Sender<T> {
    /// Connection id used as a prefix in log messages.
    id: Id,
    /// Client or server role; frames written in client mode are masked.
    mode: Mode,
    /// Frame codec used to encode outgoing headers.
    codec: base::Codec,
    /// Write half of the socket, shared with the `Receiver` through a `BiLock`.
    writer: BiLock<WriteHalf<T>>,
    /// Scratch buffer into which shared (read-only) payloads are copied
    /// before masking.
    mask_buffer: Vec<u8>,
    /// Enabled extensions, shared with the `Receiver` through a `BiLock`.
    extensions: BiLock<Vec<Box<dyn Extension + Send>>>,
    /// Whether `extensions` was non-empty when the connection was built;
    /// lets the send path skip taking the lock when there is nothing to do.
    has_extensions: bool,
}
76
/// The receiving half of a connection, created by `Builder::finish`.
#[derive(Debug)]
pub struct Receiver<T> {
    /// Connection id used as a prefix in log messages.
    id: Id,
    /// Client or server role; used when writing PONG/CLOSE answers.
    mode: Mode,
    /// Frame codec used to decode incoming headers and encode answers.
    codec: base::Codec,
    /// Read half of the socket.
    reader: ReadHalf<T>,
    /// Write half of the socket, shared with the `Sender` through a `BiLock`.
    writer: BiLock<WriteHalf<T>>,
    /// Enabled extensions, shared with the `Sender` through a `BiLock`.
    extensions: BiLock<Vec<Box<dyn Extension + Send>>>,
    /// Whether `extensions` was non-empty when the connection was built.
    has_extensions: bool,
    /// Bytes already read from the socket but not yet consumed.
    buffer: BytesMut,
    /// Unmasked payload of the control frame currently being handled.
    ctrl_buffer: BytesMut,
    /// Maximum accumulated payload size accepted for a single message.
    max_message_size: usize,
    /// Set once an incoming CLOSE frame has been answered; further calls to
    /// `receive` then fail with `Error::Closed`.
    is_closed: bool,
}
92
/// Collects connection settings and finally produces a `Sender`/`Receiver`
/// pair via `finish`.
#[derive(Debug)]
pub struct Builder<T> {
    /// Randomly chosen connection id used in log messages.
    id: Id,
    /// Client or server role of the connection being built.
    mode: Mode,
    /// The underlying socket; split into read and write halves by `finish`.
    socket: T,
    /// Frame codec, pre-configured with the maximum frame size and the
    /// reserved bits claimed by extensions.
    codec: base::Codec,
    /// Enabled extensions registered through `add_extensions`.
    extensions: Vec<Box<dyn Extension + Send>>,
    /// Initial content of the receiver's read buffer.
    buffer: BytesMut,
    /// Maximum accumulated payload size accepted for a single message.
    max_message_size: usize,
}
108
109impl<T: AsyncRead + AsyncWrite + Unpin> Builder<T> {
110 pub fn new(socket: T, mode: Mode) -> Self {
119 let mut codec = base::Codec::default();
120 codec.set_max_data_size(MAX_FRAME_SIZE);
121 Builder {
122 id: Id(rand::random()),
123 mode,
124 socket,
125 codec,
126 extensions: Vec::new(),
127 buffer: BytesMut::new(),
128 max_message_size: MAX_MESSAGE_SIZE,
129 }
130 }
131
132 pub fn set_buffer(&mut self, b: BytesMut) {
134 self.buffer = b
135 }
136
137 pub fn add_extensions<I>(&mut self, extensions: I)
141 where
142 I: IntoIterator<Item = Box<dyn Extension + Send>>,
143 {
144 for e in extensions.into_iter().filter(|e| e.is_enabled()) {
145 log::debug!("{}: using extension: {}", self.id, e.name());
146 self.codec.add_reserved_bits(e.reserved_bits());
147 self.extensions.push(e)
148 }
149 }
150
151 pub fn set_max_message_size(&mut self, max: usize) {
158 self.max_message_size = max
159 }
160
161 pub fn set_max_frame_size(&mut self, max: usize) {
163 self.codec.set_max_data_size(max);
164 }
165
166 pub fn finish(self) -> (Sender<T>, Receiver<T>) {
168 let (rhlf, whlf) = self.socket.split();
169 let (wrt1, wrt2) = BiLock::new(whlf);
170 let has_extensions = !self.extensions.is_empty();
171 let (ext1, ext2) = BiLock::new(self.extensions);
172
173 let recv = Receiver {
174 id: self.id,
175 mode: self.mode,
176 reader: rhlf,
177 writer: wrt1,
178 codec: self.codec.clone(),
179 extensions: ext1,
180 has_extensions,
181 buffer: self.buffer,
182 ctrl_buffer: BytesMut::new(),
183 max_message_size: self.max_message_size,
184 is_closed: false,
185 };
186
187 let send = Sender {
188 id: self.id,
189 mode: self.mode,
190 writer: wrt2,
191 mask_buffer: Vec::new(),
192 codec: self.codec,
193 extensions: ext2,
194 has_extensions,
195 };
196
197 (send, recv)
198 }
199}
200
201impl<T: AsyncRead + AsyncWrite + Unpin> Receiver<T> {
202 pub async fn receive(&mut self, message: &mut Vec<u8>) -> Result<Incoming<'_>, Error> {
213 let mut first_fragment_opcode = None;
214 let mut length: usize = 0;
215 let message_len = message.len();
216 loop {
217 if self.is_closed {
218 log::debug!("{}: cannot receive, connection is closed", self.id);
219 return Err(Error::Closed);
220 }
221
222 self.ctrl_buffer.clear();
223 let mut header = self.receive_header().await?;
224 log::trace!("{}: recv: {}", self.id, header);
225
226 if header.opcode().is_control() {
228 self.read_buffer(&header).await?;
229 self.ctrl_buffer = self.buffer.split_to(header.payload_len());
230 base::Codec::apply_mask(&header, &mut self.ctrl_buffer);
231 if header.opcode() == OpCode::Pong {
232 return Ok(Incoming::Pong(&self.ctrl_buffer[..]));
233 }
234 if let Some(close_reason) = self.on_control(&header).await? {
235 log::trace!("{}: recv, incoming CLOSE: {:?}", self.id, close_reason);
236 return Ok(Incoming::Closed(close_reason));
237 }
238 continue;
239 }
240
241 length = length.saturating_add(header.payload_len());
242
243 if length > self.max_message_size {
245 log::warn!("{}: accumulated message length exceeds maximum", self.id);
246
247 discard_bytes(length as u64, &mut self.reader).await?;
249 return Err(Error::MessageTooLarge { current: length, maximum: self.max_message_size });
250 }
251
252 {
254 let old_msg_len = message.len();
255
256 let bytes_to_read = {
257 let required = header.payload_len();
258 let buffered = self.buffer.len();
259
260 if buffered == 0 {
261 required
262 } else if required > buffered {
263 message.extend_from_slice(&self.buffer);
264 self.buffer.clear();
265 required - buffered
266 } else {
267 message.extend_from_slice(&self.buffer.split_to(required));
268 0
269 }
270 };
271
272 if bytes_to_read > 0 {
273 let n = message.len();
274 message.resize(n + bytes_to_read, 0u8);
275 self.reader.read_exact(&mut message[n..]).await?
276 }
277
278 debug_assert_eq!(header.payload_len(), message.len() - old_msg_len);
279
280 base::Codec::apply_mask(&header, &mut message[old_msg_len..]);
281 }
282
283 match (header.is_fin(), header.opcode()) {
284 (false, OpCode::Continue) => {
285 if first_fragment_opcode.is_none() {
287 log::debug!("{}: continue frame while not processing message fragments", self.id);
288 return Err(Error::UnexpectedOpCode(OpCode::Continue));
289 }
290 continue;
291 }
292 (false, oc) => {
293 if first_fragment_opcode.is_some() {
295 log::debug!("{}: initial fragment while processing a fragmented message", self.id);
296 return Err(Error::UnexpectedOpCode(oc));
297 }
298 first_fragment_opcode = Some(oc);
299 self.decode_with_extensions(&mut header, message).await?;
300 continue;
301 }
302 (true, OpCode::Continue) => {
303 if let Some(oc) = first_fragment_opcode.take() {
305 header.set_payload_len(message.len());
306 log::trace!("{}: last fragment: total length = {} bytes", self.id, message.len());
307 self.decode_with_extensions(&mut header, message).await?;
308 header.set_opcode(oc);
309 } else {
310 log::debug!("{}: last continue frame while not processing message fragments", self.id);
311 return Err(Error::UnexpectedOpCode(OpCode::Continue));
312 }
313 }
314 (true, oc) => {
315 if first_fragment_opcode.is_some() {
317 log::debug!("{}: regular message while processing fragmented message", self.id);
318 return Err(Error::UnexpectedOpCode(oc));
319 }
320 self.decode_with_extensions(&mut header, message).await?
321 }
322 }
323
324 let num_bytes = message.len() - message_len;
325
326 if header.opcode() == OpCode::Text {
327 return Ok(Incoming::Data(Data::Text(num_bytes)));
328 } else {
329 return Ok(Incoming::Data(Data::Binary(num_bytes)));
330 }
331 }
332 }
333
334 pub async fn receive_data(&mut self, message: &mut Vec<u8>) -> Result<Data, Error> {
336 loop {
337 if let Incoming::Data(d) = self.receive(message).await? {
338 return Ok(d);
339 }
340 }
341 }
342
343 async fn receive_header(&mut self) -> Result<Header, Error> {
345 loop {
346 match self.codec.decode_header(&self.buffer)? {
347 Parsing::Done { value: header, offset } => {
348 debug_assert!(offset <= MAX_HEADER_SIZE);
349 self.buffer.advance(offset);
350 return Ok(header);
351 }
352 Parsing::NeedMore(n) => crate::read(&mut self.reader, &mut self.buffer, n).await?,
353 }
354 }
355 }
356
357 async fn read_buffer(&mut self, header: &Header) -> Result<(), Error> {
359 if header.payload_len() <= self.buffer.len() {
360 return Ok(());
361 }
362 let i = self.buffer.len();
363 let d = header.payload_len() - i;
364 self.buffer.resize(i + d, 0u8);
365 self.reader.read_exact(&mut self.buffer[i..]).await?;
366 Ok(())
367 }
368
369 async fn on_control(&mut self, header: &Header) -> Result<Option<CloseReason>, Error> {
375 match header.opcode() {
376 OpCode::Ping => {
377 let mut answer = Header::new(OpCode::Pong);
378 let mut unused = Vec::new();
379 let mut data = Storage::Unique(&mut self.ctrl_buffer);
380 write(self.id, self.mode, &mut self.codec, &mut self.writer, &mut answer, &mut data, &mut unused)
381 .await?;
382 self.flush().await?;
383 Ok(None)
384 }
385 OpCode::Pong => Ok(None),
386 OpCode::Close => {
387 log::trace!("{}: Acknowledging CLOSE to sender", self.id);
388 let (mut header, reason) = close_answer(&self.ctrl_buffer)?;
389 let mut unused = Vec::new();
391 if let Some(CloseReason { code, .. }) = reason {
392 let mut data = code.to_be_bytes();
393 let mut data = Storage::Unique(&mut data);
394 let _ = write(
395 self.id,
396 self.mode,
397 &mut self.codec,
398 &mut self.writer,
399 &mut header,
400 &mut data,
401 &mut unused,
402 )
403 .await;
404 } else {
405 let mut data = Storage::Unique(&mut []);
406 let _ = write(
407 self.id,
408 self.mode,
409 &mut self.codec,
410 &mut self.writer,
411 &mut header,
412 &mut data,
413 &mut unused,
414 )
415 .await;
416 }
417 self.flush().await?;
418 _ = self.writer.lock().await.close().await;
421 self.is_closed = true;
422 Ok(reason)
423 }
424 OpCode::Binary
425 | OpCode::Text
426 | OpCode::Continue
427 | OpCode::Reserved3
428 | OpCode::Reserved4
429 | OpCode::Reserved5
430 | OpCode::Reserved6
431 | OpCode::Reserved7
432 | OpCode::Reserved11
433 | OpCode::Reserved12
434 | OpCode::Reserved13
435 | OpCode::Reserved14
436 | OpCode::Reserved15 => Err(Error::UnexpectedOpCode(header.opcode())),
437 }
438 }
439
440 async fn decode_with_extensions(&mut self, header: &mut Header, message: &mut Vec<u8>) -> Result<(), Error> {
442 if !self.has_extensions {
443 return Ok(());
444 }
445 for e in self.extensions.lock().await.iter_mut() {
446 log::trace!("{}: decoding with extension: {}", self.id, e.name());
447 e.decode(header, message).map_err(Error::Extension)?
448 }
449 Ok(())
450 }
451
452 async fn flush(&mut self) -> Result<(), Error> {
454 log::trace!("{}: Receiver flushing connection", self.id);
455 if self.is_closed {
456 return Ok(());
457 }
458 self.writer.lock().await.flush().await.or(Err(Error::Closed))
459 }
460}
461
462impl<T: AsyncRead + AsyncWrite + Unpin> Sender<T> {
463 pub async fn send_text(&mut self, data: impl AsRef<str>) -> Result<(), Error> {
465 let mut header = Header::new(OpCode::Text);
466 self.send_frame(&mut header, &mut Storage::Shared(data.as_ref().as_bytes())).await
467 }
468
469 pub async fn send_text_owned(&mut self, data: String) -> Result<(), Error> {
473 let mut header = Header::new(OpCode::Text);
474 self.send_frame(&mut header, &mut Storage::Owned(data.into_bytes())).await
475 }
476
477 pub async fn send_binary(&mut self, data: impl AsRef<[u8]>) -> Result<(), Error> {
479 let mut header = Header::new(OpCode::Binary);
480 self.send_frame(&mut header, &mut Storage::Shared(data.as_ref())).await
481 }
482
483 pub async fn send_binary_mut(&mut self, mut data: impl AsMut<[u8]>) -> Result<(), Error> {
488 let mut header = Header::new(OpCode::Binary);
489 self.send_frame(&mut header, &mut Storage::Unique(data.as_mut())).await
490 }
491
492 pub async fn send_ping(&mut self, data: ByteSlice125<'_>) -> Result<(), Error> {
494 let mut header = Header::new(OpCode::Ping);
495 self.write(&mut header, &mut Storage::Shared(data.as_ref())).await
496 }
497
498 pub async fn send_pong(&mut self, data: ByteSlice125<'_>) -> Result<(), Error> {
500 let mut header = Header::new(OpCode::Pong);
501 self.write(&mut header, &mut Storage::Shared(data.as_ref())).await
502 }
503
504 pub async fn flush(&mut self) -> Result<(), Error> {
506 log::trace!("{}: Sender flushing connection", self.id);
507 self.writer.lock().await.flush().await.or(Err(Error::Closed))
508 }
509
510 pub async fn close(&mut self) -> Result<(), Error> {
512 log::trace!("{}: closing connection", self.id);
513 let mut header = Header::new(OpCode::Close);
514 let code = 1000_u16.to_be_bytes(); self.write(&mut header, &mut Storage::Shared(&code[..])).await?;
516 self.flush().await?;
517 self.writer.lock().await.close().await.or(Err(Error::Closed))
518 }
519
520 async fn send_frame(&mut self, header: &mut Header, data: &mut Storage<'_>) -> Result<(), Error> {
524 if !self.has_extensions {
525 return self.write(header, data).await;
526 }
527
528 for e in self.extensions.lock().await.iter_mut() {
529 log::trace!("{}: encoding with extension: {}", self.id, e.name());
530 e.encode(header, data).map_err(Error::Extension)?
531 }
532
533 self.write(header, data).await
534 }
535
536 async fn write(&mut self, header: &mut Header, data: &mut Storage<'_>) -> Result<(), Error> {
541 write(self.id, self.mode, &mut self.codec, &mut self.writer, header, data, &mut self.mask_buffer).await
542 }
543}
544
545async fn write<T: AsyncWrite + Unpin>(
547 id: Id,
548 mode: Mode,
549 codec: &mut base::Codec,
550 writer: &mut BiLock<WriteHalf<T>>,
551 header: &mut Header,
552 data: &mut Storage<'_>,
553 mask_buffer: &mut Vec<u8>,
554) -> Result<(), Error> {
555 if mode.is_client() {
556 header.set_masked(true);
557 header.set_mask(rand::random());
558 }
559 header.set_payload_len(data.as_ref().len());
560
561 log::trace!("{}: send: {}", id, header);
562
563 let header_bytes = codec.encode_header(&header);
564 let mut w = writer.lock().await;
565 w.write_all(&header_bytes).await.or(Err(Error::Closed))?;
566
567 if !header.is_masked() {
568 return w.write_all(data.as_ref()).await.or(Err(Error::Closed));
569 }
570
571 match data {
572 Storage::Shared(slice) => {
573 mask_buffer.clear();
574 mask_buffer.extend_from_slice(slice);
575 base::Codec::apply_mask(header, mask_buffer);
576 w.write_all(mask_buffer).await.or(Err(Error::Closed))
577 }
578 Storage::Unique(slice) => {
579 base::Codec::apply_mask(header, slice);
580 w.write_all(slice).await.or(Err(Error::Closed))
581 }
582 Storage::Owned(ref mut bytes) => {
583 base::Codec::apply_mask(header, bytes);
584 w.write_all(bytes).await.or(Err(Error::Closed))
585 }
586 }
587}
588
589fn close_answer(data: &[u8]) -> Result<(Header, Option<CloseReason>), Error> {
592 let answer = Header::new(OpCode::Close);
593 if data.len() < 2 {
594 return Ok((answer, None));
595 }
596 let descr = std::str::from_utf8(&data[2..])?.into();
598 let code = u16::from_be_bytes([data[0], data[1]]);
599 let reason = CloseReason { code, descr: Some(descr) };
600
601 match code {
605 | 1000 ..= 1003
606 | 1007 ..= 1011
607 | 1012 | 1013 | 1015
610 | 3000 ..= 4999 => Ok((answer, Some(reason))), _ => {
612 Ok((answer, Some(CloseReason { code: 1002, descr: None})))
614 }
615 }
616}
617
/// Errors which may occur when sending or receiving messages.
#[non_exhaustive]
#[derive(Debug)]
pub enum Error {
    /// An I/O error other than an unexpected end of file.
    Io(io::Error),
    /// The frame codec reported an error.
    Codec(base::Error),
    /// An extension failed to encode or decode a frame.
    Extension(crate::BoxedError),
    /// A frame with an opcode that is not valid at this point was received.
    UnexpectedOpCode(OpCode),
    /// Received bytes were not valid UTF-8 (e.g. the description of a CLOSE
    /// frame).
    Utf8(str::Utf8Error),
    /// The accumulated payload length `current` exceeds the configured
    /// `maximum` message size.
    MessageTooLarge { current: usize, maximum: usize },
    /// The connection is closed.
    Closed,
}
637
/// The reason a peer gave for closing the connection.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CloseReason {
    /// The status code taken from the CLOSE frame (big-endian `u16`).
    pub code: u16,
    /// The UTF-8 description that followed the status code, if one was kept.
    pub descr: Option<String>,
}
644
645impl fmt::Display for Error {
646 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
647 match self {
648 Error::Io(e) => write!(f, "i/o error: {}", e),
649 Error::Codec(e) => write!(f, "codec error: {}", e),
650 Error::Extension(e) => write!(f, "extension error: {}", e),
651 Error::UnexpectedOpCode(c) => write!(f, "unexpected opcode: {}", c),
652 Error::Utf8(e) => write!(f, "utf-8 error: {}", e),
653 Error::MessageTooLarge { current, maximum } => {
654 write!(f, "message too large: len >= {}, maximum = {}", current, maximum)
655 }
656 Error::Closed => f.write_str("connection closed"),
657 }
658 }
659}
660
661impl std::error::Error for Error {
662 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
663 match self {
664 Error::Io(e) => Some(e),
665 Error::Codec(e) => Some(e),
666 Error::Extension(e) => Some(&**e),
667 Error::Utf8(e) => Some(e),
668 Error::UnexpectedOpCode(_) | Error::MessageTooLarge { .. } | Error::Closed => None,
669 }
670 }
671}
672
673impl From<io::Error> for Error {
674 fn from(e: io::Error) -> Self {
675 if e.kind() == io::ErrorKind::UnexpectedEof {
676 Error::Closed
677 } else {
678 Error::Io(e)
679 }
680 }
681}
682
683impl From<str::Utf8Error> for Error {
684 fn from(e: str::Utf8Error) -> Self {
685 Error::Utf8(e)
686 }
687}
688
689impl From<base::Error> for Error {
690 fn from(e: base::Error) -> Self {
691 Error::Codec(e)
692 }
693}
694
695async fn discard_bytes<R: AsyncRead + Unpin>(n: u64, reader: R) -> Result<u64, io::Error> {
697 futures::io::copy(&mut reader.take(n), &mut futures::io::sink()).await
698}
699
#[cfg(test)]
mod tests {
    use super::discard_bytes;
    use futures::{io::Cursor, AsyncReadExt};

    /// Discarding one byte must leave the reader positioned on the second.
    #[tokio::test]
    async fn discard_bytes_works() {
        let input: Vec<u8> = vec![0, 1, 2, 3, 4];
        let mut cursor = Cursor::new(input);

        discard_bytes(1_u64, &mut cursor).await.unwrap();

        let mut remainder = vec![0; 4];
        cursor.read_exact(&mut remainder).await.unwrap();
        assert_eq!(remainder, vec![1, 2, 3, 4]);
    }
}