1use alloc::sync::Arc;
2use alloc::vec::Vec;
3use core::fmt::Debug;
4
5use crate::server::ClientHello;
6use crate::{server, sign};
7
8#[derive(Debug)]
10pub struct NoServerSessionStorage {}
11
12impl server::StoresServerSessions for NoServerSessionStorage {
13 fn put(&self, _id: Vec<u8>, _sec: Vec<u8>) -> bool {
14 false
15 }
16 fn get(&self, _id: &[u8]) -> Option<Vec<u8>> {
17 None
18 }
19 fn take(&self, _id: &[u8]) -> Option<Vec<u8>> {
20 None
21 }
22 fn can_cache(&self) -> bool {
23 false
24 }
25}
26
// In-memory server session storage; only built when a hash map
// implementation is available (`std` or `hashbrown`).
#[cfg(any(feature = "std", feature = "hashbrown"))]
mod cache {
    use alloc::sync::Arc;
    use alloc::vec::Vec;
    use core::fmt::{Debug, Formatter};

    use crate::lock::Mutex;
    use crate::{limited_cache, server};

    /// A `StoresServerSessions` implementation that keeps sessions in memory.
    ///
    /// Entries live in a `limited_cache::LimitedCache` behind a mutex. The
    /// store is handed out as an `Arc` so it can be shared, and its capacity
    /// comes from the `size` passed to `new` (the size-invariant test below
    /// shows that excess entries are evicted).
    pub struct ServerSessionMemoryCache {
        cache: Mutex<limited_cache::LimitedCache<Vec<u8>, Vec<u8>>>,
    }

    impl ServerSessionMemoryCache {
        /// Create an empty, shareable cache whose `LimitedCache` is built with
        /// `size`.
        #[cfg(feature = "std")]
        pub fn new(size: usize) -> Arc<Self> {
            let cache = Mutex::new(limited_cache::LimitedCache::new(size));
            Arc::new(Self { cache })
        }

        /// Create an empty, shareable cache whose `LimitedCache` is built with
        /// `size`.
        ///
        /// Without `std`, the caller selects the mutex implementation via `M`.
        #[cfg(not(feature = "std"))]
        pub fn new<M: crate::lock::MakeMutex>(size: usize) -> Arc<Self> {
            let cache = Mutex::new::<M>(limited_cache::LimitedCache::new(size));
            Arc::new(Self { cache })
        }
    }

    // Every method takes the lock with `lock().unwrap()`, so it panics if the
    // lock cannot be obtained.
    impl server::StoresServerSessions for ServerSessionMemoryCache {
        /// Store `value` under `key`, replacing any previous value.
        ///
        /// Always returns `true`.
        fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
            let mut guard = self.cache.lock().unwrap();
            guard.insert(key, value);
            true
        }

        /// Return a copy of the value stored under `key` without removing it.
        fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
            let guard = self.cache.lock().unwrap();
            guard.get(key).cloned()
        }

        /// Remove and return the value stored under `key`, if any.
        fn take(&self, key: &[u8]) -> Option<Vec<u8>> {
            let mut guard = self.cache.lock().unwrap();
            guard.remove(key)
        }

        /// Always `true`: this store does cache sessions.
        fn can_cache(&self) -> bool {
            true
        }
    }

    impl Debug for ServerSessionMemoryCache {
        /// Prints only the type name; the entries behind the mutex are neither
        /// locked nor shown.
        fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
            let mut builder = f.debug_struct("ServerSessionMemoryCache");
            builder.finish()
        }
    }

    #[cfg(test)]
    mod tests {
        use std::vec;

        use super::*;
        use crate::server::StoresServerSessions;

        #[test]
        fn test_serversessionmemorycache_accepts_put() {
            let cache = ServerSessionMemoryCache::new(4);
            let accepted = cache.put(vec![0x01], vec![0x02]);
            assert!(accepted);
        }

        #[test]
        fn test_serversessionmemorycache_persists_put() {
            let cache = ServerSessionMemoryCache::new(4);
            assert!(cache.put(vec![0x01], vec![0x02]));
            // `get` must not consume the entry: repeated lookups keep hitting.
            for _ in 0..2 {
                assert_eq!(cache.get(&[0x01]), Some(vec![0x02]));
            }
        }

        #[test]
        fn test_serversessionmemorycache_overwrites_put() {
            let cache = ServerSessionMemoryCache::new(4);
            for value in [0x02, 0x04] {
                assert!(cache.put(vec![0x01], vec![value]));
            }
            assert_eq!(cache.get(&[0x01]), Some(vec![0x04]));
        }

        #[test]
        fn test_serversessionmemorycache_drops_to_maintain_size_invariant() {
            let cache = ServerSessionMemoryCache::new(2);
            let entries: [(u8, u8); 5] = [
                (0x01, 0x02),
                (0x03, 0x04),
                (0x05, 0x06),
                (0x07, 0x08),
                (0x09, 0x0a),
            ];
            for (key, value) in entries {
                assert!(cache.put(vec![key], vec![value]));
            }

            let retained = entries
                .iter()
                .filter(|(key, _)| cache.get(&[*key]).is_some())
                .count();

            assert!(retained < 5);
        }
    }
}
146
147#[cfg(any(feature = "std", feature = "hashbrown"))]
148pub use cache::ServerSessionMemoryCache;
149
150#[derive(Debug)]
152pub(super) struct NeverProducesTickets {}
153
154impl server::ProducesTickets for NeverProducesTickets {
155 fn enabled(&self) -> bool {
156 false
157 }
158 fn lifetime(&self) -> u32 {
159 0
160 }
161 fn encrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
162 None
163 }
164 fn decrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
165 None
166 }
167}
168
169#[derive(Debug)]
171pub(super) struct AlwaysResolvesChain(Arc<sign::CertifiedKey>);
172
173impl AlwaysResolvesChain {
174 pub(super) fn new(certified_key: sign::CertifiedKey) -> Self {
176 Self(Arc::new(certified_key))
177 }
178
179 pub(super) fn new_with_extras(certified_key: sign::CertifiedKey, ocsp: Vec<u8>) -> Self {
183 let mut r = Self::new(certified_key);
184
185 {
186 let cert = Arc::make_mut(&mut r.0);
187 if !ocsp.is_empty() {
188 cert.ocsp = Some(ocsp);
189 }
190 }
191
192 r
193 }
194}
195
196impl server::ResolvesServerCert for AlwaysResolvesChain {
197 fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<sign::CertifiedKey>> {
198 Some(Arc::clone(&self.0))
199 }
200}
201
// SNI-based certificate selection; only built when a hash map
// implementation is available (`std` or `hashbrown`).
#[cfg(any(feature = "std", feature = "hashbrown"))]
mod sni_resolver {
    use alloc::string::{String, ToString};
    use alloc::sync::Arc;
    use core::fmt::Debug;

    use pki_types::{DnsName, ServerName};

    use crate::error::Error;
    use crate::hash_map::HashMap;
    use crate::server::ClientHello;
    use crate::webpki::{verify_server_name, ParsedCertificate};
    use crate::{server, sign};

    /// Something that resolves to different cert chains/keys based on the
    /// client-supplied server name (via SNI).
    ///
    /// Certificates are registered per name with [`Self::add`], which stores
    /// names in lowercased form.
    #[derive(Debug)]
    pub struct ResolvesServerCertUsingSni {
        /// Lowercased DNS name -> certificate chain and key for that name.
        by_name: HashMap<String, Arc<sign::CertifiedKey>>,
    }

    impl ResolvesServerCertUsingSni {
        /// Create a new, empty resolver that knows no certificates.
        pub fn new() -> Self {
            Self {
                by_name: HashMap::new(),
            }
        }

        /// Register `ck` as the certificate chain/key to use for the SNI `name`.
        ///
        /// `name` is lowercased before it is stored. Adding the same
        /// (case-folded) name again replaces the previous entry.
        ///
        /// # Errors
        ///
        /// * `Error::General("Bad DNS name")` if `name` is not a valid DNS name.
        /// * Any error from obtaining `ck`'s end-entity certificate, parsing
        ///   it, or verifying that it is valid for `name`.
        ///
        /// On error the resolver is left unchanged.
        pub fn add(&mut self, name: &str, ck: sign::CertifiedKey) -> Result<(), Error> {
            let dns_name = DnsName::try_from(name)
                .map_err(|_| Error::General("Bad DNS name".into()))?
                .to_lowercase_owned();
            // Take the map key from the normalised name *before* moving it into
            // the `ServerName`. Re-extracting it later with a fallible
            // `if let` would silently skip the insert (yet still return
            // `Ok`) if the pattern ever failed to match.
            let key = dns_name.as_ref().to_string();
            let server_name = ServerName::DnsName(dns_name);

            // Sanity-check the chain. The end-entity certificate must be
            // obtainable from `ck`, must parse, and must be valid for
            // `server_name`.
            ck.end_entity_cert()
                .and_then(ParsedCertificate::try_from)
                .and_then(|cert| verify_server_name(&cert, &server_name))?;

            self.by_name.insert(key, Arc::new(ck));
            Ok(())
        }
    }

    impl server::ResolvesServerCert for ResolvesServerCertUsingSni {
        /// Look up the certificate registered for the client's SNI name.
        ///
        /// Returns `None` when the client sent no server name, or when no
        /// certificate was added for that exact name string. Keys are the
        /// lowercased names stored by `add`.
        fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<sign::CertifiedKey>> {
            client_hello
                .server_name()
                .and_then(|name| self.by_name.get(name))
                .cloned()
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use crate::server::ResolvesServerCert;

        #[test]
        fn test_resolvesservercertusingsni_requires_sni() {
            let rscsni = ResolvesServerCertUsingSni::new();
            assert!(rscsni
                .resolve(ClientHello::new(&None, &[], None, &[]))
                .is_none());
        }

        #[test]
        fn test_resolvesservercertusingsni_handles_unknown_name() {
            let rscsni = ResolvesServerCertUsingSni::new();
            let name = DnsName::try_from("hello.com")
                .unwrap()
                .to_owned();
            assert!(rscsni
                .resolve(ClientHello::new(&Some(name), &[], None, &[]))
                .is_none());
        }
    }
}
301
302#[cfg(any(feature = "std", feature = "hashbrown"))]
303pub use sni_resolver::ResolvesServerCertUsingSni;
304
#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;
    use crate::server::{ProducesTickets, StoresServerSessions};

    #[test]
    fn test_noserversessionstorage_drops_put() {
        let storage = NoServerSessionStorage {};
        let stored = storage.put(vec![0x01], vec![0x02]);
        assert!(!stored);
    }

    #[test]
    fn test_noserversessionstorage_denies_gets() {
        let storage = NoServerSessionStorage {};
        // Even after an attempted insert, nothing comes back out.
        storage.put(vec![0x01], vec![0x02]);
        let keys: [&[u8]; 3] = [&[], &[0x01], &[0x02]];
        for key in keys {
            assert_eq!(storage.get(key), None);
        }
    }

    #[test]
    fn test_noserversessionstorage_denies_takes() {
        let storage = NoServerSessionStorage {};
        let keys: [&[u8]; 3] = [&[], &[0x01], &[0x02]];
        for key in keys {
            assert_eq!(storage.take(key), None);
        }
    }

    #[test]
    fn test_neverproducestickets_does_nothing() {
        let tickets = NeverProducesTickets {};
        assert!(!tickets.enabled());
        assert_eq!(tickets.lifetime(), 0);
        assert_eq!(tickets.encrypt(&[]), None);
        assert_eq!(tickets.decrypt(&[]), None);
    }
}