rustls/msgs/
handshake.rs

1use alloc::collections::BTreeSet;
2#[cfg(feature = "logging")]
3use alloc::string::String;
4use alloc::vec;
5use alloc::vec::Vec;
6use core::ops::Deref;
7use core::{fmt, iter};
8
9use pki_types::{CertificateDer, DnsName};
10
11#[cfg(feature = "tls12")]
12use crate::crypto::ActiveKeyExchange;
13use crate::crypto::SecureRandom;
14use crate::enums::{
15    CertificateCompressionAlgorithm, CipherSuite, EchClientHelloType, HandshakeType,
16    ProtocolVersion, SignatureScheme,
17};
18use crate::error::InvalidMessage;
19#[cfg(feature = "tls12")]
20use crate::ffdhe_groups::FfdheGroup;
21use crate::log::warn;
22use crate::msgs::base::{Payload, PayloadU16, PayloadU24, PayloadU8};
23use crate::msgs::codec::{self, Codec, LengthPrefixedBuffer, ListLength, Reader, TlsListElement};
24use crate::msgs::enums::{
25    CertificateStatusType, ClientCertificateType, Compression, ECCurveType, ECPointFormat,
26    EchVersion, ExtensionType, HpkeAead, HpkeKdf, HpkeKem, KeyUpdateRequest, NamedGroup,
27    PSKKeyExchangeMode, ServerNameType,
28};
29use crate::rand;
30use crate::verify::DigitallySignedStruct;
31use crate::x509::wrap_in_sequence;
32
/// Declare a newtype that wraps one of the length-prefixed payload types.
///
/// Several TLS message fields are opaque byte strings carried in a `PayloadU8` or
/// `PayloadU16`. This macro gives each such field its own distinct type while exposing
/// nothing beyond the raw bytes: the generated type can be built from a `Vec<u8>`,
/// viewed as a byte slice, and encoded/decoded by delegating to the wrapped payload.
macro_rules! wrapped_payload(
  ($(#[$comment:meta])* $vis:vis struct $name:ident, $inner:ident,) => {
    $(#[$comment])*
    #[derive(Clone, Debug)]
    $vis struct $name($inner);

    impl From<Vec<u8>> for $name {
        fn from(bytes: Vec<u8>) -> Self {
            let inner = $inner::new(bytes);
            Self(inner)
        }
    }

    impl AsRef<[u8]> for $name {
        fn as_ref(&self) -> &[u8] {
            let Self(inner) = self;
            &inner.0[..]
        }
    }

    impl Codec<'_> for $name {
        fn encode(&self, out: &mut Vec<u8>) {
            let Self(inner) = self;
            inner.encode(out);
        }

        fn read(reader: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
            $inner::read(reader).map(Self)
        }
    }
  }
);
67
/// A fixed-size, 32-byte opaque value (used, for example, as the `random` field of
/// `ClientHelloPayload`).
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct Random(pub(crate) [u8; 32]);
70
impl fmt::Debug for Random {
    /// Format the 32 bytes by delegating to `super::base::hex`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        super::base::hex(f, &self.0)
    }
}
76
/// The fixed `Random` value associated with a HelloRetryRequest.
///
/// These bytes match the special `random` value that RFC 8446 section 4.1.3 assigns to
/// HelloRetryRequest messages.
static HELLO_RETRY_REQUEST_RANDOM: Random = Random([
    0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
    0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
]);
81
/// A `Random` whose 32 bytes are all zero.
static ZERO_RANDOM: Random = Random([0u8; 32]);
83
84impl Codec<'_> for Random {
85    fn encode(&self, bytes: &mut Vec<u8>) {
86        bytes.extend_from_slice(&self.0);
87    }
88
89    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
90        let bytes = match r.take(32) {
91            Some(bytes) => bytes,
92            None => return Err(InvalidMessage::MissingData("Random")),
93        };
94
95        let mut opaque = [0; 32];
96        opaque.clone_from_slice(bytes);
97        Ok(Self(opaque))
98    }
99}
100
impl Random {
    /// Build a `Random` from 32 bytes filled in by `secure_random`.
    ///
    /// An error from `secure_random.fill` is propagated to the caller.
    pub(crate) fn new(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
        let mut data = [0u8; 32];
        secure_random.fill(&mut data)?;
        Ok(Self(data))
    }
}
108
impl From<[u8; 32]> for Random {
    /// Wrap an existing 32-byte array without modifying it.
    #[inline]
    fn from(bytes: [u8; 32]) -> Self {
        Self(bytes)
    }
}
115
/// A session identifier of at most 32 bytes, stored inline.
#[derive(Copy, Clone)]
pub struct SessionId {
    // Number of leading bytes of `data` that are part of the identifier.
    len: usize,
    // Backing storage; only `data[..len]` is encoded, compared and printed.
    data: [u8; 32],
}
121
impl fmt::Debug for SessionId {
    /// Format only the `len` valid bytes by delegating to `super::base::hex`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        super::base::hex(f, &self.data[..self.len])
    }
}
127
128impl PartialEq for SessionId {
129    fn eq(&self, other: &Self) -> bool {
130        if self.len != other.len {
131            return false;
132        }
133
134        let mut diff = 0u8;
135        for i in 0..self.len {
136            diff |= self.data[i] ^ other.data[i];
137        }
138
139        diff == 0u8
140    }
141}
142
143impl Codec<'_> for SessionId {
144    fn encode(&self, bytes: &mut Vec<u8>) {
145        debug_assert!(self.len <= 32);
146        bytes.push(self.len as u8);
147        bytes.extend_from_slice(&self.data[..self.len]);
148    }
149
150    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
151        let len = u8::read(r)? as usize;
152        if len > 32 {
153            return Err(InvalidMessage::TrailingData("SessionID"));
154        }
155
156        let bytes = match r.take(len) {
157            Some(bytes) => bytes,
158            None => return Err(InvalidMessage::MissingData("SessionID")),
159        };
160
161        let mut out = [0u8; 32];
162        out[..len].clone_from_slice(&bytes[..len]);
163        Ok(Self { data: out, len })
164    }
165}
166
impl SessionId {
    /// Build a full-length (32-byte) session ID from bytes filled in by `secure_random`.
    ///
    /// An error from `secure_random.fill` is propagated to the caller.
    pub fn random(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
        let mut data = [0u8; 32];
        secure_random.fill(&mut data)?;
        Ok(Self { data, len: 32 })
    }

    /// Build a zero-length session ID (all storage bytes zeroed).
    pub(crate) fn empty() -> Self {
        Self {
            data: [0u8; 32],
            len: 0,
        }
    }

    /// Whether this session ID has zero length.
    #[cfg(feature = "tls12")]
    pub(crate) fn is_empty(&self) -> bool {
        self.len == 0
    }
}
186
/// An extension kept as opaque bytes, tagged with the extension type it arrived under.
#[derive(Clone, Debug, PartialEq)]
pub struct UnknownExtension {
    // Extension type this payload was read for.
    pub(crate) typ: ExtensionType,
    // Owned copy of the extension body.
    pub(crate) payload: Payload<'static>,
}
192
193impl UnknownExtension {
194    fn encode(&self, bytes: &mut Vec<u8>) {
195        self.payload.encode(bytes);
196    }
197
198    fn read(typ: ExtensionType, r: &mut Reader<'_>) -> Self {
199        let payload = Payload::read(r).into_owned();
200        Self { typ, payload }
201    }
202}
203
// Lists of `ECPointFormat` declare a `ListLength::U8` length prefix.
impl TlsListElement for ECPointFormat {
    const SIZE_LEN: ListLength = ListLength::U8;
}

// Lists of `NamedGroup` declare a `ListLength::U16` length prefix.
impl TlsListElement for NamedGroup {
    const SIZE_LEN: ListLength = ListLength::U16;
}

// Lists of `SignatureScheme` declare a `ListLength::U16` length prefix.
impl TlsListElement for SignatureScheme {
    const SIZE_LEN: ListLength = ListLength::U16;
}
215
/// The body of a single `ServerName` entry.
#[derive(Clone, Debug)]
pub(crate) enum ServerNamePayload {
    /// A host name that parsed as a DNS name.
    HostName(DnsName<'static>),
    /// A host-name entry whose bytes parsed as an IP address; the raw bytes are kept.
    IpAddress(PayloadU16),
    /// The raw body of an entry whose name type is not `HostName`.
    Unknown(Payload<'static>),
}
222
223impl ServerNamePayload {
224    pub(crate) fn new_hostname(hostname: DnsName<'static>) -> Self {
225        Self::HostName(hostname)
226    }
227
228    fn read_hostname(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
229        use pki_types::ServerName;
230        let raw = PayloadU16::read(r)?;
231
232        match ServerName::try_from(raw.0.as_slice()) {
233            Ok(ServerName::DnsName(d)) => Ok(Self::HostName(d.to_owned())),
234            Ok(ServerName::IpAddress(_)) => Ok(Self::IpAddress(raw)),
235            Ok(_) | Err(_) => {
236                warn!(
237                    "Illegal SNI hostname received {:?}",
238                    String::from_utf8_lossy(&raw.0)
239                );
240                Err(InvalidMessage::InvalidServerName)
241            }
242        }
243    }
244
245    fn encode(&self, bytes: &mut Vec<u8>) {
246        match *self {
247            Self::HostName(ref name) => {
248                (name.as_ref().len() as u16).encode(bytes);
249                bytes.extend_from_slice(name.as_ref().as_bytes());
250            }
251            Self::IpAddress(ref r) => r.encode(bytes),
252            Self::Unknown(ref r) => r.encode(bytes),
253        }
254    }
255}
256
/// One entry of a server name list: a name type plus the matching payload.
#[derive(Clone, Debug)]
pub struct ServerName {
    // Name type as read from / written to the wire.
    pub(crate) typ: ServerNameType,
    // Body of the entry; its variant is chosen from `typ` when reading.
    pub(crate) payload: ServerNamePayload,
}
262
263impl Codec<'_> for ServerName {
264    fn encode(&self, bytes: &mut Vec<u8>) {
265        self.typ.encode(bytes);
266        self.payload.encode(bytes);
267    }
268
269    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
270        let typ = ServerNameType::read(r)?;
271
272        let payload = match typ {
273            ServerNameType::HostName => ServerNamePayload::read_hostname(r)?,
274            _ => ServerNamePayload::Unknown(Payload::read(r).into_owned()),
275        };
276
277        Ok(Self { typ, payload })
278    }
279}
280
// Lists of `ServerName` declare a `ListLength::U16` length prefix.
impl TlsListElement for ServerName {
    const SIZE_LEN: ListLength = ListLength::U16;
}
284
/// Convenience queries over a list of `ServerName` entries.
pub(crate) trait ConvertServerNameList {
    /// Whether two or more entries share the same name type.
    fn has_duplicate_names_for_type(&self) -> bool;
    /// The first entry carrying a DNS host name, if any.
    fn single_hostname(&self) -> Option<DnsName<'_>>;
}
289
290impl ConvertServerNameList for [ServerName] {
291    /// RFC6066: "The ServerNameList MUST NOT contain more than one name of the same name_type."
292    fn has_duplicate_names_for_type(&self) -> bool {
293        has_duplicates::<_, _, u8>(self.iter().map(|name| name.typ))
294    }
295
296    fn single_hostname(&self) -> Option<DnsName<'_>> {
297        fn only_dns_hostnames(name: &ServerName) -> Option<DnsName<'_>> {
298            if let ServerNamePayload::HostName(ref dns) = name.payload {
299                Some(dns.borrow())
300            } else {
301                None
302            }
303        }
304
305        self.iter()
306            .filter_map(only_dns_hostnames)
307            .next()
308    }
309}
310
// `ProtocolName` is a `wrapped_payload!` newtype over `PayloadU8`.
wrapped_payload!(pub struct ProtocolName, PayloadU8,);

// Lists of `ProtocolName` declare a `ListLength::U16` length prefix.
impl TlsListElement for ProtocolName {
    const SIZE_LEN: ListLength = ListLength::U16;
}
316
/// Conversions between a list of protocol names and plain byte slices.
pub(crate) trait ConvertProtocolNameList {
    /// Build a list with one entry per slice in `names`, in the same order.
    fn from_slices(names: &[&[u8]]) -> Self;
    /// Borrow every entry as a byte slice, in list order.
    fn to_slices(&self) -> Vec<&[u8]>;
    /// The sole entry's bytes when the list holds exactly one entry, else `None`.
    fn as_single_slice(&self) -> Option<&[u8]>;
}
322
323impl ConvertProtocolNameList for Vec<ProtocolName> {
324    fn from_slices(names: &[&[u8]]) -> Self {
325        let mut ret = Self::new();
326
327        for name in names {
328            ret.push(ProtocolName::from(name.to_vec()));
329        }
330
331        ret
332    }
333
334    fn to_slices(&self) -> Vec<&[u8]> {
335        self.iter()
336            .map(|proto| proto.as_ref())
337            .collect::<Vec<&[u8]>>()
338    }
339
340    fn as_single_slice(&self) -> Option<&[u8]> {
341        if self.len() == 1 {
342            Some(self[0].as_ref())
343        } else {
344            None
345        }
346    }
347}
348
349// --- TLS 1.3 Key shares ---
/// A key share: a named group together with an opaque, `u16`-length-prefixed payload.
#[derive(Clone, Debug)]
pub struct KeyShareEntry {
    // Group the payload belongs to.
    pub(crate) group: NamedGroup,
    // Opaque key-exchange bytes for `group`.
    pub(crate) payload: PayloadU16,
}
355
356impl KeyShareEntry {
357    pub fn new(group: NamedGroup, payload: impl Into<Vec<u8>>) -> Self {
358        Self {
359            group,
360            payload: PayloadU16::new(payload.into()),
361        }
362    }
363
364    pub fn group(&self) -> NamedGroup {
365        self.group
366    }
367}
368
369impl Codec<'_> for KeyShareEntry {
370    fn encode(&self, bytes: &mut Vec<u8>) {
371        self.group.encode(bytes);
372        self.payload.encode(bytes);
373    }
374
375    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
376        let group = NamedGroup::read(r)?;
377        let payload = PayloadU16::read(r)?;
378
379        Ok(Self { group, payload })
380    }
381}
382
383// --- TLS 1.3 PresharedKey offers ---
/// One pre-shared-key identity: an opaque identity plus its obfuscated ticket age.
#[derive(Clone, Debug)]
pub(crate) struct PresharedKeyIdentity {
    // Opaque identity bytes (`u16`-length-prefixed on the wire).
    pub(crate) identity: PayloadU16,
    // Encoded as a `u32` directly after the identity.
    pub(crate) obfuscated_ticket_age: u32,
}
389
390impl PresharedKeyIdentity {
391    pub(crate) fn new(id: Vec<u8>, age: u32) -> Self {
392        Self {
393            identity: PayloadU16::new(id),
394            obfuscated_ticket_age: age,
395        }
396    }
397}
398
399impl Codec<'_> for PresharedKeyIdentity {
400    fn encode(&self, bytes: &mut Vec<u8>) {
401        self.identity.encode(bytes);
402        self.obfuscated_ticket_age.encode(bytes);
403    }
404
405    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
406        Ok(Self {
407            identity: PayloadU16::read(r)?,
408            obfuscated_ticket_age: u32::read(r)?,
409        })
410    }
411}
412
// Lists of `PresharedKeyIdentity` declare a `ListLength::U16` length prefix.
impl TlsListElement for PresharedKeyIdentity {
    const SIZE_LEN: ListLength = ListLength::U16;
}

// `PresharedKeyBinder` is a `wrapped_payload!` newtype over `PayloadU8`.
wrapped_payload!(pub(crate) struct PresharedKeyBinder, PayloadU8,);

// Lists of `PresharedKeyBinder` declare a `ListLength::U16` length prefix.
impl TlsListElement for PresharedKeyBinder {
    const SIZE_LEN: ListLength = ListLength::U16;
}
422
/// A pre-shared-key offer: a list of identities followed by a list of binders.
#[derive(Clone, Debug)]
pub struct PresharedKeyOffer {
    // Encoded first.
    pub(crate) identities: Vec<PresharedKeyIdentity>,
    // Encoded after the identities.
    pub(crate) binders: Vec<PresharedKeyBinder>,
}
428
429impl PresharedKeyOffer {
430    /// Make a new one with one entry.
431    pub(crate) fn new(id: PresharedKeyIdentity, binder: Vec<u8>) -> Self {
432        Self {
433            identities: vec![id],
434            binders: vec![PresharedKeyBinder::from(binder)],
435        }
436    }
437}
438
439impl Codec<'_> for PresharedKeyOffer {
440    fn encode(&self, bytes: &mut Vec<u8>) {
441        self.identities.encode(bytes);
442        self.binders.encode(bytes);
443    }
444
445    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
446        Ok(Self {
447            identities: Vec::read(r)?,
448            binders: Vec::read(r)?,
449        })
450    }
451}
452
453// --- RFC6066 certificate status request ---
// `ResponderId` is a `wrapped_payload!` newtype over `PayloadU16`.
wrapped_payload!(pub(crate) struct ResponderId, PayloadU16,);

// Lists of `ResponderId` declare a `ListLength::U16` length prefix.
impl TlsListElement for ResponderId {
    const SIZE_LEN: ListLength = ListLength::U16;
}
459
/// The OCSP form of a certificate status request: responder IDs plus opaque extensions.
#[derive(Clone, Debug)]
pub struct OcspCertificateStatusRequest {
    // Encoded first, as a list.
    pub(crate) responder_ids: Vec<ResponderId>,
    // Opaque, `u16`-length-prefixed bytes encoded after the responder IDs.
    pub(crate) extensions: PayloadU16,
}
465
466impl Codec<'_> for OcspCertificateStatusRequest {
467    fn encode(&self, bytes: &mut Vec<u8>) {
468        CertificateStatusType::OCSP.encode(bytes);
469        self.responder_ids.encode(bytes);
470        self.extensions.encode(bytes);
471    }
472
473    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
474        Ok(Self {
475            responder_ids: Vec::read(r)?,
476            extensions: PayloadU16::read(r)?,
477        })
478    }
479}
480
/// A certificate status request, distinguished by its status type.
#[derive(Clone, Debug)]
pub enum CertificateStatusRequest {
    /// A request whose status type is `CertificateStatusType::OCSP`.
    Ocsp(OcspCertificateStatusRequest),
    /// Any other status type, kept together with its unparsed body.
    Unknown((CertificateStatusType, Payload<'static>)),
}
486
487impl Codec<'_> for CertificateStatusRequest {
488    fn encode(&self, bytes: &mut Vec<u8>) {
489        match self {
490            Self::Ocsp(ref r) => r.encode(bytes),
491            Self::Unknown((typ, payload)) => {
492                typ.encode(bytes);
493                payload.encode(bytes);
494            }
495        }
496    }
497
498    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
499        let typ = CertificateStatusType::read(r)?;
500
501        match typ {
502            CertificateStatusType::OCSP => {
503                let ocsp_req = OcspCertificateStatusRequest::read(r)?;
504                Ok(Self::Ocsp(ocsp_req))
505            }
506            _ => {
507                let data = Payload::read(r).into_owned();
508                Ok(Self::Unknown((typ, data)))
509            }
510        }
511    }
512}
513
514impl CertificateStatusRequest {
515    pub(crate) fn build_ocsp() -> Self {
516        let ocsp = OcspCertificateStatusRequest {
517            responder_ids: Vec::new(),
518            extensions: PayloadU16::empty(),
519        };
520        Self::Ocsp(ocsp)
521    }
522}
523
524// ---
525
// Lists of `PSKKeyExchangeMode` declare a `ListLength::U8` length prefix.
impl TlsListElement for PSKKeyExchangeMode {
    const SIZE_LEN: ListLength = ListLength::U8;
}

// Lists of `KeyShareEntry` declare a `ListLength::U16` length prefix.
impl TlsListElement for KeyShareEntry {
    const SIZE_LEN: ListLength = ListLength::U16;
}

// Lists of `ProtocolVersion` declare a `ListLength::U8` length prefix.
impl TlsListElement for ProtocolVersion {
    const SIZE_LEN: ListLength = ListLength::U8;
}

// Lists of `CertificateCompressionAlgorithm` declare a `ListLength::U8` length prefix.
impl TlsListElement for CertificateCompressionAlgorithm {
    const SIZE_LEN: ListLength = ListLength::U8;
}
541
/// An extension carried in a client hello.
///
/// Each variant maps to one `ExtensionType` (see `ext_type`); `Unknown` carries its
/// own type alongside an opaque body.
#[derive(Clone, Debug)]
pub enum ClientExtension {
    EcPointFormats(Vec<ECPointFormat>),
    NamedGroups(Vec<NamedGroup>),
    SignatureAlgorithms(Vec<SignatureScheme>),
    ServerName(Vec<ServerName>),
    SessionTicket(ClientSessionTicket),
    Protocols(Vec<ProtocolName>),
    SupportedVersions(Vec<ProtocolVersion>),
    KeyShare(Vec<KeyShareEntry>),
    PresharedKeyModes(Vec<PSKKeyExchangeMode>),
    PresharedKey(PresharedKeyOffer),
    Cookie(PayloadU16),
    /// Encoded with an empty body.
    ExtendedMasterSecretRequest,
    CertificateStatusRequest(CertificateStatusRequest),
    /// Raw bytes, copied verbatim into the extension body.
    TransportParameters(Vec<u8>),
    /// Raw bytes, copied verbatim into the extension body.
    TransportParametersDraft(Vec<u8>),
    /// Encoded with an empty body.
    EarlyData,
    CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>),
    EncryptedClientHello(EncryptedClientHello),
    EncryptedClientHelloOuterExtensions(Vec<ExtensionType>),
    /// Any extension type without a dedicated variant, kept as opaque bytes.
    Unknown(UnknownExtension),
}
565
566impl ClientExtension {
567    pub(crate) fn ext_type(&self) -> ExtensionType {
568        match *self {
569            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
570            Self::NamedGroups(_) => ExtensionType::EllipticCurves,
571            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
572            Self::ServerName(_) => ExtensionType::ServerName,
573            Self::SessionTicket(_) => ExtensionType::SessionTicket,
574            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
575            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
576            Self::KeyShare(_) => ExtensionType::KeyShare,
577            Self::PresharedKeyModes(_) => ExtensionType::PSKKeyExchangeModes,
578            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
579            Self::Cookie(_) => ExtensionType::Cookie,
580            Self::ExtendedMasterSecretRequest => ExtensionType::ExtendedMasterSecret,
581            Self::CertificateStatusRequest(_) => ExtensionType::StatusRequest,
582            Self::TransportParameters(_) => ExtensionType::TransportParameters,
583            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
584            Self::EarlyData => ExtensionType::EarlyData,
585            Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate,
586            Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello,
587            Self::EncryptedClientHelloOuterExtensions(_) => {
588                ExtensionType::EncryptedClientHelloOuterExtensions
589            }
590            Self::Unknown(ref r) => r.typ,
591        }
592    }
593}
594
impl Codec<'_> for ClientExtension {
    /// Write the extension type followed by the extension body.
    ///
    /// The body is written into a `LengthPrefixedBuffer` created with
    /// `ListLength::U16` (presumably this fills in the body length once writing is
    /// done; see `msgs::codec`).
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.ext_type().encode(bytes);

        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
        match *self {
            Self::EcPointFormats(ref r) => r.encode(nested.buf),
            Self::NamedGroups(ref r) => r.encode(nested.buf),
            Self::SignatureAlgorithms(ref r) => r.encode(nested.buf),
            Self::ServerName(ref r) => r.encode(nested.buf),
            // These extensions have an empty body.
            Self::SessionTicket(ClientSessionTicket::Request)
            | Self::ExtendedMasterSecretRequest
            | Self::EarlyData => {}
            Self::SessionTicket(ClientSessionTicket::Offer(ref r)) => r.encode(nested.buf),
            Self::Protocols(ref r) => r.encode(nested.buf),
            Self::SupportedVersions(ref r) => r.encode(nested.buf),
            Self::KeyShare(ref r) => r.encode(nested.buf),
            Self::PresharedKeyModes(ref r) => r.encode(nested.buf),
            Self::PresharedKey(ref r) => r.encode(nested.buf),
            Self::Cookie(ref r) => r.encode(nested.buf),
            Self::CertificateStatusRequest(ref r) => r.encode(nested.buf),
            // Transport parameters are opaque here: the bytes are copied verbatim.
            Self::TransportParameters(ref r) | Self::TransportParametersDraft(ref r) => {
                nested.buf.extend_from_slice(r);
            }
            Self::CertificateCompressionAlgorithms(ref r) => r.encode(nested.buf),
            Self::EncryptedClientHello(ref r) => r.encode(nested.buf),
            Self::EncryptedClientHelloOuterExtensions(ref r) => r.encode(nested.buf),
            Self::Unknown(ref r) => r.encode(nested.buf),
        }
    }

    /// Read one extension: its type, a `u16` body length, and the body itself.
    ///
    /// The body is parsed from a sub-reader limited to the declared length. Types
    /// without a matching arm are kept as `Unknown`; `sub.expect_empty` is applied
    /// to the sub-reader before the parsed extension is returned.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let typ = ExtensionType::read(r)?;
        let len = u16::read(r)? as usize;
        let mut sub = r.sub(len)?;

        let ext = match typ {
            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
            ExtensionType::EllipticCurves => Self::NamedGroups(Vec::read(&mut sub)?),
            ExtensionType::SignatureAlgorithms => Self::SignatureAlgorithms(Vec::read(&mut sub)?),
            ExtensionType::ServerName => Self::ServerName(Vec::read(&mut sub)?),
            ExtensionType::SessionTicket => {
                // A non-empty body is a ticket offer; an empty body is a ticket request.
                if sub.any_left() {
                    let contents = Payload::read(&mut sub).into_owned();
                    Self::SessionTicket(ClientSessionTicket::Offer(contents))
                } else {
                    Self::SessionTicket(ClientSessionTicket::Request)
                }
            }
            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
            ExtensionType::SupportedVersions => Self::SupportedVersions(Vec::read(&mut sub)?),
            ExtensionType::KeyShare => Self::KeyShare(Vec::read(&mut sub)?),
            ExtensionType::PSKKeyExchangeModes => Self::PresharedKeyModes(Vec::read(&mut sub)?),
            ExtensionType::PreSharedKey => Self::PresharedKey(PresharedKeyOffer::read(&mut sub)?),
            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
            // Only matches an empty body; a non-empty body falls through to `Unknown`.
            ExtensionType::ExtendedMasterSecret if !sub.any_left() => {
                Self::ExtendedMasterSecretRequest
            }
            ExtensionType::StatusRequest => {
                let csr = CertificateStatusRequest::read(&mut sub)?;
                Self::CertificateStatusRequest(csr)
            }
            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
            ExtensionType::TransportParametersDraft => {
                Self::TransportParametersDraft(sub.rest().to_vec())
            }
            // Only matches an empty body; a non-empty body falls through to `Unknown`.
            ExtensionType::EarlyData if !sub.any_left() => Self::EarlyData,
            ExtensionType::CompressCertificate => {
                Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?)
            }
            ExtensionType::EncryptedClientHelloOuterExtensions => {
                Self::EncryptedClientHelloOuterExtensions(Vec::read(&mut sub)?)
            }
            // NOTE(review): there is no arm for `ExtensionType::EncryptedClientHello`, so a
            // received ECH extension is kept as `Unknown` — confirm this is intended.
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
        };

        sub.expect_empty("ClientExtension")
            .map(|_| ext)
    }
}
675
676fn trim_hostname_trailing_dot_for_sni(dns_name: &DnsName<'_>) -> DnsName<'static> {
677    let dns_name_str = dns_name.as_ref();
678
679    // RFC6066: "The hostname is represented as a byte string using
680    // ASCII encoding without a trailing dot"
681    if dns_name_str.ends_with('.') {
682        let trimmed = &dns_name_str[0..dns_name_str.len() - 1];
683        DnsName::try_from(trimmed)
684            .unwrap()
685            .to_owned()
686    } else {
687        dns_name.to_owned()
688    }
689}
690
691impl ClientExtension {
692    /// Make a basic SNI ServerNameRequest quoting `hostname`.
693    pub(crate) fn make_sni(dns_name: &DnsName<'_>) -> Self {
694        let name = ServerName {
695            typ: ServerNameType::HostName,
696            payload: ServerNamePayload::new_hostname(trim_hostname_trailing_dot_for_sni(dns_name)),
697        };
698
699        Self::ServerName(vec![name])
700    }
701}
702
/// The two forms of the client's session ticket extension.
#[derive(Clone, Debug)]
pub enum ClientSessionTicket {
    /// The extension body is empty.
    Request,
    /// The extension body carries the given bytes.
    Offer(Payload<'static>),
}
708
/// An extension carried in a server message.
///
/// Each variant maps to one `ExtensionType` (see `ext_type`); `Unknown` carries its
/// own type alongside an opaque body.
#[derive(Clone, Debug)]
pub enum ServerExtension {
    EcPointFormats(Vec<ECPointFormat>),
    /// Encoded with an empty body.
    ServerNameAck,
    /// Encoded with an empty body.
    SessionTicketAck,
    RenegotiationInfo(PayloadU8),
    Protocols(Vec<ProtocolName>),
    KeyShare(KeyShareEntry),
    PresharedKey(u16),
    /// Encoded with an empty body.
    ExtendedMasterSecretAck,
    /// Encoded with an empty body.
    CertificateStatusAck,
    SupportedVersions(ProtocolVersion),
    /// Raw bytes, copied verbatim into the extension body.
    TransportParameters(Vec<u8>),
    /// Raw bytes, copied verbatim into the extension body.
    TransportParametersDraft(Vec<u8>),
    /// Encoded with an empty body.
    EarlyData,
    EncryptedClientHello(ServerEncryptedClientHello),
    /// Any extension type without a dedicated variant, kept as opaque bytes.
    Unknown(UnknownExtension),
}
727
728impl ServerExtension {
729    pub(crate) fn ext_type(&self) -> ExtensionType {
730        match *self {
731            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
732            Self::ServerNameAck => ExtensionType::ServerName,
733            Self::SessionTicketAck => ExtensionType::SessionTicket,
734            Self::RenegotiationInfo(_) => ExtensionType::RenegotiationInfo,
735            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
736            Self::KeyShare(_) => ExtensionType::KeyShare,
737            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
738            Self::ExtendedMasterSecretAck => ExtensionType::ExtendedMasterSecret,
739            Self::CertificateStatusAck => ExtensionType::StatusRequest,
740            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
741            Self::TransportParameters(_) => ExtensionType::TransportParameters,
742            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
743            Self::EarlyData => ExtensionType::EarlyData,
744            Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello,
745            Self::Unknown(ref r) => r.typ,
746        }
747    }
748}
749
impl Codec<'_> for ServerExtension {
    /// Write the extension type followed by the extension body.
    ///
    /// The body is written into a `LengthPrefixedBuffer` created with
    /// `ListLength::U16` (presumably this fills in the body length once writing is
    /// done; see `msgs::codec`).
    fn encode(&self, bytes: &mut Vec<u8>) {
        self.ext_type().encode(bytes);

        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
        match *self {
            Self::EcPointFormats(ref r) => r.encode(nested.buf),
            // Acknowledgement-style extensions have an empty body.
            Self::ServerNameAck
            | Self::SessionTicketAck
            | Self::ExtendedMasterSecretAck
            | Self::CertificateStatusAck
            | Self::EarlyData => {}
            Self::RenegotiationInfo(ref r) => r.encode(nested.buf),
            Self::Protocols(ref r) => r.encode(nested.buf),
            Self::KeyShare(ref r) => r.encode(nested.buf),
            Self::PresharedKey(r) => r.encode(nested.buf),
            Self::SupportedVersions(ref r) => r.encode(nested.buf),
            // Transport parameters are opaque here: the bytes are copied verbatim.
            Self::TransportParameters(ref r) | Self::TransportParametersDraft(ref r) => {
                nested.buf.extend_from_slice(r);
            }
            Self::EncryptedClientHello(ref r) => r.encode(nested.buf),
            Self::Unknown(ref r) => r.encode(nested.buf),
        }
    }

    /// Read one extension: its type, a `u16` body length, and the body itself.
    ///
    /// The body is parsed from a sub-reader limited to the declared length. The
    /// acknowledgement-style arms consume nothing from it; `sub.expect_empty` is
    /// applied to the sub-reader before the parsed extension is returned.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let typ = ExtensionType::read(r)?;
        let len = u16::read(r)? as usize;
        let mut sub = r.sub(len)?;

        let ext = match typ {
            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
            ExtensionType::ServerName => Self::ServerNameAck,
            ExtensionType::SessionTicket => Self::SessionTicketAck,
            ExtensionType::StatusRequest => Self::CertificateStatusAck,
            ExtensionType::RenegotiationInfo => Self::RenegotiationInfo(PayloadU8::read(&mut sub)?),
            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
            ExtensionType::KeyShare => Self::KeyShare(KeyShareEntry::read(&mut sub)?),
            ExtensionType::PreSharedKey => Self::PresharedKey(u16::read(&mut sub)?),
            ExtensionType::ExtendedMasterSecret => Self::ExtendedMasterSecretAck,
            ExtensionType::SupportedVersions => {
                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
            }
            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
            ExtensionType::TransportParametersDraft => {
                Self::TransportParametersDraft(sub.rest().to_vec())
            }
            ExtensionType::EarlyData => Self::EarlyData,
            ExtensionType::EncryptedClientHello => {
                Self::EncryptedClientHello(ServerEncryptedClientHello::read(&mut sub)?)
            }
            // Everything else is kept as opaque bytes under its original type.
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
        };

        sub.expect_empty("ServerExtension")
            .map(|_| ext)
    }
}
808
809impl ServerExtension {
810    pub(crate) fn make_alpn(proto: &[&[u8]]) -> Self {
811        Self::Protocols(Vec::from_slices(proto))
812    }
813
814    #[cfg(feature = "tls12")]
815    pub(crate) fn make_empty_renegotiation_info() -> Self {
816        let empty = Vec::new();
817        Self::RenegotiationInfo(PayloadU8::new(empty))
818    }
819}
820
/// The body of a client hello message.
///
/// Fields are listed in the order in which they are read from and written to the wire.
#[derive(Clone, Debug)]
pub struct ClientHelloPayload {
    pub client_version: ProtocolVersion,
    pub random: Random,
    pub session_id: SessionId,
    pub cipher_suites: Vec<CipherSuite>,
    pub compression_methods: Vec<Compression>,
    /// Extensions; when this list is empty, encoding omits the extension block entirely.
    pub extensions: Vec<ClientExtension>,
}
830
831impl Codec<'_> for ClientHelloPayload {
832    fn encode(&self, bytes: &mut Vec<u8>) {
833        self.payload_encode(bytes, Encoding::Standard)
834    }
835
836    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
837        let mut ret = Self {
838            client_version: ProtocolVersion::read(r)?,
839            random: Random::read(r)?,
840            session_id: SessionId::read(r)?,
841            cipher_suites: Vec::read(r)?,
842            compression_methods: Vec::read(r)?,
843            extensions: Vec::new(),
844        };
845
846        if r.any_left() {
847            ret.extensions = Vec::read(r)?;
848        }
849
850        match (r.any_left(), ret.extensions.is_empty()) {
851            (true, _) => Err(InvalidMessage::TrailingData("ClientHelloPayload")),
852            (_, true) => Err(InvalidMessage::MissingData("ClientHelloPayload")),
853            _ => Ok(ret),
854        }
855    }
856}
857
// Lists of `CipherSuite` declare a `ListLength::U16` length prefix.
impl TlsListElement for CipherSuite {
    const SIZE_LEN: ListLength = ListLength::U16;
}

// Lists of `Compression` declare a `ListLength::U8` length prefix.
impl TlsListElement for Compression {
    const SIZE_LEN: ListLength = ListLength::U8;
}

// Lists of `ClientExtension` declare a `ListLength::U16` length prefix.
impl TlsListElement for ClientExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}

// Lists of `ExtensionType` declare a `ListLength::U8` length prefix.
impl TlsListElement for ExtensionType {
    const SIZE_LEN: ListLength = ListLength::U8;
}
873
874impl ClientHelloPayload {
875    pub(crate) fn ech_inner_encoding(&self, to_compress: Vec<ExtensionType>) -> Vec<u8> {
876        let mut bytes = Vec::new();
877        self.payload_encode(&mut bytes, Encoding::EchInnerHello { to_compress });
878        bytes
879    }
880
    /// Encode this hello into `bytes` for the given `purpose`.
    ///
    /// For `Encoding::EchInnerHello` an empty session ID is written in place of
    /// `self.session_id`, and, when `to_compress` is non-empty, a run of extensions
    /// is replaced by a single `EncryptedClientHelloOuterExtensions` marker. For any
    /// other purpose (or an empty `to_compress`) the extensions are written as-is,
    /// and are omitted altogether when the list is empty.
    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
        self.client_version.encode(bytes);
        self.random.encode(bytes);

        match purpose {
            // SessionID is required to be empty in the encoded inner client hello.
            Encoding::EchInnerHello { .. } => SessionId::empty().encode(bytes),
            _ => self.session_id.encode(bytes),
        }

        self.cipher_suites.encode(bytes);
        self.compression_methods.encode(bytes);

        let to_compress = match purpose {
            // Compressed extensions must be replaced in the encoded inner client hello.
            Encoding::EchInnerHello { to_compress } if !to_compress.is_empty() => to_compress,
            _ => {
                // Nothing to compress: write the extension list unless it is empty.
                if !self.extensions.is_empty() {
                    self.extensions.encode(bytes);
                }
                return;
            }
        };

        // Safety: not empty check in match guard.
        let first_compressed_type = *to_compress.first().unwrap();

        // Compressed extensions are in a contiguous range and must be replaced
        // with a marker extension.
        // The range starts at the first extension whose type equals the first entry of
        // `to_compress` and spans `to_compress.len()` positions; the types of the
        // following extensions in that range are not checked here.
        let compressed_start_idx = self
            .extensions
            .iter()
            .position(|ext| ext.ext_type() == first_compressed_type);
        let compressed_end_idx = compressed_start_idx.map(|start| start + to_compress.len());
        let marker_ext = ClientExtension::EncryptedClientHelloOuterExtensions(to_compress);

        let exts = self
            .extensions
            .iter()
            .enumerate()
            .filter_map(|(i, ext)| {
                if Some(i) == compressed_start_idx {
                    // First position of the range: emit the marker instead.
                    Some(&marker_ext)
                } else if Some(i) > compressed_start_idx && Some(i) < compressed_end_idx {
                    // Remaining positions of the range are dropped. When no start was
                    // found both indices are `None`; `Some(i) < None` is false, so
                    // nothing is dropped and no marker is emitted.
                    None
                } else {
                    Some(ext)
                }
            });

        // Write the surviving extensions as one `u16`-length-prefixed block.
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
        for ext in exts {
            ext.encode(nested.buf);
        }
    }
936
937    /// Returns true if there is more than one extension of a given
938    /// type.
939    pub(crate) fn has_duplicate_extension(&self) -> bool {
940        has_duplicates::<_, _, u16>(
941            self.extensions
942                .iter()
943                .map(|ext| ext.ext_type()),
944        )
945    }
946
947    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&ClientExtension> {
948        self.extensions
949            .iter()
950            .find(|x| x.ext_type() == ext)
951    }
952
953    pub(crate) fn sni_extension(&self) -> Option<&[ServerName]> {
954        let ext = self.find_extension(ExtensionType::ServerName)?;
955        match *ext {
956            // Does this comply with RFC6066?
957            //
958            // [RFC6066][] specifies that literal IP addresses are illegal in
959            // `ServerName`s with a `name_type` of `host_name`.
960            //
961            // Some clients incorrectly send such extensions: we choose to
962            // successfully parse these (into `ServerNamePayload::IpAddress`)
963            // but then act like the client sent no `server_name` extension.
964            //
965            // [RFC6066]: https://datatracker.ietf.org/doc/html/rfc6066#section-3
966            ClientExtension::ServerName(ref req)
967                if !req
968                    .iter()
969                    .any(|name| matches!(name.payload, ServerNamePayload::IpAddress(_))) =>
970            {
971                Some(req)
972            }
973            _ => None,
974        }
975    }
976
977    pub fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
978        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
979        match *ext {
980            ClientExtension::SignatureAlgorithms(ref req) => Some(req),
981            _ => None,
982        }
983    }
984
985    pub(crate) fn namedgroups_extension(&self) -> Option<&[NamedGroup]> {
986        let ext = self.find_extension(ExtensionType::EllipticCurves)?;
987        match *ext {
988            ClientExtension::NamedGroups(ref req) => Some(req),
989            _ => None,
990        }
991    }
992
993    #[cfg(feature = "tls12")]
994    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
995        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
996        match *ext {
997            ClientExtension::EcPointFormats(ref req) => Some(req),
998            _ => None,
999        }
1000    }
1001
1002    pub(crate) fn alpn_extension(&self) -> Option<&Vec<ProtocolName>> {
1003        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
1004        match *ext {
1005            ClientExtension::Protocols(ref req) => Some(req),
1006            _ => None,
1007        }
1008    }
1009
1010    pub(crate) fn quic_params_extension(&self) -> Option<Vec<u8>> {
1011        let ext = self
1012            .find_extension(ExtensionType::TransportParameters)
1013            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
1014        match *ext {
1015            ClientExtension::TransportParameters(ref bytes)
1016            | ClientExtension::TransportParametersDraft(ref bytes) => Some(bytes.to_vec()),
1017            _ => None,
1018        }
1019    }
1020
1021    #[cfg(feature = "tls12")]
1022    pub(crate) fn ticket_extension(&self) -> Option<&ClientExtension> {
1023        self.find_extension(ExtensionType::SessionTicket)
1024    }
1025
1026    pub(crate) fn versions_extension(&self) -> Option<&[ProtocolVersion]> {
1027        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1028        match *ext {
1029            ClientExtension::SupportedVersions(ref vers) => Some(vers),
1030            _ => None,
1031        }
1032    }
1033
1034    pub fn keyshare_extension(&self) -> Option<&[KeyShareEntry]> {
1035        let ext = self.find_extension(ExtensionType::KeyShare)?;
1036        match *ext {
1037            ClientExtension::KeyShare(ref shares) => Some(shares),
1038            _ => None,
1039        }
1040    }
1041
1042    pub(crate) fn has_keyshare_extension_with_duplicates(&self) -> bool {
1043        self.keyshare_extension()
1044            .map(|entries| {
1045                has_duplicates::<_, _, u16>(
1046                    entries
1047                        .iter()
1048                        .map(|kse| u16::from(kse.group)),
1049                )
1050            })
1051            .unwrap_or_default()
1052    }
1053
1054    pub(crate) fn psk(&self) -> Option<&PresharedKeyOffer> {
1055        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1056        match *ext {
1057            ClientExtension::PresharedKey(ref psk) => Some(psk),
1058            _ => None,
1059        }
1060    }
1061
1062    pub(crate) fn check_psk_ext_is_last(&self) -> bool {
1063        self.extensions
1064            .last()
1065            .map_or(false, |ext| ext.ext_type() == ExtensionType::PreSharedKey)
1066    }
1067
1068    pub(crate) fn psk_modes(&self) -> Option<&[PSKKeyExchangeMode]> {
1069        let ext = self.find_extension(ExtensionType::PSKKeyExchangeModes)?;
1070        match *ext {
1071            ClientExtension::PresharedKeyModes(ref psk_modes) => Some(psk_modes),
1072            _ => None,
1073        }
1074    }
1075
1076    pub(crate) fn psk_mode_offered(&self, mode: PSKKeyExchangeMode) -> bool {
1077        self.psk_modes()
1078            .map(|modes| modes.contains(&mode))
1079            .unwrap_or(false)
1080    }
1081
1082    pub(crate) fn set_psk_binder(&mut self, binder: impl Into<Vec<u8>>) {
1083        let last_extension = self.extensions.last_mut();
1084        if let Some(ClientExtension::PresharedKey(ref mut offer)) = last_extension {
1085            offer.binders[0] = PresharedKeyBinder::from(binder.into());
1086        }
1087    }
1088
1089    #[cfg(feature = "tls12")]
1090    pub(crate) fn ems_support_offered(&self) -> bool {
1091        self.find_extension(ExtensionType::ExtendedMasterSecret)
1092            .is_some()
1093    }
1094
1095    pub(crate) fn early_data_extension_offered(&self) -> bool {
1096        self.find_extension(ExtensionType::EarlyData)
1097            .is_some()
1098    }
1099
1100    pub(crate) fn certificate_compression_extension(
1101        &self,
1102    ) -> Option<&[CertificateCompressionAlgorithm]> {
1103        let ext = self.find_extension(ExtensionType::CompressCertificate)?;
1104        match *ext {
1105            ClientExtension::CertificateCompressionAlgorithms(ref algs) => Some(algs),
1106            _ => None,
1107        }
1108    }
1109
1110    pub(crate) fn has_certificate_compression_extension_with_duplicates(&self) -> bool {
1111        if let Some(algs) = self.certificate_compression_extension() {
1112            has_duplicates::<_, _, u16>(algs.iter().cloned())
1113        } else {
1114            false
1115        }
1116    }
1117}
1118
/// An extension carried in a `HelloRetryRequest`.
#[derive(Clone, Debug)]
pub(crate) enum HelloRetryExtension {
    /// `KeyShare` extension: a single named group.
    KeyShare(NamedGroup),
    /// `Cookie` extension payload.
    Cookie(PayloadU16),
    /// `SupportedVersions` extension: a single protocol version.
    SupportedVersions(ProtocolVersion),
    /// `EncryptedClientHello` extension: the raw extension body.
    EchHelloRetryRequest(Vec<u8>),
    /// Any other extension type, kept as an `UnknownExtension`.
    Unknown(UnknownExtension),
}
1127
1128impl HelloRetryExtension {
1129    pub(crate) fn ext_type(&self) -> ExtensionType {
1130        match *self {
1131            Self::KeyShare(_) => ExtensionType::KeyShare,
1132            Self::Cookie(_) => ExtensionType::Cookie,
1133            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
1134            Self::EchHelloRetryRequest(_) => ExtensionType::EncryptedClientHello,
1135            Self::Unknown(ref r) => r.typ,
1136        }
1137    }
1138}
1139
1140impl Codec<'_> for HelloRetryExtension {
1141    fn encode(&self, bytes: &mut Vec<u8>) {
1142        self.ext_type().encode(bytes);
1143
1144        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1145        match *self {
1146            Self::KeyShare(ref r) => r.encode(nested.buf),
1147            Self::Cookie(ref r) => r.encode(nested.buf),
1148            Self::SupportedVersions(ref r) => r.encode(nested.buf),
1149            Self::EchHelloRetryRequest(ref r) => {
1150                nested.buf.extend_from_slice(r);
1151            }
1152            Self::Unknown(ref r) => r.encode(nested.buf),
1153        }
1154    }
1155
1156    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1157        let typ = ExtensionType::read(r)?;
1158        let len = u16::read(r)? as usize;
1159        let mut sub = r.sub(len)?;
1160
1161        let ext = match typ {
1162            ExtensionType::KeyShare => Self::KeyShare(NamedGroup::read(&mut sub)?),
1163            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
1164            ExtensionType::SupportedVersions => {
1165                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
1166            }
1167            ExtensionType::EncryptedClientHello => Self::EchHelloRetryRequest(sub.rest().to_vec()),
1168            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1169        };
1170
1171        sub.expect_empty("HelloRetryExtension")
1172            .map(|_| ext)
1173    }
1174}
1175
/// Lists of `HelloRetryExtension` use the `ListLength::U16` length encoding.
impl TlsListElement for HelloRetryExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
1179
/// The payload of a HelloRetryRequest message.
///
/// The random value is not stored: encoding always writes
/// `HELLO_RETRY_REQUEST_RANDOM`.
#[derive(Clone, Debug)]
pub struct HelloRetryRequest {
    /// Legacy version field; `read` fills this with `ProtocolVersion::Unknown(0)`.
    pub(crate) legacy_version: ProtocolVersion,
    /// Session ID carried in the message.
    pub session_id: SessionId,
    /// Cipher suite carried in the message.
    pub(crate) cipher_suite: CipherSuite,
    /// Extensions carried in the message, in wire order.
    pub(crate) extensions: Vec<HelloRetryExtension>,
}
1187
1188impl Codec<'_> for HelloRetryRequest {
1189    fn encode(&self, bytes: &mut Vec<u8>) {
1190        self.payload_encode(bytes, Encoding::Standard)
1191    }
1192
1193    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1194        let session_id = SessionId::read(r)?;
1195        let cipher_suite = CipherSuite::read(r)?;
1196        let compression = Compression::read(r)?;
1197
1198        if compression != Compression::Null {
1199            return Err(InvalidMessage::UnsupportedCompression);
1200        }
1201
1202        Ok(Self {
1203            legacy_version: ProtocolVersion::Unknown(0),
1204            session_id,
1205            cipher_suite,
1206            extensions: Vec::read(r)?,
1207        })
1208    }
1209}
1210
1211impl HelloRetryRequest {
1212    /// Returns true if there is more than one extension of a given
1213    /// type.
1214    pub(crate) fn has_duplicate_extension(&self) -> bool {
1215        has_duplicates::<_, _, u16>(
1216            self.extensions
1217                .iter()
1218                .map(|ext| ext.ext_type()),
1219        )
1220    }
1221
1222    pub(crate) fn has_unknown_extension(&self) -> bool {
1223        self.extensions.iter().any(|ext| {
1224            ext.ext_type() != ExtensionType::KeyShare
1225                && ext.ext_type() != ExtensionType::SupportedVersions
1226                && ext.ext_type() != ExtensionType::Cookie
1227                && ext.ext_type() != ExtensionType::EncryptedClientHello
1228        })
1229    }
1230
1231    fn find_extension(&self, ext: ExtensionType) -> Option<&HelloRetryExtension> {
1232        self.extensions
1233            .iter()
1234            .find(|x| x.ext_type() == ext)
1235    }
1236
1237    pub fn requested_key_share_group(&self) -> Option<NamedGroup> {
1238        let ext = self.find_extension(ExtensionType::KeyShare)?;
1239        match *ext {
1240            HelloRetryExtension::KeyShare(grp) => Some(grp),
1241            _ => None,
1242        }
1243    }
1244
1245    pub(crate) fn cookie(&self) -> Option<&PayloadU16> {
1246        let ext = self.find_extension(ExtensionType::Cookie)?;
1247        match *ext {
1248            HelloRetryExtension::Cookie(ref ck) => Some(ck),
1249            _ => None,
1250        }
1251    }
1252
1253    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1254        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1255        match *ext {
1256            HelloRetryExtension::SupportedVersions(ver) => Some(ver),
1257            _ => None,
1258        }
1259    }
1260
1261    pub(crate) fn ech(&self) -> Option<&Vec<u8>> {
1262        let ext = self.find_extension(ExtensionType::EncryptedClientHello)?;
1263        match *ext {
1264            HelloRetryExtension::EchHelloRetryRequest(ref ech) => Some(ech),
1265            _ => None,
1266        }
1267    }
1268
1269    fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
1270        self.legacy_version.encode(bytes);
1271        HELLO_RETRY_REQUEST_RANDOM.encode(bytes);
1272        self.session_id.encode(bytes);
1273        self.cipher_suite.encode(bytes);
1274        Compression::Null.encode(bytes);
1275
1276        match purpose {
1277            // For the purpose of ECH confirmation, the Encrypted Client Hello extension
1278            // must have its payload replaced by 8 zero bytes.
1279            //
1280            // See draft-ietf-tls-esni-18 7.2.1:
1281            // <https://datatracker.ietf.org/doc/html/draft-ietf-tls-esni-18#name-sending-helloretryrequest-2>
1282            Encoding::EchConfirmation => {
1283                let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1284                for ext in &self.extensions {
1285                    match ext.ext_type() {
1286                        ExtensionType::EncryptedClientHello => {
1287                            HelloRetryExtension::EchHelloRetryRequest(vec![0u8; 8])
1288                                .encode(extensions.buf);
1289                        }
1290                        _ => {
1291                            ext.encode(extensions.buf);
1292                        }
1293                    }
1294                }
1295            }
1296            _ => {
1297                self.extensions.encode(bytes);
1298            }
1299        }
1300    }
1301}
1302
/// The payload of a ServerHello message.
#[derive(Clone, Debug)]
pub struct ServerHelloPayload {
    /// Legacy version field; `read` fills this with `ProtocolVersion::Unknown(0)`.
    pub(crate) legacy_version: ProtocolVersion,
    /// Server random; `read` fills this with `ZERO_RANDOM`.
    pub(crate) random: Random,
    /// Session ID carried in the message.
    pub(crate) session_id: SessionId,
    /// Cipher suite carried in the message.
    pub(crate) cipher_suite: CipherSuite,
    /// Compression method carried in the message.
    pub(crate) compression_method: Compression,
    /// Extensions carried in the message; empty when none were present.
    pub(crate) extensions: Vec<ServerExtension>,
}
1312
1313impl Codec<'_> for ServerHelloPayload {
1314    fn encode(&self, bytes: &mut Vec<u8>) {
1315        self.payload_encode(bytes, Encoding::Standard)
1316    }
1317
1318    // minus version and random, which have already been read.
1319    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1320        let session_id = SessionId::read(r)?;
1321        let suite = CipherSuite::read(r)?;
1322        let compression = Compression::read(r)?;
1323
1324        // RFC5246:
1325        // "The presence of extensions can be detected by determining whether
1326        //  there are bytes following the compression_method field at the end of
1327        //  the ServerHello."
1328        let extensions = if r.any_left() { Vec::read(r)? } else { vec![] };
1329
1330        let ret = Self {
1331            legacy_version: ProtocolVersion::Unknown(0),
1332            random: ZERO_RANDOM,
1333            session_id,
1334            cipher_suite: suite,
1335            compression_method: compression,
1336            extensions,
1337        };
1338
1339        r.expect_empty("ServerHelloPayload")
1340            .map(|_| ret)
1341    }
1342}
1343
impl HasServerExtensions for ServerHelloPayload {
    /// Exposes the message's extension list as a slice.
    fn extensions(&self) -> &[ServerExtension] {
        &self.extensions
    }
}
1349
1350impl ServerHelloPayload {
1351    pub(crate) fn key_share(&self) -> Option<&KeyShareEntry> {
1352        let ext = self.find_extension(ExtensionType::KeyShare)?;
1353        match *ext {
1354            ServerExtension::KeyShare(ref share) => Some(share),
1355            _ => None,
1356        }
1357    }
1358
1359    pub(crate) fn psk_index(&self) -> Option<u16> {
1360        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1361        match *ext {
1362            ServerExtension::PresharedKey(ref index) => Some(*index),
1363            _ => None,
1364        }
1365    }
1366
1367    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1368        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1369        match *ext {
1370            ServerExtension::EcPointFormats(ref fmts) => Some(fmts),
1371            _ => None,
1372        }
1373    }
1374
1375    #[cfg(feature = "tls12")]
1376    pub(crate) fn ems_support_acked(&self) -> bool {
1377        self.find_extension(ExtensionType::ExtendedMasterSecret)
1378            .is_some()
1379    }
1380
1381    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1382        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1383        match *ext {
1384            ServerExtension::SupportedVersions(vers) => Some(vers),
1385            _ => None,
1386        }
1387    }
1388
1389    fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
1390        self.legacy_version.encode(bytes);
1391
1392        match encoding {
1393            // When encoding a ServerHello for ECH confirmation, the random value
1394            // has the last 8 bytes zeroed out.
1395            Encoding::EchConfirmation => {
1396                // Indexing safety: self.random is 32 bytes long by definition.
1397                let rand_vec = self.random.get_encoding();
1398                bytes.extend_from_slice(&rand_vec.as_slice()[..24]);
1399                bytes.extend_from_slice(&[0u8; 8]);
1400            }
1401            _ => self.random.encode(bytes),
1402        }
1403
1404        self.session_id.encode(bytes);
1405        self.cipher_suite.encode(bytes);
1406        self.compression_method.encode(bytes);
1407
1408        if !self.extensions.is_empty() {
1409            self.extensions.encode(bytes);
1410        }
1411    }
1412}
1413
/// An ordered list of certificates, wrapping a `Vec<CertificateDer>`.
#[derive(Clone, Default, Debug)]
pub struct CertificateChain<'a>(pub Vec<CertificateDer<'a>>);
1416
1417impl CertificateChain<'_> {
1418    pub(crate) fn into_owned(self) -> CertificateChain<'static> {
1419        CertificateChain(
1420            self.0
1421                .into_iter()
1422                .map(|c| c.into_owned())
1423                .collect(),
1424        )
1425    }
1426}
1427
1428impl<'a> Codec<'a> for CertificateChain<'a> {
1429    fn encode(&self, bytes: &mut Vec<u8>) {
1430        Vec::encode(&self.0, bytes)
1431    }
1432
1433    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1434        Vec::read(r).map(Self)
1435    }
1436}
1437
1438impl<'a> Deref for CertificateChain<'a> {
1439    type Target = [CertificateDer<'a>];
1440
1441    fn deref(&self) -> &[CertificateDer<'a>] {
1442        &self.0
1443    }
1444}
1445
/// Lists of `CertificateDer` use the `ListLength::U24` length encoding, capped
/// at `CERTIFICATE_MAX_SIZE_LIMIT` and reporting `CertificatePayloadTooLarge`
/// as the associated error.
impl TlsListElement for CertificateDer<'_> {
    const SIZE_LEN: ListLength = ListLength::U24 {
        max: CERTIFICATE_MAX_SIZE_LIMIT,
        error: InvalidMessage::CertificatePayloadTooLarge,
    };
}
1452
/// TLS has a 16MB size limit on any handshake message,
/// plus a 16MB limit on any given certificate.
///
/// We contract that to 64KB (`0x1_0000` bytes) to limit the amount of
/// memory allocation that is directly controllable by the peer.
pub(crate) const CERTIFICATE_MAX_SIZE_LIMIT: usize = 0x1_0000;
1459
/// An extension attached to a `CertificateEntry`.
#[derive(Debug)]
pub(crate) enum CertificateExtension<'a> {
    /// `StatusRequest` extension: a certificate status.
    CertificateStatus(CertificateStatus<'a>),
    /// Any other extension type, kept as an `UnknownExtension`.
    Unknown(UnknownExtension),
}
1465
1466impl<'a> CertificateExtension<'a> {
1467    pub(crate) fn ext_type(&self) -> ExtensionType {
1468        match *self {
1469            Self::CertificateStatus(_) => ExtensionType::StatusRequest,
1470            Self::Unknown(ref r) => r.typ,
1471        }
1472    }
1473
1474    pub(crate) fn cert_status(&self) -> Option<&[u8]> {
1475        match *self {
1476            Self::CertificateStatus(ref cs) => Some(cs.ocsp_response.0.bytes()),
1477            _ => None,
1478        }
1479    }
1480
1481    pub(crate) fn into_owned(self) -> CertificateExtension<'static> {
1482        match self {
1483            Self::CertificateStatus(st) => CertificateExtension::CertificateStatus(st.into_owned()),
1484            Self::Unknown(unk) => CertificateExtension::Unknown(unk),
1485        }
1486    }
1487}
1488
1489impl<'a> Codec<'a> for CertificateExtension<'a> {
1490    fn encode(&self, bytes: &mut Vec<u8>) {
1491        self.ext_type().encode(bytes);
1492
1493        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1494        match *self {
1495            Self::CertificateStatus(ref r) => r.encode(nested.buf),
1496            Self::Unknown(ref r) => r.encode(nested.buf),
1497        }
1498    }
1499
1500    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1501        let typ = ExtensionType::read(r)?;
1502        let len = u16::read(r)? as usize;
1503        let mut sub = r.sub(len)?;
1504
1505        let ext = match typ {
1506            ExtensionType::StatusRequest => {
1507                let st = CertificateStatus::read(&mut sub)?;
1508                Self::CertificateStatus(st)
1509            }
1510            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1511        };
1512
1513        sub.expect_empty("CertificateExtension")
1514            .map(|_| ext)
1515    }
1516}
1517
/// Lists of `CertificateExtension` use the `ListLength::U16` length encoding.
impl<'a> TlsListElement for CertificateExtension<'a> {
    const SIZE_LEN: ListLength = ListLength::U16;
}
1521
/// One certificate together with its per-certificate extensions.
#[derive(Debug)]
pub(crate) struct CertificateEntry<'a> {
    /// The certificate itself.
    pub(crate) cert: CertificateDer<'a>,
    /// Extensions attached to this certificate; may be empty.
    pub(crate) exts: Vec<CertificateExtension<'a>>,
}
1527
1528impl<'a> Codec<'a> for CertificateEntry<'a> {
1529    fn encode(&self, bytes: &mut Vec<u8>) {
1530        self.cert.encode(bytes);
1531        self.exts.encode(bytes);
1532    }
1533
1534    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1535        Ok(Self {
1536            cert: CertificateDer::read(r)?,
1537            exts: Vec::read(r)?,
1538        })
1539    }
1540}
1541
1542impl<'a> CertificateEntry<'a> {
1543    pub(crate) fn new(cert: CertificateDer<'a>) -> Self {
1544        Self {
1545            cert,
1546            exts: Vec::new(),
1547        }
1548    }
1549
1550    pub(crate) fn into_owned(self) -> CertificateEntry<'static> {
1551        CertificateEntry {
1552            cert: self.cert.into_owned(),
1553            exts: self
1554                .exts
1555                .into_iter()
1556                .map(CertificateExtension::into_owned)
1557                .collect(),
1558        }
1559    }
1560
1561    pub(crate) fn has_duplicate_extension(&self) -> bool {
1562        has_duplicates::<_, _, u16>(
1563            self.exts
1564                .iter()
1565                .map(|ext| ext.ext_type()),
1566        )
1567    }
1568
1569    pub(crate) fn has_unknown_extension(&self) -> bool {
1570        self.exts
1571            .iter()
1572            .any(|ext| ext.ext_type() != ExtensionType::StatusRequest)
1573    }
1574
1575    pub(crate) fn ocsp_response(&self) -> Option<&[u8]> {
1576        self.exts
1577            .iter()
1578            .find(|ext| ext.ext_type() == ExtensionType::StatusRequest)
1579            .and_then(CertificateExtension::cert_status)
1580    }
1581}
1582
/// Lists of `CertificateEntry` use the `ListLength::U24` length encoding,
/// capped at `CERTIFICATE_MAX_SIZE_LIMIT` and reporting
/// `CertificatePayloadTooLarge` as the associated error.
impl<'a> TlsListElement for CertificateEntry<'a> {
    const SIZE_LEN: ListLength = ListLength::U24 {
        max: CERTIFICATE_MAX_SIZE_LIMIT,
        error: InvalidMessage::CertificatePayloadTooLarge,
    };
}
1589
/// The payload of a TLS 1.3 Certificate message: a context followed by a
/// list of certificate entries.
#[derive(Debug)]
pub struct CertificatePayloadTls13<'a> {
    /// Context bytes; `new` leaves this empty.
    pub(crate) context: PayloadU8,
    /// Certificates with their per-certificate extensions, in wire order.
    pub(crate) entries: Vec<CertificateEntry<'a>>,
}
1595
1596impl<'a> Codec<'a> for CertificatePayloadTls13<'a> {
1597    fn encode(&self, bytes: &mut Vec<u8>) {
1598        self.context.encode(bytes);
1599        self.entries.encode(bytes);
1600    }
1601
1602    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1603        Ok(Self {
1604            context: PayloadU8::read(r)?,
1605            entries: Vec::read(r)?,
1606        })
1607    }
1608}
1609
1610impl<'a> CertificatePayloadTls13<'a> {
1611    pub(crate) fn new(
1612        certs: impl Iterator<Item = &'a CertificateDer<'a>>,
1613        ocsp_response: Option<&'a [u8]>,
1614    ) -> Self {
1615        Self {
1616            context: PayloadU8::empty(),
1617            entries: certs
1618                // zip certificate iterator with `ocsp_response` followed by
1619                // an infinite-length iterator of `None`.
1620                .zip(
1621                    ocsp_response
1622                        .into_iter()
1623                        .map(Some)
1624                        .chain(iter::repeat(None)),
1625                )
1626                .map(|(cert, ocsp)| {
1627                    let mut e = CertificateEntry::new(cert.clone());
1628                    if let Some(ocsp) = ocsp {
1629                        e.exts
1630                            .push(CertificateExtension::CertificateStatus(
1631                                CertificateStatus::new(ocsp),
1632                            ));
1633                    }
1634                    e
1635                })
1636                .collect(),
1637        }
1638    }
1639
1640    pub(crate) fn into_owned(self) -> CertificatePayloadTls13<'static> {
1641        CertificatePayloadTls13 {
1642            context: self.context,
1643            entries: self
1644                .entries
1645                .into_iter()
1646                .map(CertificateEntry::into_owned)
1647                .collect(),
1648        }
1649    }
1650
1651    pub(crate) fn any_entry_has_duplicate_extension(&self) -> bool {
1652        for entry in &self.entries {
1653            if entry.has_duplicate_extension() {
1654                return true;
1655            }
1656        }
1657
1658        false
1659    }
1660
1661    pub(crate) fn any_entry_has_unknown_extension(&self) -> bool {
1662        for entry in &self.entries {
1663            if entry.has_unknown_extension() {
1664                return true;
1665            }
1666        }
1667
1668        false
1669    }
1670
1671    pub(crate) fn any_entry_has_extension(&self) -> bool {
1672        for entry in &self.entries {
1673            if !entry.exts.is_empty() {
1674                return true;
1675            }
1676        }
1677
1678        false
1679    }
1680
1681    pub(crate) fn end_entity_ocsp(&self) -> Vec<u8> {
1682        self.entries
1683            .first()
1684            .and_then(CertificateEntry::ocsp_response)
1685            .map(|resp| resp.to_vec())
1686            .unwrap_or_default()
1687    }
1688
1689    pub(crate) fn into_certificate_chain(self) -> CertificateChain<'a> {
1690        CertificateChain(
1691            self.entries
1692                .into_iter()
1693                .map(|e| e.cert)
1694                .collect(),
1695        )
1696    }
1697}
1698
/// Describes supported key exchange mechanisms.
///
/// Implements `Eq` and `Hash` alongside `PartialEq`, so values can be used as
/// keys in hashed collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum KeyExchangeAlgorithm {
    /// Diffie-Hellman Key exchange (with only known parameters as defined in [RFC 7919]).
    ///
    /// [RFC 7919]: https://datatracker.ietf.org/doc/html/rfc7919
    DHE,
    /// Key exchange performed via elliptic curve Diffie-Hellman.
    ECDHE,
}
1710
/// Both `KeyExchangeAlgorithm` variants, with `ECDHE` listed before `DHE`.
pub(crate) static ALL_KEY_EXCHANGE_ALGORITHMS: &[KeyExchangeAlgorithm] =
    &[KeyExchangeAlgorithm::ECDHE, KeyExchangeAlgorithm::DHE];
1713
// We don't support arbitrary curves.  It's a terrible
// idea and unnecessary attack surface.  Please,
// get a grip.
/// Elliptic-curve parameters: a curve type plus the named group.
///
/// Decoding accepts only `ECCurveType::NamedCurve`.
#[derive(Debug)]
pub(crate) struct EcParameters {
    /// The curve type; always `NamedCurve` for values produced by `read`.
    pub(crate) curve_type: ECCurveType,
    /// The named group identifying the curve.
    pub(crate) named_group: NamedGroup,
}
1722
1723impl Codec<'_> for EcParameters {
1724    fn encode(&self, bytes: &mut Vec<u8>) {
1725        self.curve_type.encode(bytes);
1726        self.named_group.encode(bytes);
1727    }
1728
1729    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1730        let ct = ECCurveType::read(r)?;
1731        if ct != ECCurveType::NamedCurve {
1732            return Err(InvalidMessage::UnsupportedCurveType);
1733        }
1734
1735        let grp = NamedGroup::read(r)?;
1736
1737        Ok(Self {
1738            curve_type: ct,
1739            named_group: grp,
1740        })
1741    }
1742}
1743
/// Decoding for key exchange messages whose wire format depends on the
/// key exchange algorithm in use.
#[cfg(feature = "tls12")]
pub(crate) trait KxDecode<'a>: fmt::Debug + Sized {
    /// Decode a key exchange message given the key_exchange `algo`
    fn decode(r: &mut Reader<'a>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage>;
}
1749
/// Client key exchange parameters, one variant per `KeyExchangeAlgorithm`.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) enum ClientKeyExchangeParams {
    /// Parameters decoded for `KeyExchangeAlgorithm::ECDHE`.
    Ecdh(ClientEcdhParams),
    /// Parameters decoded for `KeyExchangeAlgorithm::DHE`.
    Dh(ClientDhParams),
}
1756
#[cfg(feature = "tls12")]
impl ClientKeyExchangeParams {
    /// Returns the bytes of the `public` field of whichever variant is held.
    pub(crate) fn pub_key(&self) -> &[u8] {
        match self {
            Self::Dh(params) => params.public.0.as_slice(),
            Self::Ecdh(params) => params.public.0.as_slice(),
        }
    }

    /// Appends the encoding of the held parameters to `buf`.
    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
        match self {
            Self::Dh(params) => params.encode(buf),
            Self::Ecdh(params) => params.encode(buf),
        }
    }
}
1773
#[cfg(feature = "tls12")]
impl KxDecode<'_> for ClientKeyExchangeParams {
    /// Reads the parameter structure that corresponds to `algo`.
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        match algo {
            KeyExchangeAlgorithm::ECDHE => ClientEcdhParams::read(r).map(Self::Ecdh),
            KeyExchangeAlgorithm::DHE => ClientDhParams::read(r).map(Self::Dh),
        }
    }
}
1784
/// Client ECDH key exchange parameters.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientEcdhParams {
    /// The client's public value, stored as a `PayloadU8`.
    pub(crate) public: PayloadU8,
}
1790
#[cfg(feature = "tls12")]
impl Codec<'_> for ClientEcdhParams {
    /// Encodes the public value.
    fn encode(&self, bytes: &mut Vec<u8>) {
        let Self { public } = self;
        public.encode(bytes);
    }

    /// Reads the public value as a `PayloadU8`.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        PayloadU8::read(r).map(|public| Self { public })
    }
}
1802
/// Client DH key exchange parameters.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientDhParams {
    /// The client's public value, stored as a `PayloadU16`.
    pub(crate) public: PayloadU16,
}
1808
#[cfg(feature = "tls12")]
impl Codec<'_> for ClientDhParams {
    /// Encodes the public value.
    fn encode(&self, bytes: &mut Vec<u8>) {
        let Self { public } = self;
        public.encode(bytes);
    }

    /// Reads the public value as a `PayloadU16`.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let public = PayloadU16::read(r)?;
        Ok(Self { public })
    }
}
1821
/// Server ECDH key exchange parameters: the curve description followed by
/// the server's public key.
#[derive(Debug)]
pub(crate) struct ServerEcdhParams {
    /// Curve type and named group the public key belongs to.
    pub(crate) curve_params: EcParameters,
    /// The server's public key, held as a `PayloadU8`.
    pub(crate) public: PayloadU8,
}
1827
1828impl ServerEcdhParams {
1829    #[cfg(feature = "tls12")]
1830    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1831        Self {
1832            curve_params: EcParameters {
1833                curve_type: ECCurveType::NamedCurve,
1834                named_group: kx.group(),
1835            },
1836            public: PayloadU8::new(kx.pub_key().to_vec()),
1837        }
1838    }
1839}
1840
1841impl Codec<'_> for ServerEcdhParams {
1842    fn encode(&self, bytes: &mut Vec<u8>) {
1843        self.curve_params.encode(bytes);
1844        self.public.encode(bytes);
1845    }
1846
1847    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1848        let cp = EcParameters::read(r)?;
1849        let pb = PayloadU8::read(r)?;
1850
1851        Ok(Self {
1852            curve_params: cp,
1853            public: pb,
1854        })
1855    }
1856}
1857
/// Server DH key exchange parameters: prime, generator and the server's
/// public value.
#[derive(Debug)]
// Permits the mixed-case `dh_Ys` field name (presumably chosen to mirror the
// TLS specification's naming -- confirm against RFC 5246).
#[allow(non_snake_case)]
pub(crate) struct ServerDhParams {
    /// The group prime `p`.
    pub(crate) dh_p: PayloadU16,
    /// The group generator `g`.
    pub(crate) dh_g: PayloadU16,
    /// The server's public value.
    pub(crate) dh_Ys: PayloadU16,
}
1865
1866impl ServerDhParams {
1867    #[cfg(feature = "tls12")]
1868    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1869        let params = match kx.ffdhe_group() {
1870            Some(params) => params,
1871            None => panic!("invalid NamedGroup for DHE key exchange: {:?}", kx.group()),
1872        };
1873
1874        Self {
1875            dh_p: PayloadU16::new(params.p.to_vec()),
1876            dh_g: PayloadU16::new(params.g.to_vec()),
1877            dh_Ys: PayloadU16::new(kx.pub_key().to_vec()),
1878        }
1879    }
1880
1881    #[cfg(feature = "tls12")]
1882    pub(crate) fn as_ffdhe_group(&self) -> FfdheGroup<'_> {
1883        FfdheGroup::from_params_trimming_leading_zeros(&self.dh_p.0, &self.dh_g.0)
1884    }
1885}
1886
1887impl Codec<'_> for ServerDhParams {
1888    fn encode(&self, bytes: &mut Vec<u8>) {
1889        self.dh_p.encode(bytes);
1890        self.dh_g.encode(bytes);
1891        self.dh_Ys.encode(bytes);
1892    }
1893
1894    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1895        Ok(Self {
1896            dh_p: PayloadU16::read(r)?,
1897            dh_g: PayloadU16::read(r)?,
1898            dh_Ys: PayloadU16::read(r)?,
1899        })
1900    }
1901}
1902
/// Server-side key exchange parameters, one variant per supported
/// key exchange algorithm.
// NOTE(review): `dead_code` is presumably allowed because the constructors
// and decoder are gated on the `tls12` feature -- confirm.
#[allow(dead_code)]
#[derive(Debug)]
pub(crate) enum ServerKeyExchangeParams {
    /// Parameters used with `KeyExchangeAlgorithm::ECDHE`.
    Ecdh(ServerEcdhParams),
    /// Parameters used with `KeyExchangeAlgorithm::DHE`.
    Dh(ServerDhParams),
}
1909
1910impl ServerKeyExchangeParams {
1911    #[cfg(feature = "tls12")]
1912    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1913        match kx.group().key_exchange_algorithm() {
1914            KeyExchangeAlgorithm::DHE => Self::Dh(ServerDhParams::new(kx)),
1915            KeyExchangeAlgorithm::ECDHE => Self::Ecdh(ServerEcdhParams::new(kx)),
1916        }
1917    }
1918
1919    #[cfg(feature = "tls12")]
1920    pub(crate) fn pub_key(&self) -> &[u8] {
1921        match self {
1922            Self::Ecdh(ecdh) => &ecdh.public.0,
1923            Self::Dh(dh) => &dh.dh_Ys.0,
1924        }
1925    }
1926
1927    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
1928        match self {
1929            Self::Ecdh(ecdh) => ecdh.encode(buf),
1930            Self::Dh(dh) => dh.encode(buf),
1931        }
1932    }
1933}
1934
#[cfg(feature = "tls12")]
impl KxDecode<'_> for ServerKeyExchangeParams {
    /// Reads the parameter structure matching `algo` from `r` and wraps it
    /// in the corresponding variant.
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        match algo {
            KeyExchangeAlgorithm::ECDHE => ServerEcdhParams::read(r).map(Self::Ecdh),
            KeyExchangeAlgorithm::DHE => ServerDhParams::read(r).map(Self::Dh),
        }
    }
}
1945
/// A fully decoded ServerKeyExchange: the key exchange parameters followed
/// by a digitally-signed struct.
#[derive(Debug)]
pub struct ServerKeyExchange {
    /// The ECDH or DH parameters.
    pub(crate) params: ServerKeyExchangeParams,
    /// The digitally-signed struct that follows the parameters on the wire.
    pub(crate) dss: DigitallySignedStruct,
}
1951
1952impl ServerKeyExchange {
1953    pub fn encode(&self, buf: &mut Vec<u8>) {
1954        self.params.encode(buf);
1955        self.dss.encode(buf);
1956    }
1957}
1958
/// A ServerKeyExchange message body, either decoded or still raw.
///
/// `Codec::read` always yields `Unknown`; see
/// `ServerKeyExchangePayload::unwrap_given_kxa` for the full parse.
#[derive(Debug)]
pub enum ServerKeyExchangePayload {
    /// A structured message (for example one built locally for sending).
    Known(ServerKeyExchange),
    /// The undecoded message bytes.
    Unknown(Payload<'static>),
}
1964
1965impl From<ServerKeyExchange> for ServerKeyExchangePayload {
1966    fn from(value: ServerKeyExchange) -> Self {
1967        Self::Known(value)
1968    }
1969}
1970
1971impl Codec<'_> for ServerKeyExchangePayload {
1972    fn encode(&self, bytes: &mut Vec<u8>) {
1973        match *self {
1974            Self::Known(ref x) => x.encode(bytes),
1975            Self::Unknown(ref x) => x.encode(bytes),
1976        }
1977    }
1978
1979    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1980        // read as Unknown, fully parse when we know the
1981        // KeyExchangeAlgorithm
1982        Ok(Self::Unknown(Payload::read(r).into_owned()))
1983    }
1984}
1985
1986impl ServerKeyExchangePayload {
1987    #[cfg(feature = "tls12")]
1988    pub(crate) fn unwrap_given_kxa(&self, kxa: KeyExchangeAlgorithm) -> Option<ServerKeyExchange> {
1989        if let Self::Unknown(ref unk) = *self {
1990            let mut rd = Reader::init(unk.bytes());
1991
1992            let result = ServerKeyExchange {
1993                params: ServerKeyExchangeParams::decode(&mut rd, kxa).ok()?,
1994                dss: DigitallySignedStruct::read(&mut rd).ok()?,
1995            };
1996
1997            if !rd.any_left() {
1998                return Some(result);
1999            };
2000        }
2001
2002        None
2003    }
2004}
2005
2006// -- EncryptedExtensions (TLS1.3 only) --
2007
impl TlsListElement for ServerExtension {
    // Selects `ListLength::U16` for lists of server extensions
    // (presumably the width of the list's length prefix -- see `msgs::codec`).
    const SIZE_LEN: ListLength = ListLength::U16;
}
2011
2012pub(crate) trait HasServerExtensions {
2013    fn extensions(&self) -> &[ServerExtension];
2014
2015    /// Returns true if there is more than one extension of a given
2016    /// type.
2017    fn has_duplicate_extension(&self) -> bool {
2018        has_duplicates::<_, _, u16>(
2019            self.extensions()
2020                .iter()
2021                .map(|ext| ext.ext_type()),
2022        )
2023    }
2024
2025    fn find_extension(&self, ext: ExtensionType) -> Option<&ServerExtension> {
2026        self.extensions()
2027            .iter()
2028            .find(|x| x.ext_type() == ext)
2029    }
2030
2031    fn alpn_protocol(&self) -> Option<&[u8]> {
2032        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
2033        match *ext {
2034            ServerExtension::Protocols(ref protos) => protos.as_single_slice(),
2035            _ => None,
2036        }
2037    }
2038
2039    fn quic_params_extension(&self) -> Option<Vec<u8>> {
2040        let ext = self
2041            .find_extension(ExtensionType::TransportParameters)
2042            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
2043        match *ext {
2044            ServerExtension::TransportParameters(ref bytes)
2045            | ServerExtension::TransportParametersDraft(ref bytes) => Some(bytes.to_vec()),
2046            _ => None,
2047        }
2048    }
2049
2050    fn server_ech_extension(&self) -> Option<ServerEncryptedClientHello> {
2051        let ext = self.find_extension(ExtensionType::EncryptedClientHello)?;
2052        match ext {
2053            ServerExtension::EncryptedClientHello(ech) => Some(ech.clone()),
2054            _ => None,
2055        }
2056    }
2057
2058    fn early_data_extension_offered(&self) -> bool {
2059        self.find_extension(ExtensionType::EarlyData)
2060            .is_some()
2061    }
2062}
2063
2064impl HasServerExtensions for Vec<ServerExtension> {
2065    fn extensions(&self) -> &[ServerExtension] {
2066        self
2067    }
2068}
2069
impl TlsListElement for ClientCertificateType {
    // Selects `ListLength::U8` for lists of client certificate types
    // (presumably the width of the list's length prefix -- see `msgs::codec`).
    const SIZE_LEN: ListLength = ListLength::U8;
}
2073
wrapped_payload!(
    /// A `DistinguishedName` is a `Vec<u8>` wrapped in internal types.
    ///
    /// It contains the DER or BER encoded [`Subject` field from RFC 5280](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6)
    /// for a single certificate. The Subject field is [encoded as an RFC 5280 `Name`](https://datatracker.ietf.org/doc/html/rfc5280#page-116).
    /// It can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
    ///
    /// The underlying bytes are available through `AsRef<[u8]>`.
    ///
    /// ```ignore
    /// for name in distinguished_names {
    ///     use x509_parser::prelude::FromDer;
    ///     println!("{}", x509_parser::x509::X509Name::from_der(name.as_ref())?.1);
    /// }
    /// ```
    pub struct DistinguishedName,
    PayloadU16,
);
2090
2091impl DistinguishedName {
2092    /// Create a [`DistinguishedName`] after prepending its outer SEQUENCE encoding.
2093    ///
2094    /// This can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
2095    ///
2096    /// ```ignore
2097    /// use x509_parser::prelude::FromDer;
2098    /// println!("{}", x509_parser::x509::X509Name::from_der(dn.as_ref())?.1);
2099    /// ```
2100    pub fn in_sequence(bytes: &[u8]) -> Self {
2101        Self(PayloadU16::new(wrap_in_sequence(bytes)))
2102    }
2103}
2104
impl TlsListElement for DistinguishedName {
    // Selects `ListLength::U16` for lists of distinguished names
    // (presumably the width of the list's length prefix -- see `msgs::codec`).
    const SIZE_LEN: ListLength = ListLength::U16;
}
2108
/// The CertificateRequest message body in its non-TLS1.3 form
/// (TLS1.3 uses [`CertificateRequestPayloadTls13`]).
#[derive(Debug)]
pub struct CertificateRequestPayload {
    /// Client certificate types listed in the request.
    pub(crate) certtypes: Vec<ClientCertificateType>,
    /// Signature schemes listed in the request; decoding rejects an empty list.
    pub(crate) sigschemes: Vec<SignatureScheme>,
    /// Distinguished names of certificate authorities listed in the request.
    pub(crate) canames: Vec<DistinguishedName>,
}
2115
2116impl Codec<'_> for CertificateRequestPayload {
2117    fn encode(&self, bytes: &mut Vec<u8>) {
2118        self.certtypes.encode(bytes);
2119        self.sigschemes.encode(bytes);
2120        self.canames.encode(bytes);
2121    }
2122
2123    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2124        let certtypes = Vec::read(r)?;
2125        let sigschemes = Vec::read(r)?;
2126        let canames = Vec::read(r)?;
2127
2128        if sigschemes.is_empty() {
2129            warn!("meaningless CertificateRequest message");
2130            Err(InvalidMessage::NoSignatureSchemes)
2131        } else {
2132            Ok(Self {
2133                certtypes,
2134                sigschemes,
2135                canames,
2136            })
2137        }
2138    }
2139}
2140
/// An extension carried in a TLS1.3 CertificateRequest
/// (see [`CertificateRequestPayloadTls13`]).
#[derive(Debug)]
pub(crate) enum CertReqExtension {
    /// Body of an `ExtensionType::SignatureAlgorithms` extension.
    SignatureAlgorithms(Vec<SignatureScheme>),
    /// Body of an `ExtensionType::CertificateAuthorities` extension.
    AuthorityNames(Vec<DistinguishedName>),
    /// Body of an `ExtensionType::CompressCertificate` extension.
    CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>),
    /// Any other extension type, kept undecoded.
    Unknown(UnknownExtension),
}
2148
2149impl CertReqExtension {
2150    pub(crate) fn ext_type(&self) -> ExtensionType {
2151        match *self {
2152            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
2153            Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities,
2154            Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate,
2155            Self::Unknown(ref r) => r.typ,
2156        }
2157    }
2158}
2159
2160impl Codec<'_> for CertReqExtension {
2161    fn encode(&self, bytes: &mut Vec<u8>) {
2162        self.ext_type().encode(bytes);
2163
2164        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2165        match *self {
2166            Self::SignatureAlgorithms(ref r) => r.encode(nested.buf),
2167            Self::AuthorityNames(ref r) => r.encode(nested.buf),
2168            Self::CertificateCompressionAlgorithms(ref r) => r.encode(nested.buf),
2169            Self::Unknown(ref r) => r.encode(nested.buf),
2170        }
2171    }
2172
2173    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2174        let typ = ExtensionType::read(r)?;
2175        let len = u16::read(r)? as usize;
2176        let mut sub = r.sub(len)?;
2177
2178        let ext = match typ {
2179            ExtensionType::SignatureAlgorithms => {
2180                let schemes = Vec::read(&mut sub)?;
2181                if schemes.is_empty() {
2182                    return Err(InvalidMessage::NoSignatureSchemes);
2183                }
2184                Self::SignatureAlgorithms(schemes)
2185            }
2186            ExtensionType::CertificateAuthorities => {
2187                let cas = Vec::read(&mut sub)?;
2188                Self::AuthorityNames(cas)
2189            }
2190            ExtensionType::CompressCertificate => {
2191                Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?)
2192            }
2193            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2194        };
2195
2196        sub.expect_empty("CertReqExtension")
2197            .map(|_| ext)
2198    }
2199}
2200
impl TlsListElement for CertReqExtension {
    // Selects `ListLength::U16` for lists of certificate request extensions
    // (presumably the width of the list's length prefix -- see `msgs::codec`).
    const SIZE_LEN: ListLength = ListLength::U16;
}
2204
/// The TLS1.3 form of the CertificateRequest message body.
#[derive(Debug)]
pub struct CertificateRequestPayloadTls13 {
    /// The request context, held as a `PayloadU8`.
    pub(crate) context: PayloadU8,
    /// The extensions carried by the request.
    pub(crate) extensions: Vec<CertReqExtension>,
}
2210
2211impl Codec<'_> for CertificateRequestPayloadTls13 {
2212    fn encode(&self, bytes: &mut Vec<u8>) {
2213        self.context.encode(bytes);
2214        self.extensions.encode(bytes);
2215    }
2216
2217    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2218        let context = PayloadU8::read(r)?;
2219        let extensions = Vec::read(r)?;
2220
2221        Ok(Self {
2222            context,
2223            extensions,
2224        })
2225    }
2226}
2227
2228impl CertificateRequestPayloadTls13 {
2229    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&CertReqExtension> {
2230        self.extensions
2231            .iter()
2232            .find(|x| x.ext_type() == ext)
2233    }
2234
2235    pub(crate) fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
2236        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
2237        match *ext {
2238            CertReqExtension::SignatureAlgorithms(ref sa) => Some(sa),
2239            _ => None,
2240        }
2241    }
2242
2243    pub(crate) fn authorities_extension(&self) -> Option<&[DistinguishedName]> {
2244        let ext = self.find_extension(ExtensionType::CertificateAuthorities)?;
2245        match *ext {
2246            CertReqExtension::AuthorityNames(ref an) => Some(an),
2247            _ => None,
2248        }
2249    }
2250
2251    pub(crate) fn certificate_compression_extension(
2252        &self,
2253    ) -> Option<&[CertificateCompressionAlgorithm]> {
2254        let ext = self.find_extension(ExtensionType::CompressCertificate)?;
2255        match *ext {
2256            CertReqExtension::CertificateCompressionAlgorithms(ref comps) => Some(comps),
2257            _ => None,
2258        }
2259    }
2260}
2261
// -- NewSessionTicket --
/// The NewSessionTicket message body in its non-TLS1.3 form
/// (TLS1.3 uses [`NewSessionTicketPayloadTls13`]).
#[derive(Debug)]
pub struct NewSessionTicketPayload {
    /// Lifetime hint, the first `u32` of the message.
    pub(crate) lifetime_hint: u32,
    /// The ticket bytes, held as a `PayloadU16`.
    pub(crate) ticket: PayloadU16,
}
2268
2269impl NewSessionTicketPayload {
2270    #[cfg(feature = "tls12")]
2271    pub(crate) fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self {
2272        Self {
2273            lifetime_hint,
2274            ticket: PayloadU16::new(ticket),
2275        }
2276    }
2277}
2278
2279impl Codec<'_> for NewSessionTicketPayload {
2280    fn encode(&self, bytes: &mut Vec<u8>) {
2281        self.lifetime_hint.encode(bytes);
2282        self.ticket.encode(bytes);
2283    }
2284
2285    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2286        let lifetime = u32::read(r)?;
2287        let ticket = PayloadU16::read(r)?;
2288
2289        Ok(Self {
2290            lifetime_hint: lifetime,
2291            ticket,
2292        })
2293    }
2294}
2295
// -- NewSessionTicket (TLS1.3 form) --
/// An extension carried in a TLS1.3 NewSessionTicket
/// (see [`NewSessionTicketPayloadTls13`]).
#[derive(Debug)]
pub(crate) enum NewSessionTicketExtension {
    /// Body of an `ExtensionType::EarlyData` extension: a single `u32`
    /// (exposed as `max_early_data_size()` on the TLS1.3 payload).
    EarlyData(u32),
    /// Any other extension type, kept undecoded.
    Unknown(UnknownExtension),
}
2302
2303impl NewSessionTicketExtension {
2304    pub(crate) fn ext_type(&self) -> ExtensionType {
2305        match *self {
2306            Self::EarlyData(_) => ExtensionType::EarlyData,
2307            Self::Unknown(ref r) => r.typ,
2308        }
2309    }
2310}
2311
2312impl Codec<'_> for NewSessionTicketExtension {
2313    fn encode(&self, bytes: &mut Vec<u8>) {
2314        self.ext_type().encode(bytes);
2315
2316        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2317        match *self {
2318            Self::EarlyData(r) => r.encode(nested.buf),
2319            Self::Unknown(ref r) => r.encode(nested.buf),
2320        }
2321    }
2322
2323    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2324        let typ = ExtensionType::read(r)?;
2325        let len = u16::read(r)? as usize;
2326        let mut sub = r.sub(len)?;
2327
2328        let ext = match typ {
2329            ExtensionType::EarlyData => Self::EarlyData(u32::read(&mut sub)?),
2330            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2331        };
2332
2333        sub.expect_empty("NewSessionTicketExtension")
2334            .map(|_| ext)
2335    }
2336}
2337
impl TlsListElement for NewSessionTicketExtension {
    // Selects `ListLength::U16` for lists of ticket extensions
    // (presumably the width of the list's length prefix -- see `msgs::codec`).
    const SIZE_LEN: ListLength = ListLength::U16;
}
2341
/// The TLS1.3 form of the NewSessionTicket message body.
#[derive(Debug)]
pub struct NewSessionTicketPayloadTls13 {
    /// Ticket lifetime, the first `u32` of the message.
    pub(crate) lifetime: u32,
    /// The `age_add` value, the second `u32` of the message.
    pub(crate) age_add: u32,
    /// The ticket nonce, held as a `PayloadU8`.
    pub(crate) nonce: PayloadU8,
    /// The ticket bytes, held as a `PayloadU16`.
    pub(crate) ticket: PayloadU16,
    /// Extensions attached to the ticket.
    pub(crate) exts: Vec<NewSessionTicketExtension>,
}
2350
2351impl NewSessionTicketPayloadTls13 {
2352    pub(crate) fn new(lifetime: u32, age_add: u32, nonce: Vec<u8>, ticket: Vec<u8>) -> Self {
2353        Self {
2354            lifetime,
2355            age_add,
2356            nonce: PayloadU8::new(nonce),
2357            ticket: PayloadU16::new(ticket),
2358            exts: vec![],
2359        }
2360    }
2361
2362    pub(crate) fn has_duplicate_extension(&self) -> bool {
2363        has_duplicates::<_, _, u16>(
2364            self.exts
2365                .iter()
2366                .map(|ext| ext.ext_type()),
2367        )
2368    }
2369
2370    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&NewSessionTicketExtension> {
2371        self.exts
2372            .iter()
2373            .find(|x| x.ext_type() == ext)
2374    }
2375
2376    pub(crate) fn max_early_data_size(&self) -> Option<u32> {
2377        let ext = self.find_extension(ExtensionType::EarlyData)?;
2378        match *ext {
2379            NewSessionTicketExtension::EarlyData(ref sz) => Some(*sz),
2380            _ => None,
2381        }
2382    }
2383}
2384
2385impl Codec<'_> for NewSessionTicketPayloadTls13 {
2386    fn encode(&self, bytes: &mut Vec<u8>) {
2387        self.lifetime.encode(bytes);
2388        self.age_add.encode(bytes);
2389        self.nonce.encode(bytes);
2390        self.ticket.encode(bytes);
2391        self.exts.encode(bytes);
2392    }
2393
2394    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2395        let lifetime = u32::read(r)?;
2396        let age_add = u32::read(r)?;
2397        let nonce = PayloadU8::read(r)?;
2398        let ticket = PayloadU16::read(r)?;
2399        let exts = Vec::read(r)?;
2400
2401        Ok(Self {
2402            lifetime,
2403            age_add,
2404            nonce,
2405            ticket,
2406            exts,
2407        })
2408    }
2409}
2410
2411// -- RFC6066 certificate status types
2412
/// A CertificateStatus message body.
///
/// Only supports OCSP: decoding rejects any other status type and encoding
/// always writes `CertificateStatusType::OCSP`.
#[derive(Debug)]
pub struct CertificateStatus<'a> {
    /// The OCSP response bytes, held as a `PayloadU24`.
    pub(crate) ocsp_response: PayloadU24<'a>,
}
2418
2419impl<'a> Codec<'a> for CertificateStatus<'a> {
2420    fn encode(&self, bytes: &mut Vec<u8>) {
2421        CertificateStatusType::OCSP.encode(bytes);
2422        self.ocsp_response.encode(bytes);
2423    }
2424
2425    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2426        let typ = CertificateStatusType::read(r)?;
2427
2428        match typ {
2429            CertificateStatusType::OCSP => Ok(Self {
2430                ocsp_response: PayloadU24::read(r)?,
2431            }),
2432            _ => Err(InvalidMessage::InvalidCertificateStatusType),
2433        }
2434    }
2435}
2436
2437impl<'a> CertificateStatus<'a> {
2438    pub(crate) fn new(ocsp: &'a [u8]) -> Self {
2439        CertificateStatus {
2440            ocsp_response: PayloadU24(Payload::Borrowed(ocsp)),
2441        }
2442    }
2443
2444    #[cfg(feature = "tls12")]
2445    pub(crate) fn into_inner(self) -> Vec<u8> {
2446        self.ocsp_response.0.into_vec()
2447    }
2448
2449    pub(crate) fn into_owned(self) -> CertificateStatus<'static> {
2450        CertificateStatus {
2451            ocsp_response: self.ocsp_response.into_owned(),
2452        }
2453    }
2454}
2455
2456// -- RFC8879 compressed certificates
2457
/// A CompressedCertificate message body.
#[derive(Debug)]
pub struct CompressedCertificatePayload<'a> {
    /// The compression algorithm named in the message.
    pub(crate) alg: CertificateCompressionAlgorithm,
    /// The declared uncompressed length; carried on the wire as a u24.
    pub(crate) uncompressed_len: u32,
    /// The compressed bytes, held as a `PayloadU24`.
    pub(crate) compressed: PayloadU24<'a>,
}
2464
2465impl<'a> Codec<'a> for CompressedCertificatePayload<'a> {
2466    fn encode(&self, bytes: &mut Vec<u8>) {
2467        self.alg.encode(bytes);
2468        codec::u24(self.uncompressed_len).encode(bytes);
2469        self.compressed.encode(bytes);
2470    }
2471
2472    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2473        Ok(Self {
2474            alg: CertificateCompressionAlgorithm::read(r)?,
2475            uncompressed_len: codec::u24::read(r)?.0,
2476            compressed: PayloadU24::read(r)?,
2477        })
2478    }
2479}
2480
2481impl CompressedCertificatePayload<'_> {
2482    fn into_owned(self) -> CompressedCertificatePayload<'static> {
2483        CompressedCertificatePayload {
2484            compressed: self.compressed.into_owned(),
2485            ..self
2486        }
2487    }
2488
2489    pub(crate) fn as_borrowed(&self) -> CompressedCertificatePayload<'_> {
2490        CompressedCertificatePayload {
2491            alg: self.alg,
2492            uncompressed_len: self.uncompressed_len,
2493            compressed: PayloadU24(Payload::Borrowed(self.compressed.0.bytes())),
2494        }
2495    }
2496}
2497
/// The body of a handshake message, one variant per message kind.
///
/// Variants with a `Tls13` suffix are the TLS1.3 layouts of messages whose
/// handshake type is shared with an earlier layout; the decoder picks between
/// them using the protocol version it is given.
#[derive(Debug)]
pub enum HandshakePayload<'a> {
    /// HelloRequest; has no body.
    HelloRequest,
    ClientHello(ClientHelloPayload),
    ServerHello(ServerHelloPayload),
    /// Decoded from a ServerHello-typed message whose random equals
    /// `HELLO_RETRY_REQUEST_RANDOM`.
    HelloRetryRequest(HelloRetryRequest),
    Certificate(CertificateChain<'a>),
    CertificateTls13(CertificatePayloadTls13<'a>),
    CompressedCertificate(CompressedCertificatePayload<'a>),
    ServerKeyExchange(ServerKeyExchangePayload),
    CertificateRequest(CertificateRequestPayload),
    CertificateRequestTls13(CertificateRequestPayloadTls13),
    CertificateVerify(DigitallySignedStruct),
    /// ServerHelloDone; has no body.
    ServerHelloDone,
    /// EndOfEarlyData; has no body.
    EndOfEarlyData,
    /// ClientKeyExchange, kept as raw bytes.
    ClientKeyExchange(Payload<'a>),
    NewSessionTicket(NewSessionTicketPayload),
    NewSessionTicketTls13(NewSessionTicketPayloadTls13),
    EncryptedExtensions(Vec<ServerExtension>),
    KeyUpdate(KeyUpdateRequest),
    /// Finished, kept as raw bytes.
    Finished(Payload<'a>),
    CertificateStatus(CertificateStatus<'a>),
    /// Never produced by decoding: the decoder rejects this type on the wire.
    MessageHash(Payload<'a>),
    /// Any handshake type the decoder does not recognise, kept as raw bytes.
    Unknown(Payload<'a>),
}
2523
2524impl HandshakePayload<'_> {
2525    fn encode(&self, bytes: &mut Vec<u8>) {
2526        use self::HandshakePayload::*;
2527        match *self {
2528            HelloRequest | ServerHelloDone | EndOfEarlyData => {}
2529            ClientHello(ref x) => x.encode(bytes),
2530            ServerHello(ref x) => x.encode(bytes),
2531            HelloRetryRequest(ref x) => x.encode(bytes),
2532            Certificate(ref x) => x.encode(bytes),
2533            CertificateTls13(ref x) => x.encode(bytes),
2534            CompressedCertificate(ref x) => x.encode(bytes),
2535            ServerKeyExchange(ref x) => x.encode(bytes),
2536            ClientKeyExchange(ref x) => x.encode(bytes),
2537            CertificateRequest(ref x) => x.encode(bytes),
2538            CertificateRequestTls13(ref x) => x.encode(bytes),
2539            CertificateVerify(ref x) => x.encode(bytes),
2540            NewSessionTicket(ref x) => x.encode(bytes),
2541            NewSessionTicketTls13(ref x) => x.encode(bytes),
2542            EncryptedExtensions(ref x) => x.encode(bytes),
2543            KeyUpdate(ref x) => x.encode(bytes),
2544            Finished(ref x) => x.encode(bytes),
2545            CertificateStatus(ref x) => x.encode(bytes),
2546            MessageHash(ref x) => x.encode(bytes),
2547            Unknown(ref x) => x.encode(bytes),
2548        }
2549    }
2550
2551    fn into_owned(self) -> HandshakePayload<'static> {
2552        use HandshakePayload::*;
2553
2554        match self {
2555            HelloRequest => HelloRequest,
2556            ClientHello(x) => ClientHello(x),
2557            ServerHello(x) => ServerHello(x),
2558            HelloRetryRequest(x) => HelloRetryRequest(x),
2559            Certificate(x) => Certificate(x.into_owned()),
2560            CertificateTls13(x) => CertificateTls13(x.into_owned()),
2561            CompressedCertificate(x) => CompressedCertificate(x.into_owned()),
2562            ServerKeyExchange(x) => ServerKeyExchange(x),
2563            CertificateRequest(x) => CertificateRequest(x),
2564            CertificateRequestTls13(x) => CertificateRequestTls13(x),
2565            CertificateVerify(x) => CertificateVerify(x),
2566            ServerHelloDone => ServerHelloDone,
2567            EndOfEarlyData => EndOfEarlyData,
2568            ClientKeyExchange(x) => ClientKeyExchange(x.into_owned()),
2569            NewSessionTicket(x) => NewSessionTicket(x),
2570            NewSessionTicketTls13(x) => NewSessionTicketTls13(x),
2571            EncryptedExtensions(x) => EncryptedExtensions(x),
2572            KeyUpdate(x) => KeyUpdate(x),
2573            Finished(x) => Finished(x.into_owned()),
2574            CertificateStatus(x) => CertificateStatus(x.into_owned()),
2575            MessageHash(x) => MessageHash(x.into_owned()),
2576            Unknown(x) => Unknown(x.into_owned()),
2577        }
2578    }
2579}
2580
/// A complete handshake message: its type plus the decoded body.
#[derive(Debug)]
pub struct HandshakeMessagePayload<'a> {
    /// The handshake type. When encoding, `HelloRetryRequest` is written
    /// to the wire as `ServerHello`.
    pub typ: HandshakeType,
    /// The message body.
    pub payload: HandshakePayload<'a>,
}
2586
2587impl<'a> Codec<'a> for HandshakeMessagePayload<'a> {
2588    fn encode(&self, bytes: &mut Vec<u8>) {
2589        self.payload_encode(bytes, Encoding::Standard);
2590    }
2591
2592    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2593        Self::read_version(r, ProtocolVersion::TLSv1_2)
2594    }
2595}
2596
2597impl<'a> HandshakeMessagePayload<'a> {
    /// Decodes one handshake message from `r`, using `vers` to choose between
    /// layouts for message types that differ in TLS1.3 (Certificate,
    /// CertificateRequest and NewSessionTicket).
    ///
    /// Returns an error if the header or body fails to decode, if the type is
    /// `MessageHash` or `HelloRetryRequest` (neither is accepted from the
    /// wire), or if the body is not consumed exactly.
    pub(crate) fn read_version(
        r: &mut Reader<'a>,
        vers: ProtocolVersion,
    ) -> Result<Self, InvalidMessage> {
        // Header: handshake type, then a u24 body length.
        let mut typ = HandshakeType::read(r)?;
        let len = codec::u24::read(r)?.0 as usize;
        // The body is parsed from a sub-reader limited to `len` bytes.
        let mut sub = r.sub(len)?;

        // NOTE: arm order matters -- guarded arms must precede their
        // unguarded counterparts for the same handshake type.
        let payload = match typ {
            // A HelloRequest with a non-empty body does not match here and
            // falls through to the `Unknown` catch-all at the bottom.
            HandshakeType::HelloRequest if sub.left() == 0 => HandshakePayload::HelloRequest,
            HandshakeType::ClientHello => {
                HandshakePayload::ClientHello(ClientHelloPayload::read(&mut sub)?)
            }
            HandshakeType::ServerHello => {
                let version = ProtocolVersion::read(&mut sub)?;
                let random = Random::read(&mut sub)?;

                // The random value distinguishes a HelloRetryRequest from a
                // genuine ServerHello; `typ` is rewritten so the returned
                // message reports itself as HelloRetryRequest.
                if random == HELLO_RETRY_REQUEST_RANDOM {
                    let mut hrr = HelloRetryRequest::read(&mut sub)?;
                    hrr.legacy_version = version;
                    typ = HandshakeType::HelloRetryRequest;
                    HandshakePayload::HelloRetryRequest(hrr)
                } else {
                    let mut shp = ServerHelloPayload::read(&mut sub)?;
                    shp.legacy_version = version;
                    shp.random = random;
                    HandshakePayload::ServerHello(shp)
                }
            }
            HandshakeType::Certificate if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificatePayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateTls13(p)
            }
            HandshakeType::Certificate => {
                HandshakePayload::Certificate(CertificateChain::read(&mut sub)?)
            }
            HandshakeType::ServerKeyExchange => {
                let p = ServerKeyExchangePayload::read(&mut sub)?;
                HandshakePayload::ServerKeyExchange(p)
            }
            HandshakeType::ServerHelloDone => {
                sub.expect_empty("ServerHelloDone")?;
                HandshakePayload::ServerHelloDone
            }
            HandshakeType::ClientKeyExchange => {
                HandshakePayload::ClientKeyExchange(Payload::read(&mut sub))
            }
            HandshakeType::CertificateRequest if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificateRequestPayloadTls13::read(&mut sub)?;
                HandshakePayload::CertificateRequestTls13(p)
            }
            HandshakeType::CertificateRequest => {
                let p = CertificateRequestPayload::read(&mut sub)?;
                HandshakePayload::CertificateRequest(p)
            }
            HandshakeType::CompressedCertificate => HandshakePayload::CompressedCertificate(
                CompressedCertificatePayload::read(&mut sub)?,
            ),
            HandshakeType::CertificateVerify => {
                HandshakePayload::CertificateVerify(DigitallySignedStruct::read(&mut sub)?)
            }
            HandshakeType::NewSessionTicket if vers == ProtocolVersion::TLSv1_3 => {
                let p = NewSessionTicketPayloadTls13::read(&mut sub)?;
                HandshakePayload::NewSessionTicketTls13(p)
            }
            HandshakeType::NewSessionTicket => {
                let p = NewSessionTicketPayload::read(&mut sub)?;
                HandshakePayload::NewSessionTicket(p)
            }
            HandshakeType::EncryptedExtensions => {
                HandshakePayload::EncryptedExtensions(Vec::read(&mut sub)?)
            }
            HandshakeType::KeyUpdate => {
                HandshakePayload::KeyUpdate(KeyUpdateRequest::read(&mut sub)?)
            }
            HandshakeType::EndOfEarlyData => {
                sub.expect_empty("EndOfEarlyData")?;
                HandshakePayload::EndOfEarlyData
            }
            HandshakeType::Finished => HandshakePayload::Finished(Payload::read(&mut sub)),
            HandshakeType::CertificateStatus => {
                HandshakePayload::CertificateStatus(CertificateStatus::read(&mut sub)?)
            }
            HandshakeType::MessageHash => {
                // does not appear on the wire
                return Err(InvalidMessage::UnexpectedMessage("MessageHash"));
            }
            HandshakeType::HelloRetryRequest => {
                // not legal on wire
                return Err(InvalidMessage::UnexpectedMessage("HelloRetryRequest"));
            }
            // Anything else is kept as raw bytes.
            _ => HandshakePayload::Unknown(Payload::read(&mut sub)),
        };

        // Reject bodies with trailing bytes the payload parser did not consume.
        sub.expect_empty("HandshakeMessagePayload")
            .map(|_| Self { typ, payload })
    }
2695
2696    pub(crate) fn encoding_for_binder_signing(&self) -> Vec<u8> {
2697        let mut ret = self.get_encoding();
2698
2699        let binder_len = match self.payload {
2700            HandshakePayload::ClientHello(ref ch) => match ch.extensions.last() {
2701                Some(ClientExtension::PresharedKey(ref offer)) => {
2702                    let mut binders_encoding = Vec::new();
2703                    offer
2704                        .binders
2705                        .encode(&mut binders_encoding);
2706                    binders_encoding.len()
2707                }
2708                _ => 0,
2709            },
2710            _ => 0,
2711        };
2712
2713        let ret_len = ret.len() - binder_len;
2714        ret.truncate(ret_len);
2715        ret
2716    }
2717
2718    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
2719        // output type, length, and encoded payload
2720        match self.typ {
2721            HandshakeType::HelloRetryRequest => HandshakeType::ServerHello,
2722            _ => self.typ,
2723        }
2724        .encode(bytes);
2725
2726        let nested = LengthPrefixedBuffer::new(
2727            ListLength::U24 {
2728                max: usize::MAX,
2729                error: InvalidMessage::MessageTooLarge,
2730            },
2731            bytes,
2732        );
2733
2734        match &self.payload {
2735            // for Server Hello and HelloRetryRequest payloads we need to encode the payload
2736            // differently based on the purpose of the encoding.
2737            HandshakePayload::ServerHello(payload) => payload.payload_encode(nested.buf, encoding),
2738            HandshakePayload::HelloRetryRequest(payload) => {
2739                payload.payload_encode(nested.buf, encoding)
2740            }
2741
2742            // All other payload types are encoded the same regardless of purpose.
2743            _ => self.payload.encode(nested.buf),
2744        }
2745    }
2746
2747    pub(crate) fn build_handshake_hash(hash: &[u8]) -> Self {
2748        Self {
2749            typ: HandshakeType::MessageHash,
2750            payload: HandshakePayload::MessageHash(Payload::new(hash.to_vec())),
2751        }
2752    }
2753
2754    pub(crate) fn into_owned(self) -> HandshakeMessagePayload<'static> {
2755        let Self { typ, payload } = self;
2756        HandshakeMessagePayload {
2757            typ,
2758            payload: payload.into_owned(),
2759        }
2760    }
2761}
2762
2763#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
2764pub struct HpkeSymmetricCipherSuite {
2765    pub kdf_id: HpkeKdf,
2766    pub aead_id: HpkeAead,
2767}
2768
2769impl Codec<'_> for HpkeSymmetricCipherSuite {
2770    fn encode(&self, bytes: &mut Vec<u8>) {
2771        self.kdf_id.encode(bytes);
2772        self.aead_id.encode(bytes);
2773    }
2774
2775    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2776        Ok(Self {
2777            kdf_id: HpkeKdf::read(r)?,
2778            aead_id: HpkeAead::read(r)?,
2779        })
2780    }
2781}
2782
2783impl TlsListElement for HpkeSymmetricCipherSuite {
2784    const SIZE_LEN: ListLength = ListLength::U16;
2785}
2786
2787#[derive(Clone, Debug, PartialEq)]
2788pub struct HpkeKeyConfig {
2789    pub config_id: u8,
2790    pub kem_id: HpkeKem,
2791    pub public_key: PayloadU16,
2792    pub symmetric_cipher_suites: Vec<HpkeSymmetricCipherSuite>,
2793}
2794
2795impl Codec<'_> for HpkeKeyConfig {
2796    fn encode(&self, bytes: &mut Vec<u8>) {
2797        self.config_id.encode(bytes);
2798        self.kem_id.encode(bytes);
2799        self.public_key.encode(bytes);
2800        self.symmetric_cipher_suites
2801            .encode(bytes);
2802    }
2803
2804    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2805        Ok(Self {
2806            config_id: u8::read(r)?,
2807            kem_id: HpkeKem::read(r)?,
2808            public_key: PayloadU16::read(r)?,
2809            symmetric_cipher_suites: Vec::<HpkeSymmetricCipherSuite>::read(r)?,
2810        })
2811    }
2812}
2813
2814#[derive(Clone, Debug, PartialEq)]
2815pub struct EchConfigContents {
2816    pub key_config: HpkeKeyConfig,
2817    pub maximum_name_length: u8,
2818    pub public_name: DnsName<'static>,
2819    pub extensions: Vec<EchConfigExtension>,
2820}
2821
2822impl EchConfigContents {
2823    /// Returns true if there is more than one extension of a given
2824    /// type.
2825    pub(crate) fn has_duplicate_extension(&self) -> bool {
2826        has_duplicates::<_, _, u16>(
2827            self.extensions
2828                .iter()
2829                .map(|ext| ext.ext_type()),
2830        )
2831    }
2832
2833    /// Returns true if there is at least one mandatory unsupported extension.
2834    pub(crate) fn has_unknown_mandatory_extension(&self) -> bool {
2835        self.extensions
2836            .iter()
2837            // An extension is considered mandatory if the high bit of its type is set.
2838            .any(|ext| {
2839                matches!(ext.ext_type(), ExtensionType::Unknown(_))
2840                    && u16::from(ext.ext_type()) & 0x8000 != 0
2841            })
2842    }
2843}
2844
2845impl Codec<'_> for EchConfigContents {
2846    fn encode(&self, bytes: &mut Vec<u8>) {
2847        self.key_config.encode(bytes);
2848        self.maximum_name_length.encode(bytes);
2849        let dns_name = &self.public_name.borrow();
2850        PayloadU8::encode_slice(dns_name.as_ref().as_ref(), bytes);
2851        self.extensions.encode(bytes);
2852    }
2853
2854    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2855        Ok(Self {
2856            key_config: HpkeKeyConfig::read(r)?,
2857            maximum_name_length: u8::read(r)?,
2858            public_name: {
2859                DnsName::try_from(PayloadU8::read(r)?.0.as_slice())
2860                    .map_err(|_| InvalidMessage::InvalidServerName)?
2861                    .to_owned()
2862            },
2863            extensions: Vec::read(r)?,
2864        })
2865    }
2866}
2867
2868/// An encrypted client hello (ECH) config.
2869#[derive(Clone, Debug, PartialEq)]
2870pub enum EchConfigPayload {
2871    /// A recognized V18 ECH configuration.
2872    V18(EchConfigContents),
2873    /// An unknown version ECH configuration.
2874    Unknown {
2875        version: EchVersion,
2876        contents: PayloadU16,
2877    },
2878}
2879
2880impl TlsListElement for EchConfigPayload {
2881    const SIZE_LEN: ListLength = ListLength::U16;
2882}
2883
2884impl Codec<'_> for EchConfigPayload {
2885    fn encode(&self, bytes: &mut Vec<u8>) {
2886        match self {
2887            Self::V18(c) => {
2888                // Write the version, the length, and the contents.
2889                EchVersion::V18.encode(bytes);
2890                let inner = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2891                c.encode(inner.buf);
2892            }
2893            Self::Unknown { version, contents } => {
2894                // Unknown configuration versions are opaque.
2895                version.encode(bytes);
2896                contents.encode(bytes);
2897            }
2898        }
2899    }
2900
2901    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2902        let version = EchVersion::read(r)?;
2903        let length = u16::read(r)?;
2904        let mut contents = r.sub(length as usize)?;
2905
2906        Ok(match version {
2907            EchVersion::V18 => Self::V18(EchConfigContents::read(&mut contents)?),
2908            _ => {
2909                // Note: we don't PayloadU16::read() here because we've already read the length prefix.
2910                let data = PayloadU16::new(contents.rest().into());
2911                Self::Unknown {
2912                    version,
2913                    contents: data,
2914                }
2915            }
2916        })
2917    }
2918}
2919
2920#[derive(Clone, Debug, PartialEq)]
2921pub enum EchConfigExtension {
2922    Unknown(UnknownExtension),
2923}
2924
2925impl EchConfigExtension {
2926    pub(crate) fn ext_type(&self) -> ExtensionType {
2927        match *self {
2928            Self::Unknown(ref r) => r.typ,
2929        }
2930    }
2931}
2932
2933impl Codec<'_> for EchConfigExtension {
2934    fn encode(&self, bytes: &mut Vec<u8>) {
2935        self.ext_type().encode(bytes);
2936
2937        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2938        match *self {
2939            Self::Unknown(ref r) => r.encode(nested.buf),
2940        }
2941    }
2942
2943    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2944        let typ = ExtensionType::read(r)?;
2945        let len = u16::read(r)? as usize;
2946        let mut sub = r.sub(len)?;
2947
2948        #[allow(clippy::match_single_binding)] // Future-proofing.
2949        let ext = match typ {
2950            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2951        };
2952
2953        sub.expect_empty("EchConfigExtension")
2954            .map(|_| ext)
2955    }
2956}
2957
2958impl TlsListElement for EchConfigExtension {
2959    const SIZE_LEN: ListLength = ListLength::U16;
2960}
2961
2962/// Representation of the `ECHClientHello` client extension specified in
2963/// [draft-ietf-tls-esni Section 5].
2964///
2965/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
2966#[derive(Clone, Debug)]
2967pub enum EncryptedClientHello {
2968    /// A `ECHClientHello` with type [EchClientHelloType::ClientHelloOuter].
2969    Outer(EncryptedClientHelloOuter),
2970    /// An empty `ECHClientHello` with type [EchClientHelloType::ClientHelloInner].
2971    ///
2972    /// This variant has no payload.
2973    Inner,
2974}
2975
2976impl Codec<'_> for EncryptedClientHello {
2977    fn encode(&self, bytes: &mut Vec<u8>) {
2978        match self {
2979            Self::Outer(payload) => {
2980                EchClientHelloType::ClientHelloOuter.encode(bytes);
2981                payload.encode(bytes);
2982            }
2983            Self::Inner => {
2984                EchClientHelloType::ClientHelloInner.encode(bytes);
2985                // Empty payload.
2986            }
2987        }
2988    }
2989
2990    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2991        match EchClientHelloType::read(r)? {
2992            EchClientHelloType::ClientHelloOuter => {
2993                Ok(Self::Outer(EncryptedClientHelloOuter::read(r)?))
2994            }
2995            EchClientHelloType::ClientHelloInner => Ok(Self::Inner),
2996            _ => Err(InvalidMessage::InvalidContentType),
2997        }
2998    }
2999}
3000
3001/// Representation of the ECHClientHello extension with type outer specified in
3002/// [draft-ietf-tls-esni Section 5].
3003///
3004/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3005#[derive(Clone, Debug)]
3006pub struct EncryptedClientHelloOuter {
3007    /// The cipher suite used to encrypt ClientHelloInner. Must match a value from
3008    /// ECHConfigContents.cipher_suites list.
3009    pub cipher_suite: HpkeSymmetricCipherSuite,
3010    /// The ECHConfigContents.key_config.config_id for the chosen ECHConfig.
3011    pub config_id: u8,
3012    /// The HPKE encapsulated key, used by servers to decrypt the corresponding payload field.
3013    /// This field is empty in a ClientHelloOuter sent in response to a HelloRetryRequest.
3014    pub enc: PayloadU16,
3015    /// The serialized and encrypted ClientHelloInner structure, encrypted using HPKE.
3016    pub payload: PayloadU16,
3017}
3018
3019impl Codec<'_> for EncryptedClientHelloOuter {
3020    fn encode(&self, bytes: &mut Vec<u8>) {
3021        self.cipher_suite.encode(bytes);
3022        self.config_id.encode(bytes);
3023        self.enc.encode(bytes);
3024        self.payload.encode(bytes);
3025    }
3026
3027    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3028        Ok(Self {
3029            cipher_suite: HpkeSymmetricCipherSuite::read(r)?,
3030            config_id: u8::read(r)?,
3031            enc: PayloadU16::read(r)?,
3032            payload: PayloadU16::read(r)?,
3033        })
3034    }
3035}
3036
3037/// Representation of the ECHEncryptedExtensions extension specified in
3038/// [draft-ietf-tls-esni Section 5].
3039///
3040/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3041#[derive(Clone, Debug)]
3042pub struct ServerEncryptedClientHello {
3043    pub(crate) retry_configs: Vec<EchConfigPayload>,
3044}
3045
3046impl Codec<'_> for ServerEncryptedClientHello {
3047    fn encode(&self, bytes: &mut Vec<u8>) {
3048        self.retry_configs.encode(bytes);
3049    }
3050
3051    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3052        Ok(Self {
3053            retry_configs: Vec::<EchConfigPayload>::read(r)?,
3054        })
3055    }
3056}
3057
3058/// The method of encoding to use for a handshake message.
3059///
3060/// In some cases a handshake message may be encoded differently depending on the purpose
3061/// the encoded message is being used for. For example, a [ServerHelloPayload] may be encoded
3062/// with the last 8 bytes of the random zeroed out when being encoded for ECH confirmation.
3063pub(crate) enum Encoding {
3064    /// Standard RFC 8446 encoding.
3065    Standard,
3066    /// Encoding for ECH confirmation.
3067    EchConfirmation,
3068    /// Encoding for ECH inner client hello.
3069    EchInnerHello { to_compress: Vec<ExtensionType> },
3070}
3071
/// Returns `true` if two items of `iter` convert to the same `T`.
///
/// Each item is converted with `Into<T>` and inserted into an ordered set; iteration
/// stops at the first item whose converted value is already in the set.
fn has_duplicates<I: IntoIterator<Item = E>, E: Into<T>, T: Eq + Ord>(iter: I) -> bool {
    let mut seen = BTreeSet::<T>::new();

    // `insert` returns false when the value was already present.
    iter.into_iter()
        .any(|item| !seen.insert(item.into()))
}
3083
#[cfg(test)]
mod tests {
    use super::*;

    /// Two copies of the same non-mandatory unknown extension count as duplicates.
    #[test]
    fn test_ech_config_dupe_exts() {
        let ext = unknown_ext(0x42);
        let mut config = config_template();
        config
            .extensions
            .extend(vec![ext.clone(), ext]);

        assert!(config.has_duplicate_extension());
        assert!(!config.has_unknown_mandatory_extension());
    }

    /// A single unknown extension with the high bit set is mandatory but not a duplicate.
    #[test]
    fn test_ech_config_mandatory_exts() {
        let mut config = config_template();
        config
            .extensions
            .push(unknown_ext(0x42 | 0x8000)); // Note: high bit set.

        assert!(!config.has_duplicate_extension());
        assert!(config.has_unknown_mandatory_extension());
    }

    /// Builds an unknown extension of the given type with a one-byte payload.
    fn unknown_ext(typ: u16) -> EchConfigExtension {
        EchConfigExtension::Unknown(UnknownExtension {
            typ: ExtensionType::Unknown(typ),
            payload: Payload::new(vec![0x42]),
        })
    }

    /// Builds an ECH config with no extensions for the tests to extend.
    fn config_template() -> EchConfigContents {
        let key_config = HpkeKeyConfig {
            config_id: 0,
            kem_id: HpkeKem::DHKEM_P256_HKDF_SHA256,
            public_key: PayloadU16(b"xxx".into()),
            symmetric_cipher_suites: vec![HpkeSymmetricCipherSuite {
                kdf_id: HpkeKdf::HKDF_SHA256,
                aead_id: HpkeAead::AES_128_GCM,
            }],
        };

        EchConfigContents {
            key_config,
            maximum_name_length: 0,
            public_name: DnsName::try_from("example.com").unwrap(),
            extensions: vec![],
        }
    }
}