bounded_collections/
bounded_btree_set.rs

1// This file is part of Substrate.
2
3// Copyright (C) 2023 Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: Apache-2.0
5
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! Traits, types and structs to support a bounded `BTreeSet`.
19
20use crate::{Get, TryCollect};
21use alloc::collections::BTreeSet;
22use codec::{Compact, Decode, Encode, MaxEncodedLen};
23use core::{borrow::Borrow, marker::PhantomData, ops::Deref};
24#[cfg(feature = "serde")]
25use serde::{
26	de::{Error, SeqAccess, Visitor},
27	Deserialize, Deserializer, Serialize,
28};
29
30/// A bounded set based on a B-Tree.
31///
32/// B-Trees represent a fundamental compromise between cache-efficiency and actually minimizing
33/// the amount of work performed in a search. See [`BTreeSet`] for more details.
34///
35/// Unlike a standard `BTreeSet`, there is an enforced upper limit to the number of items in the
36/// set. All internal operations ensure this bound is respected.
37#[cfg_attr(feature = "serde", derive(Serialize), serde(transparent))]
38#[derive(Encode, scale_info::TypeInfo)]
39#[scale_info(skip_type_params(S))]
40pub struct BoundedBTreeSet<T, S>(BTreeSet<T>, #[cfg_attr(feature = "serde", serde(skip_serializing))] PhantomData<S>);
41
#[cfg(feature = "serde")]
impl<'de, T, S: Get<u32>> Deserialize<'de> for BoundedBTreeSet<T, S>
where
	T: Ord + Deserialize<'de>,
	S: Clone,
{
	/// Deserializes a sequence into a bounded set.
	///
	/// Fails with an "out of bounds" error when the input holds more elements than `S::get()`
	/// allows.
	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
	where
		D: Deserializer<'de>,
	{
		// Sequence visitor that gathers the elements into a plain `BTreeSet` while enforcing
		// the bound `S` along the way.
		struct BoundedSeqVisitor<T, S>(PhantomData<(T, S)>);

		impl<'de, T, S> Visitor<'de> for BoundedSeqVisitor<T, S>
		where
			T: Ord + Deserialize<'de>,
			S: Get<u32> + Clone,
		{
			type Value = BTreeSet<T>;

			fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
				formatter.write_str("a sequence")
			}

			fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
			where
				A: SeqAccess<'de>,
			{
				let hinted = seq.size_hint().unwrap_or(0);
				let limit = usize::try_from(S::get())
					.map_err(|_| A::Error::custom("can't convert to usize"))?;

				// Bail out before reading anything if the advertised size is already too big.
				if hinted > limit {
					return Err(A::Error::custom("out of bounds"))
				}

				let mut collected = BTreeSet::new();
				while let Some(element) = seq.next_element()? {
					// The check runs before the insertion, so any further element is rejected
					// once the set is full (even one equal to an existing member).
					if collected.len() >= limit {
						return Err(A::Error::custom("out of bounds"))
					}
					collected.insert(element);
				}

				Ok(collected)
			}
		}

		let collected = deserializer.deserialize_seq(BoundedSeqVisitor::<T, S>(PhantomData))?;
		Self::try_from(collected).map_err(|_| Error::custom("out of bounds"))
	}
}
98
99impl<T, S> Decode for BoundedBTreeSet<T, S>
100where
101	T: Decode + Ord,
102	S: Get<u32>,
103{
104	fn decode<I: codec::Input>(input: &mut I) -> Result<Self, codec::Error> {
105		// Same as the underlying implementation for `Decode` on `BTreeSet`, except we fail early if
106		// the len is too big.
107		let len: u32 = <Compact<u32>>::decode(input)?.into();
108		if len > S::get() {
109			return Err("BoundedBTreeSet exceeds its limit".into())
110		}
111		input.descend_ref()?;
112		let inner = Result::from_iter((0..len).map(|_| Decode::decode(input)))?;
113		input.ascend_ref();
114		Ok(Self(inner, PhantomData))
115	}
116
117	fn skip<I: codec::Input>(input: &mut I) -> Result<(), codec::Error> {
118		BTreeSet::<T>::skip(input)
119	}
120}
121
122impl<T, S> BoundedBTreeSet<T, S>
123where
124	S: Get<u32>,
125{
126	/// Get the bound of the type in `usize`.
127	pub fn bound() -> usize {
128		S::get() as usize
129	}
130}
131
132impl<T, S> BoundedBTreeSet<T, S>
133where
134	T: Ord,
135	S: Get<u32>,
136{
137	/// Create `Self` from `t` without any checks.
138	fn unchecked_from(t: BTreeSet<T>) -> Self {
139		Self(t, Default::default())
140	}
141
142	/// Create a new `BoundedBTreeSet`.
143	///
144	/// Does not allocate.
145	pub fn new() -> Self {
146		BoundedBTreeSet(BTreeSet::new(), PhantomData)
147	}
148
149	/// Consume self, and return the inner `BTreeSet`.
150	///
151	/// This is useful when a mutating API of the inner type is desired, and closure-based mutation
152	/// such as provided by [`try_mutate`][Self::try_mutate] is inconvenient.
153	pub fn into_inner(self) -> BTreeSet<T> {
154		debug_assert!(self.0.len() <= Self::bound());
155		self.0
156	}
157
158	/// Consumes self and mutates self via the given `mutate` function.
159	///
160	/// If the outcome of mutation is within bounds, `Some(Self)` is returned. Else, `None` is
161	/// returned.
162	///
163	/// This is essentially a *consuming* shorthand [`Self::into_inner`] -> `...` ->
164	/// [`Self::try_from`].
165	pub fn try_mutate(mut self, mut mutate: impl FnMut(&mut BTreeSet<T>)) -> Option<Self> {
166		mutate(&mut self.0);
167		(self.0.len() <= Self::bound()).then(move || self)
168	}
169
170	/// Clears the set, removing all elements.
171	pub fn clear(&mut self) {
172		self.0.clear()
173	}
174
175	/// Exactly the same semantics as [`BTreeSet::insert`], but returns an `Err` (and is a noop) if
176	/// the new length of the set exceeds `S`.
177	///
178	/// In the `Err` case, returns the inserted item so it can be further used without cloning.
179	pub fn try_insert(&mut self, item: T) -> Result<bool, T> {
180		if self.len() < Self::bound() || self.0.contains(&item) {
181			Ok(self.0.insert(item))
182		} else {
183			Err(item)
184		}
185	}
186
187	/// Remove an item from the set, returning whether it was previously in the set.
188	///
189	/// The item may be any borrowed form of the set's item type, but the ordering on the borrowed
190	/// form _must_ match the ordering on the item type.
191	pub fn remove<Q>(&mut self, item: &Q) -> bool
192	where
193		T: Borrow<Q>,
194		Q: Ord + ?Sized,
195	{
196		self.0.remove(item)
197	}
198
199	/// Removes and returns the value in the set, if any, that is equal to the given one.
200	///
201	/// The value may be any borrowed form of the set's value type, but the ordering on the borrowed
202	/// form _must_ match the ordering on the value type.
203	pub fn take<Q>(&mut self, value: &Q) -> Option<T>
204	where
205		T: Borrow<Q> + Ord,
206		Q: Ord + ?Sized,
207	{
208		self.0.take(value)
209	}
210}
211
212impl<T, S> Default for BoundedBTreeSet<T, S>
213where
214	T: Ord,
215	S: Get<u32>,
216{
217	fn default() -> Self {
218		Self::new()
219	}
220}
221
222impl<T, S> Clone for BoundedBTreeSet<T, S>
223where
224	BTreeSet<T>: Clone,
225{
226	fn clone(&self) -> Self {
227		BoundedBTreeSet(self.0.clone(), PhantomData)
228	}
229}
230
231impl<T, S> core::fmt::Debug for BoundedBTreeSet<T, S>
232where
233	BTreeSet<T>: core::fmt::Debug,
234	S: Get<u32>,
235{
236	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
237		f.debug_tuple("BoundedBTreeSet").field(&self.0).field(&Self::bound()).finish()
238	}
239}
240
// `Hash` is implemented by hand: a derive would demand that every generic parameter (including
// the bound `S`) implements it as well.
#[cfg(feature = "std")]
impl<T: std::hash::Hash, S> std::hash::Hash for BoundedBTreeSet<T, S> {
	/// Feeds only the inner set into the hasher; the bound `S` does not contribute.
	fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
		std::hash::Hash::hash(&self.0, state)
	}
}
249
250impl<T, S1, S2> PartialEq<BoundedBTreeSet<T, S1>> for BoundedBTreeSet<T, S2>
251where
252	BTreeSet<T>: PartialEq,
253	S1: Get<u32>,
254	S2: Get<u32>,
255{
256	fn eq(&self, other: &BoundedBTreeSet<T, S1>) -> bool {
257		S1::get() == S2::get() && self.0 == other.0
258	}
259}
260
261impl<T, S> Eq for BoundedBTreeSet<T, S>
262where
263	BTreeSet<T>: Eq,
264	S: Get<u32>,
265{
266}
267
268impl<T, S> PartialEq<BTreeSet<T>> for BoundedBTreeSet<T, S>
269where
270	BTreeSet<T>: PartialEq,
271	S: Get<u32>,
272{
273	fn eq(&self, other: &BTreeSet<T>) -> bool {
274		self.0 == *other
275	}
276}
277
278impl<T, S> PartialOrd for BoundedBTreeSet<T, S>
279where
280	BTreeSet<T>: PartialOrd,
281	S: Get<u32>,
282{
283	fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
284		self.0.partial_cmp(&other.0)
285	}
286}
287
288impl<T, S> Ord for BoundedBTreeSet<T, S>
289where
290	BTreeSet<T>: Ord,
291	S: Get<u32>,
292{
293	fn cmp(&self, other: &Self) -> core::cmp::Ordering {
294		self.0.cmp(&other.0)
295	}
296}
297
298impl<T, S> IntoIterator for BoundedBTreeSet<T, S> {
299	type Item = T;
300	type IntoIter = alloc::collections::btree_set::IntoIter<T>;
301
302	fn into_iter(self) -> Self::IntoIter {
303		self.0.into_iter()
304	}
305}
306
307impl<'a, T, S> IntoIterator for &'a BoundedBTreeSet<T, S> {
308	type Item = &'a T;
309	type IntoIter = alloc::collections::btree_set::Iter<'a, T>;
310
311	fn into_iter(self) -> Self::IntoIter {
312		self.0.iter()
313	}
314}
315
316impl<T, S> MaxEncodedLen for BoundedBTreeSet<T, S>
317where
318	T: MaxEncodedLen,
319	S: Get<u32>,
320{
321	fn max_encoded_len() -> usize {
322		Self::bound()
323			.saturating_mul(T::max_encoded_len())
324			.saturating_add(codec::Compact(S::get()).encoded_size())
325	}
326}
327
328impl<T, S> Deref for BoundedBTreeSet<T, S>
329where
330	T: Ord,
331{
332	type Target = BTreeSet<T>;
333
334	fn deref(&self) -> &Self::Target {
335		&self.0
336	}
337}
338
339impl<T, S> AsRef<BTreeSet<T>> for BoundedBTreeSet<T, S>
340where
341	T: Ord,
342{
343	fn as_ref(&self) -> &BTreeSet<T> {
344		&self.0
345	}
346}
347
348impl<T, S> From<BoundedBTreeSet<T, S>> for BTreeSet<T>
349where
350	T: Ord,
351{
352	fn from(set: BoundedBTreeSet<T, S>) -> Self {
353		set.0
354	}
355}
356
357impl<T, S> TryFrom<BTreeSet<T>> for BoundedBTreeSet<T, S>
358where
359	T: Ord,
360	S: Get<u32>,
361{
362	type Error = ();
363
364	fn try_from(value: BTreeSet<T>) -> Result<Self, Self::Error> {
365		(value.len() <= Self::bound())
366			.then(move || BoundedBTreeSet(value, PhantomData))
367			.ok_or(())
368	}
369}
370
371impl<T, S> codec::DecodeLength for BoundedBTreeSet<T, S> {
372	fn len(self_encoded: &[u8]) -> Result<usize, codec::Error> {
373		// `BoundedBTreeSet<T, S>` is stored just a `BTreeSet<T>`, which is stored as a
374		// `Compact<u32>` with its length followed by an iteration of its items. We can just use
375		// the underlying implementation.
376		<BTreeSet<T> as codec::DecodeLength>::len(self_encoded)
377	}
378}
379
380impl<T, S> codec::EncodeLike<BTreeSet<T>> for BoundedBTreeSet<T, S> where BTreeSet<T>: Encode {}
381
382impl<I, T, Bound> TryCollect<BoundedBTreeSet<T, Bound>> for I
383where
384	T: Ord,
385	I: ExactSizeIterator + Iterator<Item = T>,
386	Bound: Get<u32>,
387{
388	type Error = &'static str;
389
390	fn try_collect(self) -> Result<BoundedBTreeSet<T, Bound>, Self::Error> {
391		if self.len() > Bound::get() as usize {
392			Err("iterator length too big")
393		} else {
394			Ok(BoundedBTreeSet::<T, Bound>::unchecked_from(self.collect::<BTreeSet<T>>()))
395		}
396	}
397}
398
399#[cfg(test)]
400mod test {
401	use super::*;
402	use crate::ConstU32;
403	use alloc::{vec, vec::Vec};
404	use codec::CompactLen;
405
406	fn set_from_keys<T>(keys: &[T]) -> BTreeSet<T>
407	where
408		T: Ord + Copy,
409	{
410		keys.iter().copied().collect()
411	}
412
413	fn boundedset_from_keys<T, S>(keys: &[T]) -> BoundedBTreeSet<T, S>
414	where
415		T: Ord + Copy,
416		S: Get<u32>,
417	{
418		set_from_keys(keys).try_into().unwrap()
419	}
420
421	#[test]
422	fn encoding_same_as_unbounded_set() {
423		let b = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
424		let m = set_from_keys(&[1, 2, 3, 4, 5, 6]);
425
426		assert_eq!(b.encode(), m.encode());
427	}
428
429	#[test]
430	fn try_insert_works() {
431		let mut bounded = boundedset_from_keys::<u32, ConstU32<4>>(&[1, 2, 3]);
432		bounded.try_insert(0).unwrap();
433		assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));
434
435		assert!(bounded.try_insert(9).is_err());
436		assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));
437	}
438
439	#[test]
440	fn deref_coercion_works() {
441		let bounded = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3]);
442		// these methods come from deref-ed vec.
443		assert_eq!(bounded.len(), 3);
444		assert!(bounded.iter().next().is_some());
445		assert!(!bounded.is_empty());
446	}
447
448	#[test]
449	fn try_mutate_works() {
450		let bounded = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
451		let bounded = bounded
452			.try_mutate(|v| {
453				v.insert(7);
454			})
455			.unwrap();
456		assert_eq!(bounded.len(), 7);
457		assert!(bounded
458			.try_mutate(|v| {
459				v.insert(8);
460			})
461			.is_none());
462	}
463
464	#[test]
465	fn btree_map_eq_works() {
466		let bounded = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
467		assert_eq!(bounded, set_from_keys(&[1, 2, 3, 4, 5, 6]));
468	}
469
470	#[test]
471	fn too_big_fail_to_decode() {
472		let v: Vec<u32> = vec![1, 2, 3, 4, 5];
473		assert_eq!(
474			BoundedBTreeSet::<u32, ConstU32<4>>::decode(&mut &v.encode()[..]),
475			Err("BoundedBTreeSet exceeds its limit".into()),
476		);
477	}
478
479	#[test]
480	fn dont_consume_more_data_than_bounded_len() {
481		let s = set_from_keys(&[1, 2, 3, 4, 5, 6]);
482		let data = s.encode();
483		let data_input = &mut &data[..];
484
485		BoundedBTreeSet::<u32, ConstU32<4>>::decode(data_input).unwrap_err();
486		assert_eq!(data_input.len(), data.len() - Compact::<u32>::compact_len(&(data.len() as u32)));
487	}
488
489	#[test]
490	fn unequal_eq_impl_insert_works() {
491		// given a struct with a strange notion of equality
492		#[derive(Debug)]
493		struct Unequal(u32, bool);
494
495		impl PartialEq for Unequal {
496			fn eq(&self, other: &Self) -> bool {
497				self.0 == other.0
498			}
499		}
500		impl Eq for Unequal {}
501
502		impl Ord for Unequal {
503			fn cmp(&self, other: &Self) -> core::cmp::Ordering {
504				self.0.cmp(&other.0)
505			}
506		}
507
508		impl PartialOrd for Unequal {
509			fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
510				Some(self.cmp(other))
511			}
512		}
513
514		let mut set = BoundedBTreeSet::<Unequal, ConstU32<4>>::new();
515
516		// when the set is full
517
518		for i in 0..4 {
519			set.try_insert(Unequal(i, false)).unwrap();
520		}
521
522		// can't insert a new distinct member
523		set.try_insert(Unequal(5, false)).unwrap_err();
524
525		// but _can_ insert a distinct member which compares equal, though per the documentation,
526		// neither the set length nor the actual member are changed
527		set.try_insert(Unequal(0, true)).unwrap();
528		assert_eq!(set.len(), 4);
529		let zero_item = set.get(&Unequal(0, true)).unwrap();
530		assert_eq!(zero_item.0, 0);
531		assert_eq!(zero_item.1, false);
532	}
533
534	#[test]
535	fn eq_works() {
536		// of same type
537		let b1 = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2]);
538		let b2 = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2]);
539		assert_eq!(b1, b2);
540
541		// of different type, but same value and bound.
542		crate::parameter_types! {
543			B1: u32 = 7;
544			B2: u32 = 7;
545		}
546		let b1 = boundedset_from_keys::<u32, B1>(&[1, 2]);
547		let b2 = boundedset_from_keys::<u32, B2>(&[1, 2]);
548		assert_eq!(b1, b2);
549	}
550
551	#[test]
552	fn can_be_collected() {
553		let b1 = boundedset_from_keys::<u32, ConstU32<5>>(&[1, 2, 3, 4]);
554		let b2: BoundedBTreeSet<u32, ConstU32<5>> = b1.iter().map(|k| k + 1).try_collect().unwrap();
555		assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3, 4, 5]);
556
557		// can also be collected into a collection of length 4.
558		let b2: BoundedBTreeSet<u32, ConstU32<4>> = b1.iter().map(|k| k + 1).try_collect().unwrap();
559		assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3, 4, 5]);
560
561		// can be mutated further into iterators that are `ExactSizedIterator`.
562		let b2: BoundedBTreeSet<u32, ConstU32<5>> = b1.iter().map(|k| k + 1).rev().skip(2).try_collect().unwrap();
563		// note that the binary tree will re-sort this, so rev() is not really seen
564		assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3]);
565
566		let b2: BoundedBTreeSet<u32, ConstU32<5>> = b1.iter().map(|k| k + 1).take(2).try_collect().unwrap();
567		assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3]);
568
569		// but these worn't work
570		let b2: Result<BoundedBTreeSet<u32, ConstU32<3>>, _> = b1.iter().map(|k| k + 1).try_collect();
571		assert!(b2.is_err());
572
573		let b2: Result<BoundedBTreeSet<u32, ConstU32<1>>, _> = b1.iter().map(|k| k + 1).skip(2).try_collect();
574		assert!(b2.is_err());
575	}
576
577	// Just a test that structs containing `BoundedBTreeSet` can derive `Hash`. (This was broken
578	// when it was deriving `Hash`).
579	#[test]
580	#[cfg(feature = "std")]
581	fn container_can_derive_hash() {
582		#[derive(Hash)]
583		struct Foo {
584			bar: u8,
585			set: BoundedBTreeSet<String, ConstU32<16>>,
586		}
587	}
588
589	#[cfg(feature = "serde")]
590	mod serde {
591		use super::*;
592		use crate::alloc::string::ToString as _;
593
594		#[test]
595		fn test_serializer() {
596			let mut c = BoundedBTreeSet::<u32, ConstU32<6>>::new();
597			c.try_insert(0).unwrap();
598			c.try_insert(1).unwrap();
599			c.try_insert(2).unwrap();
600
601			assert_eq!(serde_json::json!(&c).to_string(), r#"[0,1,2]"#);
602		}
603
604		#[test]
605		fn test_deserializer() {
606			let c: Result<BoundedBTreeSet<u32, ConstU32<6>>, serde_json::error::Error> =
607				serde_json::from_str(r#"[0,1,2]"#);
608			assert!(c.is_ok());
609			let c = c.unwrap();
610
611			assert_eq!(c.len(), 3);
612			assert!(c.contains(&0));
613			assert!(c.contains(&1));
614			assert!(c.contains(&2));
615		}
616
617		#[test]
618		fn test_deserializer_bound() {
619			let c: Result<BoundedBTreeSet<u32, ConstU32<3>>, serde_json::error::Error> =
620				serde_json::from_str(r#"[0,1,2]"#);
621			assert!(c.is_ok());
622			let c = c.unwrap();
623
624			assert_eq!(c.len(), 3);
625			assert!(c.contains(&0));
626			assert!(c.contains(&1));
627			assert!(c.contains(&2));
628		}
629
630		#[test]
631		fn test_deserializer_failed() {
632			let c: Result<BoundedBTreeSet<u32, ConstU32<4>>, serde_json::error::Error> =
633				serde_json::from_str(r#"[0,1,2,3,4]"#);
634
635			match c {
636				Err(msg) => assert_eq!(msg.to_string(), "out of bounds at line 1 column 11"),
637				_ => unreachable!("deserializer must raise error"),
638			}
639		}
640	}
641}