1use crate::{as_u64, Parsing};
21use std::{fmt, io};
22
/// Upper bound, in bytes, on the size of an encoded frame header.
///
/// A header consists of two fixed bytes, at most eight bytes of extended
/// payload length and at most four bytes of masking key.
pub(crate) const MAX_HEADER_SIZE: usize = 2 + 8 + 4;

/// Largest payload length, in bytes, the decoder accepts for a control frame
/// (close, ping or pong).
pub(crate) const MAX_CTRL_BODY_SIZE: u64 = 125;
28
/// Operation code of a websocket frame.
///
/// The declaration order of the variants is what the derived `PartialOrd`
/// and `Ord` implementations compare by. It differs from the numeric wire
/// values, which are noted on each variant.
#[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Clone, Copy)]
pub enum OpCode {
    /// Continuation frame (wire value 0).
    Continue,
    /// Text frame (wire value 1).
    Text,
    /// Binary frame (wire value 2).
    Binary,
    /// Close control frame (wire value 8).
    Close,
    /// Ping control frame (wire value 9).
    Ping,
    /// Pong control frame (wire value 10).
    Pong,
    /// Reserved opcode (wire value 3).
    Reserved3,
    /// Reserved opcode (wire value 4).
    Reserved4,
    /// Reserved opcode (wire value 5).
    Reserved5,
    /// Reserved opcode (wire value 6).
    Reserved6,
    /// Reserved opcode (wire value 7).
    Reserved7,
    /// Reserved opcode (wire value 11).
    Reserved11,
    /// Reserved opcode (wire value 12).
    Reserved12,
    /// Reserved opcode (wire value 13).
    Reserved13,
    /// Reserved opcode (wire value 14).
    Reserved14,
    /// Reserved opcode (wire value 15).
    Reserved15,
}

impl OpCode {
    /// Returns `true` for the control opcodes `Close`, `Ping` and `Pong`.
    ///
    /// Reserved opcodes are never reported as control opcodes, not even
    /// `Reserved11` to `Reserved15`; use [`OpCode::is_reserved`] for those.
    pub fn is_control(self) -> bool {
        matches!(self, OpCode::Close | OpCode::Ping | OpCode::Pong)
    }

    /// Returns `true` if this is one of the `Reserved*` opcodes.
    pub fn is_reserved(self) -> bool {
        matches!(
            self,
            OpCode::Reserved3
                | OpCode::Reserved4
                | OpCode::Reserved5
                | OpCode::Reserved6
                | OpCode::Reserved7
                | OpCode::Reserved11
                | OpCode::Reserved12
                | OpCode::Reserved13
                | OpCode::Reserved14
                | OpCode::Reserved15
        )
    }
}

impl fmt::Display for OpCode {
    /// Writes the variant name; reserved opcodes are rendered as
    /// `Reserved:<wire value>`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let name = match self {
            OpCode::Continue => "Continue",
            OpCode::Text => "Text",
            OpCode::Binary => "Binary",
            OpCode::Close => "Close",
            OpCode::Ping => "Ping",
            OpCode::Pong => "Pong",
            OpCode::Reserved3 => "Reserved:3",
            OpCode::Reserved4 => "Reserved:4",
            OpCode::Reserved5 => "Reserved:5",
            OpCode::Reserved6 => "Reserved:6",
            OpCode::Reserved7 => "Reserved:7",
            OpCode::Reserved11 => "Reserved:11",
            OpCode::Reserved12 => "Reserved:12",
            OpCode::Reserved13 => "Reserved:13",
            OpCode::Reserved14 => "Reserved:14",
            OpCode::Reserved15 => "Reserved:15",
        };
        f.write_str(name)
    }
}
118
/// Error produced when a byte value does not map to any `OpCode`.
///
/// The private unit field prevents construction outside of this module.
#[derive(Debug, Clone)]
pub struct UnknownOpCode(());

impl fmt::Display for UnknownOpCode {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "unknown opcode")
    }
}

impl std::error::Error for UnknownOpCode {}
131
132impl TryFrom<u8> for OpCode {
133 type Error = UnknownOpCode;
134
135 fn try_from(val: u8) -> Result<OpCode, Self::Error> {
136 match val {
137 0 => Ok(OpCode::Continue),
138 1 => Ok(OpCode::Text),
139 2 => Ok(OpCode::Binary),
140 3 => Ok(OpCode::Reserved3),
141 4 => Ok(OpCode::Reserved4),
142 5 => Ok(OpCode::Reserved5),
143 6 => Ok(OpCode::Reserved6),
144 7 => Ok(OpCode::Reserved7),
145 8 => Ok(OpCode::Close),
146 9 => Ok(OpCode::Ping),
147 10 => Ok(OpCode::Pong),
148 11 => Ok(OpCode::Reserved11),
149 12 => Ok(OpCode::Reserved12),
150 13 => Ok(OpCode::Reserved13),
151 14 => Ok(OpCode::Reserved14),
152 15 => Ok(OpCode::Reserved15),
153 _ => Err(UnknownOpCode(())),
154 }
155 }
156}
157
158impl From<OpCode> for u8 {
159 fn from(opcode: OpCode) -> u8 {
160 match opcode {
161 OpCode::Continue => 0,
162 OpCode::Text => 1,
163 OpCode::Binary => 2,
164 OpCode::Close => 8,
165 OpCode::Ping => 9,
166 OpCode::Pong => 10,
167 OpCode::Reserved3 => 3,
168 OpCode::Reserved4 => 4,
169 OpCode::Reserved5 => 5,
170 OpCode::Reserved6 => 6,
171 OpCode::Reserved7 => 7,
172 OpCode::Reserved11 => 11,
173 OpCode::Reserved12 => 12,
174 OpCode::Reserved13 => 13,
175 OpCode::Reserved14 => 14,
176 OpCode::Reserved15 => 15,
177 }
178 }
179}
180
181#[derive(Debug, Clone)]
185pub struct Header {
186 fin: bool,
187 rsv1: bool,
188 rsv2: bool,
189 rsv3: bool,
190 masked: bool,
191 opcode: OpCode,
192 mask: u32,
193 payload_len: usize,
194}
195
196impl fmt::Display for Header {
197 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
198 write!(
199 f,
200 "({} (fin {}) (rsv {}{}{}) (mask ({} {:x})) (len {}))",
201 self.opcode,
202 self.fin as u8,
203 self.rsv1 as u8,
204 self.rsv2 as u8,
205 self.rsv3 as u8,
206 self.masked as u8,
207 self.mask,
208 self.payload_len
209 )
210 }
211}
212
213impl Header {
214 pub fn new(oc: OpCode) -> Self {
216 Header { fin: true, rsv1: false, rsv2: false, rsv3: false, masked: false, opcode: oc, mask: 0, payload_len: 0 }
217 }
218
219 pub fn is_fin(&self) -> bool {
221 self.fin
222 }
223
224 pub fn set_fin(&mut self, fin: bool) -> &mut Self {
226 self.fin = fin;
227 self
228 }
229
230 pub fn is_rsv1(&self) -> bool {
232 self.rsv1
233 }
234
235 pub fn set_rsv1(&mut self, rsv1: bool) -> &mut Self {
237 self.rsv1 = rsv1;
238 self
239 }
240
241 pub fn is_rsv2(&self) -> bool {
243 self.rsv2
244 }
245
246 pub fn set_rsv2(&mut self, rsv2: bool) -> &mut Self {
248 self.rsv2 = rsv2;
249 self
250 }
251
252 pub fn is_rsv3(&self) -> bool {
254 self.rsv3
255 }
256
257 pub fn set_rsv3(&mut self, rsv3: bool) -> &mut Self {
259 self.rsv3 = rsv3;
260 self
261 }
262
263 pub fn is_masked(&self) -> bool {
265 self.masked
266 }
267
268 pub fn set_masked(&mut self, masked: bool) -> &mut Self {
270 self.masked = masked;
271 self
272 }
273
274 pub fn opcode(&self) -> OpCode {
276 self.opcode
277 }
278
279 pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Self {
281 self.opcode = opcode;
282 self
283 }
284
285 pub fn mask(&self) -> u32 {
287 self.mask
288 }
289
290 pub fn set_mask(&mut self, mask: u32) -> &mut Self {
292 self.mask = mask;
293 self
294 }
295
296 pub fn payload_len(&self) -> usize {
298 self.payload_len
299 }
300
301 pub fn set_payload_len(&mut self, len: usize) -> &mut Self {
303 self.payload_len = len;
304 self
305 }
306}
307
/// Value of the 7-bit length field in the second header byte which signals
/// that the payload length follows as a 2-byte big-endian integer.
const TWO_EXT: u8 = 126;

/// Value of the 7-bit length field in the second header byte which signals
/// that the payload length follows as an 8-byte big-endian integer.
const EIGHT_EXT: u8 = 127;
317
318#[derive(Debug, Clone)]
322pub struct Codec {
323 max_data_size: usize,
325 reserved_bits: u8,
327 header_buffer: [u8; MAX_HEADER_SIZE],
329}
330
331impl Default for Codec {
332 fn default() -> Self {
333 Codec { max_data_size: 256 * 1024 * 1024, reserved_bits: 0, header_buffer: [0; MAX_HEADER_SIZE] }
334 }
335}
336
337impl Codec {
338 pub fn new() -> Self {
343 Codec::default()
344 }
345
346 pub fn max_data_size(&self) -> usize {
348 self.max_data_size
349 }
350
351 pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
353 self.max_data_size = size;
354 self
355 }
356
357 pub fn reserved_bits(&self) -> (bool, bool, bool) {
359 let r = self.reserved_bits;
360 (r & 4 == 4, r & 2 == 2, r & 1 == 1)
361 }
362
363 pub fn add_reserved_bits(&mut self, bits: (bool, bool, bool)) -> &mut Self {
365 let (r1, r2, r3) = bits;
366 self.reserved_bits |= (r1 as u8) << 2 | (r2 as u8) << 1 | r3 as u8;
367 self
368 }
369
370 pub fn clear_reserved_bits(&mut self) {
372 self.reserved_bits = 0
373 }
374
375 pub fn decode_header(&self, bytes: &[u8]) -> Result<Parsing<Header, usize>, Error> {
377 if bytes.len() < 2 {
378 return Ok(Parsing::NeedMore(2 - bytes.len()));
379 }
380
381 let first = bytes[0];
382 let second = bytes[1];
383 let mut offset = 2;
384
385 let fin = first & 0x80 != 0;
386 let opcode = OpCode::try_from(first & 0xF)?;
387
388 if opcode.is_reserved() {
389 return Err(Error::ReservedOpCode);
390 }
391
392 if opcode.is_control() && !fin {
393 return Err(Error::FragmentedControl);
394 }
395
396 let mut header = Header::new(opcode);
397 header.set_fin(fin);
398
399 let rsv1 = first & 0x40 != 0;
400 if rsv1 && (self.reserved_bits & 4 == 0) {
401 return Err(Error::InvalidReservedBit(1));
402 }
403 header.set_rsv1(rsv1);
404
405 let rsv2 = first & 0x20 != 0;
406 if rsv2 && (self.reserved_bits & 2 == 0) {
407 return Err(Error::InvalidReservedBit(2));
408 }
409 header.set_rsv2(rsv2);
410
411 let rsv3 = first & 0x10 != 0;
412 if rsv3 && (self.reserved_bits & 1 == 0) {
413 return Err(Error::InvalidReservedBit(3));
414 }
415 header.set_rsv3(rsv3);
416 header.set_masked(second & 0x80 != 0);
417
418 let len: u64 = match second & 0x7F {
419 TWO_EXT => {
420 if bytes.len() < offset + 2 {
421 return Ok(Parsing::NeedMore(offset + 2 - bytes.len()));
422 }
423 let len = u16::from_be_bytes([bytes[offset], bytes[offset + 1]]);
424 offset += 2;
425 u64::from(len)
426 }
427 EIGHT_EXT => {
428 if bytes.len() < offset + 8 {
429 return Ok(Parsing::NeedMore(offset + 8 - bytes.len()));
430 }
431 let mut b = [0; 8];
432 b.copy_from_slice(&bytes[offset..offset + 8]);
433 offset += 8;
434 u64::from_be_bytes(b)
435 }
436 n => u64::from(n),
437 };
438
439 if len > MAX_CTRL_BODY_SIZE && header.opcode().is_control() {
440 return Err(Error::InvalidControlFrameLen);
441 }
442
443 let len: usize = if len > as_u64(self.max_data_size) {
444 return Err(Error::PayloadTooLarge { actual: len, maximum: as_u64(self.max_data_size) });
445 } else {
446 len as usize
447 };
448
449 header.set_payload_len(len);
450
451 if header.is_masked() {
452 if bytes.len() < offset + 4 {
453 return Ok(Parsing::NeedMore(offset + 4 - bytes.len()));
454 }
455 let mut b = [0; 4];
456 b.copy_from_slice(&bytes[offset..offset + 4]);
457 offset += 4;
458 header.set_mask(u32::from_be_bytes(b));
459 }
460
461 Ok(Parsing::Done { value: header, offset })
462 }
463
464 pub fn encode_header(&mut self, header: &Header) -> &[u8] {
466 let mut offset = 0;
467
468 let mut first_byte = 0_u8;
469 if header.is_fin() {
470 first_byte |= 0x80
471 }
472 if header.is_rsv1() {
473 first_byte |= 0x40
474 }
475 if header.is_rsv2() {
476 first_byte |= 0x20
477 }
478 if header.is_rsv3() {
479 first_byte |= 0x10
480 }
481
482 let opcode: u8 = header.opcode().into();
483 first_byte |= opcode;
484
485 self.header_buffer[offset] = first_byte;
486 offset += 1;
487
488 let mut second_byte = 0_u8;
489 if header.is_masked() {
490 second_byte |= 0x80
491 }
492
493 let len = header.payload_len();
494
495 if len < usize::from(TWO_EXT) {
496 second_byte |= len as u8;
497 self.header_buffer[offset] = second_byte;
498 offset += 1;
499 } else if len <= usize::from(u16::max_value()) {
500 second_byte |= TWO_EXT;
501 self.header_buffer[offset] = second_byte;
502 offset += 1;
503 self.header_buffer[offset..offset + 2].copy_from_slice(&(len as u16).to_be_bytes());
504 offset += 2;
505 } else {
506 second_byte |= EIGHT_EXT;
507 self.header_buffer[offset] = second_byte;
508 offset += 1;
509 self.header_buffer[offset..offset + 8].copy_from_slice(&as_u64(len).to_be_bytes());
510 offset += 8;
511 }
512
513 if header.is_masked() {
514 self.header_buffer[offset..offset + 4].copy_from_slice(&header.mask().to_be_bytes());
515 offset += 4;
516 }
517
518 &self.header_buffer[..offset]
519 }
520
521 pub fn apply_mask(header: &Header, data: &mut [u8]) {
523 if header.is_masked() {
524 let mask = header.mask().to_be_bytes();
525 for (byte, &key) in data.iter_mut().zip(mask.iter().cycle()) {
526 *byte ^= key;
527 }
528 }
529 }
530}
531
/// Errors that can occur while encoding or decoding frame headers.
#[non_exhaustive]
#[derive(Debug)]
pub enum Error {
    /// An underlying I/O error.
    Io(io::Error),
    /// The opcode value does not correspond to any known opcode.
    UnknownOpCode,
    /// The opcode is one of the reserved opcodes.
    ReservedOpCode,
    /// A control frame did not have its FIN flag set.
    FragmentedControl,
    /// A control frame announced a payload longer than permitted.
    InvalidControlFrameLen,
    /// The given reserved bit (1, 2 or 3) was set although not allowed.
    InvalidReservedBit(u8),
    /// The announced payload length exceeds the configured maximum.
    PayloadTooLarge {
        /// Payload length announced by the header.
        actual: u64,
        /// Largest payload length accepted.
        maximum: u64,
    },
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Variants carrying data are formatted right away; all others share
        // the single `write_str` call at the end.
        let message = match self {
            Error::Io(e) => return write!(f, "i/o error: {}", e),
            Error::InvalidReservedBit(n) => return write!(f, "invalid reserved bit: {}", n),
            Error::PayloadTooLarge { actual, maximum } => {
                return write!(f, "payload too large: len = {}, maximum = {}", actual, maximum);
            }
            Error::UnknownOpCode => "unknown opcode",
            Error::ReservedOpCode => "reserved opcode",
            Error::FragmentedControl => "fragmented control frame",
            Error::InvalidControlFrameLen => "invalid control frame length",
        };
        f.write_str(message)
    }
}

impl std::error::Error for Error {
    /// Only `Error::Io` wraps another error; every other variant has no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::Io(inner) => Some(inner),
            // Listed explicitly so that adding a variant forces a decision here.
            Error::UnknownOpCode
            | Error::ReservedOpCode
            | Error::FragmentedControl
            | Error::InvalidControlFrameLen
            | Error::InvalidReservedBit(_)
            | Error::PayloadTooLarge { .. } => None,
        }
    }
}

impl From<io::Error> for Error {
    fn from(e: io::Error) -> Self {
        Self::Io(e)
    }
}
587
588impl From<UnknownOpCode> for Error {
589 fn from(_: UnknownOpCode) -> Self {
590 Error::UnknownOpCode
591 }
592}
593
#[cfg(test)]
mod test {
    //! Unit tests for header decoding and the reserved-bit configuration.

    use super::{Codec, Error, OpCode};
    use crate::Parsing;
    use quickcheck::QuickCheck;

    /// A single byte is not enough for the two fixed header bytes.
    #[test]
    fn decode_partial_header() {
        let partial_header: &[u8] = &[0x89];
        assert!(matches!(
            Codec::new().decode_header(partial_header),
            Ok(Parsing::NeedMore(1))
        ))
    }

    /// Truncated 2-byte and 8-byte extended lengths report the missing bytes.
    #[test]
    fn decode_partial_len() {
        let partial_length_1: &[u8] = &[0x89, 0xFE, 0x01];
        assert!(matches!(
            Codec::new().decode_header(partial_length_1),
            Ok(Parsing::NeedMore(1))
        ));
        let partial_length_2: &[u8] = &[0x89, 0xFF, 0x01, 0x02, 0x03, 0x04];
        assert!(matches!(
            Codec::new().decode_header(partial_length_2),
            Ok(Parsing::NeedMore(4))
        ))
    }

    /// A truncated masking key reports the missing bytes.
    #[test]
    fn decode_partial_mask() {
        let partial_mask: &[u8] = &[0x82, 0xFE, 0x01, 0x02, 0x00, 0x00];
        assert!(matches!(
            Codec::new().decode_header(partial_mask),
            Ok(Parsing::NeedMore(2))
        ))
    }

    /// The header decodes even though only part of the payload is present.
    #[test]
    fn decode_partial_payload() {
        let partial_payload: &[u8] = &[0x82, 0x85, 0x01, 0x02, 0x03, 0x04, 0x00, 0x00];
        match Codec::new().decode_header(partial_payload) {
            Ok(Parsing::Done { value, offset }) => {
                // 5 payload bytes announced, 2 available after the header.
                assert_eq!(3, value.payload_len() - (partial_payload.len() - offset))
            }
            _ => panic!("expected the header to decode completely"),
        }
    }

    /// Control frames must not announce more than 125 payload bytes.
    #[test]
    fn decode_invalid_control_payload_len() {
        let ctrl_payload_len: &[u8] = &[0x89, 0xFE, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        assert!(matches!(
            Codec::new().decode_header(ctrl_payload_len),
            Err(Error::InvalidControlFrameLen)
        ))
    }

    /// Reserved bits are rejected unless they have been allowed.
    #[test]
    fn decode_reserved() {
        // FIN plus RSV3, RSV2 and RSV1 respectively.
        let first_bytes = [0x90, 0xa0, 0xc0];
        for first in &first_bytes {
            let mut buf = [0; 2];
            buf[0] |= *first;
            assert!(matches!(
                Codec::new().decode_header(&buf),
                Err(Error::InvalidReservedBit(_))
            ))
        }
    }

    /// Control frames without FIN are rejected.
    #[test]
    fn decode_fragmented_control() {
        // Close, ping and pong opcodes in the first byte, FIN cleared.
        let control_opcodes = [8, 9, 10];
        for opcode in &control_opcodes {
            let mut buf = [0; 2];
            buf[0] |= *opcode;
            assert!(matches!(
                Codec::new().decode_header(&buf),
                Err(Error::FragmentedControl)
            ))
        }
    }

    /// Reserved opcodes are rejected.
    #[test]
    fn decode_reserved_opcodes() {
        let reserved = [3, 4, 5, 6, 7, 11, 12, 13, 14, 15];
        for res in &reserved {
            let mut buf = [0; 2];
            buf[0] |= 0x80 | *res;
            assert!(matches!(
                Codec::new().decode_header(&buf),
                Err(Error::ReservedOpCode)
            ))
        }
    }

    /// A masked ping without payload decodes into the expected header.
    #[test]
    fn decode_ping_no_data() {
        let ping_no_data: &[u8] = &[0x89, 0x80, 0x00, 0x00, 0x00, 0x01];
        let c = Codec::new();
        match c.decode_header(ping_no_data) {
            Ok(Parsing::Done { value: header, .. }) => {
                assert!(header.is_fin());
                assert!(!header.is_rsv1());
                assert!(!header.is_rsv2());
                assert!(!header.is_rsv3());
                assert_eq!(header.opcode(), OpCode::Ping);
                assert_eq!(header.payload_len(), 0);
                assert!(header.is_masked());
                assert_eq!(header.mask(), 1);
            }
            _ => panic!("expected the header to decode completely"),
        }
    }

    /// Reserved bits added to a fresh codec are reported back unchanged.
    #[test]
    fn reserved_bits() {
        fn property(bits: (bool, bool, bool)) -> bool {
            let mut c = Codec::new();
            assert_eq!((false, false, false), c.reserved_bits());
            c.add_reserved_bits(bits);
            bits == c.reserved_bits()
        }
        QuickCheck::new().quickcheck(property as fn((bool, bool, bool)) -> bool)
    }
}