tungstenite/protocol/frame/
frame.rs

1use byteorder::{ByteOrder, NetworkEndian, ReadBytesExt, WriteBytesExt};
2use log::*;
3use std::{
4    borrow::Cow,
5    default::Default,
6    fmt,
7    io::{Cursor, ErrorKind, Read, Write},
8    result::Result as StdResult,
9    str::Utf8Error,
10    string::{FromUtf8Error, String},
11};
12
13use super::{
14    coding::{CloseCode, Control, Data, OpCode},
15    mask::{apply_mask, generate_mask},
16};
17use crate::error::{Error, ProtocolError, Result};
18
19/// A struct representing the close command.
20#[derive(Debug, Clone, Eq, PartialEq)]
21pub struct CloseFrame<'t> {
22    /// The reason as a code.
23    pub code: CloseCode,
24    /// The reason as text string.
25    pub reason: Cow<'t, str>,
26}
27
28impl<'t> CloseFrame<'t> {
29    /// Convert into a owned string.
30    pub fn into_owned(self) -> CloseFrame<'static> {
31        CloseFrame { code: self.code, reason: self.reason.into_owned().into() }
32    }
33}
34
35impl<'t> fmt::Display for CloseFrame<'t> {
36    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37        write!(f, "{} ({})", self.reason, self.code)
38    }
39}
40
41/// A struct representing a WebSocket frame header.
42#[allow(missing_copy_implementations)]
43#[derive(Debug, Clone, Eq, PartialEq)]
44pub struct FrameHeader {
45    /// Indicates that the frame is the last one of a possibly fragmented message.
46    pub is_final: bool,
47    /// Reserved for protocol extensions.
48    pub rsv1: bool,
49    /// Reserved for protocol extensions.
50    pub rsv2: bool,
51    /// Reserved for protocol extensions.
52    pub rsv3: bool,
53    /// WebSocket protocol opcode.
54    pub opcode: OpCode,
55    /// A frame mask, if any.
56    pub mask: Option<[u8; 4]>,
57}
58
59impl Default for FrameHeader {
60    fn default() -> Self {
61        FrameHeader {
62            is_final: true,
63            rsv1: false,
64            rsv2: false,
65            rsv3: false,
66            opcode: OpCode::Control(Control::Close),
67            mask: None,
68        }
69    }
70}
71
72impl FrameHeader {
73    /// Parse a header from an input stream.
74    /// Returns `None` if insufficient data and does not consume anything in this case.
75    /// Payload size is returned along with the header.
76    pub fn parse(cursor: &mut Cursor<impl AsRef<[u8]>>) -> Result<Option<(Self, u64)>> {
77        let initial = cursor.position();
78        match Self::parse_internal(cursor) {
79            ret @ Ok(None) => {
80                cursor.set_position(initial);
81                ret
82            }
83            ret => ret,
84        }
85    }
86
87    /// Get the size of the header formatted with given payload length.
88    #[allow(clippy::len_without_is_empty)]
89    pub fn len(&self, length: u64) -> usize {
90        2 + LengthFormat::for_length(length).extra_bytes() + if self.mask.is_some() { 4 } else { 0 }
91    }
92
93    /// Format a header for given payload size.
94    pub fn format(&self, length: u64, output: &mut impl Write) -> Result<()> {
95        let code: u8 = self.opcode.into();
96
97        let one = {
98            code | if self.is_final { 0x80 } else { 0 }
99                | if self.rsv1 { 0x40 } else { 0 }
100                | if self.rsv2 { 0x20 } else { 0 }
101                | if self.rsv3 { 0x10 } else { 0 }
102        };
103
104        let lenfmt = LengthFormat::for_length(length);
105
106        let two = { lenfmt.length_byte() | if self.mask.is_some() { 0x80 } else { 0 } };
107
108        output.write_all(&[one, two])?;
109        match lenfmt {
110            LengthFormat::U8(_) => (),
111            LengthFormat::U16 => output.write_u16::<NetworkEndian>(length as u16)?,
112            LengthFormat::U64 => output.write_u64::<NetworkEndian>(length)?,
113        }
114
115        if let Some(ref mask) = self.mask {
116            output.write_all(mask)?
117        }
118
119        Ok(())
120    }
121
122    /// Generate a random frame mask and store this in the header.
123    ///
124    /// Of course this does not change frame contents. It just generates a mask.
125    pub(crate) fn set_random_mask(&mut self) {
126        self.mask = Some(generate_mask())
127    }
128}
129
130impl FrameHeader {
131    /// Internal parse engine.
132    /// Returns `None` if insufficient data.
133    /// Payload size is returned along with the header.
134    fn parse_internal(cursor: &mut impl Read) -> Result<Option<(Self, u64)>> {
135        let (first, second) = {
136            let mut head = [0u8; 2];
137            if cursor.read(&mut head)? != 2 {
138                return Ok(None);
139            }
140            trace!("Parsed headers {:?}", head);
141            (head[0], head[1])
142        };
143
144        trace!("First: {:b}", first);
145        trace!("Second: {:b}", second);
146
147        let is_final = first & 0x80 != 0;
148
149        let rsv1 = first & 0x40 != 0;
150        let rsv2 = first & 0x20 != 0;
151        let rsv3 = first & 0x10 != 0;
152
153        let opcode = OpCode::from(first & 0x0F);
154        trace!("Opcode: {:?}", opcode);
155
156        let masked = second & 0x80 != 0;
157        trace!("Masked: {:?}", masked);
158
159        let length = {
160            let length_byte = second & 0x7F;
161            let length_length = LengthFormat::for_byte(length_byte).extra_bytes();
162            if length_length > 0 {
163                match cursor.read_uint::<NetworkEndian>(length_length) {
164                    Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
165                        return Ok(None);
166                    }
167                    Err(err) => {
168                        return Err(err.into());
169                    }
170                    Ok(read) => read,
171                }
172            } else {
173                u64::from(length_byte)
174            }
175        };
176
177        let mask = if masked {
178            let mut mask_bytes = [0u8; 4];
179            if cursor.read(&mut mask_bytes)? != 4 {
180                return Ok(None);
181            } else {
182                Some(mask_bytes)
183            }
184        } else {
185            None
186        };
187
188        // Disallow bad opcode
189        match opcode {
190            OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
191                return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F)))
192            }
193            _ => (),
194        }
195
196        let hdr = FrameHeader { is_final, rsv1, rsv2, rsv3, opcode, mask };
197
198        Ok(Some((hdr, length)))
199    }
200}
201
202/// A struct representing a WebSocket frame.
203#[derive(Debug, Clone, Eq, PartialEq)]
204pub struct Frame {
205    header: FrameHeader,
206    payload: Vec<u8>,
207}
208
209impl Frame {
210    /// Get the length of the frame.
211    /// This is the length of the header + the length of the payload.
212    #[inline]
213    pub fn len(&self) -> usize {
214        let length = self.payload.len();
215        self.header.len(length as u64) + length
216    }
217
218    /// Check if the frame is empty.
219    #[inline]
220    pub fn is_empty(&self) -> bool {
221        self.len() == 0
222    }
223
224    /// Get a reference to the frame's header.
225    #[inline]
226    pub fn header(&self) -> &FrameHeader {
227        &self.header
228    }
229
230    /// Get a mutable reference to the frame's header.
231    #[inline]
232    pub fn header_mut(&mut self) -> &mut FrameHeader {
233        &mut self.header
234    }
235
236    /// Get a reference to the frame's payload.
237    #[inline]
238    pub fn payload(&self) -> &Vec<u8> {
239        &self.payload
240    }
241
242    /// Get a mutable reference to the frame's payload.
243    #[inline]
244    pub fn payload_mut(&mut self) -> &mut Vec<u8> {
245        &mut self.payload
246    }
247
248    /// Test whether the frame is masked.
249    #[inline]
250    pub(crate) fn is_masked(&self) -> bool {
251        self.header.mask.is_some()
252    }
253
254    /// Generate a random mask for the frame.
255    ///
256    /// This just generates a mask, payload is not changed. The actual masking is performed
257    /// either on `format()` or on `apply_mask()` call.
258    #[inline]
259    pub(crate) fn set_random_mask(&mut self) {
260        self.header.set_random_mask()
261    }
262
263    /// This method unmasks the payload and should only be called on frames that are actually
264    /// masked. In other words, those frames that have just been received from a client endpoint.
265    #[inline]
266    pub(crate) fn apply_mask(&mut self) {
267        if let Some(mask) = self.header.mask.take() {
268            apply_mask(&mut self.payload, mask)
269        }
270    }
271
272    /// Consume the frame into its payload as binary.
273    #[inline]
274    pub fn into_data(self) -> Vec<u8> {
275        self.payload
276    }
277
278    /// Consume the frame into its payload as string.
279    #[inline]
280    pub fn into_string(self) -> StdResult<String, FromUtf8Error> {
281        String::from_utf8(self.payload)
282    }
283
284    /// Get frame payload as `&str`.
285    #[inline]
286    pub fn to_text(&self) -> Result<&str, Utf8Error> {
287        std::str::from_utf8(&self.payload)
288    }
289
290    /// Consume the frame into a closing frame.
291    #[inline]
292    pub(crate) fn into_close(self) -> Result<Option<CloseFrame<'static>>> {
293        match self.payload.len() {
294            0 => Ok(None),
295            1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
296            _ => {
297                let mut data = self.payload;
298                let code = NetworkEndian::read_u16(&data[0..2]).into();
299                data.drain(0..2);
300                let text = String::from_utf8(data)?;
301                Ok(Some(CloseFrame { code, reason: text.into() }))
302            }
303        }
304    }
305
306    /// Create a new data frame.
307    #[inline]
308    pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
309        debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
310
311        Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
312    }
313
314    /// Create a new Pong control frame.
315    #[inline]
316    pub fn pong(data: Vec<u8>) -> Frame {
317        Frame {
318            header: FrameHeader {
319                opcode: OpCode::Control(Control::Pong),
320                ..FrameHeader::default()
321            },
322            payload: data,
323        }
324    }
325
326    /// Create a new Ping control frame.
327    #[inline]
328    pub fn ping(data: Vec<u8>) -> Frame {
329        Frame {
330            header: FrameHeader {
331                opcode: OpCode::Control(Control::Ping),
332                ..FrameHeader::default()
333            },
334            payload: data,
335        }
336    }
337
338    /// Create a new Close control frame.
339    #[inline]
340    pub fn close(msg: Option<CloseFrame>) -> Frame {
341        let payload = if let Some(CloseFrame { code, reason }) = msg {
342            let mut p = Vec::with_capacity(reason.as_bytes().len() + 2);
343            p.write_u16::<NetworkEndian>(code.into()).unwrap(); // can't fail
344            p.extend_from_slice(reason.as_bytes());
345            p
346        } else {
347            Vec::new()
348        };
349
350        Frame { header: FrameHeader::default(), payload }
351    }
352
353    /// Create a frame from given header and data.
354    pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self {
355        Frame { header, payload }
356    }
357
358    /// Write a frame out to a buffer
359    pub fn format(mut self, output: &mut impl Write) -> Result<()> {
360        self.header.format(self.payload.len() as u64, output)?;
361        self.apply_mask();
362        output.write_all(self.payload())?;
363        Ok(())
364    }
365}
366
367impl fmt::Display for Frame {
368    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
369        write!(
370            f,
371            "
372<FRAME>
373final: {}
374reserved: {} {} {}
375opcode: {}
376length: {}
377payload length: {}
378payload: 0x{}
379            ",
380            self.header.is_final,
381            self.header.rsv1,
382            self.header.rsv2,
383            self.header.rsv3,
384            self.header.opcode,
385            // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
386            self.len(),
387            self.payload.len(),
388            self.payload.iter().map(|byte| format!("{:02x}", byte)).collect::<String>()
389        )
390    }
391}
392
/// The three ways a payload length can be encoded in a frame header.
enum LengthFormat {
    /// The length fits into the length byte itself; the value is stored here.
    U8(u8),
    /// The length byte is the marker 126 and two extra bytes follow.
    U16,
    /// The length byte is the marker 127 and eight extra bytes follow.
    U64,
}

impl LengthFormat {
    /// Choose the smallest encoding able to represent `length`.
    #[inline]
    fn for_length(length: u64) -> Self {
        match length {
            0..=125 => LengthFormat::U8(length as u8),
            126..=65535 => LengthFormat::U16,
            _ => LengthFormat::U64,
        }
    }

    /// Number of bytes that follow the length byte in this encoding.
    #[inline]
    fn extra_bytes(&self) -> usize {
        match self {
            LengthFormat::U8(_) => 0,
            LengthFormat::U16 => 2,
            LengthFormat::U64 => 8,
        }
    }

    /// Value of the seven length bits in the second header byte for this encoding.
    #[inline]
    fn length_byte(&self) -> u8 {
        match self {
            LengthFormat::U8(short) => *short,
            LengthFormat::U16 => 126,
            LengthFormat::U64 => 127,
        }
    }

    /// Decode the encoding from a header byte; the top (mask) bit is ignored.
    #[inline]
    fn for_byte(byte: u8) -> Self {
        let low = byte & 0x7F;
        if low == 126 {
            LengthFormat::U16
        } else if low == 127 {
            LengthFormat::U64
        } else {
            LengthFormat::U8(low)
        }
    }
}
443
#[cfg(test)]
mod tests {
    use super::*;

    use super::super::coding::{Data, OpCode};
    use std::io::Cursor;

    /// A short unmasked binary frame: the header parses and the rest is the payload.
    #[test]
    fn parse() {
        let bytes: Vec<u8> = vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07];
        let mut raw: Cursor<Vec<u8>> = Cursor::new(bytes);

        let parsed = FrameHeader::parse(&mut raw).unwrap();
        let (header, length) = parsed.unwrap();
        assert_eq!(length, 7);

        // Everything left after the header belongs to the payload.
        let mut payload = Vec::new();
        raw.read_to_end(&mut payload).unwrap();

        let data = Frame::from_payload(header, payload).into_data();
        assert_eq!(data, vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
    }

    /// An unmasked ping serializes to a two-byte header followed by its payload.
    #[test]
    fn format() {
        let frame = Frame::ping(vec![0x01, 0x02]);
        let mut buf = Vec::with_capacity(frame.len());
        frame.format(&mut buf).unwrap();

        let expected: Vec<u8> = vec![0x89, 0x02, 0x01, 0x02];
        assert_eq!(buf, expected);
    }

    /// The textual dump of a frame mentions its payload.
    #[test]
    fn display() {
        let frame = Frame::message("hi there".into(), OpCode::Data(Data::Text), true);
        let view = frame.to_string();
        assert!(view.contains("payload:"));
    }
}