1use alloc::vec::Vec;
2use core::mem;
3use core::ops::Range;
4
5use super::buffers::{Coalescer, Delocator, Locator};
6use crate::error::InvalidMessage;
7use crate::msgs::codec::{u24, Codec};
8use crate::msgs::message::InboundPlainMessage;
9use crate::{ContentType, ProtocolVersion};
10
/// Tracks handshake messages, and fragments of them, found in buffered
/// `Handshake` records.
///
/// Messages are recorded as index ranges (`FragmentSpan::bounds`) into a
/// buffer owned by the caller, not as owned bytes; that buffer is passed back
/// in to `iter` and `coalesce`.
#[derive(Debug)]
pub(crate) struct HandshakeDeframer {
    /// Spans of handshake messages in the order they were received.
    ///
    /// After `coalesce` has run, only the final span can be incomplete.
    spans: Vec<FragmentSpan>,

    /// Number of bytes of the containing buffer that the caller may discard
    /// once every buffered span has been consumed. It is handed out (and reset
    /// to zero) together with the last span yielded by `HandshakeIter`.
    outer_discard: usize,
}
20
21impl HandshakeDeframer {
22 pub(crate) fn input_message(
37 &mut self,
38 msg: InboundPlainMessage<'_>,
39 containing_buffer: &Locator,
40 outer_discard: usize,
41 ) {
42 debug_assert_eq!(msg.typ, ContentType::Handshake);
43 debug_assert!(containing_buffer.fully_contains(msg.payload));
44 debug_assert!(self.outer_discard <= outer_discard);
45
46 self.outer_discard = outer_discard;
47
48 if let Some(_last_incomplete) = self
57 .spans
58 .last()
59 .filter(|span| !span.is_complete())
60 {
61 self.spans.push(FragmentSpan {
62 version: msg.version,
63 size: None,
64 bounds: containing_buffer.locate(msg.payload),
65 });
66 return;
67 }
68
69 for span in DissectHandshakeIter::new(msg, containing_buffer) {
72 self.spans.push(span);
73 }
74 }
75
76 pub(crate) fn has_message_ready(&self) -> bool {
78 match self.spans.first() {
79 Some(span) => span.is_complete(),
80 None => false,
81 }
82 }
83
84 pub(crate) fn is_active(&self) -> bool {
86 !self.spans.is_empty()
87 }
88
89 pub(crate) fn is_aligned(&self) -> bool {
92 self.spans
93 .iter()
94 .all(|span| span.is_complete())
95 }
96
97 pub(crate) fn iter<'a, 'b>(&'a mut self, containing_buffer: &'b [u8]) -> HandshakeIter<'a, 'b> {
99 HandshakeIter {
100 deframer: self,
101 containing_buffer: Delocator::new(containing_buffer),
102 index: 0,
103 }
104 }
105
106 pub(crate) fn coalesce(&mut self, containing_buffer: &mut [u8]) -> Result<(), InvalidMessage> {
160 while let Some(i) = self.requires_coalesce() {
164 self.coalesce_one(i, Coalescer::new(containing_buffer));
165 }
166
167 match self
169 .spans
170 .iter()
171 .any(|span| span.size.unwrap_or_default() > MAX_HANDSHAKE_SIZE)
172 {
173 true => Err(InvalidMessage::HandshakePayloadTooLarge),
174 false => Ok(()),
175 }
176 }
177
178 fn coalesce_one(&mut self, index: usize, mut containing_buffer: Coalescer<'_>) {
181 let second = self.spans.remove(index + 1);
182 let mut first = self.spans.remove(index);
183
184 let len = second.bounds.len();
186 let target = Range {
187 start: first.bounds.end,
188 end: first.bounds.end + len,
189 };
190
191 containing_buffer.copy_within(second.bounds, target);
192 let delocator = containing_buffer.delocator();
193
194 first.bounds.end += len;
196
197 let msg = InboundPlainMessage {
199 typ: ContentType::Handshake,
200 version: first.version,
201 payload: delocator.slice_from_range(&first.bounds),
202 };
203
204 for (i, span) in DissectHandshakeIter::new(msg, &delocator.locator()).enumerate() {
205 self.spans.insert(index + i, span);
206 }
207 }
208
209 fn requires_coalesce(&self) -> Option<usize> {
214 self.spans
215 .split_last()
216 .and_then(|(_last, elements)| {
217 elements
218 .iter()
219 .enumerate()
220 .find_map(|(i, span)| (!span.is_complete()).then_some(i))
221 })
222 }
223}
224
225impl Default for HandshakeDeframer {
226 fn default() -> Self {
227 Self {
228 spans: Vec::with_capacity(16),
231 outer_discard: 0,
232 }
233 }
234}
235
/// Splits one record payload into consecutive handshake-message spans.
///
/// Each yielded `FragmentSpan` is located relative to `containing_buffer`.
struct DissectHandshakeIter<'a, 'b> {
    /// Version copied from the record the payload came from.
    version: ProtocolVersion,
    /// Bytes not yet dissected; emptied as spans are yielded.
    payload: &'b [u8],
    /// Translates sub-slices of `payload` into ranges of the outer buffer.
    containing_buffer: &'a Locator,
}
241
242impl<'a, 'b> DissectHandshakeIter<'a, 'b> {
243 fn new(msg: InboundPlainMessage<'b>, containing_buffer: &'a Locator) -> Self {
244 Self {
245 version: msg.version,
246 payload: msg.payload,
247 containing_buffer,
248 }
249 }
250}
251
252impl Iterator for DissectHandshakeIter<'_, '_> {
253 type Item = FragmentSpan;
254
255 fn next(&mut self) -> Option<Self::Item> {
256 if self.payload.is_empty() {
257 return None;
258 }
259
260 if self.payload.len() < HANDSHAKE_HEADER_LEN {
262 let buf = mem::take(&mut self.payload);
263 let bounds = self.containing_buffer.locate(buf);
264 return Some(FragmentSpan {
265 version: self.version,
266 size: None,
267 bounds: bounds.clone(),
268 });
269 }
270
271 let (header, rest) = mem::take(&mut self.payload).split_at(HANDSHAKE_HEADER_LEN);
272
273 let size = u24::read_bytes(&header[1..])
275 .unwrap()
276 .into();
277
278 let available = if size < rest.len() {
279 self.payload = &rest[size..];
280 size
281 } else {
282 rest.len()
283 };
284
285 let mut bounds = self.containing_buffer.locate(header);
286 bounds.end += available;
287 Some(FragmentSpan {
288 version: self.version,
289 size: Some(size),
290 bounds: bounds.clone(),
291 })
292 }
293}
294
/// Iterator over the complete handshake messages buffered in a
/// `HandshakeDeframer`, created by `HandshakeDeframer::iter`.
///
/// Yields `(message, discard)` pairs. When dropped, it removes the spans it
/// yielded from the deframer.
pub(crate) struct HandshakeIter<'a, 'b> {
    /// Deframer whose spans are read; its `spans` are trimmed on drop.
    deframer: &'a mut HandshakeDeframer,
    /// Resolves span ranges back into slices of the containing buffer.
    containing_buffer: Delocator<'b>,
    /// Index of the next span to yield, which is also the count yielded so far.
    index: usize,
}
300
301impl<'a, 'b> Iterator for HandshakeIter<'a, 'b> {
302 type Item = (InboundPlainMessage<'b>, usize);
303
304 fn next(&mut self) -> Option<Self::Item> {
305 let next_span = self.deframer.spans.get(self.index)?;
306
307 if !next_span.is_complete() {
308 return None;
309 }
310
311 let discard = if self.deframer.spans.len() - 1 == self.index {
315 mem::take(&mut self.deframer.outer_discard)
316 } else {
317 0
318 };
319
320 self.index += 1;
321 Some((
322 InboundPlainMessage {
323 typ: ContentType::Handshake,
324 version: next_span.version,
325 payload: self
326 .containing_buffer
327 .slice_from_range(&next_span.bounds),
328 },
329 discard,
330 ))
331 }
332}
333
334impl Drop for HandshakeIter<'_, '_> {
335 fn drop(&mut self) {
336 self.deframer.spans.drain(..self.index);
337 }
338}
339
/// Location of one handshake message, or a fragment of one, within the
/// containing buffer.
#[derive(Debug)]
struct FragmentSpan {
    /// Protocol version taken from the record this span came from.
    version: ProtocolVersion,

    /// Body length declared by the handshake header, excluding the header.
    ///
    /// `None` when no header has been parsed for this span. Either `bounds`
    /// was too short to contain one, or the span was appended after an
    /// incomplete span and awaits `HandshakeDeframer::coalesce`.
    size: Option<usize>,

    /// Range of the containing buffer covered by this span, header included.
    bounds: Range<usize>,
}
354
355impl FragmentSpan {
356 fn is_complete(&self) -> bool {
359 match self.size {
360 Some(sz) => sz + HANDSHAKE_HEADER_LEN == self.bounds.len(),
361 None => false,
362 }
363 }
364}
365
/// Size of a handshake message header: a one-byte message type followed by
/// a three-byte (u24) body length.
const HANDSHAKE_HEADER_LEN: usize = 4;

/// Largest declared body length accepted; anything bigger makes
/// `HandshakeDeframer::coalesce` fail. This is far below what the u24 length
/// field could express.
const MAX_HANDSHAKE_SIZE: usize = 65_535;
372
#[cfg(test)]
mod tests {
    use std::vec;

    use super::*;
    use crate::msgs::deframer::DeframerIter;

    /// Feeds `slice`, which must lie within `within`, to `hs` as a TLS 1.3
    /// handshake record. The end offset of `slice` is used as the outer discard.
    fn add_bytes(hs: &mut HandshakeDeframer, slice: &[u8], within: &[u8]) {
        let msg = InboundPlainMessage {
            typ: ContentType::Handshake,
            version: ProtocolVersion::TLSv1_3,
            payload: slice,
        };
        let locator = Locator::new(within);
        let discard = locator.locate(slice).end;
        hs.input_message(msg, &locator, discard);
    }

    /// A message whose header is itself split across records is reassembled
    /// by `coalesce` into one contiguous span.
    #[test]
    fn coalesce() {
        let mut input = vec![0, 0, 0, 0x21, 0, 0, 0, 0, 0x01, 0xff, 0x00, 0x01];
        let mut hs = HandshakeDeframer::default();

        add_bytes(&mut hs, &input[3..4], &input);
        assert_eq!(hs.requires_coalesce(), None);
        add_bytes(&mut hs, &input[4..6], &input);
        assert_eq!(hs.requires_coalesce(), Some(0));
        add_bytes(&mut hs, &input[8..10], &input);
        assert_eq!(hs.requires_coalesce(), Some(0));

        hs.coalesce(&mut input).unwrap();

        let (msg, discard) = hs.iter(&input).next().unwrap();
        assert_eq!(msg.typ, ContentType::Handshake);
        assert_eq!(msg.version, ProtocolVersion::TLSv1_3);
        assert_eq!(msg.payload, &[0x21, 0x00, 0x00, 0x01, 0xff]);

        input.drain(..discard);
        assert_eq!(input, &[0, 1]);
    }

    /// A header-only record followed by a body-only record coalesces into a
    /// single complete span.
    #[test]
    fn append() {
        let mut input = vec![0, 0, 0, 0x21, 0, 0, 5, 0, 0, 1, 2, 3, 4, 5, 0];
        let mut hs = HandshakeDeframer::default();

        add_bytes(&mut hs, &input[3..7], &input);
        add_bytes(&mut hs, &input[9..14], &input);
        assert_eq!(hs.spans.len(), 2);

        hs.coalesce(&mut input).unwrap();
        assert_eq!(hs.spans.len(), 1);

        let (msg, discard) = hs.iter(&input).next().unwrap();
        assert_eq!(msg.typ, ContentType::Handshake);
        assert_eq!(msg.version, ProtocolVersion::TLSv1_3);
        assert_eq!(msg.payload, &[0x21, 0x00, 0x00, 0x05, 1, 2, 3, 4, 5]);

        input.drain(..discard);
        assert_eq!(input, &[0]);
    }

    /// A declared body length of 0x010000 exceeds `MAX_HANDSHAKE_SIZE` and is
    /// rejected even though the body itself never arrives.
    #[test]
    fn coalesce_rejects_excess_size_message() {
        const X: u8 = 0xff;
        let mut input = vec![0x21, 0x01, 0x00, X, 0x00, 0xab, X];
        let mut hs = HandshakeDeframer::default();

        add_bytes(&mut hs, &input[0..3], &input);
        add_bytes(&mut hs, &input[4..6], &input);

        assert_eq!(
            hs.coalesce(&mut input),
            Err(InvalidMessage::HandshakePayloadTooLarge)
        );
    }

    /// A complete message followed by a header-only fragment yields just the
    /// complete message. Because it is not the last span, its discard is zero.
    #[test]
    fn iter_only_returns_full_messages() {
        let input = [0, 0, 0, 0x21, 0, 0, 1, 0xab, 0x21, 0, 0, 1];

        let mut hs = HandshakeDeframer::default();

        add_bytes(&mut hs, &input[3..8], &input);
        add_bytes(&mut hs, &input[8..12], &input);

        let mut iter = hs.iter(&input);
        let (msg, discard) = iter.next().unwrap();
        assert!(iter.next().is_none());

        assert_eq!(msg.typ, ContentType::Handshake);
        assert_eq!(msg.version, ProtocolVersion::TLSv1_3);
        assert_eq!(msg.payload, &[0x21, 0x00, 0x00, 0x01, 0xab]);
        assert_eq!(discard, 0);
    }

    /// A recorded flight of five handshake messages is deframed. Only the
    /// last message releases the whole input for discarding.
    #[test]
    fn handshake_flight() {
        let mut input = include_bytes!("../../testdata/handshake-test.1.bin").to_vec();
        let locator = Locator::new(&input);

        let mut hs = HandshakeDeframer::default();

        let mut iter = DeframerIter::new(&mut input[..]);

        // `while let` rather than `for`: the body needs
        // `iter.bytes_consumed()`, which a `for` loop's borrow would prevent.
        while let Some(message) = iter.next() {
            let plain = message.unwrap().into_plain_message();
            hs.input_message(plain, &locator, iter.bytes_consumed());
        }

        hs.coalesce(&mut input[..]).unwrap();

        let mut iter = hs.iter(&input[..]);
        for _ in 0..4 {
            let (msg, discard) = iter.next().unwrap();
            assert!(matches!(
                msg,
                InboundPlainMessage {
                    typ: ContentType::Handshake,
                    ..
                }
            ));
            assert_eq!(discard, 0);
        }

        let (msg, discard) = iter.next().unwrap();
        assert!(matches!(
            msg,
            InboundPlainMessage {
                typ: ContentType::Handshake,
                ..
            }
        ));
        assert_eq!(discard, 4280);
        drop(iter);

        input.drain(0..discard);
        assert!(input.is_empty());
    }
}