1use super::{
24 scattered::Scattered,
25 sphinx::{Surb, PAYLOAD_DATA_SIZE, SURB_SIZE},
26};
27use arrayref::{array_mut_ref, array_refs, mut_array_refs};
28use hashlink::{linked_hash_map::Entry, LinkedHashMap};
29use log::{debug, log, Level};
30use std::cmp::{max, min};
31
32pub const MESSAGE_ID_SIZE: usize = 16;
34pub type MessageId = [u8; MESSAGE_ID_SIZE];
36const FRAGMENT_INDEX_SIZE: usize = 2;
37type FragmentIndex = u16;
38const FRAGMENT_DATA_SIZE_SIZE: usize = 2;
39type FragmentDataSize = u16;
40const FRAGMENT_NUM_SURBS_SIZE: usize = 1;
41type FragmentNumSurbs = u8;
42const FRAGMENT_HEADER_SIZE: usize = MESSAGE_ID_SIZE +
43 FRAGMENT_INDEX_SIZE + FRAGMENT_INDEX_SIZE + FRAGMENT_DATA_SIZE_SIZE + FRAGMENT_NUM_SURBS_SIZE; pub const FRAGMENT_SIZE: usize = PAYLOAD_DATA_SIZE;
49pub type Fragment = [u8; FRAGMENT_SIZE];
50const FRAGMENT_PAYLOAD_SIZE: usize = FRAGMENT_SIZE - FRAGMENT_HEADER_SIZE;
51type FragmentPayload = [u8; FRAGMENT_PAYLOAD_SIZE];
52const MAX_SURBS_PER_FRAGMENT: usize = FRAGMENT_PAYLOAD_SIZE / SURB_SIZE;
53
54#[allow(clippy::type_complexity)]
55fn split_fragment(
56 fragment: &Fragment,
57) -> (
58 &MessageId,
59 &[u8; FRAGMENT_INDEX_SIZE],
60 &[u8; FRAGMENT_INDEX_SIZE],
61 &[u8; FRAGMENT_DATA_SIZE_SIZE],
62 &[u8; FRAGMENT_NUM_SURBS_SIZE],
63 &FragmentPayload,
64) {
65 array_refs![
66 fragment,
67 MESSAGE_ID_SIZE,
68 FRAGMENT_INDEX_SIZE,
69 FRAGMENT_INDEX_SIZE,
70 FRAGMENT_DATA_SIZE_SIZE,
71 FRAGMENT_NUM_SURBS_SIZE,
72 FRAGMENT_PAYLOAD_SIZE
73 ]
74}
75
76fn message_id(fragment: &Fragment) -> &MessageId {
77 split_fragment(fragment).0
78}
79
80fn num_fragments(fragment: &Fragment) -> usize {
81 (FragmentIndex::from_le_bytes(*split_fragment(fragment).1) as usize) + 1
82}
83
84fn fragment_index(fragment: &Fragment) -> usize {
85 FragmentIndex::from_le_bytes(*split_fragment(fragment).2) as usize
86}
87
88fn fragment_data_size(fragment: &Fragment) -> usize {
89 FragmentDataSize::from_le_bytes(*split_fragment(fragment).3) as usize
90}
91
92fn fragment_num_surbs(fragment: &Fragment) -> usize {
93 FragmentNumSurbs::from_le_bytes(*split_fragment(fragment).4) as usize
94}
95
96fn fragment_payload(fragment: &Fragment) -> &FragmentPayload {
97 split_fragment(fragment).5
98}
99
/// Reason a fragment failed header validation in `check_fragment`.
#[derive(Debug, thiserror::Error)]
enum CheckFragmentErr {
    /// The fragment index is not less than the number of fragments in the message; `max` is the
    /// largest valid index.
    #[error("Out-of-range index ({index}, max {max})")]
    Index { index: usize, max: usize },
    /// The declared data size plus the space taken by the declared SURBs exceeds the payload
    /// size `max`.
    #[error("Bad payload size ({size}, max {max})")]
    PayloadSize { size: usize, max: usize },
}
107
108fn check_fragment(fragment: &Fragment) -> Result<(), CheckFragmentErr> {
109 if fragment_index(fragment) >= num_fragments(fragment) {
110 return Err(CheckFragmentErr::Index {
111 index: fragment_index(fragment),
112 max: num_fragments(fragment) - 1,
113 })
114 }
115
116 let data_size = fragment_data_size(fragment);
117 let num_surbs = fragment_num_surbs(fragment);
118 let payload_size = data_size + (num_surbs * SURB_SIZE);
119 if payload_size > FRAGMENT_PAYLOAD_SIZE {
120 return Err(CheckFragmentErr::PayloadSize { size: payload_size, max: FRAGMENT_PAYLOAD_SIZE })
121 }
122
123 Ok(())
124}
125
/// A message reassembled from one or more fragments.
#[derive(Debug, PartialEq, Eq)]
pub struct GenericMessage {
    /// Message ID, taken from the fragment headers.
    pub id: MessageId,
    /// Data of all fragments concatenated in fragment order.
    pub data: Vec<u8>,
    /// SURBs of all fragments in fragment order; within a fragment, SURBs are read starting
    /// from the end of the payload.
    pub surbs: Vec<Surb>,
}
132
133impl GenericMessage {
134 fn from_fragments<'a>(fragments: impl Iterator<Item = &'a Fragment> + Clone) -> Self {
137 let id = *message_id(fragments.clone().next().expect("At least one fragment"));
138
139 let mut data = Vec::with_capacity(fragments.clone().map(fragment_data_size).sum());
140 let mut surbs = Vec::with_capacity(fragments.clone().map(fragment_num_surbs).sum());
141 for fragment in fragments {
142 debug_assert!(check_fragment(fragment).is_ok());
143 let payload = fragment_payload(fragment);
144 data.extend_from_slice(&payload[..fragment_data_size(fragment)]);
145 surbs.extend(
146 payload
147 .rchunks_exact(SURB_SIZE)
149 .map(|surb| {
150 TryInto::<&Surb>::try_into(surb)
151 .expect("All slices returned by rchunks_exact have length SURB_SIZE")
152 })
153 .take(fragment_num_surbs(fragment)),
154 );
155 }
156
157 Self { id, data, surbs }
158 }
159}
160
/// Reason a fragment could not be added to an `IncompleteMessage`.
#[derive(Debug, thiserror::Error)]
enum IncompleteMessageInsertErr {
    /// The fragment's fragment count (first field) differs from the message's (second field).
    #[error("Inconsistent number of fragments for message ({0} vs {1})")]
    InconsistentNumFragments(usize, usize),
    /// A fragment with the same index has already been received.
    #[error("Already have this fragment")]
    AlreadyHave,
}
168
/// A message for which not all fragments have been received yet.
struct IncompleteMessage {
    // One slot per fragment, indexed by fragment index; `None` until that fragment arrives
    fragments: Vec<Option<Box<Fragment>>>,
    // Number of `Some` slots in `fragments`
    num_received_fragments: usize,
}
174
175impl IncompleteMessage {
176 fn new(num_fragments: usize) -> Self {
177 Self { fragments: vec![None; num_fragments], num_received_fragments: 0 }
178 }
179
180 fn insert(&mut self, fragment: &Fragment) -> Result<(), IncompleteMessageInsertErr> {
184 debug_assert!(check_fragment(fragment).is_ok());
185
186 if num_fragments(fragment) != self.fragments.len() {
187 return Err(IncompleteMessageInsertErr::InconsistentNumFragments(
188 num_fragments(fragment),
189 self.fragments.len(),
190 ))
191 }
192
193 let slot = &mut self.fragments[fragment_index(fragment)];
194 if slot.is_some() {
195 return Err(IncompleteMessageInsertErr::AlreadyHave)
196 }
197
198 *slot = Some((*fragment).into());
199 self.num_received_fragments += 1;
200 debug_assert!(self.num_received_fragments <= self.fragments.len());
201 Ok(())
202 }
203
204 fn complete_fragments(&self) -> Option<impl Iterator<Item = &Fragment> + Clone> {
207 (self.num_received_fragments == self.fragments.len()).then(|| {
208 self.fragments
209 .iter()
210 .map(|fragment| fragment.as_ref().expect("All fragments received").as_ref())
211 })
212 }
213}
214
/// Reassembles messages from received fragments, bounding the number of incomplete messages and
/// fragments held at once.
pub struct FragmentAssembler {
    // Incomplete messages keyed by message ID. Updated messages are moved to the back, and
    // eviction pops from the front, so the front is the least recently updated message.
    incomplete_messages: LinkedHashMap<MessageId, IncompleteMessage>,
    // Total number of received fragments across all messages in `incomplete_messages`
    num_incomplete_fragments: usize,

    // Limit on `incomplete_messages.len()`; exceeding it triggers eviction
    max_incomplete_messages: usize,
    // Limit on `num_incomplete_fragments`; exceeding it triggers eviction
    max_incomplete_fragments: usize,
    // Fragments of messages with more fragments than this are dropped
    max_fragments_per_message: usize,
}
230
231impl FragmentAssembler {
232 pub fn new(
233 max_incomplete_messages: usize,
234 max_incomplete_fragments: usize,
235 max_fragments_per_message: usize,
236 ) -> Self {
237 Self {
238 incomplete_messages: LinkedHashMap::with_capacity(
239 max_incomplete_messages.saturating_add(1),
241 ),
242 num_incomplete_fragments: 0,
243 max_incomplete_messages,
244 max_incomplete_fragments,
245 max_fragments_per_message,
246 }
247 }
248
249 fn need_eviction(&self) -> bool {
250 (self.incomplete_messages.len() > self.max_incomplete_messages) ||
251 (self.num_incomplete_fragments > self.max_incomplete_fragments)
252 }
253
254 fn maybe_evict(&mut self, log_target: &str) {
257 if self.need_eviction() {
258 debug!(target: log_target, "Too many incomplete messages; evicting LRU");
259 let incomplete_message = self
260 .incomplete_messages
261 .pop_front()
262 .expect("Over messages or fragments limit, there must be at least one message")
263 .1;
264 debug_assert!(
265 self.num_incomplete_fragments >= incomplete_message.num_received_fragments
266 );
267 self.num_incomplete_fragments -= incomplete_message.num_received_fragments;
268 debug_assert!(!self.need_eviction());
272 }
273 }
274
275 pub fn insert(&mut self, fragment: &Fragment, log_target: &str) -> Option<GenericMessage> {
278 if let Err(err) = check_fragment(fragment) {
279 debug!(target: log_target, "Received bad fragment: {err}");
280 return None
281 }
282 let num_fragments = num_fragments(fragment);
283 if num_fragments > self.max_fragments_per_message {
284 return None
285 }
286 if num_fragments == 1 {
287 return Some(GenericMessage::from_fragments(std::iter::once(fragment)))
288 }
289 match self.incomplete_messages.entry(*message_id(fragment)) {
290 Entry::Occupied(mut entry) => {
291 let incomplete_message = entry.get_mut();
292 if let Err(err) = incomplete_message.insert(fragment) {
293 let level = match err {
294 IncompleteMessageInsertErr::AlreadyHave => Level::Trace,
295 _ => Level::Debug,
296 };
297 log!(target: log_target, level, "Fragment insert failed: {err}");
298 return None
299 }
300 self.num_incomplete_fragments += 1;
301 let message =
302 incomplete_message.complete_fragments().map(GenericMessage::from_fragments);
303 if message.is_some() {
304 self.num_incomplete_fragments -= entry.remove().num_received_fragments;
305 } else {
306 entry.to_back();
307 self.maybe_evict(log_target);
308 }
309 message
310 },
311 Entry::Vacant(entry) => {
312 let mut incomplete_message = IncompleteMessage::new(num_fragments);
313 assert!(incomplete_message.insert(fragment).is_ok());
315 entry.insert(incomplete_message);
316 self.num_incomplete_fragments += 1;
317 self.maybe_evict(log_target);
318 None
319 },
320 }
321 }
322}
323
/// Description of one fragment to be written, as produced by `fragment_blueprints`.
pub struct FragmentBlueprint<'a> {
    message_id: MessageId,
    // Index of the last fragment of the message (number of fragments minus one)
    last_index: FragmentIndex,
    // Index of this fragment
    index: FragmentIndex,
    // Portion of the message data carried by this fragment
    data: Scattered<'a, u8>,
    // Number of SURB slots at the end of this fragment's payload
    num_surbs: FragmentNumSurbs,
}
331
332impl<'a> FragmentBlueprint<'a> {
333 pub fn write_except_surbs(&self, fragment: &mut Fragment) {
334 let (message_id, last_index, index, data_size, num_surbs, payload) = mut_array_refs![
335 fragment,
336 MESSAGE_ID_SIZE,
337 FRAGMENT_INDEX_SIZE,
338 FRAGMENT_INDEX_SIZE,
339 FRAGMENT_DATA_SIZE_SIZE,
340 FRAGMENT_NUM_SURBS_SIZE,
341 FRAGMENT_PAYLOAD_SIZE
342 ];
343
344 *message_id = self.message_id;
346 *last_index = self.last_index.to_le_bytes();
347 *index = self.index.to_le_bytes();
348 *data_size = (self.data.len() as FragmentDataSize).to_le_bytes();
349 *num_surbs = self.num_surbs.to_le_bytes();
350
351 self.data.copy_to_slice(&mut payload[..self.data.len()]);
353 }
354
355 pub fn surbs<'fragment>(
356 &self,
357 fragment: &'fragment mut Fragment,
358 ) -> impl Iterator<Item = &'fragment mut Surb> {
359 array_mut_ref![fragment, FRAGMENT_HEADER_SIZE, FRAGMENT_PAYLOAD_SIZE]
360 .rchunks_exact_mut(SURB_SIZE)
362 .map(|surb| {
363 TryInto::<&mut Surb>::try_into(surb)
364 .expect("All slices returned by rchunks_exact_mut have length SURB_SIZE")
365 })
366 .take(self.num_surbs as usize)
367 }
368}
369
/// Divides `x` by `y`, rounding up.
///
/// Returns 0 when `x` is 0, even if `y` is 0. Panics if `y` is 0 and `x` is not. Unlike
/// `(x + y - 1) / y`, this cannot overflow.
fn div_ceil(x: usize, y: usize) -> usize {
    x.checked_sub(1).map_or(0, |below| below / y + 1)
}
378
379pub fn fragment_blueprints<'a>(
384 message_id: &MessageId,
385 mut data: Scattered<'a, u8>,
386 mut num_surbs: usize,
387) -> Option<impl ExactSizeIterator<Item = FragmentBlueprint<'a>>> {
388 let message_id = *message_id;
389
390 let num_fragments_for_surbs = div_ceil(num_surbs, MAX_SURBS_PER_FRAGMENT);
392 let surb_fragments_unused_size = num_fragments_for_surbs.saturating_mul(FRAGMENT_PAYLOAD_SIZE) -
393 num_surbs.saturating_mul(SURB_SIZE);
394 let remaining_data_size = data.len().saturating_sub(surb_fragments_unused_size);
395 let num_fragments_for_remaining_data = div_ceil(remaining_data_size, FRAGMENT_PAYLOAD_SIZE);
396 let num_fragments =
397 max(num_fragments_for_surbs.saturating_add(num_fragments_for_remaining_data), 1);
398
399 let last_index = num_fragments - 1;
400 (last_index <= (FragmentIndex::MAX as usize)).then(|| {
401 (0..num_fragments).map(move |index| {
402 let fragment_num_surbs = min(num_surbs, MAX_SURBS_PER_FRAGMENT);
403 num_surbs -= fragment_num_surbs;
404 let fragment_unused_size = FRAGMENT_PAYLOAD_SIZE - (fragment_num_surbs * SURB_SIZE);
405 let fragment_data_size = min(data.len(), fragment_unused_size);
406 let (fragment_data, remaining_data) = data.split_at(fragment_data_size);
407 data = remaining_data;
408 FragmentBlueprint {
409 message_id,
410 last_index: last_index as FragmentIndex,
411 index: index as FragmentIndex,
412 data: fragment_data,
413 num_surbs: fragment_num_surbs as FragmentNumSurbs,
414 }
415 })
416 })
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use itertools::Itertools;
    use rand::{prelude::SliceRandom, Rng, RngCore};

    const LOG_TARGET: &str = "mixnet";

    #[test]
    fn create_and_insert_small() {
        let mut rng = rand::thread_rng();

        let id = rng.gen();
        let mut blueprints = fragment_blueprints(&id, [42].as_slice().into(), 1).unwrap();
        assert_eq!(blueprints.len(), 1);
        let blueprint = blueprints.next().unwrap();

        let mut fragment = [0; FRAGMENT_SIZE];
        blueprint.write_except_surbs(&mut fragment);
        let mut dummy_surb = [0; SURB_SIZE];
        rng.fill_bytes(&mut dummy_surb);
        {
            let mut surbs = blueprint.surbs(&mut fragment);
            *surbs.next().unwrap() = dummy_surb;
            assert!(surbs.next().is_none());
        }

        let mut fa = FragmentAssembler::new(1, usize::MAX, usize::MAX);
        assert_eq!(
            fa.insert(&fragment, LOG_TARGET),
            Some(GenericMessage { id, data: vec![42], surbs: vec![dummy_surb] })
        );
    }

    /// Builds the fragments of a message carrying `data` and no SURBs.
    fn no_surb_fragments(message_id: &MessageId, data: &[u8]) -> Vec<Fragment> {
        fragment_blueprints(message_id, data.into(), 0)
            .unwrap()
            .map(|blueprint| {
                let mut fragment = [0; FRAGMENT_SIZE];
                blueprint.write_except_surbs(&mut fragment);
                fragment
            })
            .collect()
    }

    /// Inserts `fragments` in order and returns the message completed, if any. Asserts that no
    /// message is completed before the last fragment.
    fn insert_fragments<'a>(
        fa: &mut FragmentAssembler,
        mut fragments: impl Iterator<Item = &'a Fragment>,
    ) -> Option<GenericMessage> {
        let message = fragments.find_map(|fragment| fa.insert(fragment, LOG_TARGET));
        assert!(fragments.next().is_none());
        message
    }

    #[test]
    fn create_and_insert_large() {
        let mut rng = rand::thread_rng();

        let id = rng.gen();
        let mut data = vec![0; 60000];
        rng.fill_bytes(&mut data);
        let mut fragments = no_surb_fragments(&id, &data);
        assert_eq!(fragments.len(), 30);
        fragments.shuffle(&mut rng);

        let mut fa = FragmentAssembler::new(1, usize::MAX, usize::MAX);
        assert_eq!(
            insert_fragments(&mut fa, fragments.iter()),
            Some(GenericMessage { id, data, surbs: Vec::new() })
        );
    }

    #[test]
    fn create_and_insert_empty() {
        let id = rand::thread_rng().gen();
        // An empty message still produces exactly one fragment
        let fragments = no_surb_fragments(&id, &[]);
        assert_eq!(fragments.len(), 1);

        let mut fa = FragmentAssembler::new(1, usize::MAX, usize::MAX);
        assert_eq!(
            insert_fragments(&mut fa, fragments.iter()),
            Some(GenericMessage { id, data: Vec::new(), surbs: Vec::new() })
        );
    }

    #[test]
    fn create_too_large() {
        let too_large = vec![0; (((FragmentIndex::MAX as usize) + 1) * FRAGMENT_PAYLOAD_SIZE) + 1];
        assert!(
            fragment_blueprints(&[0; MESSAGE_ID_SIZE], too_large.as_slice().into(), 0).is_none()
        );
    }

    #[test]
    fn too_many_fragments_per_message() {
        let mut rng = rand::thread_rng();

        let id = rng.gen();
        let mut data = vec![0; 5000];
        rng.fill_bytes(&mut data);
        let fragments = no_surb_fragments(&id, &data);
        assert_eq!(fragments.len(), 3);

        // Limit below the message's fragment count: every fragment is dropped, nothing stored
        let mut fa = FragmentAssembler::new(1, usize::MAX, 2);
        assert_eq!(insert_fragments(&mut fa, fragments.iter()), None);
        assert!(fa.incomplete_messages.is_empty());
        assert_eq!(fa.num_incomplete_fragments, 0);

        // Limit equal to the message's fragment count: the message is reassembled
        let mut fa = FragmentAssembler::new(1, usize::MAX, 3);
        assert_eq!(
            insert_fragments(&mut fa, fragments.iter()),
            Some(GenericMessage { id, data, surbs: Vec::new() })
        );
    }

    #[test]
    fn duplicate_fragment_ignored() {
        let mut rng = rand::thread_rng();

        let id = rng.gen();
        let mut data = vec![0; 5000];
        rng.fill_bytes(&mut data);
        let fragments = no_surb_fragments(&id, &data);
        assert_eq!(fragments.len(), 3);

        let mut fa = FragmentAssembler::new(1, usize::MAX, usize::MAX);
        assert_eq!(fa.insert(&fragments[0], LOG_TARGET), None);
        // The duplicate must be rejected without inflating the fragment count
        assert_eq!(fa.insert(&fragments[0], LOG_TARGET), None);
        assert_eq!(fa.num_incomplete_fragments, 1);

        assert_eq!(
            insert_fragments(&mut fa, fragments[1..].iter()),
            Some(GenericMessage { id, data, surbs: Vec::new() })
        );
        assert!(fa.incomplete_messages.is_empty());
        assert_eq!(fa.num_incomplete_fragments, 0);
    }

    #[test]
    fn message_limit_eviction() {
        let mut rng = rand::thread_rng();

        let first_id = rng.gen();
        let mut first_data = vec![0; 3000];
        rng.fill_bytes(&mut first_data);
        let first_fragments = no_surb_fragments(&first_id, &first_data);

        let second_id = rng.gen();
        let mut second_data = vec![0; 3000];
        rng.fill_bytes(&mut second_data);
        let second_fragments = no_surb_fragments(&second_id, &second_data);

        let mut fa = FragmentAssembler::new(1, usize::MAX, usize::MAX);

        // One message at a time fits within the limit
        assert_eq!(
            insert_fragments(&mut fa, first_fragments.iter()),
            Some(GenericMessage { id: first_id, data: first_data, surbs: Vec::new() })
        );
        assert_eq!(
            insert_fragments(&mut fa, second_fragments.iter()),
            Some(GenericMessage { id: second_id, data: second_data, surbs: Vec::new() })
        );

        // Interleaving forces each message to evict the other before it completes
        assert_eq!(
            insert_fragments(&mut fa, first_fragments.iter().interleave(&second_fragments)),
            None
        );
    }

    #[test]
    fn fragment_limit_eviction() {
        let mut rng = rand::thread_rng();

        let first_id = rng.gen();
        let mut first_data = vec![0; 5000];
        rng.fill_bytes(&mut first_data);
        let first_fragments = no_surb_fragments(&first_id, &first_data);

        let second_id = rng.gen();
        let mut second_data = vec![0; 5000];
        rng.fill_bytes(&mut second_data);
        let second_fragments = no_surb_fragments(&second_id, &second_data);

        // A fragment limit of 1 can never hold enough fragments to complete either message
        let mut fa = FragmentAssembler::new(2, 1, usize::MAX);
        assert_eq!(insert_fragments(&mut fa, first_fragments.iter()), None);
        assert_eq!(insert_fragments(&mut fa, second_fragments.iter()), None);

        let mut fa = FragmentAssembler::new(2, 2, usize::MAX);

        // Sequentially, each message completes within the fragment limit
        assert_eq!(
            insert_fragments(&mut fa, first_fragments.iter()),
            Some(GenericMessage { id: first_id, data: first_data, surbs: Vec::new() })
        );
        assert_eq!(
            insert_fragments(&mut fa, second_fragments.iter()),
            Some(GenericMessage { id: second_id, data: second_data, surbs: Vec::new() })
        );

        // Interleaved, the fragment limit forces evictions before either completes
        assert_eq!(
            insert_fragments(&mut fa, first_fragments.iter().interleave(&second_fragments)),
            None
        );
    }
}