1use futures::future::Either;
12use std::fmt;
13
14#[derive(Clone, Debug, PartialEq, Eq)]
16pub struct Header<T> {
17 version: Version,
18 tag: Tag,
19 flags: Flags,
20 stream_id: StreamId,
21 length: Len,
22 _marker: std::marker::PhantomData<T>,
23}
24
25impl<T> fmt::Display for Header<T> {
26 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
27 write!(
28 f,
29 "(Header {:?} {} (len {}) (flags {:?}))",
30 self.tag,
31 self.stream_id,
32 self.length.val(),
33 self.flags.val()
34 )
35 }
36}
37
38impl<T> Header<T> {
39 pub fn tag(&self) -> Tag {
40 self.tag
41 }
42
43 pub fn flags(&self) -> Flags {
44 self.flags
45 }
46
47 pub fn stream_id(&self) -> StreamId {
48 self.stream_id
49 }
50
51 pub fn len(&self) -> Len {
52 self.length
53 }
54
55 #[cfg(test)]
56 pub fn set_len(&mut self, len: u32) {
57 self.length = Len(len)
58 }
59
60 fn cast<U>(self) -> Header<U> {
62 Header {
63 version: self.version,
64 tag: self.tag,
65 flags: self.flags,
66 stream_id: self.stream_id,
67 length: self.length,
68 _marker: std::marker::PhantomData,
69 }
70 }
71
72 pub(crate) fn right<U>(self) -> Header<Either<U, T>> {
74 self.cast()
75 }
76
77 pub(crate) fn left<U>(self) -> Header<Either<T, U>> {
79 self.cast()
80 }
81}
82
83impl<A: private::Sealed> From<Header<A>> for Header<()> {
84 fn from(h: Header<A>) -> Header<()> {
85 h.cast()
86 }
87}
88
89impl Header<()> {
90 pub(crate) fn into_data(self) -> Header<Data> {
91 debug_assert_eq!(self.tag, Tag::Data);
92 self.cast()
93 }
94
95 pub(crate) fn into_window_update(self) -> Header<WindowUpdate> {
96 debug_assert_eq!(self.tag, Tag::WindowUpdate);
97 self.cast()
98 }
99
100 pub(crate) fn into_ping(self) -> Header<Ping> {
101 debug_assert_eq!(self.tag, Tag::Ping);
102 self.cast()
103 }
104}
105
106impl<T: HasSyn> Header<T> {
107 pub fn syn(&mut self) {
109 self.flags.0 |= SYN.0
110 }
111}
112
113impl<T: HasAck> Header<T> {
114 pub fn ack(&mut self) {
116 self.flags.0 |= ACK.0
117 }
118}
119
120impl<T: HasFin> Header<T> {
121 pub fn fin(&mut self) {
123 self.flags.0 |= FIN.0
124 }
125}
126
127impl<T: HasRst> Header<T> {
128 pub fn rst(&mut self) {
130 self.flags.0 |= RST.0
131 }
132}
133
134impl Header<Data> {
135 pub fn data(id: StreamId, len: u32) -> Self {
137 Header {
138 version: Version(0),
139 tag: Tag::Data,
140 flags: Flags(0),
141 stream_id: id,
142 length: Len(len),
143 _marker: std::marker::PhantomData,
144 }
145 }
146}
147
148impl Header<WindowUpdate> {
149 pub fn window_update(id: StreamId, credit: u32) -> Self {
151 Header {
152 version: Version(0),
153 tag: Tag::WindowUpdate,
154 flags: Flags(0),
155 stream_id: id,
156 length: Len(credit),
157 _marker: std::marker::PhantomData,
158 }
159 }
160
161 pub fn credit(&self) -> u32 {
163 self.length.0
164 }
165}
166
167impl Header<Ping> {
168 pub fn ping(nonce: u32) -> Self {
170 Header {
171 version: Version(0),
172 tag: Tag::Ping,
173 flags: Flags(0),
174 stream_id: StreamId(0),
175 length: Len(nonce),
176 _marker: std::marker::PhantomData,
177 }
178 }
179
180 pub fn nonce(&self) -> u32 {
182 self.length.0
183 }
184}
185
186impl Header<GoAway> {
187 pub fn term() -> Self {
189 Self::go_away(0)
190 }
191
192 pub fn protocol_error() -> Self {
194 Self::go_away(1)
195 }
196
197 pub fn internal_error() -> Self {
199 Self::go_away(2)
200 }
201
202 fn go_away(code: u32) -> Self {
203 Header {
204 version: Version(0),
205 tag: Tag::GoAway,
206 flags: Flags(0),
207 stream_id: StreamId(0),
208 length: Len(code),
209 _marker: std::marker::PhantomData,
210 }
211 }
212}
213
/// Marker for `Data` frame headers (`Header<Data>`).
///
/// All four marker enums are uninhabited: they exist only as the type
/// parameter of `Header`. They derive `PartialEq`/`Eq` because the derived
/// comparisons on `Header<T>` require `T: PartialEq`/`T: Eq`. Without these
/// derives, typed headers such as `Header<Data>` could not be compared, and
/// only the untyped `Header<()>` could.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Data {}

/// Marker for `WindowUpdate` frame headers (`Header<WindowUpdate>`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WindowUpdate {}

/// Marker for `Ping` frame headers (`Header<Ping>`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Ping {}

/// Marker for `GoAway` frame headers (`Header<GoAway>`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GoAway {}
229
230pub trait HasSyn: private::Sealed {}
232impl HasSyn for Data {}
233impl HasSyn for WindowUpdate {}
234impl HasSyn for Ping {}
235impl<A: HasSyn, B: HasSyn> HasSyn for Either<A, B> {}
236
237pub trait HasAck: private::Sealed {}
239impl HasAck for Data {}
240impl HasAck for WindowUpdate {}
241impl HasAck for Ping {}
242impl<A: HasAck, B: HasAck> HasAck for Either<A, B> {}
243
244pub trait HasFin: private::Sealed {}
246impl HasFin for Data {}
247impl HasFin for WindowUpdate {}
248
249pub trait HasRst: private::Sealed {}
251impl HasRst for Data {}
252impl HasRst for WindowUpdate {}
253
254pub(super) mod private {
255 pub trait Sealed {}
256
257 impl Sealed for super::Data {}
258 impl Sealed for super::WindowUpdate {}
259 impl Sealed for super::Ping {}
260 impl Sealed for super::GoAway {}
261 impl<A: Sealed, B: Sealed> Sealed for super::Either<A, B> {}
262}
263
/// Frame type, encoded on the wire as the second header byte.
///
/// The discriminants ARE the wire values: `encode` writes `tag as u8`, and
/// `decode` maps bytes 0..=3 back to these variants. They are therefore
/// spelled out explicitly (with `repr(u8)`), so reordering or inserting
/// variants cannot silently change the encoding.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum Tag {
    Data = 0,
    WindowUpdate = 1,
    Ping = 2,
    GoAway = 3,
}
272
/// Protocol version, stored in the first header byte.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Version(u8);

/// The 32-bit value in the last four header bytes; what it means
/// (length, credit, nonce, code) depends on the frame tag.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Len(u32);

impl Len {
    /// Returns the raw 32-bit value.
    pub fn val(self) -> u32 {
        let Len(raw) = self;
        raw
    }
}
286
/// Stream id 0, which denotes the session as a whole.
pub const CONNECTION_ID: StreamId = StreamId(0);

/// Identifier of a stream within a session.
///
/// Parity decides ownership: even ids (including 0) are server ids,
/// odd ids are client ids.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct StreamId(u32);

impl StreamId {
    /// Wraps a raw id.
    pub(crate) fn new(val: u32) -> Self {
        StreamId(val)
    }

    /// `true` for even ids; note that this includes `CONNECTION_ID`.
    pub fn is_server(self) -> bool {
        (self.0 & 1) == 0
    }

    /// `true` for odd ids.
    pub fn is_client(self) -> bool {
        (self.0 & 1) == 1
    }

    /// `true` only for `CONNECTION_ID` (id 0).
    pub fn is_session(self) -> bool {
        self.0 == CONNECTION_ID.0
    }

    /// Returns the raw id.
    pub fn val(self) -> u32 {
        let StreamId(raw) = self;
        raw
    }
}

/// Prints the bare numeric id (formatter width/fill flags are not applied).
impl fmt::Display for StreamId {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let StreamId(raw) = self;
        write!(f, "{}", raw)
    }
}
322
// `StreamId` derives `Hash` over a single `u32` field, which is what
// `nohash_hasher`'s pass-through hasher expects; this marker impl opts the
// type into it. NOTE(review): keep in sync if `StreamId`'s fields change.
impl nohash_hasher::IsEnabled for StreamId {}
324
/// Bit set of header flags (the 16-bit flags field of the header).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct Flags(u16);

impl Flags {
    /// Returns `true` if every bit set in `other` is also set in `self`.
    ///
    /// An empty `other` is therefore contained in every set.
    pub fn contains(self, other: Flags) -> bool {
        other.0 & !self.0 == 0
    }

    /// Returns the raw flag bits.
    pub fn val(self) -> u16 {
        let Flags(bits) = self;
        bits
    }
}
338
339pub const SYN: Flags = Flags(1);
341
342pub const ACK: Flags = Flags(2);
344
345pub const FIN: Flags = Flags(4);
347
348pub const RST: Flags = Flags(8);
350
351pub const HEADER_SIZE: usize = 12;
353
354pub fn encode<T>(hdr: &Header<T>) -> [u8; HEADER_SIZE] {
356 let mut buf = [0; HEADER_SIZE];
357 buf[0] = hdr.version.0;
358 buf[1] = hdr.tag as u8;
359 buf[2..4].copy_from_slice(&hdr.flags.0.to_be_bytes());
360 buf[4..8].copy_from_slice(&hdr.stream_id.0.to_be_bytes());
361 buf[8..HEADER_SIZE].copy_from_slice(&hdr.length.0.to_be_bytes());
362 buf
363}
364
365pub fn decode(buf: &[u8; HEADER_SIZE]) -> Result<Header<()>, HeaderDecodeError> {
367 if buf[0] != 0 {
368 return Err(HeaderDecodeError::Version(buf[0]));
369 }
370
371 let hdr = Header {
372 version: Version(buf[0]),
373 tag: match buf[1] {
374 0 => Tag::Data,
375 1 => Tag::WindowUpdate,
376 2 => Tag::Ping,
377 3 => Tag::GoAway,
378 t => return Err(HeaderDecodeError::Type(t)),
379 },
380 flags: Flags(u16::from_be_bytes([buf[2], buf[3]])),
381 stream_id: StreamId(u32::from_be_bytes([buf[4], buf[5], buf[6], buf[7]])),
382 length: Len(u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]])),
383 _marker: std::marker::PhantomData,
384 };
385
386 Ok(hdr)
387}
388
/// Error returned by `decode` when a header cannot be parsed.
///
/// Both variants carry only a `u8`, so the type also derives `Clone` and
/// `Eq` alongside `PartialEq`.
#[non_exhaustive]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum HeaderDecodeError {
    /// The version byte was not 0; carries the byte that was read.
    Version(u8),
    /// The type byte did not name a known frame type; carries the byte read.
    Type(u8),
}

impl std::fmt::Display for HeaderDecodeError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            HeaderDecodeError::Version(v) => write!(f, "unknown version: {}", v),
            HeaderDecodeError::Type(t) => write!(f, "unknown frame type: {}", t),
        }
    }
}

impl std::error::Error for HeaderDecodeError {}
409
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen, QuickCheck};

    /// Generates version-0 headers with a random known tag and random
    /// flags, stream id and length.
    impl Arbitrary for Header<()> {
        fn arbitrary(g: &mut Gen) -> Self {
            let tag = *g
                .choose(&[Tag::Data, Tag::WindowUpdate, Tag::Ping, Tag::GoAway])
                .unwrap();

            Header {
                version: Version(0),
                tag,
                flags: Flags(Arbitrary::arbitrary(g)),
                stream_id: StreamId(Arbitrary::arbitrary(g)),
                length: Len(Arbitrary::arbitrary(g)),
                _marker: std::marker::PhantomData,
            }
        }
    }

    #[test]
    fn encode_decode_identity() {
        fn property(hdr: Header<()>) -> bool {
            match decode(&encode(&hdr)) {
                Ok(x) => x == hdr,
                Err(e) => {
                    eprintln!("decode error: {}", e);
                    false
                }
            }
        }
        QuickCheck::new()
            .tests(10_000)
            .quickcheck(property as fn(Header<()>) -> bool)
    }

    // The error paths of `decode` are not reachable from the Arbitrary impl
    // above (it only produces valid headers), so they are pinned explicitly.

    #[test]
    fn decode_rejects_unknown_version() {
        let mut buf = encode(&Header::ping(7));
        buf[0] = 1;
        assert_eq!(decode(&buf), Err(HeaderDecodeError::Version(1)));
    }

    #[test]
    fn decode_rejects_unknown_type() {
        let mut buf = encode(&Header::ping(7));
        for &ty in &[4u8, 0x7f, 0xff] {
            buf[1] = ty;
            assert_eq!(decode(&buf), Err(HeaderDecodeError::Type(ty)));
        }
    }

    #[test]
    fn typed_constructors_round_trip() {
        let data = decode(&encode(&Header::data(StreamId::new(3), 10))).unwrap();
        assert_eq!(data.tag(), Tag::Data);
        assert_eq!(data.stream_id(), StreamId::new(3));
        assert_eq!(data.into_data().len().val(), 10);

        let update = decode(&encode(&Header::window_update(StreamId::new(5), 256))).unwrap();
        assert_eq!(update.tag(), Tag::WindowUpdate);
        assert_eq!(update.into_window_update().credit(), 256);

        let ping = decode(&encode(&Header::ping(42))).unwrap();
        assert_eq!(ping.tag(), Tag::Ping);
        assert!(ping.stream_id().is_session());
        assert_eq!(ping.into_ping().nonce(), 42);

        let go_aways = vec![
            (Header::term(), 0),
            (Header::protocol_error(), 1),
            (Header::internal_error(), 2),
        ];
        for (hdr, code) in go_aways {
            let decoded = decode(&encode(&hdr)).unwrap();
            assert_eq!(decoded.tag(), Tag::GoAway);
            assert!(decoded.stream_id().is_session());
            assert_eq!(decoded.len().val(), code);
        }
    }

    #[test]
    fn flag_setters_accumulate() {
        let mut hdr = Header::data(StreamId::new(1), 0);
        hdr.syn();
        hdr.fin();
        let flags = decode(&encode(&hdr)).unwrap().flags();
        assert!(flags.contains(SYN));
        assert!(flags.contains(FIN));
        assert!(!flags.contains(ACK));
        assert!(!flags.contains(RST));
        assert_eq!(flags.val(), SYN.val() | FIN.val());
    }
}