// litep2p/yamux/frame/header.rs

// Copyright (c) 2018-2019 Parity Technologies (UK) Ltd.
//
// Licensed under the Apache License, Version 2.0 or MIT license, at your option.
//
// A copy of the Apache License, Version 2.0 is included in the software as
// LICENSE-APACHE and a copy of the MIT license is included in the software
// as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0
// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
// at https://opensource.org/licenses/MIT.

use futures::future::Either;
use std::fmt;

14/// The message frame header.
15#[derive(Clone, Debug, PartialEq, Eq)]
16pub struct Header<T> {
17    version: Version,
18    tag: Tag,
19    flags: Flags,
20    stream_id: StreamId,
21    length: Len,
22    _marker: std::marker::PhantomData<T>,
23}
24
25impl<T> fmt::Display for Header<T> {
26    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
27        write!(
28            f,
29            "(Header {:?} {} (len {}) (flags {:?}))",
30            self.tag,
31            self.stream_id,
32            self.length.val(),
33            self.flags.val()
34        )
35    }
36}
37
38impl<T> Header<T> {
39    pub fn tag(&self) -> Tag {
40        self.tag
41    }
42
43    pub fn flags(&self) -> Flags {
44        self.flags
45    }
46
47    pub fn stream_id(&self) -> StreamId {
48        self.stream_id
49    }
50
51    pub fn len(&self) -> Len {
52        self.length
53    }
54
55    #[cfg(test)]
56    pub fn set_len(&mut self, len: u32) {
57        self.length = Len(len)
58    }
59
60    /// Arbitrary type cast, use with caution.
61    fn cast<U>(self) -> Header<U> {
62        Header {
63            version: self.version,
64            tag: self.tag,
65            flags: self.flags,
66            stream_id: self.stream_id,
67            length: self.length,
68            _marker: std::marker::PhantomData,
69        }
70    }
71
72    /// Introduce this header to the right of a binary header type.
73    pub(crate) fn right<U>(self) -> Header<Either<U, T>> {
74        self.cast()
75    }
76
77    /// Introduce this header to the left of a binary header type.
78    pub(crate) fn left<U>(self) -> Header<Either<T, U>> {
79        self.cast()
80    }
81}
82
83impl<A: private::Sealed> From<Header<A>> for Header<()> {
84    fn from(h: Header<A>) -> Header<()> {
85        h.cast()
86    }
87}
88
89impl Header<()> {
90    pub(crate) fn into_data(self) -> Header<Data> {
91        debug_assert_eq!(self.tag, Tag::Data);
92        self.cast()
93    }
94
95    pub(crate) fn into_window_update(self) -> Header<WindowUpdate> {
96        debug_assert_eq!(self.tag, Tag::WindowUpdate);
97        self.cast()
98    }
99
100    pub(crate) fn into_ping(self) -> Header<Ping> {
101        debug_assert_eq!(self.tag, Tag::Ping);
102        self.cast()
103    }
104}
105
106impl<T: HasSyn> Header<T> {
107    /// Set the [`SYN`] flag.
108    pub fn syn(&mut self) {
109        self.flags.0 |= SYN.0
110    }
111}
112
113impl<T: HasAck> Header<T> {
114    /// Set the [`ACK`] flag.
115    pub fn ack(&mut self) {
116        self.flags.0 |= ACK.0
117    }
118}
119
120impl<T: HasFin> Header<T> {
121    /// Set the [`FIN`] flag.
122    pub fn fin(&mut self) {
123        self.flags.0 |= FIN.0
124    }
125}
126
127impl<T: HasRst> Header<T> {
128    /// Set the [`RST`] flag.
129    pub fn rst(&mut self) {
130        self.flags.0 |= RST.0
131    }
132}
133
134impl Header<Data> {
135    /// Create a new data frame header.
136    pub fn data(id: StreamId, len: u32) -> Self {
137        Header {
138            version: Version(0),
139            tag: Tag::Data,
140            flags: Flags(0),
141            stream_id: id,
142            length: Len(len),
143            _marker: std::marker::PhantomData,
144        }
145    }
146}
147
148impl Header<WindowUpdate> {
149    /// Create a new window update frame header.
150    pub fn window_update(id: StreamId, credit: u32) -> Self {
151        Header {
152            version: Version(0),
153            tag: Tag::WindowUpdate,
154            flags: Flags(0),
155            stream_id: id,
156            length: Len(credit),
157            _marker: std::marker::PhantomData,
158        }
159    }
160
161    /// The credit this window update grants to the remote.
162    pub fn credit(&self) -> u32 {
163        self.length.0
164    }
165}
166
167impl Header<Ping> {
168    /// Create a new ping frame header.
169    pub fn ping(nonce: u32) -> Self {
170        Header {
171            version: Version(0),
172            tag: Tag::Ping,
173            flags: Flags(0),
174            stream_id: StreamId(0),
175            length: Len(nonce),
176            _marker: std::marker::PhantomData,
177        }
178    }
179
180    /// The nonce of this ping.
181    pub fn nonce(&self) -> u32 {
182        self.length.0
183    }
184}
185
186impl Header<GoAway> {
187    /// Terminate the session without indicating an error to the remote.
188    pub fn term() -> Self {
189        Self::go_away(0)
190    }
191
192    /// Terminate the session indicating a protocol error to the remote.
193    pub fn protocol_error() -> Self {
194        Self::go_away(1)
195    }
196
197    /// Terminate the session indicating an internal error to the remote.
198    pub fn internal_error() -> Self {
199        Self::go_away(2)
200    }
201
202    fn go_away(code: u32) -> Self {
203        Header {
204            version: Version(0),
205            tag: Tag::GoAway,
206            flags: Flags(0),
207            stream_id: StreamId(0),
208            length: Len(code),
209            _marker: std::marker::PhantomData,
210        }
211    }
212}
213
// The message-type markers below are uninhabited enums. They only ever appear as the
// phantom type parameter of `Header<T>`, never as values. Besides `Clone`/`Debug` they
// derive `PartialEq`/`Eq`, because the derived comparisons on `Header<T>` require
// `T: PartialEq`. Without this, typed headers such as `Header<Data>` could not be
// compared with `==`.

/// Data message type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Data {}

/// Window update message type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WindowUpdate {}

/// Ping message type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Ping {}

/// Go Away message type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum GoAway {}

230/// Types which have a `syn` method.
231pub trait HasSyn: private::Sealed {}
232impl HasSyn for Data {}
233impl HasSyn for WindowUpdate {}
234impl HasSyn for Ping {}
235impl<A: HasSyn, B: HasSyn> HasSyn for Either<A, B> {}
236
237/// Types which have an `ack` method.
238pub trait HasAck: private::Sealed {}
239impl HasAck for Data {}
240impl HasAck for WindowUpdate {}
241impl HasAck for Ping {}
242impl<A: HasAck, B: HasAck> HasAck for Either<A, B> {}
243
244/// Types which have a `fin` method.
245pub trait HasFin: private::Sealed {}
246impl HasFin for Data {}
247impl HasFin for WindowUpdate {}
248
249/// Types which have a `rst` method.
250pub trait HasRst: private::Sealed {}
251impl HasRst for Data {}
252impl HasRst for WindowUpdate {}
253
254pub(super) mod private {
255    pub trait Sealed {}
256
257    impl Sealed for super::Data {}
258    impl Sealed for super::WindowUpdate {}
259    impl Sealed for super::Ping {}
260    impl Sealed for super::GoAway {}
261    impl<A: Sealed, B: Sealed> Sealed for super::Either<A, B> {}
262}
263
/// A tag is the runtime representation of a message type.
///
/// The discriminants are the on-wire frame type codes:
/// - [`encode`] writes them with `tag as u8`;
/// - [`decode`] maps the same byte values back.
///
/// They are spelled out explicitly, with a `u8` representation, so that reordering or
/// inserting variants cannot silently change the wire format.
#[repr(u8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Tag {
    /// Data frame, type code 0.
    Data = 0,
    /// Window update frame, type code 1.
    WindowUpdate = 1,
    /// Ping frame, type code 2.
    Ping = 2,
    /// Go away frame, type code 3.
    GoAway = 3,
}

/// The protocol version a message corresponds to (first byte of the header).
///
/// Every constructor in this module writes `0`, and `decode` rejects any other value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Version(u8);

/// The message length: the header's trailing 32-bit field.
///
/// For data frames this is the value passed to the data constructor. Window update, ping
/// and go-away headers reuse the same field for the credit, nonce and error code.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Len(u32);

impl Len {
    /// Returns the raw 32-bit value.
    pub fn val(self) -> u32 {
        let Len(raw) = self;
        raw
    }
}

/// The stream ID reserved for the session as a whole.
pub const CONNECTION_ID: StreamId = StreamId(0);

/// The ID of a stream.
///
/// The value 0 denotes no particular stream but the whole session. Even IDs count as
/// server-side and odd IDs as client-side.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct StreamId(u32);

impl StreamId {
    /// Wraps a raw 32-bit stream ID.
    pub(crate) fn new(val: u32) -> Self {
        Self(val)
    }

    /// `true` for even IDs (this includes the session ID 0).
    pub fn is_server(self) -> bool {
        (self.0 & 1) == 0
    }

    /// `true` for odd IDs.
    pub fn is_client(self) -> bool {
        (self.0 & 1) == 1
    }

    /// `true` if this is [`CONNECTION_ID`], i.e. the whole session.
    pub fn is_session(self) -> bool {
        self == CONNECTION_ID
    }

    /// Returns the raw 32-bit value.
    pub fn val(self) -> u32 {
        let StreamId(raw) = self;
        raw
    }
}

impl fmt::Display for StreamId {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}", self.val())
    }
}

323impl nohash_hasher::IsEnabled for StreamId {}
324
/// Possible flags set on a message.
///
/// A 16-bit bit set. The individual flags are the constants [`SYN`], [`ACK`], [`FIN`]
/// and [`RST`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Flags(u16);

impl Flags {
    /// `true` if every bit set in `other` is also set in `self`.
    pub fn contains(self, other: Flags) -> bool {
        (other.0 & !self.0) == 0
    }

    /// Returns the raw 16-bit value.
    pub fn val(self) -> u16 {
        let Flags(bits) = self;
        bits
    }
}

/// Indicates the start of a new stream.
pub const SYN: Flags = Flags(1 << 0);

/// Acknowledges the start of a new stream.
pub const ACK: Flags = Flags(1 << 1);

/// Indicates the half-closing of a stream.
pub const FIN: Flags = Flags(1 << 2);

/// Indicates an immediate stream reset.
pub const RST: Flags = Flags(1 << 3);

/// The serialised header size in bytes:
/// version (1) + type (1) + flags (2) + stream ID (4) + length (4).
pub const HEADER_SIZE: usize = 1 + 1 + 2 + 4 + 4;

354/// Encode a [`Header`] value.
355pub fn encode<T>(hdr: &Header<T>) -> [u8; HEADER_SIZE] {
356    let mut buf = [0; HEADER_SIZE];
357    buf[0] = hdr.version.0;
358    buf[1] = hdr.tag as u8;
359    buf[2..4].copy_from_slice(&hdr.flags.0.to_be_bytes());
360    buf[4..8].copy_from_slice(&hdr.stream_id.0.to_be_bytes());
361    buf[8..HEADER_SIZE].copy_from_slice(&hdr.length.0.to_be_bytes());
362    buf
363}
364
365/// Decode a [`Header`] value.
366pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result<Header<()>, HeaderDecodeError> {
367    if buf[0] != 0 {
368        return Err(HeaderDecodeError::Version(buf[0]));
369    }
370
371    let hdr = Header {
372        version: Version(buf[0]),
373        tag: match buf[1] {
374            0 => Tag::Data,
375            1 => Tag::WindowUpdate,
376            2 => Tag::Ping,
377            3 => Tag::GoAway,
378            t => return Err(HeaderDecodeError::Type(t)),
379        },
380        flags: Flags(u16::from_be_bytes([buf[2], buf[3]])),
381        stream_id: StreamId(u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]])),
382        length: Len(u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]])),
383        _marker: std::marker::PhantomData,
384    };
385
386    Ok(hdr)
387}
388
/// Possible errors while decoding a message frame header.
///
/// Every payload is a plain byte. The type therefore derives `Eq` and `Clone` in
/// addition to `PartialEq`, so it can be used under `Eq` bounds and duplicated freely.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum HeaderDecodeError {
    /// Unknown version.
    Version(u8),
    /// An unknown frame type.
    Type(u8),
}

impl std::fmt::Display for HeaderDecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            HeaderDecodeError::Version(v) => write!(f, "unknown version: {}", v),
            HeaderDecodeError::Type(t) => write!(f, "unknown frame type: {}", t),
        }
    }
}

impl std::error::Error for HeaderDecodeError {}

#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen, QuickCheck};

    impl Arbitrary for Header<()> {
        fn arbitrary(g: &mut Gen) -> Self {
            let tag = *g.choose(&[Tag::Data, Tag::WindowUpdate, Tag::Ping, Tag::GoAway]).unwrap();

            Header {
                version: Version(0),
                tag,
                flags: Flags(Arbitrary::arbitrary(g)),
                stream_id: StreamId(Arbitrary::arbitrary(g)),
                length: Len(Arbitrary::arbitrary(g)),
                _marker: std::marker::PhantomData,
            }
        }
    }

    #[test]
    fn encode_decode_identity() {
        fn property(hdr: Header<()>) -> bool {
            match decode(&encode(&hdr)) {
                Ok(x) => x == hdr,
                Err(e) => {
                    eprintln!("decode error: {}", e);
                    false
                }
            }
        }
        QuickCheck::new().tests(10_000).quickcheck(property as fn(Header<()>) -> bool)
    }

    /// Pins the exact wire layout: version, type, flags, stream ID, length (big-endian).
    #[test]
    fn encode_uses_big_endian_wire_layout() {
        let mut hdr = Header::data(StreamId::new(0x0102_0304), 0x0506_0708);
        hdr.syn();
        assert_eq!(encode(&hdr), [0, 0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn decode_rejects_unknown_version() {
        let mut buf = encode(&Header::ping(7));
        buf[0] = 1;
        assert_eq!(decode(&buf), Err(HeaderDecodeError::Version(1)));
    }

    #[test]
    fn decode_rejects_unknown_frame_type() {
        let mut buf = encode(&Header::ping(7));
        buf[1] = 4;
        assert_eq!(decode(&buf), Err(HeaderDecodeError::Type(4)));
    }

    /// When both the version and the type byte are invalid, the version error wins.
    #[test]
    fn decode_checks_version_before_type() {
        let mut buf = encode(&Header::term());
        buf[0] = 2;
        buf[1] = 0xff;
        assert_eq!(decode(&buf), Err(HeaderDecodeError::Version(2)));
    }

    #[test]
    fn typed_accessors_survive_untyped_round_trip() {
        let untyped: Header<()> = Header::window_update(StreamId::new(3), 10).into();
        let hdr = untyped.into_window_update();
        assert_eq!(hdr.credit(), 10);
        assert_eq!(hdr.stream_id(), StreamId::new(3));

        let ping: Header<()> = Header::ping(42).into();
        assert_eq!(ping.into_ping().nonce(), 42);
    }

    #[test]
    fn display_summarises_header() {
        let mut hdr = Header::ping(7);
        hdr.ack();
        assert_eq!(hdr.to_string(), "(Header Ping 0 (len 7) (flags 2))");
    }
}