mixnet/core/
fragment.rs

1// Copyright 2022 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! Mixnet message fragment handling.
22
23use super::{
24	scattered::Scattered,
25	sphinx::{Surb, PAYLOAD_DATA_SIZE, SURB_SIZE},
26};
27use arrayref::{array_mut_ref, array_refs, mut_array_refs};
28use hashlink::{linked_hash_map::Entry, LinkedHashMap};
29use log::{debug, log, Level};
30use std::cmp::{max, min};
31
/// Size in bytes of a [`MessageId`].
pub const MESSAGE_ID_SIZE: usize = 16;
/// Message identifier. Should be randomly generated. Attached to fragments to enable reassembly.
pub type MessageId = [u8; MESSAGE_ID_SIZE];
/// Size in bytes of an encoded [`FragmentIndex`].
const FRAGMENT_INDEX_SIZE: usize = 2;
/// Fragment index within a message; also used for the last-fragment index in the header.
type FragmentIndex = u16;
/// Size in bytes of an encoded [`FragmentDataSize`].
const FRAGMENT_DATA_SIZE_SIZE: usize = 2;
/// Number of data bytes carried in a single fragment's payload.
type FragmentDataSize = u16;
/// Size in bytes of an encoded [`FragmentNumSurbs`].
const FRAGMENT_NUM_SURBS_SIZE: usize = 1;
/// Number of SURBs carried in a single fragment's payload.
type FragmentNumSurbs = u8;
/// Total size in bytes of a fragment header; see [`split_fragment`] for the field layout.
const FRAGMENT_HEADER_SIZE: usize = MESSAGE_ID_SIZE +
	FRAGMENT_INDEX_SIZE + // Last fragment index (number of fragments - 1)
	FRAGMENT_INDEX_SIZE + // Index of this fragment
	FRAGMENT_DATA_SIZE_SIZE + // Number of data bytes in this fragment
	FRAGMENT_NUM_SURBS_SIZE; // Number of SURBs in this fragment

/// Size in bytes of a whole [`Fragment`]; a fragment exactly fills a packet payload.
pub const FRAGMENT_SIZE: usize = PAYLOAD_DATA_SIZE;
/// A single message fragment: header followed by payload.
pub type Fragment = [u8; FRAGMENT_SIZE];
/// Size in bytes of the payload part of a fragment (data bytes plus SURBs).
const FRAGMENT_PAYLOAD_SIZE: usize = FRAGMENT_SIZE - FRAGMENT_HEADER_SIZE;
/// The payload part of a fragment.
type FragmentPayload = [u8; FRAGMENT_PAYLOAD_SIZE];
/// Maximum number of SURBs that can fit in one fragment payload.
const MAX_SURBS_PER_FRAGMENT: usize = FRAGMENT_PAYLOAD_SIZE / SURB_SIZE;
53
/// Split a fragment into references to its header fields and payload, in order: message ID,
/// last fragment index (number of fragments - 1), index of this fragment, number of data bytes,
/// number of SURBs, payload. The index/size fields are raw little-endian bytes; use the accessor
/// functions below to decode them.
#[allow(clippy::type_complexity)]
fn split_fragment(
	fragment: &Fragment,
) -> (
	&MessageId,
	&[u8; FRAGMENT_INDEX_SIZE],
	&[u8; FRAGMENT_INDEX_SIZE],
	&[u8; FRAGMENT_DATA_SIZE_SIZE],
	&[u8; FRAGMENT_NUM_SURBS_SIZE],
	&FragmentPayload,
) {
	array_refs![
		fragment,
		MESSAGE_ID_SIZE,
		FRAGMENT_INDEX_SIZE,
		FRAGMENT_INDEX_SIZE,
		FRAGMENT_DATA_SIZE_SIZE,
		FRAGMENT_NUM_SURBS_SIZE,
		FRAGMENT_PAYLOAD_SIZE
	]
}
75
76fn message_id(fragment: &Fragment) -> &MessageId {
77	split_fragment(fragment).0
78}
79
80fn num_fragments(fragment: &Fragment) -> usize {
81	(FragmentIndex::from_le_bytes(*split_fragment(fragment).1) as usize) + 1
82}
83
84fn fragment_index(fragment: &Fragment) -> usize {
85	FragmentIndex::from_le_bytes(*split_fragment(fragment).2) as usize
86}
87
88fn fragment_data_size(fragment: &Fragment) -> usize {
89	FragmentDataSize::from_le_bytes(*split_fragment(fragment).3) as usize
90}
91
92fn fragment_num_surbs(fragment: &Fragment) -> usize {
93	FragmentNumSurbs::from_le_bytes(*split_fragment(fragment).4) as usize
94}
95
96fn fragment_payload(fragment: &Fragment) -> &FragmentPayload {
97	split_fragment(fragment).5
98}
99
/// Reasons a fragment may be rejected by [`check_fragment`].
#[derive(Debug, thiserror::Error)]
enum CheckFragmentErr {
	/// The fragment's index is not less than the fragment count claimed in its header.
	#[error("Out-of-range index ({index}, max {max})")]
	Index { index: usize, max: usize },
	/// The claimed data bytes plus SURBs would not fit in a fragment payload.
	#[error("Bad payload size ({size}, max {max})")]
	PayloadSize { size: usize, max: usize },
}
107
108fn check_fragment(fragment: &Fragment) -> Result<(), CheckFragmentErr> {
109	if fragment_index(fragment) >= num_fragments(fragment) {
110		return Err(CheckFragmentErr::Index {
111			index: fragment_index(fragment),
112			max: num_fragments(fragment) - 1,
113		})
114	}
115
116	let data_size = fragment_data_size(fragment);
117	let num_surbs = fragment_num_surbs(fragment);
118	let payload_size = data_size + (num_surbs * SURB_SIZE);
119	if payload_size > FRAGMENT_PAYLOAD_SIZE {
120		return Err(CheckFragmentErr::PayloadSize { size: payload_size, max: FRAGMENT_PAYLOAD_SIZE })
121	}
122
123	Ok(())
124}
125
/// A reassembled message: data plus any SURBs that were packed alongside it.
#[derive(Debug, PartialEq, Eq)]
pub struct GenericMessage {
	// Message identifier, as carried in every fragment's header.
	pub id: MessageId,
	// Message data, concatenated from the fragments' payloads.
	pub data: Vec<u8>,
	// SURBs collected from the fragments' payloads.
	pub surbs: Vec<Surb>,
}
132
133impl GenericMessage {
134	/// Construct a message from a list of fragments. The fragments must all be valid (checked by
135	/// [`check_fragment`]) and in the correct order.
136	fn from_fragments<'a>(fragments: impl Iterator<Item = &'a Fragment> + Clone) -> Self {
137		let id = *message_id(fragments.clone().next().expect("At least one fragment"));
138
139		let mut data = Vec::with_capacity(fragments.clone().map(fragment_data_size).sum());
140		let mut surbs = Vec::with_capacity(fragments.clone().map(fragment_num_surbs).sum());
141		for fragment in fragments {
142			debug_assert!(check_fragment(fragment).is_ok());
143			let payload = fragment_payload(fragment);
144			data.extend_from_slice(&payload[..fragment_data_size(fragment)]);
145			surbs.extend(
146				payload
147					// TODO Use array_rchunks if/when this is stabilised
148					.rchunks_exact(SURB_SIZE)
149					.map(|surb| {
150						TryInto::<&Surb>::try_into(surb)
151							.expect("All slices returned by rchunks_exact have length SURB_SIZE")
152					})
153					.take(fragment_num_surbs(fragment)),
154			);
155		}
156
157		Self { id, data, surbs }
158	}
159}
160
/// Reasons a fragment may fail to be inserted into an [`IncompleteMessage`].
#[derive(Debug, thiserror::Error)]
enum IncompleteMessageInsertErr {
	/// The fragment's header claims a different total fragment count than seen previously.
	#[error("Inconsistent number of fragments for message ({0} vs {1})")]
	InconsistentNumFragments(usize, usize),
	/// A fragment with this index was already received.
	#[error("Already have this fragment")]
	AlreadyHave,
}
168
/// Reassembly state for a message for which only some fragments have been received so far.
struct IncompleteMessage {
	// Received fragments, indexed by fragment index; None for fragments not yet received.
	fragments: Vec<Option<Box<Fragment>>>,
	/// Count of [`Some`] in `fragments`.
	num_received_fragments: usize,
}
174
175impl IncompleteMessage {
176	fn new(num_fragments: usize) -> Self {
177		Self { fragments: vec![None; num_fragments], num_received_fragments: 0 }
178	}
179
180	/// Attempt to insert `fragment`, which must be a valid fragment (checked by
181	/// [`check_fragment`]). Success implies
182	/// [`num_received_fragments`](Self::num_received_fragments) was incremented.
183	fn insert(&mut self, fragment: &Fragment) -> Result<(), IncompleteMessageInsertErr> {
184		debug_assert!(check_fragment(fragment).is_ok());
185
186		if num_fragments(fragment) != self.fragments.len() {
187			return Err(IncompleteMessageInsertErr::InconsistentNumFragments(
188				num_fragments(fragment),
189				self.fragments.len(),
190			))
191		}
192
193		let slot = &mut self.fragments[fragment_index(fragment)];
194		if slot.is_some() {
195			return Err(IncompleteMessageInsertErr::AlreadyHave)
196		}
197
198		*slot = Some((*fragment).into());
199		self.num_received_fragments += 1;
200		debug_assert!(self.num_received_fragments <= self.fragments.len());
201		Ok(())
202	}
203
204	/// Returns [`None`] if we don't have all the fragments yet. Otherwise, returns an iterator
205	/// over the completed list of fragments.
206	fn complete_fragments(&self) -> Option<impl Iterator<Item = &Fragment> + Clone> {
207		(self.num_received_fragments == self.fragments.len()).then(|| {
208			self.fragments
209				.iter()
210				.map(|fragment| fragment.as_ref().expect("All fragments received").as_ref())
211		})
212	}
213}
214
/// Reassembles messages from fragments, evicting state for incomplete messages in LRU order
/// when the configured limits are exceeded.
pub struct FragmentAssembler {
	/// Incomplete messages, in LRU order: least recently used at the front, most recently at the
	/// back. All messages have at least one received fragment.
	incomplete_messages: LinkedHashMap<MessageId, IncompleteMessage>,
	/// Total number of received fragments across all messages in `incomplete_messages`.
	num_incomplete_fragments: usize,

	/// Maximum number of incomplete messages to keep in `incomplete_messages`.
	max_incomplete_messages: usize,
	/// Maximum number of received fragments to keep across all messages in `incomplete_messages`.
	max_incomplete_fragments: usize,
	/// Maximum number of fragments per message. Fragments of messages with more than this many
	/// fragments are dropped on receipt.
	max_fragments_per_message: usize,
}
230
impl FragmentAssembler {
	/// Create a new, empty [`FragmentAssembler`] with the given limits.
	pub fn new(
		max_incomplete_messages: usize,
		max_incomplete_fragments: usize,
		max_fragments_per_message: usize,
	) -> Self {
		Self {
			incomplete_messages: LinkedHashMap::with_capacity(
				// Plus one because we only evict _after_ going over the limit
				max_incomplete_messages.saturating_add(1),
			),
			num_incomplete_fragments: 0,
			max_incomplete_messages,
			max_incomplete_fragments,
			max_fragments_per_message,
		}
	}

	// Returns true if either the messages or the fragments limit is currently exceeded.
	fn need_eviction(&self) -> bool {
		(self.incomplete_messages.len() > self.max_incomplete_messages) ||
			(self.num_incomplete_fragments > self.max_incomplete_fragments)
	}

	/// Evict a message if we're over the messages or fragments limit. This should be called after
	/// each fragment insertion.
	fn maybe_evict(&mut self, log_target: &str) {
		if self.need_eviction() {
			debug!(target: log_target, "Too many incomplete messages; evicting LRU");
			// The front of the map is the least recently used message
			let incomplete_message = self
				.incomplete_messages
				.pop_front()
				.expect("Over messages or fragments limit, there must be at least one message")
				.1;
			debug_assert!(
				self.num_incomplete_fragments >= incomplete_message.num_received_fragments
			);
			self.num_incomplete_fragments -= incomplete_message.num_received_fragments;
			// Called after each fragment insertion, so could only have been one message or
			// fragment over the limit. Each message has at least one received fragment, so having
			// popped a message we should now be within both limits.
			debug_assert!(!self.need_eviction());
		}
	}

	/// Attempt to insert `fragment`. If this completes a message, the completed message is
	/// returned. Invalid fragments, and fragments of messages with more than
	/// `max_fragments_per_message` fragments, are dropped.
	pub fn insert(&mut self, fragment: &Fragment, log_target: &str) -> Option<GenericMessage> {
		if let Err(err) = check_fragment(fragment) {
			debug!(target: log_target, "Received bad fragment: {err}");
			return None
		}
		let num_fragments = num_fragments(fragment);
		if num_fragments > self.max_fragments_per_message {
			return None
		}
		// A single-fragment message is complete immediately; no state needs tracking
		if num_fragments == 1 {
			return Some(GenericMessage::from_fragments(std::iter::once(fragment)))
		}
		match self.incomplete_messages.entry(*message_id(fragment)) {
			Entry::Occupied(mut entry) => {
				let incomplete_message = entry.get_mut();
				if let Err(err) = incomplete_message.insert(fragment) {
					// Duplicate fragments are logged at a lower level than other failures
					let level = match err {
						IncompleteMessageInsertErr::AlreadyHave => Level::Trace,
						_ => Level::Debug,
					};
					log!(target: log_target, level, "Fragment insert failed: {err}");
					return None
				}
				self.num_incomplete_fragments += 1;
				let message =
					incomplete_message.complete_fragments().map(GenericMessage::from_fragments);
				if message.is_some() {
					// Message complete; drop its state and its fragments from the count
					self.num_incomplete_fragments -= entry.remove().num_received_fragments;
				} else {
					// Still incomplete; mark as most recently used and enforce the limits
					entry.to_back();
					self.maybe_evict(log_target);
				}
				message
			},
			Entry::Vacant(entry) => {
				let mut incomplete_message = IncompleteMessage::new(num_fragments);
				// Insert of first fragment cannot fail
				assert!(incomplete_message.insert(fragment).is_ok());
				entry.insert(incomplete_message);
				self.num_incomplete_fragments += 1;
				self.maybe_evict(log_target);
				None
			},
		}
	}
}
323
/// Everything needed to write out a single fragment, apart from the SURB contents.
pub struct FragmentBlueprint<'a> {
	// Message ID to put in the fragment header.
	message_id: MessageId,
	// Index of the last fragment in the message (number of fragments minus one).
	last_index: FragmentIndex,
	// Index of this fragment within the message.
	index: FragmentIndex,
	// Data bytes for this fragment's payload.
	data: Scattered<'a, u8>,
	// Number of SURB slots at the end of this fragment's payload.
	num_surbs: FragmentNumSurbs,
}
331
impl<'a> FragmentBlueprint<'a> {
	/// Write the fragment header and data into `fragment`, leaving the SURB slots (see
	/// [`surbs`](Self::surbs)) untouched.
	pub fn write_except_surbs(&self, fragment: &mut Fragment) {
		let (message_id, last_index, index, data_size, num_surbs, payload) = mut_array_refs![
			fragment,
			MESSAGE_ID_SIZE,
			FRAGMENT_INDEX_SIZE,
			FRAGMENT_INDEX_SIZE,
			FRAGMENT_DATA_SIZE_SIZE,
			FRAGMENT_NUM_SURBS_SIZE,
			FRAGMENT_PAYLOAD_SIZE
		];

		// Write header
		*message_id = self.message_id;
		*last_index = self.last_index.to_le_bytes();
		*index = self.index.to_le_bytes();
		*data_size = (self.data.len() as FragmentDataSize).to_le_bytes();
		*num_surbs = self.num_surbs.to_le_bytes();

		// Write payload
		self.data.copy_to_slice(&mut payload[..self.data.len()]);
	}

	/// Returns mutable references to the SURB slots in `fragment`, for the caller to fill in.
	/// The slots are taken from the end of the payload working backwards.
	pub fn surbs<'fragment>(
		&self,
		fragment: &'fragment mut Fragment,
	) -> impl Iterator<Item = &'fragment mut Surb> {
		array_mut_ref![fragment, FRAGMENT_HEADER_SIZE, FRAGMENT_PAYLOAD_SIZE]
			// TODO Use array_rchunks_mut if/when this is stabilised
			.rchunks_exact_mut(SURB_SIZE)
			.map(|surb| {
				TryInto::<&mut Surb>::try_into(surb)
					.expect("All slices returned by rchunks_exact_mut have length SURB_SIZE")
			})
			.take(self.num_surbs as usize)
	}
}
369
// TODO Use usize::div_ceil when this is stabilised
/// Divide `x` by `y`, rounding up. Written as `((x - 1) / y) + 1` rather than
/// `(x + y - 1) / y` to avoid overflow near `usize::MAX`; `x == 0` is handled separately
/// since `x - 1` would underflow.
fn div_ceil(x: usize, y: usize) -> usize {
	match x.checked_sub(1) {
		None => 0,
		Some(x_minus_one) => (x_minus_one / y) + 1,
	}
}
378
/// Generate fragment blueprints containing the provided message ID and data and the specified
/// number of SURBs. Returns [`None`] if more fragments would be required than are possible to
/// encode. Note that the actual number of fragments supported by the receiver is likely to be
/// significantly less than this.
pub fn fragment_blueprints<'a>(
	message_id: &MessageId,
	mut data: Scattered<'a, u8>,
	mut num_surbs: usize,
) -> Option<impl ExactSizeIterator<Item = FragmentBlueprint<'a>>> {
	let message_id = *message_id;

	// Figure out how many fragments we need
	let num_fragments_for_surbs = div_ceil(num_surbs, MAX_SURBS_PER_FRAGMENT);
	// Payload space left over for data in the fragments that carry the SURBs
	let surb_fragments_unused_size = num_fragments_for_surbs.saturating_mul(FRAGMENT_PAYLOAD_SIZE) -
		num_surbs.saturating_mul(SURB_SIZE);
	// Data that does not fit alongside the SURBs needs additional fragments
	let remaining_data_size = data.len().saturating_sub(surb_fragments_unused_size);
	let num_fragments_for_remaining_data = div_ceil(remaining_data_size, FRAGMENT_PAYLOAD_SIZE);
	// Always generate at least one fragment, even for an empty message with no SURBs
	let num_fragments =
		max(num_fragments_for_surbs.saturating_add(num_fragments_for_remaining_data), 1);

	// The last fragment index must fit in the header field
	let last_index = num_fragments - 1;
	(last_index <= (FragmentIndex::MAX as usize)).then(|| {
		(0..num_fragments).map(move |index| {
			// Front-load the SURBs, filling each fragment's remaining space with data;
			// num_surbs and data are consumed across successive iterations
			let fragment_num_surbs = min(num_surbs, MAX_SURBS_PER_FRAGMENT);
			num_surbs -= fragment_num_surbs;
			let fragment_unused_size = FRAGMENT_PAYLOAD_SIZE - (fragment_num_surbs * SURB_SIZE);
			let fragment_data_size = min(data.len(), fragment_unused_size);
			let (fragment_data, remaining_data) = data.split_at(fragment_data_size);
			data = remaining_data;
			FragmentBlueprint {
				message_id,
				last_index: last_index as FragmentIndex,
				index: index as FragmentIndex,
				data: fragment_data,
				num_surbs: fragment_num_surbs as FragmentNumSurbs,
			}
		})
	})
}
418
#[cfg(test)]
mod tests {
	use super::*;
	use itertools::Itertools;
	use rand::{prelude::SliceRandom, Rng, RngCore};

	const LOG_TARGET: &str = "mixnet";

	#[test]
	fn create_and_insert_small() {
		let mut rng = rand::thread_rng();

		// A one-byte message with a single SURB should need exactly one fragment
		let id = rng.gen();
		let mut blueprints = fragment_blueprints(&id, [42].as_slice().into(), 1).unwrap();
		assert_eq!(blueprints.len(), 1);
		let blueprint = blueprints.next().unwrap();

		let mut fragment = [0; FRAGMENT_SIZE];
		blueprint.write_except_surbs(&mut fragment);
		let mut dummy_surb = [0; SURB_SIZE];
		rng.fill_bytes(&mut dummy_surb);
		{
			// Exactly one SURB slot should be provided
			let mut surbs = blueprint.surbs(&mut fragment);
			*surbs.next().unwrap() = dummy_surb;
			assert!(surbs.next().is_none());
		}

		// A single-fragment message should complete on the first insert
		let mut fa = FragmentAssembler::new(1, usize::MAX, usize::MAX);
		assert_eq!(
			fa.insert(&fragment, LOG_TARGET),
			Some(GenericMessage { id, data: vec![42], surbs: vec![dummy_surb] })
		);
	}

	// Generate the written-out fragments for a message with the given ID and data and no SURBs.
	fn no_surb_fragments(message_id: &MessageId, data: &[u8]) -> Vec<Fragment> {
		fragment_blueprints(message_id, data.into(), 0)
			.unwrap()
			.map(|blueprint| {
				let mut fragment = [0; FRAGMENT_SIZE];
				blueprint.write_except_surbs(&mut fragment);
				fragment
			})
			.collect()
	}

	// Insert fragments until a message completes, returning it; asserts no fragments are left
	// over after completion (or that all were consumed without completing).
	fn insert_fragments<'a>(
		fa: &mut FragmentAssembler,
		mut fragments: impl Iterator<Item = &'a Fragment>,
	) -> Option<GenericMessage> {
		let message = fragments.find_map(|fragment| fa.insert(fragment, LOG_TARGET));
		assert!(fragments.next().is_none());
		message
	}

	#[test]
	fn create_and_insert_large() {
		let mut rng = rand::thread_rng();

		let id = rng.gen();
		let mut data = vec![0; 60000];
		rng.fill_bytes(&mut data);
		let mut fragments = no_surb_fragments(&id, &data);
		assert_eq!(fragments.len(), 30);
		// Reassembly should work regardless of the order fragments arrive in
		fragments.shuffle(&mut rng);

		let mut fa = FragmentAssembler::new(1, usize::MAX, usize::MAX);
		assert_eq!(
			insert_fragments(&mut fa, fragments.iter()),
			Some(GenericMessage { id, data, surbs: Vec::new() })
		);
	}

	#[test]
	fn create_too_large() {
		// One byte more than the maximum encodable number of fragments can carry
		let too_large = vec![0; (((FragmentIndex::MAX as usize) + 1) * FRAGMENT_PAYLOAD_SIZE) + 1];
		assert!(
			fragment_blueprints(&[0; MESSAGE_ID_SIZE], too_large.as_slice().into(), 0).is_none()
		);
	}

	#[test]
	fn message_limit_eviction() {
		let mut rng = rand::thread_rng();

		let first_id = rng.gen();
		let mut first_data = vec![0; 3000];
		rng.fill_bytes(&mut first_data);
		let first_fragments = no_surb_fragments(&first_id, &first_data);

		let second_id = rng.gen();
		let mut second_data = vec![0; 3000];
		rng.fill_bytes(&mut second_data);
		let second_fragments = no_surb_fragments(&second_id, &second_data);

		// Limit of one incomplete message
		let mut fa = FragmentAssembler::new(1, usize::MAX, usize::MAX);

		// One message at a time should work
		assert_eq!(
			insert_fragments(&mut fa, first_fragments.iter()),
			Some(GenericMessage { id: first_id, data: first_data, surbs: Vec::new() })
		);
		assert_eq!(
			insert_fragments(&mut fa, second_fragments.iter()),
			Some(GenericMessage { id: second_id, data: second_data, surbs: Vec::new() })
		);

		// Alternating fragments should not work due to eviction
		assert_eq!(
			insert_fragments(&mut fa, first_fragments.iter().interleave(&second_fragments)),
			None
		);
	}

	#[test]
	fn fragment_limit_eviction() {
		let mut rng = rand::thread_rng();

		let first_id = rng.gen();
		let mut first_data = vec![0; 5000];
		rng.fill_bytes(&mut first_data);
		let first_fragments = no_surb_fragments(&first_id, &first_data);

		let second_id = rng.gen();
		let mut second_data = vec![0; 5000];
		rng.fill_bytes(&mut second_data);
		let second_fragments = no_surb_fragments(&second_id, &second_data);

		// With a one-fragment limit it should not be possible to reconstruct either message
		let mut fa = FragmentAssembler::new(2, 1, usize::MAX);
		assert_eq!(insert_fragments(&mut fa, first_fragments.iter()), None);
		assert_eq!(insert_fragments(&mut fa, second_fragments.iter()), None);

		let mut fa = FragmentAssembler::new(2, 2, usize::MAX);

		// With a two-fragment limit it should be possible to reconstruct them individually
		assert_eq!(
			insert_fragments(&mut fa, first_fragments.iter()),
			Some(GenericMessage { id: first_id, data: first_data, surbs: Vec::new() })
		);
		assert_eq!(
			insert_fragments(&mut fa, second_fragments.iter()),
			Some(GenericMessage { id: second_id, data: second_data, surbs: Vec::new() })
		);

		// But not when interleaved
		assert_eq!(
			insert_fragments(&mut fa, first_fragments.iter().interleave(&second_fragments)),
			None
		);
	}
}