// rustls/msgs/persist.rs

1use alloc::vec::Vec;
2use core::cmp;
3#[cfg(feature = "tls12")]
4use core::mem;
5
6use pki_types::{DnsName, UnixTime};
7use zeroize::Zeroizing;
8
9use crate::enums::{CipherSuite, ProtocolVersion};
10use crate::error::InvalidMessage;
11use crate::msgs::base::{PayloadU16, PayloadU8};
12use crate::msgs::codec::{Codec, Reader};
13use crate::msgs::handshake::CertificateChain;
14#[cfg(feature = "tls12")]
15use crate::msgs::handshake::SessionId;
16#[cfg(feature = "tls12")]
17use crate::tls12::Tls12CipherSuite;
18use crate::tls13::Tls13CipherSuite;
19
20pub(crate) struct Retrieved<T> {
21    pub(crate) value: T,
22    retrieved_at: UnixTime,
23}
24
25impl<T> Retrieved<T> {
26    pub(crate) fn new(value: T, retrieved_at: UnixTime) -> Self {
27        Self {
28            value,
29            retrieved_at,
30        }
31    }
32
33    pub(crate) fn map<M>(&self, f: impl FnOnce(&T) -> Option<&M>) -> Option<Retrieved<&M>> {
34        Some(Retrieved {
35            value: f(&self.value)?,
36            retrieved_at: self.retrieved_at,
37        })
38    }
39}
40
41impl Retrieved<&Tls13ClientSessionValue> {
42    pub(crate) fn obfuscated_ticket_age(&self) -> u32 {
43        let age_secs = self
44            .retrieved_at
45            .as_secs()
46            .saturating_sub(self.value.common.epoch);
47        let age_millis = age_secs as u32 * 1000;
48        age_millis.wrapping_add(self.value.age_add)
49    }
50}
51
52impl<T: core::ops::Deref<Target = ClientSessionCommon>> Retrieved<T> {
53    pub(crate) fn has_expired(&self) -> bool {
54        let common = &*self.value;
55        common.lifetime_secs != 0
56            && common
57                .epoch
58                .saturating_add(u64::from(common.lifetime_secs))
59                < self.retrieved_at.as_secs()
60    }
61}
62
63impl<T> core::ops::Deref for Retrieved<T> {
64    type Target = T;
65
66    fn deref(&self) -> &Self::Target {
67        &self.value
68    }
69}
70
71#[derive(Debug)]
72pub struct Tls13ClientSessionValue {
73    suite: &'static Tls13CipherSuite,
74    age_add: u32,
75    max_early_data_size: u32,
76    pub(crate) common: ClientSessionCommon,
77    quic_params: PayloadU16,
78}
79
80impl Tls13ClientSessionValue {
81    pub(crate) fn new(
82        suite: &'static Tls13CipherSuite,
83        ticket: Vec<u8>,
84        secret: &[u8],
85        server_cert_chain: CertificateChain<'static>,
86        time_now: UnixTime,
87        lifetime_secs: u32,
88        age_add: u32,
89        max_early_data_size: u32,
90    ) -> Self {
91        Self {
92            suite,
93            age_add,
94            max_early_data_size,
95            common: ClientSessionCommon::new(
96                ticket,
97                secret,
98                time_now,
99                lifetime_secs,
100                server_cert_chain,
101            ),
102            quic_params: PayloadU16(Vec::new()),
103        }
104    }
105
106    pub fn max_early_data_size(&self) -> u32 {
107        self.max_early_data_size
108    }
109
110    pub fn suite(&self) -> &'static Tls13CipherSuite {
111        self.suite
112    }
113
114    #[doc(hidden)]
115    /// Test only: rewind epoch by `delta` seconds.
116    pub fn rewind_epoch(&mut self, delta: u32) {
117        self.common.epoch -= delta as u64;
118    }
119
120    #[doc(hidden)]
121    /// Test only: replace `max_early_data_size` with `new`
122    pub fn _private_set_max_early_data_size(&mut self, new: u32) {
123        self.max_early_data_size = new;
124    }
125
126    pub fn set_quic_params(&mut self, quic_params: &[u8]) {
127        self.quic_params = PayloadU16(quic_params.to_vec());
128    }
129
130    pub fn quic_params(&self) -> Vec<u8> {
131        self.quic_params.0.clone()
132    }
133}
134
135impl core::ops::Deref for Tls13ClientSessionValue {
136    type Target = ClientSessionCommon;
137
138    fn deref(&self) -> &Self::Target {
139        &self.common
140    }
141}
142
/// Client-side data kept for resuming a TLS 1.2 session, by session ID or ticket.
///
/// When the `tls12` feature is disabled, every field is compiled out.
#[derive(Debug, Clone)]
pub struct Tls12ClientSessionValue {
    /// Cipher suite associated with the session.
    #[cfg(feature = "tls12")]
    suite: &'static Tls12CipherSuite,
    /// Session ID recorded for this session.
    #[cfg(feature = "tls12")]
    pub(crate) session_id: SessionId,
    /// Extended-master-secret flag recorded for this session.
    #[cfg(feature = "tls12")]
    extended_ms: bool,
    /// Ticket, secret, timing and certificate data shared with TLS 1.3 session values.
    #[doc(hidden)]
    #[cfg(feature = "tls12")]
    pub(crate) common: ClientSessionCommon,
}
155
#[cfg(feature = "tls12")]
impl Tls12ClientSessionValue {
    /// Builds a TLS 1.2 resumption value.
    ///
    /// The shared data is assembled by `ClientSessionCommon::new`, which caps
    /// `lifetime_secs`.
    pub(crate) fn new(
        suite: &'static Tls12CipherSuite,
        session_id: SessionId,
        ticket: Vec<u8>,
        master_secret: &[u8],
        server_cert_chain: CertificateChain<'static>,
        time_now: UnixTime,
        lifetime_secs: u32,
        extended_ms: bool,
    ) -> Self {
        let common = ClientSessionCommon::new(
            ticket,
            master_secret,
            time_now,
            lifetime_secs,
            server_cert_chain,
        );

        Self {
            suite,
            session_id,
            extended_ms,
            common,
        }
    }

    /// Moves the stored ticket out, leaving an empty ticket in its place.
    pub(crate) fn take_ticket(&mut self) -> Vec<u8> {
        let ticket = &mut self.common.ticket.0;
        mem::take(ticket)
    }

    /// Returns whether extended master secret was recorded for this session.
    pub(crate) fn extended_ms(&self) -> bool {
        self.extended_ms
    }

    /// Returns the cipher suite recorded for this session.
    pub(crate) fn suite(&self) -> &'static Tls12CipherSuite {
        self.suite
    }

    #[doc(hidden)]
    /// Test only: move the session's epoch `delta` seconds into the past.
    ///
    /// Panics on underflow in debug builds if `delta` exceeds the current epoch.
    pub fn rewind_epoch(&mut self, delta: u32) {
        self.common.epoch -= u64::from(delta);
    }
}
200
/// Exposes the shared `ClientSessionCommon` data directly on a TLS 1.2 session value.
#[cfg(feature = "tls12")]
impl core::ops::Deref for Tls12ClientSessionValue {
    type Target = ClientSessionCommon;

    fn deref(&self) -> &ClientSessionCommon {
        let Self { common, .. } = self;
        common
    }
}
209
210#[derive(Debug, Clone)]
211pub struct ClientSessionCommon {
212    ticket: PayloadU16,
213    secret: Zeroizing<PayloadU8>,
214    epoch: u64,
215    lifetime_secs: u32,
216    server_cert_chain: CertificateChain<'static>,
217}
218
219impl ClientSessionCommon {
220    fn new(
221        ticket: Vec<u8>,
222        secret: &[u8],
223        time_now: UnixTime,
224        lifetime_secs: u32,
225        server_cert_chain: CertificateChain<'static>,
226    ) -> Self {
227        Self {
228            ticket: PayloadU16(ticket),
229            secret: Zeroizing::new(PayloadU8(secret.to_vec())),
230            epoch: time_now.as_secs(),
231            lifetime_secs: cmp::min(lifetime_secs, MAX_TICKET_LIFETIME),
232            server_cert_chain,
233        }
234    }
235
236    pub(crate) fn server_cert_chain(&self) -> &CertificateChain<'static> {
237        &self.server_cert_chain
238    }
239
240    pub(crate) fn secret(&self) -> &[u8] {
241        self.secret.0.as_ref()
242    }
243
244    pub(crate) fn ticket(&self) -> &[u8] {
245        self.ticket.0.as_ref()
246    }
247}
248
/// Upper bound, in seconds, applied to every ticket lifetime we store: seven days.
///
/// RFC 8446 section 4.6.1 forbids servers from advertising a longer ticket lifetime.
const MAX_TICKET_LIFETIME: u32 = 7 * 24 * 60 * 60;
250
/// This is the maximum allowed skew between server and client clocks, over
/// the maximum ticket lifetime period.  This encompasses TCP retransmission
/// times in case packet loss occurs when the client sends the ClientHello
/// or receives the NewSessionTicket, _and_ actual clock skew over this period.
///
/// Expressed in milliseconds (one minute).
const MAX_FRESHNESS_SKEW_MS: u32 = 60 * 1000;
256
257// --- Server types ---
258#[derive(Debug)]
259pub struct ServerSessionValue {
260    pub(crate) sni: Option<DnsName<'static>>,
261    pub(crate) version: ProtocolVersion,
262    pub(crate) cipher_suite: CipherSuite,
263    pub(crate) master_secret: Zeroizing<PayloadU8>,
264    pub(crate) extended_ms: bool,
265    pub(crate) client_cert_chain: Option<CertificateChain<'static>>,
266    pub(crate) alpn: Option<PayloadU8>,
267    pub(crate) application_data: PayloadU16,
268    pub creation_time_sec: u64,
269    pub(crate) age_obfuscation_offset: u32,
270    freshness: Option<bool>,
271}
272
273impl Codec<'_> for ServerSessionValue {
274    fn encode(&self, bytes: &mut Vec<u8>) {
275        if let Some(ref sni) = self.sni {
276            1u8.encode(bytes);
277            let sni_bytes: &str = sni.as_ref();
278            PayloadU8::new(Vec::from(sni_bytes)).encode(bytes);
279        } else {
280            0u8.encode(bytes);
281        }
282        self.version.encode(bytes);
283        self.cipher_suite.encode(bytes);
284        self.master_secret.encode(bytes);
285        (u8::from(self.extended_ms)).encode(bytes);
286        if let Some(ref chain) = self.client_cert_chain {
287            1u8.encode(bytes);
288            chain.encode(bytes);
289        } else {
290            0u8.encode(bytes);
291        }
292        if let Some(ref alpn) = self.alpn {
293            1u8.encode(bytes);
294            alpn.encode(bytes);
295        } else {
296            0u8.encode(bytes);
297        }
298        self.application_data.encode(bytes);
299        self.creation_time_sec.encode(bytes);
300        self.age_obfuscation_offset
301            .encode(bytes);
302    }
303
304    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
305        let has_sni = u8::read(r)?;
306        let sni = if has_sni == 1 {
307            let dns_name = PayloadU8::read(r)?;
308            let dns_name = match DnsName::try_from(dns_name.0.as_slice()) {
309                Ok(dns_name) => dns_name.to_owned(),
310                Err(_) => return Err(InvalidMessage::InvalidServerName),
311            };
312
313            Some(dns_name)
314        } else {
315            None
316        };
317
318        let v = ProtocolVersion::read(r)?;
319        let cs = CipherSuite::read(r)?;
320        let ms = Zeroizing::new(PayloadU8::read(r)?);
321        let ems = u8::read(r)?;
322        let has_ccert = u8::read(r)? == 1;
323        let ccert = if has_ccert {
324            Some(CertificateChain::read(r)?.into_owned())
325        } else {
326            None
327        };
328        let has_alpn = u8::read(r)? == 1;
329        let alpn = if has_alpn {
330            Some(PayloadU8::read(r)?)
331        } else {
332            None
333        };
334        let application_data = PayloadU16::read(r)?;
335        let creation_time_sec = u64::read(r)?;
336        let age_obfuscation_offset = u32::read(r)?;
337
338        Ok(Self {
339            sni,
340            version: v,
341            cipher_suite: cs,
342            master_secret: ms,
343            extended_ms: ems == 1u8,
344            client_cert_chain: ccert,
345            alpn,
346            application_data,
347            creation_time_sec,
348            age_obfuscation_offset,
349            freshness: None,
350        })
351    }
352}
353
354impl ServerSessionValue {
355    pub(crate) fn new(
356        sni: Option<&DnsName<'_>>,
357        v: ProtocolVersion,
358        cs: CipherSuite,
359        ms: &[u8],
360        client_cert_chain: Option<CertificateChain<'static>>,
361        alpn: Option<Vec<u8>>,
362        application_data: Vec<u8>,
363        creation_time: UnixTime,
364        age_obfuscation_offset: u32,
365    ) -> Self {
366        Self {
367            sni: sni.map(|dns| dns.to_owned()),
368            version: v,
369            cipher_suite: cs,
370            master_secret: Zeroizing::new(PayloadU8::new(ms.to_vec())),
371            extended_ms: false,
372            client_cert_chain,
373            alpn: alpn.map(PayloadU8::new),
374            application_data: PayloadU16::new(application_data),
375            creation_time_sec: creation_time.as_secs(),
376            age_obfuscation_offset,
377            freshness: None,
378        }
379    }
380
381    #[cfg(feature = "tls12")]
382    pub(crate) fn set_extended_ms_used(&mut self) {
383        self.extended_ms = true;
384    }
385
386    pub(crate) fn set_freshness(
387        mut self,
388        obfuscated_client_age_ms: u32,
389        time_now: UnixTime,
390    ) -> Self {
391        let client_age_ms = obfuscated_client_age_ms.wrapping_sub(self.age_obfuscation_offset);
392        let server_age_ms = (time_now
393            .as_secs()
394            .saturating_sub(self.creation_time_sec) as u32)
395            .saturating_mul(1000);
396
397        let age_difference = if client_age_ms < server_age_ms {
398            server_age_ms - client_age_ms
399        } else {
400            client_age_ms - server_age_ms
401        };
402
403        self.freshness = Some(age_difference <= MAX_FRESHNESS_SKEW_MS);
404        self
405    }
406
407    pub(crate) fn is_fresh(&self) -> bool {
408        self.freshness.unwrap_or_default()
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "std")] // for UnixTime::now
    #[test]
    fn serversessionvalue_is_debug() {
        use std::{println, vec};
        let ssv = ServerSessionValue::new(
            None,
            ProtocolVersion::TLSv1_3,
            CipherSuite::TLS13_AES_128_GCM_SHA256,
            &[1, 2, 3],
            None,
            None,
            vec![4, 5, 6],
            UnixTime::now(),
            0x12345678,
        );
        println!("{:?}", ssv);
    }

    #[test]
    fn serversessionvalue_no_sni() {
        let bytes = [
            0x00, // no SNI
            0x03, 0x03, // TLS 1.2
            0xc0, 0x23, // cipher suite
            0x03, 0x01, 0x02, 0x03, // master secret: u8 length + 3 bytes
            0x00, // extended master secret not used
            0x00, // no client certificate chain
            0x00, // no ALPN
            0x00, 0x00, // empty application data
            0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, // creation time
            0xfe, 0xed, 0xf0, 0x0d, // age obfuscation offset
        ];
        let mut rd = Reader::init(&bytes);
        let ssv = ServerSessionValue::read(&mut rd).unwrap();
        assert!(ssv.sni.is_none());
        assert!(ssv.client_cert_chain.is_none());
        assert_eq!(ssv.get_encoding(), bytes);
    }

    #[test]
    fn serversessionvalue_with_cert() {
        // Same as the no-SNI fixture, but with a client certificate chain holding
        // one (opaque, 1-byte) certificate.
        let bytes = [
            0x00, // no SNI
            0x03, 0x03, // TLS 1.2
            0xc0, 0x23, // cipher suite
            0x03, 0x01, 0x02, 0x03, // master secret: u8 length + 3 bytes
            0x00, // extended master secret not used
            0x01, // client certificate chain present
            0x00, 0x00, 0x04, // chain: u24 length of the certificate list
            0x00, 0x00, 0x01, 0xaa, // one certificate: u24 length + 1 byte
            0x00, // no ALPN
            0x00, 0x00, // empty application data
            0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, // creation time
            0xfe, 0xed, 0xf0, 0x0d, // age obfuscation offset
        ];
        let mut rd = Reader::init(&bytes);
        let ssv = ServerSessionValue::read(&mut rd).unwrap();
        assert!(ssv.client_cert_chain.is_some());
        assert_eq!(ssv.get_encoding(), bytes);
    }
}