soketto/
base.rs

1// Copyright (c) 2019 Parity Technologies (UK) Ltd.
2// Copyright (c) 2016 twist developers
3//
4// Licensed under the Apache License, Version 2.0
5// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
6// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
7// option. All files in the project carrying such notice may not be copied,
8// modified, or distributed except according to those terms.
9
10// This file is largely based on the original twist implementation.
11// See [frame/base.rs] and [codec/base.rs].
12//
13// [frame/base.rs]: https://github.com/rustyhorde/twist/blob/449d8b75c2/src/frame/base.rs
14// [codec/base.rs]: https://github.com/rustyhorde/twist/blob/449d8b75c2/src/codec/base.rs
15
16//! A websocket [base frame][base] codec.
17//!
18//! [base]: https://tools.ietf.org/html/rfc6455#section-5.2
19
20use crate::{as_u64, Parsing};
21use std::{fmt, io};
22
/// Largest possible encoded frame header: 2 fixed bytes, up to 8 bytes of
/// extended payload length, and a 4-byte masking key.
pub(crate) const MAX_HEADER_SIZE: usize = 2 + 8 + 4;

/// Largest payload (in bytes) a control frame is allowed to carry.
pub(crate) const MAX_CTRL_BODY_SIZE: u64 = 125;
28
29// OpCode /////////////////////////////////////////////////////////////////////////////////////////
30
/// Operation codes defined in [RFC 6455](https://tools.ietf.org/html/rfc6455#section-5.2).
///
/// Convert from the 4-bit wire value with `OpCode::try_from(u8)` and back
/// with `u8::from(OpCode)`.
#[derive(Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Clone, Copy)]
pub enum OpCode {
	/// A continuation frame of a fragmented message.
	Continue,
	/// A text data frame.
	Text,
	/// A binary data frame.
	Binary,
	/// A close control frame.
	Close,
	/// A ping control frame.
	Ping,
	/// A pong control frame.
	Pong,
	/// A reserved op code (wire value 3).
	Reserved3,
	/// A reserved op code (wire value 4).
	Reserved4,
	/// A reserved op code (wire value 5).
	Reserved5,
	/// A reserved op code (wire value 6).
	Reserved6,
	/// A reserved op code (wire value 7).
	Reserved7,
	/// A reserved op code (wire value 11).
	Reserved11,
	/// A reserved op code (wire value 12).
	Reserved12,
	/// A reserved op code (wire value 13).
	Reserved13,
	/// A reserved op code (wire value 14).
	Reserved14,
	/// A reserved op code (wire value 15).
	Reserved15,
}

impl OpCode {
	/// Is this a control opcode (`Close`, `Ping` or `Pong`)?
	pub fn is_control(self) -> bool {
		matches!(self, OpCode::Close | OpCode::Ping | OpCode::Pong)
	}

	/// Is this opcode reserved (wire values 3-7 and 11-15)?
	pub fn is_reserved(self) -> bool {
		matches!(
			self,
			OpCode::Reserved3
				| OpCode::Reserved4
				| OpCode::Reserved5
				| OpCode::Reserved6
				| OpCode::Reserved7
				| OpCode::Reserved11
				| OpCode::Reserved12
				| OpCode::Reserved13
				| OpCode::Reserved14
				| OpCode::Reserved15
		)
	}
}

impl fmt::Display for OpCode {
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		match self {
			OpCode::Continue => f.write_str("Continue"),
			OpCode::Text => f.write_str("Text"),
			OpCode::Binary => f.write_str("Binary"),
			OpCode::Close => f.write_str("Close"),
			OpCode::Ping => f.write_str("Ping"),
			OpCode::Pong => f.write_str("Pong"),
			OpCode::Reserved3 => f.write_str("Reserved:3"),
			OpCode::Reserved4 => f.write_str("Reserved:4"),
			OpCode::Reserved5 => f.write_str("Reserved:5"),
			OpCode::Reserved6 => f.write_str("Reserved:6"),
			OpCode::Reserved7 => f.write_str("Reserved:7"),
			OpCode::Reserved11 => f.write_str("Reserved:11"),
			OpCode::Reserved12 => f.write_str("Reserved:12"),
			OpCode::Reserved13 => f.write_str("Reserved:13"),
			OpCode::Reserved14 => f.write_str("Reserved:14"),
			OpCode::Reserved15 => f.write_str("Reserved:15"),
		}
	}
}

/// Error returned by `OpCode::try_from` if an unknown opcode
/// number (anything above 15) is encountered.
#[derive(Clone, Debug)]
pub struct UnknownOpCode(());

impl fmt::Display for UnknownOpCode {
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		f.write_str("unknown opcode")
	}
}

impl std::error::Error for UnknownOpCode {}

impl TryFrom<u8> for OpCode {
	type Error = UnknownOpCode;

	/// Map a 4-bit wire value (0-15) to its opcode; larger values are rejected.
	fn try_from(val: u8) -> Result<OpCode, Self::Error> {
		match val {
			0 => Ok(OpCode::Continue),
			1 => Ok(OpCode::Text),
			2 => Ok(OpCode::Binary),
			3 => Ok(OpCode::Reserved3),
			4 => Ok(OpCode::Reserved4),
			5 => Ok(OpCode::Reserved5),
			6 => Ok(OpCode::Reserved6),
			7 => Ok(OpCode::Reserved7),
			8 => Ok(OpCode::Close),
			9 => Ok(OpCode::Ping),
			10 => Ok(OpCode::Pong),
			11 => Ok(OpCode::Reserved11),
			12 => Ok(OpCode::Reserved12),
			13 => Ok(OpCode::Reserved13),
			14 => Ok(OpCode::Reserved14),
			15 => Ok(OpCode::Reserved15),
			_ => Err(UnknownOpCode(())),
		}
	}
}

impl From<OpCode> for u8 {
	/// The 4-bit wire value of the opcode (inverse of `OpCode::try_from`).
	fn from(opcode: OpCode) -> u8 {
		match opcode {
			OpCode::Continue => 0,
			OpCode::Text => 1,
			OpCode::Binary => 2,
			OpCode::Reserved3 => 3,
			OpCode::Reserved4 => 4,
			OpCode::Reserved5 => 5,
			OpCode::Reserved6 => 6,
			OpCode::Reserved7 => 7,
			OpCode::Close => 8,
			OpCode::Ping => 9,
			OpCode::Pong => 10,
			OpCode::Reserved11 => 11,
			OpCode::Reserved12 => 12,
			OpCode::Reserved13 => 13,
			OpCode::Reserved14 => 14,
			OpCode::Reserved15 => 15,
		}
	}
}
180
181// Frame header ///////////////////////////////////////////////////////////////////////////////////
182
183/// A websocket base frame header, i.e. everything but the payload.
184#[derive(Debug, Clone)]
185pub struct Header {
186	fin: bool,
187	rsv1: bool,
188	rsv2: bool,
189	rsv3: bool,
190	masked: bool,
191	opcode: OpCode,
192	mask: u32,
193	payload_len: usize,
194}
195
196impl fmt::Display for Header {
197	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
198		write!(
199			f,
200			"({} (fin {}) (rsv {}{}{}) (mask ({} {:x})) (len {}))",
201			self.opcode,
202			self.fin as u8,
203			self.rsv1 as u8,
204			self.rsv2 as u8,
205			self.rsv3 as u8,
206			self.masked as u8,
207			self.mask,
208			self.payload_len
209		)
210	}
211}
212
213impl Header {
214	/// Create a new frame header with a given [`OpCode`].
215	pub fn new(oc: OpCode) -> Self {
216		Header { fin: true, rsv1: false, rsv2: false, rsv3: false, masked: false, opcode: oc, mask: 0, payload_len: 0 }
217	}
218
219	/// Is the `fin` flag set?
220	pub fn is_fin(&self) -> bool {
221		self.fin
222	}
223
224	/// Set the `fin` flag.
225	pub fn set_fin(&mut self, fin: bool) -> &mut Self {
226		self.fin = fin;
227		self
228	}
229
230	/// Is the `rsv1` flag set?
231	pub fn is_rsv1(&self) -> bool {
232		self.rsv1
233	}
234
235	/// Set the `rsv1` flag.
236	pub fn set_rsv1(&mut self, rsv1: bool) -> &mut Self {
237		self.rsv1 = rsv1;
238		self
239	}
240
241	/// Is the `rsv2` flag set?
242	pub fn is_rsv2(&self) -> bool {
243		self.rsv2
244	}
245
246	/// Set the `rsv2` flag.
247	pub fn set_rsv2(&mut self, rsv2: bool) -> &mut Self {
248		self.rsv2 = rsv2;
249		self
250	}
251
252	/// Is the `rsv3` flag set?
253	pub fn is_rsv3(&self) -> bool {
254		self.rsv3
255	}
256
257	/// Set the `rsv3` flag.
258	pub fn set_rsv3(&mut self, rsv3: bool) -> &mut Self {
259		self.rsv3 = rsv3;
260		self
261	}
262
263	/// Is the `masked` flag set?
264	pub fn is_masked(&self) -> bool {
265		self.masked
266	}
267
268	/// Set the `masked` flag.
269	pub fn set_masked(&mut self, masked: bool) -> &mut Self {
270		self.masked = masked;
271		self
272	}
273
274	/// Get the `opcode`.
275	pub fn opcode(&self) -> OpCode {
276		self.opcode
277	}
278
279	/// Set the `opcode`
280	pub fn set_opcode(&mut self, opcode: OpCode) -> &mut Self {
281		self.opcode = opcode;
282		self
283	}
284
285	/// Get the `mask`.
286	pub fn mask(&self) -> u32 {
287		self.mask
288	}
289
290	/// Set the `mask`
291	pub fn set_mask(&mut self, mask: u32) -> &mut Self {
292		self.mask = mask;
293		self
294	}
295
296	/// Get the payload length.
297	pub fn payload_len(&self) -> usize {
298		self.payload_len
299	}
300
301	/// Set the payload length.
302	pub fn set_payload_len(&mut self, len: usize) -> &mut Self {
303		self.payload_len = len;
304		self
305	}
306}
307
// Base codec /////////////////////////////////////////////////////////////////////////////////////
309
/// Payload-length marker (126): the real length follows as a 16-bit
/// big-endian integer in the next two bytes.
const TWO_EXT: u8 = 0x7E;

/// Payload-length marker (127): the real length follows as a 64-bit
/// big-endian integer in the next eight bytes.
const EIGHT_EXT: u8 = TWO_EXT + 1;
317
318/// Codec for encoding/decoding websocket [base] frames.
319///
320/// [base]: https://tools.ietf.org/html/rfc6455#section-5.2
321#[derive(Debug, Clone)]
322pub struct Codec {
323	/// Maximum size of payload data per frame.
324	max_data_size: usize,
325	/// Bits reserved by an extension.
326	reserved_bits: u8,
327	/// Scratch buffer used during header encoding.
328	header_buffer: [u8; MAX_HEADER_SIZE],
329}
330
331impl Default for Codec {
332	fn default() -> Self {
333		Codec { max_data_size: 256 * 1024 * 1024, reserved_bits: 0, header_buffer: [0; MAX_HEADER_SIZE] }
334	}
335}
336
337impl Codec {
338	/// Create a new base frame codec.
339	///
340	/// The codec will support decoding payload lengths up to 256 MiB
341	/// (use `set_max_data_size` to change this value).
342	pub fn new() -> Self {
343		Codec::default()
344	}
345
346	/// Get the configured maximum payload length.
347	pub fn max_data_size(&self) -> usize {
348		self.max_data_size
349	}
350
351	/// Limit the maximum size of payload data to `size` bytes.
352	pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
353		self.max_data_size = size;
354		self
355	}
356
357	/// The reserved bits currently configured.
358	pub fn reserved_bits(&self) -> (bool, bool, bool) {
359		let r = self.reserved_bits;
360		(r & 4 == 4, r & 2 == 2, r & 1 == 1)
361	}
362
363	/// Add to the reserved bits in use.
364	pub fn add_reserved_bits(&mut self, bits: (bool, bool, bool)) -> &mut Self {
365		let (r1, r2, r3) = bits;
366		self.reserved_bits |= (r1 as u8) << 2 | (r2 as u8) << 1 | r3 as u8;
367		self
368	}
369
370	/// Reset the reserved bits.
371	pub fn clear_reserved_bits(&mut self) {
372		self.reserved_bits = 0
373	}
374
375	/// Decode a websocket frame header.
376	pub fn decode_header(&self, bytes: &[u8]) -> Result<Parsing<Header, usize>, Error> {
377		if bytes.len() < 2 {
378			return Ok(Parsing::NeedMore(2 - bytes.len()));
379		}
380
381		let first = bytes[0];
382		let second = bytes[1];
383		let mut offset = 2;
384
385		let fin = first & 0x80 != 0;
386		let opcode = OpCode::try_from(first & 0xF)?;
387
388		if opcode.is_reserved() {
389			return Err(Error::ReservedOpCode);
390		}
391
392		if opcode.is_control() && !fin {
393			return Err(Error::FragmentedControl);
394		}
395
396		let mut header = Header::new(opcode);
397		header.set_fin(fin);
398
399		let rsv1 = first & 0x40 != 0;
400		if rsv1 && (self.reserved_bits & 4 == 0) {
401			return Err(Error::InvalidReservedBit(1));
402		}
403		header.set_rsv1(rsv1);
404
405		let rsv2 = first & 0x20 != 0;
406		if rsv2 && (self.reserved_bits & 2 == 0) {
407			return Err(Error::InvalidReservedBit(2));
408		}
409		header.set_rsv2(rsv2);
410
411		let rsv3 = first & 0x10 != 0;
412		if rsv3 && (self.reserved_bits & 1 == 0) {
413			return Err(Error::InvalidReservedBit(3));
414		}
415		header.set_rsv3(rsv3);
416		header.set_masked(second & 0x80 != 0);
417
418		let len: u64 = match second & 0x7F {
419			TWO_EXT => {
420				if bytes.len() < offset + 2 {
421					return Ok(Parsing::NeedMore(offset + 2 - bytes.len()));
422				}
423				let len = u16::from_be_bytes([bytes[offset], bytes[offset + 1]]);
424				offset += 2;
425				u64::from(len)
426			}
427			EIGHT_EXT => {
428				if bytes.len() < offset + 8 {
429					return Ok(Parsing::NeedMore(offset + 8 - bytes.len()));
430				}
431				let mut b = [0; 8];
432				b.copy_from_slice(&bytes[offset..offset + 8]);
433				offset += 8;
434				u64::from_be_bytes(b)
435			}
436			n => u64::from(n),
437		};
438
439		if len > MAX_CTRL_BODY_SIZE && header.opcode().is_control() {
440			return Err(Error::InvalidControlFrameLen);
441		}
442
443		let len: usize = if len > as_u64(self.max_data_size) {
444			return Err(Error::PayloadTooLarge { actual: len, maximum: as_u64(self.max_data_size) });
445		} else {
446			len as usize
447		};
448
449		header.set_payload_len(len);
450
451		if header.is_masked() {
452			if bytes.len() < offset + 4 {
453				return Ok(Parsing::NeedMore(offset + 4 - bytes.len()));
454			}
455			let mut b = [0; 4];
456			b.copy_from_slice(&bytes[offset..offset + 4]);
457			offset += 4;
458			header.set_mask(u32::from_be_bytes(b));
459		}
460
461		Ok(Parsing::Done { value: header, offset })
462	}
463
464	/// Encode a websocket frame header.
465	pub fn encode_header(&mut self, header: &Header) -> &[u8] {
466		let mut offset = 0;
467
468		let mut first_byte = 0_u8;
469		if header.is_fin() {
470			first_byte |= 0x80
471		}
472		if header.is_rsv1() {
473			first_byte |= 0x40
474		}
475		if header.is_rsv2() {
476			first_byte |= 0x20
477		}
478		if header.is_rsv3() {
479			first_byte |= 0x10
480		}
481
482		let opcode: u8 = header.opcode().into();
483		first_byte |= opcode;
484
485		self.header_buffer[offset] = first_byte;
486		offset += 1;
487
488		let mut second_byte = 0_u8;
489		if header.is_masked() {
490			second_byte |= 0x80
491		}
492
493		let len = header.payload_len();
494
495		if len < usize::from(TWO_EXT) {
496			second_byte |= len as u8;
497			self.header_buffer[offset] = second_byte;
498			offset += 1;
499		} else if len <= usize::from(u16::max_value()) {
500			second_byte |= TWO_EXT;
501			self.header_buffer[offset] = second_byte;
502			offset += 1;
503			self.header_buffer[offset..offset + 2].copy_from_slice(&(len as u16).to_be_bytes());
504			offset += 2;
505		} else {
506			second_byte |= EIGHT_EXT;
507			self.header_buffer[offset] = second_byte;
508			offset += 1;
509			self.header_buffer[offset..offset + 8].copy_from_slice(&as_u64(len).to_be_bytes());
510			offset += 8;
511		}
512
513		if header.is_masked() {
514			self.header_buffer[offset..offset + 4].copy_from_slice(&header.mask().to_be_bytes());
515			offset += 4;
516		}
517
518		&self.header_buffer[..offset]
519	}
520
521	/// Use the given header's mask and apply it to the data.
522	pub fn apply_mask(header: &Header, data: &mut [u8]) {
523		if header.is_masked() {
524			let mask = header.mask().to_be_bytes();
525			for (byte, &key) in data.iter_mut().zip(mask.iter().cycle()) {
526				*byte ^= key;
527			}
528		}
529	}
530}
531
/// Errors the base frame decoder can report.
#[non_exhaustive]
#[derive(Debug)]
pub enum Error {
	/// An underlying I/O operation failed.
	Io(io::Error),
	/// An opcode number outside the known range was decoded.
	UnknownOpCode,
	/// A reserved opcode was decoded.
	ReservedOpCode,
	/// A control frame without the fin bit set (i.e. fragmented) was decoded.
	FragmentedControl,
	/// A control frame announced a payload length that is not allowed.
	InvalidControlFrameLen,
	/// A reserved bit was set that is not accepted; carries the bit's number.
	InvalidReservedBit(u8),
	/// A frame announced a payload longer than the configured maximum.
	PayloadTooLarge {
		/// The payload length announced by the frame header.
		actual: u64,
		/// The configured maximum payload length.
		maximum: u64,
	},
}

impl fmt::Display for Error {
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		// Variants with data format themselves; the rest map to a fixed message.
		let message = match self {
			Error::Io(inner) => return write!(f, "i/o error: {}", inner),
			Error::InvalidReservedBit(bit) => return write!(f, "invalid reserved bit: {}", bit),
			Error::PayloadTooLarge { actual, maximum } => {
				return write!(f, "payload too large: len = {}, maximum = {}", actual, maximum)
			}
			Error::UnknownOpCode => "unknown opcode",
			Error::ReservedOpCode => "reserved opcode",
			Error::FragmentedControl => "fragmented control frame",
			Error::InvalidControlFrameLen => "invalid control frame length",
		};
		f.write_str(message)
	}
}

impl std::error::Error for Error {
	fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
		match self {
			Error::UnknownOpCode
			| Error::ReservedOpCode
			| Error::FragmentedControl
			| Error::InvalidControlFrameLen
			| Error::InvalidReservedBit(_)
			| Error::PayloadTooLarge { .. } => None,
			Error::Io(inner) => Some(inner),
		}
	}
}

impl From<io::Error> for Error {
	fn from(err: io::Error) -> Self {
		Self::Io(err)
	}
}
587
588impl From<UnknownOpCode> for Error {
589	fn from(_: UnknownOpCode) -> Self {
590		Error::UnknownOpCode
591	}
592}
593
594// Tests //////////////////////////////////////////////////////////////////////////////////////////
595
#[cfg(test)]
mod test {
	use super::{Codec, Error, Header, OpCode};
	use crate::Parsing;
	use quickcheck::QuickCheck;

	#[test]
	fn decode_partial_header() {
		let partial_header: &[u8] = &[0x89];
		assert!(matches! {
			Codec::new().decode_header(partial_header),
			Ok(Parsing::NeedMore(1))
		})
	}

	#[test]
	fn decode_partial_len() {
		let partial_length_1: &[u8] = &[0x89, 0xFE, 0x01];
		assert!(matches! {
			Codec::new().decode_header(partial_length_1),
			Ok(Parsing::NeedMore(1))
		});
		let partial_length_2: &[u8] = &[0x89, 0xFF, 0x01, 0x02, 0x03, 0x04];
		assert!(matches! {
			Codec::new().decode_header(partial_length_2),
			Ok(Parsing::NeedMore(4))
		})
	}

	#[test]
	fn decode_partial_mask() {
		let partial_mask: &[u8] = &[0x82, 0xFE, 0x01, 0x02, 0x00, 0x00];
		assert!(matches! {
			Codec::new().decode_header(partial_mask),
			Ok(Parsing::NeedMore(2))
		})
	}

	#[test]
	fn decode_partial_payload() {
		let partial_payload: &[u8] = &[0x82, 0x85, 0x01, 0x02, 0x03, 0x04, 0x00, 0x00];
		match Codec::new().decode_header(partial_payload) {
			Ok(Parsing::Done { value, offset }) => {
				// 5 payload bytes announced, only 2 present after the header.
				assert_eq!(3, value.payload_len() - (partial_payload.len() - offset))
			}
			_ => panic!("expected the header to be fully decoded"),
		}
	}

	#[test]
	fn decode_invalid_control_payload_len() {
		// Payload on control frame must be 125 bytes or less. 2nd byte must be 0xFD or less.
		let ctrl_payload_len: &[u8] = &[0x89, 0xFE, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
		assert!(matches! {
			Codec::new().decode_header(ctrl_payload_len),
			Err(Error::InvalidControlFrameLen)
		})
	}

	/// Checking that each of rsv1, rsv2 and rsv3, when set without being enabled,
	/// is rejected and reported with its bit number.
	#[test]
	fn decode_reserved() {
		// (first byte = fin + exactly one reserved bit, expected reserved bit number)
		let cases = [(0xc0_u8, 1_u8), (0xa0, 2), (0x90, 3)];
		for &(first, bit) in &cases {
			let buf = [first, 0];
			assert!(matches! {
				Codec::new().decode_header(&buf),
				Err(Error::InvalidReservedBit(n)) if n == bit
			})
		}
	}

	/// Checking that a reserved bit is accepted once it has been enabled.
	#[test]
	fn decode_enabled_reserved_bit() {
		let mut codec = Codec::new();
		codec.add_reserved_bits((true, false, false));
		// fin + rsv1 + binary opcode, unmasked, empty payload.
		match codec.decode_header(&[0xc2, 0x00]) {
			Ok(Parsing::Done { value, offset }) => {
				assert!(value.is_rsv1());
				assert_eq!(value.opcode(), OpCode::Binary);
				assert_eq!(offset, 2);
			}
			_ => panic!("expected rsv1 to be accepted once enabled"),
		}
	}

	/// Checking that a control frame, where fin bit is 0, returns an error.
	#[test]
	fn decode_fragmented_control() {
		// Close, Ping and Pong opcodes, with the fin bit left unset.
		let control_opcodes = [8, 9, 10];
		for opcode in &control_opcodes {
			let mut buf = [0; 2];
			buf[0] |= *opcode;
			assert!(matches! {
				Codec::new().decode_header(&buf),
				Err(Error::FragmentedControl)
			})
		}
	}

	/// Checking that reserved opcodes return an error.
	#[test]
	fn decode_reserved_opcodes() {
		let reserved = [3, 4, 5, 6, 7, 11, 12, 13, 14, 15];
		for res in &reserved {
			let mut buf = [0; 2];
			buf[0] |= 0x80 | *res;
			assert!(matches! {
				Codec::new().decode_header(&buf),
				Err(Error::ReservedOpCode)
			})
		}
	}

	#[test]
	fn decode_ping_no_data() {
		let ping_no_data: &[u8] = &[0x89, 0x80, 0x00, 0x00, 0x00, 0x01];
		let c = Codec::new();
		match c.decode_header(ping_no_data) {
			Ok(Parsing::Done { value: header, .. }) => {
				assert!(header.is_fin());
				assert!(!header.is_rsv1());
				assert!(!header.is_rsv2());
				assert!(!header.is_rsv3());
				assert_eq!(header.opcode(), OpCode::Ping);
				assert_eq!(header.payload_len(), 0)
			}
			_ => panic!("expected the ping header to be fully decoded"),
		}
	}

	/// Checking that every payload length encoding (7-bit, 16-bit and 64-bit)
	/// survives an encode/decode round trip.
	#[test]
	fn encode_decode_roundtrip() {
		for &len in &[0usize, 1, 125, 126, 65535, 65536] {
			let mut header = Header::new(OpCode::Binary);
			header.set_masked(true).set_mask(0x0102_0304).set_payload_len(len);
			let mut codec = Codec::new();
			let bytes = codec.encode_header(&header).to_vec();
			match Codec::new().decode_header(&bytes) {
				Ok(Parsing::Done { value, offset }) => {
					assert_eq!(offset, bytes.len());
					assert_eq!(value.payload_len(), len);
					assert!(value.is_fin());
					assert!(value.is_masked());
					assert_eq!(value.mask(), 0x0102_0304);
					assert_eq!(value.opcode(), OpCode::Binary);
				}
				_ => panic!("failed to decode encoded header with len {}", len),
			}
		}
	}

	/// Checking that applying the mask twice restores the original data.
	#[test]
	fn apply_mask_roundtrip() {
		let mut header = Header::new(OpCode::Text);
		header.set_masked(true).set_mask(0xDEAD_BEEF);
		let original = b"hello websocket".to_vec();
		let mut data = original.clone();
		Codec::apply_mask(&header, &mut data);
		assert_ne!(data, original);
		Codec::apply_mask(&header, &mut data);
		assert_eq!(data, original);
	}

	#[test]
	fn reserved_bits() {
		fn property(bits: (bool, bool, bool)) -> bool {
			let mut c = Codec::new();
			assert_eq!((false, false, false), c.reserved_bits());
			c.add_reserved_bits(bits);
			bits == c.reserved_bits()
		}
		QuickCheck::new().quickcheck(property as fn((bool, bool, bool)) -> bool)
	}
}