// h2/frame/headers.rs

1use super::{util, StreamDependency, StreamId};
2use crate::ext::Protocol;
3use crate::frame::{Error, Frame, Head, Kind};
4use crate::hpack::{self, BytesStr};
5
6use http::header::{self, HeaderName, HeaderValue};
7use http::{uri, HeaderMap, Method, Request, StatusCode, Uri};
8
9use bytes::{BufMut, Bytes, BytesMut};
10
11use std::fmt;
12use std::io::Cursor;
13
/// Output buffer used when encoding frames: a `BytesMut` wrapped in a `Limit`,
/// so `remaining_mut()` reports how much more may be written into the current
/// frame.
type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;
15
/// A HEADERS frame.
///
/// This could be either a request or a response (or a trailer block, see
/// `Headers::trailers`).
#[derive(Eq, PartialEq)]
pub struct Headers {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The stream dependency information, if any. Only filled in by `load`
    /// when the PRIORITY flag is set.
    stream_dep: Option<StreamDependency>,

    /// The header block: pseudo-headers plus regular fields.
    header_block: HeaderBlock,

    /// The associated flags
    flags: HeadersFlag,
}

/// The raw flag octet of a HEADERS frame; bits are `END_STREAM`,
/// `END_HEADERS`, `PADDED` and `PRIORITY`.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct HeadersFlag(u8);
36
/// A PUSH_PROMISE frame: announces, on `stream_id`, a server push whose
/// request is described by the header block and which will be delivered on
/// `promised_id`.
#[derive(Eq, PartialEq)]
pub struct PushPromise {
    /// The ID of the stream with which this frame is associated.
    stream_id: StreamId,

    /// The ID of the stream being reserved by this PushPromise.
    promised_id: StreamId,

    /// The header block of the promised request.
    header_block: HeaderBlock,

    /// The associated flags
    flags: PushPromiseFlag,
}

/// The raw flag octet of a PUSH_PROMISE frame; meaningful bits are
/// `END_HEADERS` and `PADDED`.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct PushPromiseFlag(u8);
54
/// A CONTINUATION frame to be written, carrying the part of an encoded header
/// block that did not fit into the previous frame.
#[derive(Debug)]
pub struct Continuation {
    /// Stream ID of continuation frame
    stream_id: StreamId,

    /// The not-yet-written remainder of the HPACK-encoded header block.
    header_block: EncodingHeaderBlock,
}
62
// TODO: These fields shouldn't be `pub`
/// HTTP/2 pseudo-header fields (`:method`, `:scheme`, `:authority`, `:path`,
/// `:protocol`, `:status`), kept separate from the regular header map.
/// `None` means the pseudo-header is absent.
#[derive(Debug, Default, Eq, PartialEq)]
pub struct Pseudo {
    // Request
    pub method: Option<Method>,
    pub scheme: Option<BytesStr>,
    pub authority: Option<BytesStr>,
    pub path: Option<BytesStr>,
    /// `:protocol`; when set on a CONNECT request, `Pseudo::request` keeps
    /// `:scheme` and `:path` (extended CONNECT).
    pub protocol: Option<Protocol>,

    // Response
    pub status: Option<StatusCode>,
}
76
/// Iterator fed to the HPACK encoder: yields the pseudo-headers first, then
/// the regular header fields.
#[derive(Debug)]
pub struct Iter {
    /// Pseudo headers; reset to `None` once every one has been yielded.
    pseudo: Option<Pseudo>,

    /// Header fields
    fields: header::IntoIter<HeaderValue>,
}
85
/// A header block shared by HEADERS and PUSH_PROMISE frames, either decoded
/// from the wire or waiting to be encoded.
#[derive(Debug, PartialEq, Eq)]
struct HeaderBlock {
    /// The decoded header fields
    fields: HeaderMap,

    /// Precomputed size of all of our header fields, for perf reasons.
    /// Each field counts name + value + 32 octets; pseudo-headers are not
    /// included here.
    field_size: usize,

    /// Set to true if decoding went over the max header list size.
    is_over_size: bool,

    /// Pseudo headers, broken out because they must come before every
    /// regular field in the block.
    pseudo: Pseudo,
}
101
/// A header block that has already been HPACK-encoded and is being written
/// out, possibly across several frames.
#[derive(Debug)]
struct EncodingHeaderBlock {
    /// Encoded bytes that have not been written to a frame yet.
    hpack: Bytes,
}
106
/// END_STREAM (0x1): the sender will send nothing more on this stream.
const END_STREAM: u8 = 0x1;
/// END_HEADERS (0x4): the header block ends in this frame.
const END_HEADERS: u8 = 0x4;
/// PADDED (0x8): a pad-length octet and trailing padding are present.
const PADDED: u8 = 0x8;
/// PRIORITY (0x20): a five-octet stream dependency precedes the header block.
const PRIORITY: u8 = 0x20;
/// Every flag defined for HEADERS frames.
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
112
113// ===== impl Headers =====
114
115impl Headers {
116    /// Create a new HEADERS frame
117    pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
118        Headers {
119            stream_id,
120            stream_dep: None,
121            header_block: HeaderBlock {
122                field_size: calculate_headermap_size(&fields),
123                fields,
124                is_over_size: false,
125                pseudo,
126            },
127            flags: HeadersFlag::default(),
128        }
129    }
130
131    pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
132        let mut flags = HeadersFlag::default();
133        flags.set_end_stream();
134
135        Headers {
136            stream_id,
137            stream_dep: None,
138            header_block: HeaderBlock {
139                field_size: calculate_headermap_size(&fields),
140                fields,
141                is_over_size: false,
142                pseudo: Pseudo::default(),
143            },
144            flags,
145        }
146    }
147
148    /// Loads the header frame but doesn't actually do HPACK decoding.
149    ///
150    /// HPACK decoding is done in the `load_hpack` step.
151    pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
152        let flags = HeadersFlag(head.flag());
153        let mut pad = 0;
154
155        tracing::trace!("loading headers; flags={:?}", flags);
156
157        if head.stream_id().is_zero() {
158            return Err(Error::InvalidStreamId);
159        }
160
161        // Read the padding length
162        if flags.is_padded() {
163            if src.is_empty() {
164                return Err(Error::MalformedMessage);
165            }
166            pad = src[0] as usize;
167
168            // Drop the padding
169            let _ = src.split_to(1);
170        }
171
172        // Read the stream dependency
173        let stream_dep = if flags.is_priority() {
174            if src.len() < 5 {
175                return Err(Error::MalformedMessage);
176            }
177            let stream_dep = StreamDependency::load(&src[..5])?;
178
179            if stream_dep.dependency_id() == head.stream_id() {
180                return Err(Error::InvalidDependencyId);
181            }
182
183            // Drop the next 5 bytes
184            let _ = src.split_to(5);
185
186            Some(stream_dep)
187        } else {
188            None
189        };
190
191        if pad > 0 {
192            if pad > src.len() {
193                return Err(Error::TooMuchPadding);
194            }
195
196            let len = src.len() - pad;
197            src.truncate(len);
198        }
199
200        let headers = Headers {
201            stream_id: head.stream_id(),
202            stream_dep,
203            header_block: HeaderBlock {
204                fields: HeaderMap::new(),
205                field_size: 0,
206                is_over_size: false,
207                pseudo: Pseudo::default(),
208            },
209            flags,
210        };
211
212        Ok((headers, src))
213    }
214
215    pub fn load_hpack(
216        &mut self,
217        src: &mut BytesMut,
218        max_header_list_size: usize,
219        decoder: &mut hpack::Decoder,
220    ) -> Result<(), Error> {
221        self.header_block.load(src, max_header_list_size, decoder)
222    }
223
224    pub fn stream_id(&self) -> StreamId {
225        self.stream_id
226    }
227
228    pub fn is_end_headers(&self) -> bool {
229        self.flags.is_end_headers()
230    }
231
232    pub fn set_end_headers(&mut self) {
233        self.flags.set_end_headers();
234    }
235
236    pub fn is_end_stream(&self) -> bool {
237        self.flags.is_end_stream()
238    }
239
240    pub fn set_end_stream(&mut self) {
241        self.flags.set_end_stream()
242    }
243
244    pub fn is_over_size(&self) -> bool {
245        self.header_block.is_over_size
246    }
247
248    pub fn into_parts(self) -> (Pseudo, HeaderMap) {
249        (self.header_block.pseudo, self.header_block.fields)
250    }
251
252    #[cfg(feature = "unstable")]
253    pub fn pseudo_mut(&mut self) -> &mut Pseudo {
254        &mut self.header_block.pseudo
255    }
256
257    /// Whether it has status 1xx
258    pub(crate) fn is_informational(&self) -> bool {
259        self.header_block.pseudo.is_informational()
260    }
261
262    pub fn fields(&self) -> &HeaderMap {
263        &self.header_block.fields
264    }
265
266    pub fn into_fields(self) -> HeaderMap {
267        self.header_block.fields
268    }
269
270    pub fn encode(
271        self,
272        encoder: &mut hpack::Encoder,
273        dst: &mut EncodeBuf<'_>,
274    ) -> Option<Continuation> {
275        // At this point, the `is_end_headers` flag should always be set
276        debug_assert!(self.flags.is_end_headers());
277
278        // Get the HEADERS frame head
279        let head = self.head();
280
281        self.header_block
282            .into_encoding(encoder)
283            .encode(&head, dst, |_| {})
284    }
285
286    fn head(&self) -> Head {
287        Head::new(Kind::Headers, self.flags.into(), self.stream_id)
288    }
289}
290
291impl<T> From<Headers> for Frame<T> {
292    fn from(src: Headers) -> Self {
293        Frame::Headers(src)
294    }
295}
296
297impl fmt::Debug for Headers {
298    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
299        let mut builder = f.debug_struct("Headers");
300        builder
301            .field("stream_id", &self.stream_id)
302            .field("flags", &self.flags);
303
304        if let Some(ref protocol) = self.header_block.pseudo.protocol {
305            builder.field("protocol", protocol);
306        }
307
308        if let Some(ref dep) = self.stream_dep {
309            builder.field("stream_dep", dep);
310        }
311
312        // `fields` and `pseudo` purposefully not included
313        builder.finish()
314    }
315}
316
317// ===== util =====
318
/// Error returned by [`parse_u64`] when the input is not a valid unsigned
/// decimal integer that fits in a `u64`.
#[derive(Debug, PartialEq, Eq)]
pub struct ParseU64Error;

/// Parses a non-empty run of ASCII digits (e.g. a `content-length` value)
/// into a `u64`.
///
/// # Errors
/// Returns `ParseU64Error` if `src` is empty, contains any byte other than
/// `0`-`9` (signs and whitespace included), or denotes a value greater than
/// `u64::MAX`.
pub fn parse_u64(src: &[u8]) -> Result<u64, ParseU64Error> {
    // `Content-Length = 1*DIGIT`: an empty value is not a number.
    if src.is_empty() {
        return Err(ParseU64Error);
    }

    // Checked arithmetic detects overflow exactly, so every value up to
    // `u64::MAX` (20 digits) is accepted and nothing larger slips through.
    src.iter().try_fold(0u64, |acc, &digit| {
        if !digit.is_ascii_digit() {
            return Err(ParseU64Error);
        }

        acc.checked_mul(10)
            .and_then(|acc| acc.checked_add(u64::from(digit - b'0')))
            .ok_or(ParseU64Error)
    })
}
341
342// ===== impl PushPromise =====
343
/// Reasons a promised request is rejected by `PushPromise::validate_request`.
#[derive(Debug)]
pub enum PushPromiseHeaderError {
    /// A `content-length` header is present but is not exactly `0`; carries
    /// the result of parsing it (`Err` if it was not a valid number).
    InvalidContentLength(Result<u64, ParseU64Error>),
    /// The request method is not both safe and cacheable (only GET and HEAD
    /// are accepted).
    NotSafeAndCacheable,
}
349
350impl PushPromise {
351    pub fn new(
352        stream_id: StreamId,
353        promised_id: StreamId,
354        pseudo: Pseudo,
355        fields: HeaderMap,
356    ) -> Self {
357        PushPromise {
358            flags: PushPromiseFlag::default(),
359            header_block: HeaderBlock {
360                field_size: calculate_headermap_size(&fields),
361                fields,
362                is_over_size: false,
363                pseudo,
364            },
365            promised_id,
366            stream_id,
367        }
368    }
369
370    pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> {
371        use PushPromiseHeaderError::*;
372        // The spec has some requirements for promised request headers
373        // [https://httpwg.org/specs/rfc7540.html#PushRequests]
374
375        // A promised request "that indicates the presence of a request body
376        // MUST reset the promised stream with a stream error"
377        if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
378            let parsed_length = parse_u64(content_length.as_bytes());
379            if parsed_length != Ok(0) {
380                return Err(InvalidContentLength(parsed_length));
381            }
382        }
383        // "The server MUST include a method in the :method pseudo-header field
384        // that is safe and cacheable"
385        if !Self::safe_and_cacheable(req.method()) {
386            return Err(NotSafeAndCacheable);
387        }
388
389        Ok(())
390    }
391
392    fn safe_and_cacheable(method: &Method) -> bool {
393        // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods
394        // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods
395        method == Method::GET || method == Method::HEAD
396    }
397
398    pub fn fields(&self) -> &HeaderMap {
399        &self.header_block.fields
400    }
401
402    #[cfg(feature = "unstable")]
403    pub fn into_fields(self) -> HeaderMap {
404        self.header_block.fields
405    }
406
407    /// Loads the push promise frame but doesn't actually do HPACK decoding.
408    ///
409    /// HPACK decoding is done in the `load_hpack` step.
410    pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
411        let flags = PushPromiseFlag(head.flag());
412        let mut pad = 0;
413
414        if head.stream_id().is_zero() {
415            return Err(Error::InvalidStreamId);
416        }
417
418        // Read the padding length
419        if flags.is_padded() {
420            if src.is_empty() {
421                return Err(Error::MalformedMessage);
422            }
423
424            // TODO: Ensure payload is sized correctly
425            pad = src[0] as usize;
426
427            // Drop the padding
428            let _ = src.split_to(1);
429        }
430
431        if src.len() < 5 {
432            return Err(Error::MalformedMessage);
433        }
434
435        let (promised_id, _) = StreamId::parse(&src[..4]);
436        // Drop promised_id bytes
437        let _ = src.split_to(4);
438
439        if pad > 0 {
440            if pad > src.len() {
441                return Err(Error::TooMuchPadding);
442            }
443
444            let len = src.len() - pad;
445            src.truncate(len);
446        }
447
448        let frame = PushPromise {
449            flags,
450            header_block: HeaderBlock {
451                fields: HeaderMap::new(),
452                field_size: 0,
453                is_over_size: false,
454                pseudo: Pseudo::default(),
455            },
456            promised_id,
457            stream_id: head.stream_id(),
458        };
459        Ok((frame, src))
460    }
461
462    pub fn load_hpack(
463        &mut self,
464        src: &mut BytesMut,
465        max_header_list_size: usize,
466        decoder: &mut hpack::Decoder,
467    ) -> Result<(), Error> {
468        self.header_block.load(src, max_header_list_size, decoder)
469    }
470
471    pub fn stream_id(&self) -> StreamId {
472        self.stream_id
473    }
474
475    pub fn promised_id(&self) -> StreamId {
476        self.promised_id
477    }
478
479    pub fn is_end_headers(&self) -> bool {
480        self.flags.is_end_headers()
481    }
482
483    pub fn set_end_headers(&mut self) {
484        self.flags.set_end_headers();
485    }
486
487    pub fn is_over_size(&self) -> bool {
488        self.header_block.is_over_size
489    }
490
491    pub fn encode(
492        self,
493        encoder: &mut hpack::Encoder,
494        dst: &mut EncodeBuf<'_>,
495    ) -> Option<Continuation> {
496        // At this point, the `is_end_headers` flag should always be set
497        debug_assert!(self.flags.is_end_headers());
498
499        let head = self.head();
500        let promised_id = self.promised_id;
501
502        self.header_block
503            .into_encoding(encoder)
504            .encode(&head, dst, |dst| {
505                dst.put_u32(promised_id.into());
506            })
507    }
508
509    fn head(&self) -> Head {
510        Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
511    }
512
513    /// Consume `self`, returning the parts of the frame
514    pub fn into_parts(self) -> (Pseudo, HeaderMap) {
515        (self.header_block.pseudo, self.header_block.fields)
516    }
517}
518
519impl<T> From<PushPromise> for Frame<T> {
520    fn from(src: PushPromise) -> Self {
521        Frame::PushPromise(src)
522    }
523}
524
525impl fmt::Debug for PushPromise {
526    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
527        f.debug_struct("PushPromise")
528            .field("stream_id", &self.stream_id)
529            .field("promised_id", &self.promised_id)
530            .field("flags", &self.flags)
531            // `fields` and `pseudo` purposefully not included
532            .finish()
533    }
534}
535
536// ===== impl Continuation =====
537
538impl Continuation {
539    fn head(&self) -> Head {
540        Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
541    }
542
543    pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> {
544        // Get the CONTINUATION frame head
545        let head = self.head();
546
547        self.header_block.encode(&head, dst, |_| {})
548    }
549}
550
551// ===== impl Pseudo =====
552
553impl Pseudo {
554    pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
555        let parts = uri::Parts::from(uri);
556
557        let (scheme, path) = if method == Method::CONNECT && protocol.is_none() {
558            (None, None)
559        } else {
560            let path = parts
561                .path_and_query
562                .map(|v| BytesStr::from(v.as_str()))
563                .unwrap_or(BytesStr::from_static(""));
564
565            let path = if !path.is_empty() {
566                path
567            } else if method == Method::OPTIONS {
568                BytesStr::from_static("*")
569            } else {
570                BytesStr::from_static("/")
571            };
572
573            (parts.scheme, Some(path))
574        };
575
576        let mut pseudo = Pseudo {
577            method: Some(method),
578            scheme: None,
579            authority: None,
580            path,
581            protocol,
582            status: None,
583        };
584
585        // If the URI includes a scheme component, add it to the pseudo headers
586        if let Some(scheme) = scheme {
587            pseudo.set_scheme(scheme);
588        }
589
590        // If the URI includes an authority component, add it to the pseudo
591        // headers
592        if let Some(authority) = parts.authority {
593            pseudo.set_authority(BytesStr::from(authority.as_str()));
594        }
595
596        pseudo
597    }
598
599    pub fn response(status: StatusCode) -> Self {
600        Pseudo {
601            method: None,
602            scheme: None,
603            authority: None,
604            path: None,
605            protocol: None,
606            status: Some(status),
607        }
608    }
609
610    #[cfg(feature = "unstable")]
611    pub fn set_status(&mut self, value: StatusCode) {
612        self.status = Some(value);
613    }
614
615    pub fn set_scheme(&mut self, scheme: uri::Scheme) {
616        let bytes_str = match scheme.as_str() {
617            "http" => BytesStr::from_static("http"),
618            "https" => BytesStr::from_static("https"),
619            s => BytesStr::from(s),
620        };
621        self.scheme = Some(bytes_str);
622    }
623
624    #[cfg(feature = "unstable")]
625    pub fn set_protocol(&mut self, protocol: Protocol) {
626        self.protocol = Some(protocol);
627    }
628
629    pub fn set_authority(&mut self, authority: BytesStr) {
630        self.authority = Some(authority);
631    }
632
633    /// Whether it has status 1xx
634    pub(crate) fn is_informational(&self) -> bool {
635        self.status
636            .map_or(false, |status| status.is_informational())
637    }
638}
639
640// ===== impl EncodingHeaderBlock =====
641
impl EncodingHeaderBlock {
    /// Writes one frame described by `head` into `dst`, carrying as much of
    /// the encoded header block as `dst` has room for.
    ///
    /// `f` runs right after the frame head so the caller can write a fixed
    /// payload prefix (PUSH_PROMISE uses it for the promised stream ID).
    /// That prefix counts toward the frame length.
    ///
    /// Returns `Some(Continuation)` holding the unwritten remainder when the
    /// block did not fit; in that case `END_HEADERS` is cleared on the frame
    /// just written. Returns `None` when the whole block was written.
    ///
    /// # Panics
    /// Panics if the payload length does not fit in the 24-bit length field.
    fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation>
    where
        F: FnOnce(&mut EncodeBuf<'_>),
    {
        let head_pos = dst.get_ref().len();

        // At this point, we don't know how big the h2 frame will be.
        // So, we write the head with length 0, then write the body, and
        // finally write the length once we know the size.
        head.encode(0, dst);

        let payload_pos = dst.get_ref().len();

        // Caller-supplied payload prefix (e.g. the promised stream ID).
        f(dst);

        // Now, encode the header payload
        let continuation = if self.hpack.len() > dst.remaining_mut() {
            // Only part of the block fits: `split_to` hands out the part that
            // fits and leaves the rest in `self.hpack` for the continuation.
            dst.put_slice(&self.hpack.split_to(dst.remaining_mut()));

            Some(Continuation {
                stream_id: head.stream_id(),
                header_block: self,
            })
        } else {
            dst.put_slice(&self.hpack);

            None
        };

        // Compute the header block length
        let payload_len = (dst.get_ref().len() - payload_pos) as u64;

        // Write the frame length: the low three bytes of the big-endian u64
        // fill the head's 24-bit length field; the upper five must be zero.
        let payload_len_be = payload_len.to_be_bytes();
        assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
        (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);

        if continuation.is_some() {
            // There will be continuation frames, so the `is_end_headers` flag
            // must be unset. Offset 4 is the flags octet (3-byte length, then
            // 1-byte type).
            debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);

            dst.get_mut()[head_pos + 4] -= END_HEADERS;
        }

        continuation
    }
}
691
692// ===== impl Iter =====
693
694impl Iterator for Iter {
695    type Item = hpack::Header<Option<HeaderName>>;
696
697    fn next(&mut self) -> Option<Self::Item> {
698        use crate::hpack::Header::*;
699
700        if let Some(ref mut pseudo) = self.pseudo {
701            if let Some(method) = pseudo.method.take() {
702                return Some(Method(method));
703            }
704
705            if let Some(scheme) = pseudo.scheme.take() {
706                return Some(Scheme(scheme));
707            }
708
709            if let Some(authority) = pseudo.authority.take() {
710                return Some(Authority(authority));
711            }
712
713            if let Some(path) = pseudo.path.take() {
714                return Some(Path(path));
715            }
716
717            if let Some(protocol) = pseudo.protocol.take() {
718                return Some(Protocol(protocol));
719            }
720
721            if let Some(status) = pseudo.status.take() {
722                return Some(Status(status));
723            }
724        }
725
726        self.pseudo = None;
727
728        self.fields
729            .next()
730            .map(|(name, value)| Field { name, value })
731    }
732}
733
734// ===== impl HeadersFlag =====
735
736impl HeadersFlag {
737    pub fn empty() -> HeadersFlag {
738        HeadersFlag(0)
739    }
740
741    pub fn load(bits: u8) -> HeadersFlag {
742        HeadersFlag(bits & ALL)
743    }
744
745    pub fn is_end_stream(&self) -> bool {
746        self.0 & END_STREAM == END_STREAM
747    }
748
749    pub fn set_end_stream(&mut self) {
750        self.0 |= END_STREAM;
751    }
752
753    pub fn is_end_headers(&self) -> bool {
754        self.0 & END_HEADERS == END_HEADERS
755    }
756
757    pub fn set_end_headers(&mut self) {
758        self.0 |= END_HEADERS;
759    }
760
761    pub fn is_padded(&self) -> bool {
762        self.0 & PADDED == PADDED
763    }
764
765    pub fn is_priority(&self) -> bool {
766        self.0 & PRIORITY == PRIORITY
767    }
768}
769
770impl Default for HeadersFlag {
771    /// Returns a `HeadersFlag` value with `END_HEADERS` set.
772    fn default() -> Self {
773        HeadersFlag(END_HEADERS)
774    }
775}
776
777impl From<HeadersFlag> for u8 {
778    fn from(src: HeadersFlag) -> u8 {
779        src.0
780    }
781}
782
783impl fmt::Debug for HeadersFlag {
784    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
785        util::debug_flags(fmt, self.0)
786            .flag_if(self.is_end_headers(), "END_HEADERS")
787            .flag_if(self.is_end_stream(), "END_STREAM")
788            .flag_if(self.is_padded(), "PADDED")
789            .flag_if(self.is_priority(), "PRIORITY")
790            .finish()
791    }
792}
793
794// ===== impl PushPromiseFlag =====
795
796impl PushPromiseFlag {
797    pub fn empty() -> PushPromiseFlag {
798        PushPromiseFlag(0)
799    }
800
801    pub fn load(bits: u8) -> PushPromiseFlag {
802        PushPromiseFlag(bits & ALL)
803    }
804
805    pub fn is_end_headers(&self) -> bool {
806        self.0 & END_HEADERS == END_HEADERS
807    }
808
809    pub fn set_end_headers(&mut self) {
810        self.0 |= END_HEADERS;
811    }
812
813    pub fn is_padded(&self) -> bool {
814        self.0 & PADDED == PADDED
815    }
816}
817
818impl Default for PushPromiseFlag {
819    /// Returns a `PushPromiseFlag` value with `END_HEADERS` set.
820    fn default() -> Self {
821        PushPromiseFlag(END_HEADERS)
822    }
823}
824
825impl From<PushPromiseFlag> for u8 {
826    fn from(src: PushPromiseFlag) -> u8 {
827        src.0
828    }
829}
830
831impl fmt::Debug for PushPromiseFlag {
832    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
833        util::debug_flags(fmt, self.0)
834            .flag_if(self.is_end_headers(), "END_HEADERS")
835            .flag_if(self.is_padded(), "PADDED")
836            .finish()
837    }
838}
839
840// ===== HeaderBlock =====
841
842impl HeaderBlock {
843    fn load(
844        &mut self,
845        src: &mut BytesMut,
846        max_header_list_size: usize,
847        decoder: &mut hpack::Decoder,
848    ) -> Result<(), Error> {
849        let mut reg = !self.fields.is_empty();
850        let mut malformed = false;
851        let mut headers_size = self.calculate_header_list_size();
852
853        macro_rules! set_pseudo {
854            ($field:ident, $val:expr) => {{
855                if reg {
856                    tracing::trace!("load_hpack; header malformed -- pseudo not at head of block");
857                    malformed = true;
858                } else if self.pseudo.$field.is_some() {
859                    tracing::trace!("load_hpack; header malformed -- repeated pseudo");
860                    malformed = true;
861                } else {
862                    let __val = $val;
863                    headers_size +=
864                        decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
865                    if headers_size < max_header_list_size {
866                        self.pseudo.$field = Some(__val);
867                    } else if !self.is_over_size {
868                        tracing::trace!("load_hpack; header list size over max");
869                        self.is_over_size = true;
870                    }
871                }
872            }};
873        }
874
875        let mut cursor = Cursor::new(src);
876
877        // If the header frame is malformed, we still have to continue decoding
878        // the headers. A malformed header frame is a stream level error, but
879        // the hpack state is connection level. In order to maintain correct
880        // state for other streams, the hpack decoding process must complete.
881        let res = decoder.decode(&mut cursor, |header| {
882            use crate::hpack::Header::*;
883
884            match header {
885                Field { name, value } => {
886                    // Connection level header fields are not supported and must
887                    // result in a protocol error.
888
889                    if name == header::CONNECTION
890                        || name == header::TRANSFER_ENCODING
891                        || name == header::UPGRADE
892                        || name == "keep-alive"
893                        || name == "proxy-connection"
894                    {
895                        tracing::trace!("load_hpack; connection level header");
896                        malformed = true;
897                    } else if name == header::TE && value != "trailers" {
898                        tracing::trace!(
899                            "load_hpack; TE header not set to trailers; val={:?}",
900                            value
901                        );
902                        malformed = true;
903                    } else {
904                        reg = true;
905
906                        headers_size += decoded_header_size(name.as_str().len(), value.len());
907                        if headers_size < max_header_list_size {
908                            self.field_size +=
909                                decoded_header_size(name.as_str().len(), value.len());
910                            self.fields.append(name, value);
911                        } else if !self.is_over_size {
912                            tracing::trace!("load_hpack; header list size over max");
913                            self.is_over_size = true;
914                        }
915                    }
916                }
917                Authority(v) => set_pseudo!(authority, v),
918                Method(v) => set_pseudo!(method, v),
919                Scheme(v) => set_pseudo!(scheme, v),
920                Path(v) => set_pseudo!(path, v),
921                Protocol(v) => set_pseudo!(protocol, v),
922                Status(v) => set_pseudo!(status, v),
923            }
924        });
925
926        if let Err(e) = res {
927            tracing::trace!("hpack decoding error; err={:?}", e);
928            return Err(e.into());
929        }
930
931        if malformed {
932            tracing::trace!("malformed message");
933            return Err(Error::MalformedMessage);
934        }
935
936        Ok(())
937    }
938
939    fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock {
940        let mut hpack = BytesMut::new();
941        let headers = Iter {
942            pseudo: Some(self.pseudo),
943            fields: self.fields.into_iter(),
944        };
945
946        encoder.encode(headers, &mut hpack);
947
948        EncodingHeaderBlock {
949            hpack: hpack.freeze(),
950        }
951    }
952
953    /// Calculates the size of the currently decoded header list.
954    ///
955    /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
956    ///
957    /// > The value is based on the uncompressed size of header fields,
958    /// > including the length of the name and value in octets plus an
959    /// > overhead of 32 octets for each header field.
960    fn calculate_header_list_size(&self) -> usize {
961        macro_rules! pseudo_size {
962            ($name:ident) => {{
963                self.pseudo
964                    .$name
965                    .as_ref()
966                    .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
967                    .unwrap_or(0)
968            }};
969        }
970
971        pseudo_size!(method)
972            + pseudo_size!(scheme)
973            + pseudo_size!(status)
974            + pseudo_size!(authority)
975            + pseudo_size!(path)
976            + self.field_size
977    }
978}
979
980fn calculate_headermap_size(map: &HeaderMap) -> usize {
981    map.iter()
982        .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
983        .sum::<usize>()
984}
985
/// Cost of one header entry toward SETTINGS_MAX_HEADER_LIST_SIZE: name and
/// value lengths in octets plus a fixed 32-octet per-entry overhead.
fn decoded_header_size(name: usize, value: usize) -> usize {
    const PER_ENTRY_OVERHEAD: usize = 32;
    PER_ENTRY_OVERHEAD + name + value
}
989
#[cfg(test)]
mod test {
    //! Unit tests for header-frame encoding and request pseudo-header construction.

    use super::*;
    use crate::frame;
    use crate::hpack::{huffman, Encoder};

    /// Encodes three `hello` fields so that the first HEADERS frame runs out of
    /// room partway through the first field's value. It then checks that the
    /// returned continuation resumes the same header block correctly in a
    /// CONTINUATION frame.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        let headers = Headers::new(
            StreamId::ZERO,
            Default::default(),
            HeaderMap::from_iter(vec![
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("world"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("zomg"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("sup"),
                ),
            ]),
        );

        // Only 8 payload bytes fit after the frame header, so encoding must
        // stop early and hand back a continuation.
        let continuation = headers
            .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
            .unwrap();

        assert_eq!(17, dst.len());
        // Frame header: length 8, type 0x1 (HEADERS), flags 0 (END_HEADERS not
        // set), stream 0.
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        // 0x40 = literal with incremental indexing and a new name. The next
        // byte is a Huffman-flagged (0x80) name length of 4.
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        // Huffman-flagged value length of 4 for "world".
        assert_eq!(0x80 | 4, dst[15]);

        // Only the first byte of the Huffman-coded "world" fit in this frame.
        // Keep it so it can be joined with the rest from the next frame.
        let mut world = dst[16..17].to_owned();

        dst.clear();

        // The remaining bytes fit in a single frame, so no further continuation
        // is returned.
        assert!(continuation
            .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
            .is_none());

        // The remaining three bytes of "world" lead the CONTINUATION payload.
        world.extend_from_slice(&dst[9..12]);
        assert_eq!("world", huff_decode(&world));

        assert_eq!(24, dst.len());
        // Frame header: length 15, type 0x9 (CONTINUATION), flags 0x4
        // (END_HEADERS), stream 0.
        assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

        // The remaining fields are literals without indexing that reference an
        // indexed name. The bytes 15, 47 encode name index 15 + 47 = 62: the
        // first dynamic-table entry, i.e. the `hello` name added above (RFC 7541
        // §6.2.2 / §5.1). Each is followed by a Huffman-flagged value length of 3.
        assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        assert_eq!("zomg", huff_decode(&dst[15..18]));
        assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        assert_eq!("sup", huff_decode(&dst[21..]));
    }

    /// Huffman-decodes `src` into a fresh buffer. Panics if `src` is not valid
    /// Huffman-coded data.
    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }

    /// A plain CONNECT (no `:protocol`) must carry only `:method` and
    /// `:authority`, whatever scheme or path the URI contains.
    #[test]
    fn test_connect_request_pseudo_headers_omits_path_and_scheme() {
        // CONNECT requests MUST NOT include :scheme & :path pseudo-header fields
        // See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.5

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443"),
                None
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                ..Default::default()
            }
        );

        // The path in the URI is dropped as well.
        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com/test"),
                None
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com").into(),
                ..Default::default()
            }
        );

        // An authority-form URI (no scheme) yields just the authority.
        assert_eq!(
            Pseudo::request(Method::CONNECT, Uri::from_static("example.com:8443"), None),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                ..Default::default()
            }
        );
    }

    /// An extended CONNECT (one carrying a `:protocol`) must additionally carry
    /// `:scheme` and `:path`. An empty path becomes "/".
    #[test]
    fn test_extended_connect_request_pseudo_headers_includes_path_and_scheme() {
        // On requests that contain the :protocol pseudo-header field, the
        // :scheme and :path pseudo-header fields of the target URI (see
        // Section 5) MUST also be included.
        // See: https://datatracker.ietf.org/doc/html/rfc8441#section-4

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443"),
                Protocol::from_static("the-bread-protocol").into()
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                scheme: BytesStr::from_static("https").into(),
                path: BytesStr::from_static("/").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443/test"),
                Protocol::from_static("the-bread-protocol").into()
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                scheme: BytesStr::from_static("https").into(),
                path: BytesStr::from_static("/test").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("http://example.com/a/b/c"),
                Protocol::from_static("the-bread-protocol").into()
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com").into(),
                scheme: BytesStr::from_static("http").into(),
                path: BytesStr::from_static("/a/b/c").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );
    }

    /// OPTIONS with an authority-form URI (no path) must send `:path` = "*".
    #[test]
    fn test_options_request_with_empty_path_has_asterisk_as_pseudo_path() {
        // an OPTIONS request for an "http" or "https" URI that does not include a path component;
        // these MUST include a ":path" pseudo-header field with a value of '*' (see Section 7.1 of [HTTP]).
        // See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.3.1
        assert_eq!(
            Pseudo::request(Method::OPTIONS, Uri::from_static("example.com:8080"), None,),
            Pseudo {
                method: Method::OPTIONS.into(),
                authority: BytesStr::from_static("example.com:8080").into(),
                path: BytesStr::from_static("*").into(),
                ..Default::default()
            }
        );
    }

    /// Ordinary methods always get `:scheme` and `:path`. An empty path is
    /// normalised to "/".
    #[test]
    fn test_non_option_and_non_connect_requests_include_path_and_scheme() {
        let methods = [
            Method::GET,
            Method::POST,
            Method::PUT,
            Method::DELETE,
            Method::HEAD,
            Method::PATCH,
            Method::TRACE,
        ];

        for method in methods {
            // URI with no path: `:path` becomes "/".
            assert_eq!(
                Pseudo::request(
                    method.clone(),
                    Uri::from_static("http://example.com:8080"),
                    None,
                ),
                Pseudo {
                    method: method.clone().into(),
                    authority: BytesStr::from_static("example.com:8080").into(),
                    scheme: BytesStr::from_static("http").into(),
                    path: BytesStr::from_static("/").into(),
                    ..Default::default()
                }
            );
            // URI with an explicit path: passed through unchanged.
            assert_eq!(
                Pseudo::request(
                    method.clone(),
                    Uri::from_static("https://example.com/a/b/c"),
                    None,
                ),
                Pseudo {
                    method: method.into(),
                    authority: BytesStr::from_static("example.com").into(),
                    scheme: BytesStr::from_static("https").into(),
                    path: BytesStr::from_static("/a/b/c").into(),
                    ..Default::default()
                }
            );
        }
    }
}