1use crate::{Get, TryCollect};
21use alloc::collections::BTreeSet;
22use codec::{Compact, Decode, Encode, MaxEncodedLen};
23use core::{borrow::Borrow, marker::PhantomData, ops::Deref};
24#[cfg(feature = "serde")]
25use serde::{
26 de::{Error, SeqAccess, Visitor},
27 Deserialize, Deserializer, Serialize,
28};
29
/// A set backed by a [`BTreeSet`] whose maximum length is described by the type parameter `S`.
///
/// The first field is the wrapped set; the second is a zero-sized marker that only carries the
/// bound type `S`. The fallible constructors and mutators in this module check the length
/// against `S::get()` before handing out a value.
///
/// `S` is excluded from the generated type information (`skip_type_params`). With the `serde`
/// feature enabled the wrapper serializes transparently as the inner set, and the marker field
/// is skipped.
#[cfg_attr(feature = "serde", derive(Serialize), serde(transparent))]
#[derive(Encode, scale_info::TypeInfo)]
#[scale_info(skip_type_params(S))]
pub struct BoundedBTreeSet<T, S>(BTreeSet<T>, #[cfg_attr(feature = "serde", serde(skip_serializing))] PhantomData<S>);
41
/// Deserializes a sequence into a [`BoundedBTreeSet`], rejecting input that exceeds the bound.
///
/// No `Clone` bound is required on `S`: the bound type is only ever queried through
/// `S::get()`, never cloned.
#[cfg(feature = "serde")]
impl<'de, T, S: Get<u32>> Deserialize<'de> for BoundedBTreeSet<T, S>
where
    T: Ord + Deserialize<'de>,
{
    /// Reads a sequence of `T` and wraps it, failing with `"out of bounds"` when the input
    /// holds more elements than `S::get()` allows.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        /// Sequence visitor that collects elements into a `BTreeSet` while enforcing `S::get()`.
        struct BTreeSetVisitor<T, S>(PhantomData<(T, S)>);

        impl<'de, T, S> Visitor<'de> for BTreeSetVisitor<T, S>
        where
            T: Ord + Deserialize<'de>,
            S: Get<u32>,
        {
            type Value = BTreeSet<T>;

            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
                formatter.write_str("a sequence")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let max = usize::try_from(S::get())
                    .map_err(|_| A::Error::custom("can't convert to usize"))?;

                // Fail fast when the format announces more elements than the bound allows.
                // A missing hint is treated as zero and left to the per-element check below.
                if seq.size_hint().unwrap_or(0) > max {
                    return Err(A::Error::custom("out of bounds"));
                }

                let mut values = BTreeSet::new();
                while let Some(value) = seq.next_element()? {
                    // Checked before inserting, so any element arriving once the set is full
                    // is rejected (including one that duplicates an existing entry).
                    if values.len() >= max {
                        return Err(A::Error::custom("out of bounds"));
                    }
                    values.insert(value);
                }

                Ok(values)
            }
        }

        deserializer
            .deserialize_seq(BTreeSetVisitor::<T, S>(PhantomData))
            .and_then(|set| {
                // The visitor already enforced the bound; the checked conversion keeps the
                // invariant in one place.
                BoundedBTreeSet::<T, S>::try_from(set).map_err(|_| Error::custom("out of bounds"))
            })
    }
}
98
99impl<T, S> Decode for BoundedBTreeSet<T, S>
100where
101 T: Decode + Ord,
102 S: Get<u32>,
103{
104 fn decode<I: codec::Input>(input: &mut I) -> Result<Self, codec::Error> {
105 let len: u32 = <Compact<u32>>::decode(input)?.into();
108 if len > S::get() {
109 return Err("BoundedBTreeSet exceeds its limit".into())
110 }
111 input.descend_ref()?;
112 let inner = Result::from_iter((0..len).map(|_| Decode::decode(input)))?;
113 input.ascend_ref();
114 Ok(Self(inner, PhantomData))
115 }
116
117 fn skip<I: codec::Input>(input: &mut I) -> Result<(), codec::Error> {
118 BTreeSet::<T>::skip(input)
119 }
120}
121
122impl<T, S> BoundedBTreeSet<T, S>
123where
124 S: Get<u32>,
125{
126 pub fn bound() -> usize {
128 S::get() as usize
129 }
130}
131
132impl<T, S> BoundedBTreeSet<T, S>
133where
134 T: Ord,
135 S: Get<u32>,
136{
137 fn unchecked_from(t: BTreeSet<T>) -> Self {
139 Self(t, Default::default())
140 }
141
142 pub fn new() -> Self {
146 BoundedBTreeSet(BTreeSet::new(), PhantomData)
147 }
148
149 pub fn into_inner(self) -> BTreeSet<T> {
154 debug_assert!(self.0.len() <= Self::bound());
155 self.0
156 }
157
158 pub fn try_mutate(mut self, mut mutate: impl FnMut(&mut BTreeSet<T>)) -> Option<Self> {
166 mutate(&mut self.0);
167 (self.0.len() <= Self::bound()).then(move || self)
168 }
169
170 pub fn clear(&mut self) {
172 self.0.clear()
173 }
174
175 pub fn try_insert(&mut self, item: T) -> Result<bool, T> {
180 if self.len() < Self::bound() || self.0.contains(&item) {
181 Ok(self.0.insert(item))
182 } else {
183 Err(item)
184 }
185 }
186
187 pub fn remove<Q>(&mut self, item: &Q) -> bool
192 where
193 T: Borrow<Q>,
194 Q: Ord + ?Sized,
195 {
196 self.0.remove(item)
197 }
198
199 pub fn take<Q>(&mut self, value: &Q) -> Option<T>
204 where
205 T: Borrow<Q> + Ord,
206 Q: Ord + ?Sized,
207 {
208 self.0.take(value)
209 }
210}
211
212impl<T, S> Default for BoundedBTreeSet<T, S>
213where
214 T: Ord,
215 S: Get<u32>,
216{
217 fn default() -> Self {
218 Self::new()
219 }
220}
221
222impl<T, S> Clone for BoundedBTreeSet<T, S>
223where
224 BTreeSet<T>: Clone,
225{
226 fn clone(&self) -> Self {
227 BoundedBTreeSet(self.0.clone(), PhantomData)
228 }
229}
230
231impl<T, S> core::fmt::Debug for BoundedBTreeSet<T, S>
232where
233 BTreeSet<T>: core::fmt::Debug,
234 S: Get<u32>,
235{
236 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
237 f.debug_tuple("BoundedBTreeSet").field(&self.0).field(&Self::bound()).finish()
238 }
239}
240
// Hand-written so that only `T` has to implement `Hash`; the bound type `S` carries no
// such requirement here.
#[cfg(feature = "std")]
impl<T: std::hash::Hash, S> std::hash::Hash for BoundedBTreeSet<T, S> {
    /// Hashes the inner set only; the bound does not contribute to the hash.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.0, state);
    }
}
249
250impl<T, S1, S2> PartialEq<BoundedBTreeSet<T, S1>> for BoundedBTreeSet<T, S2>
251where
252 BTreeSet<T>: PartialEq,
253 S1: Get<u32>,
254 S2: Get<u32>,
255{
256 fn eq(&self, other: &BoundedBTreeSet<T, S1>) -> bool {
257 S1::get() == S2::get() && self.0 == other.0
258 }
259}
260
// Marker impl: equality on the wrapper is a full equivalence relation whenever it is one
// for the inner `BTreeSet<T>`.
impl<T, S> Eq for BoundedBTreeSet<T, S>
where
    BTreeSet<T>: Eq,
    S: Get<u32>,
{
}
267
268impl<T, S> PartialEq<BTreeSet<T>> for BoundedBTreeSet<T, S>
269where
270 BTreeSet<T>: PartialEq,
271 S: Get<u32>,
272{
273 fn eq(&self, other: &BTreeSet<T>) -> bool {
274 self.0 == *other
275 }
276}
277
278impl<T, S> PartialOrd for BoundedBTreeSet<T, S>
279where
280 BTreeSet<T>: PartialOrd,
281 S: Get<u32>,
282{
283 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
284 self.0.partial_cmp(&other.0)
285 }
286}
287
288impl<T, S> Ord for BoundedBTreeSet<T, S>
289where
290 BTreeSet<T>: Ord,
291 S: Get<u32>,
292{
293 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
294 self.0.cmp(&other.0)
295 }
296}
297
298impl<T, S> IntoIterator for BoundedBTreeSet<T, S> {
299 type Item = T;
300 type IntoIter = alloc::collections::btree_set::IntoIter<T>;
301
302 fn into_iter(self) -> Self::IntoIter {
303 self.0.into_iter()
304 }
305}
306
307impl<'a, T, S> IntoIterator for &'a BoundedBTreeSet<T, S> {
308 type Item = &'a T;
309 type IntoIter = alloc::collections::btree_set::Iter<'a, T>;
310
311 fn into_iter(self) -> Self::IntoIter {
312 self.0.iter()
313 }
314}
315
316impl<T, S> MaxEncodedLen for BoundedBTreeSet<T, S>
317where
318 T: MaxEncodedLen,
319 S: Get<u32>,
320{
321 fn max_encoded_len() -> usize {
322 Self::bound()
323 .saturating_mul(T::max_encoded_len())
324 .saturating_add(codec::Compact(S::get()).encoded_size())
325 }
326}
327
328impl<T, S> Deref for BoundedBTreeSet<T, S>
329where
330 T: Ord,
331{
332 type Target = BTreeSet<T>;
333
334 fn deref(&self) -> &Self::Target {
335 &self.0
336 }
337}
338
339impl<T, S> AsRef<BTreeSet<T>> for BoundedBTreeSet<T, S>
340where
341 T: Ord,
342{
343 fn as_ref(&self) -> &BTreeSet<T> {
344 &self.0
345 }
346}
347
348impl<T, S> From<BoundedBTreeSet<T, S>> for BTreeSet<T>
349where
350 T: Ord,
351{
352 fn from(set: BoundedBTreeSet<T, S>) -> Self {
353 set.0
354 }
355}
356
357impl<T, S> TryFrom<BTreeSet<T>> for BoundedBTreeSet<T, S>
358where
359 T: Ord,
360 S: Get<u32>,
361{
362 type Error = ();
363
364 fn try_from(value: BTreeSet<T>) -> Result<Self, Self::Error> {
365 (value.len() <= Self::bound())
366 .then(move || BoundedBTreeSet(value, PhantomData))
367 .ok_or(())
368 }
369}
370
impl<T, S> codec::DecodeLength for BoundedBTreeSet<T, S> {
    /// Reads the element count from an encoded value by delegating to the `BTreeSet<T>`
    /// implementation of `DecodeLength`.
    ///
    /// This relies on the wrapper encoding exactly like the inner set, which the
    /// `encoding_same_as_unbounded_set` test in this module checks.
    fn len(self_encoded: &[u8]) -> Result<usize, codec::Error> {
        // Fully qualified on purpose: `BTreeSet` also has an inherent `len(&self)`, which a
        // plain `BTreeSet::<T>::len(..)` path would resolve to instead of the trait function.
        <BTreeSet<T> as codec::DecodeLength>::len(self_encoded)
    }
}
379
// Marker impl declaring that a `BoundedBTreeSet<T, S>` encodes like a `BTreeSet<T>`.
impl<T, S> codec::EncodeLike<BTreeSet<T>> for BoundedBTreeSet<T, S> where BTreeSet<T>: Encode {}
381
382impl<I, T, Bound> TryCollect<BoundedBTreeSet<T, Bound>> for I
383where
384 T: Ord,
385 I: ExactSizeIterator + Iterator<Item = T>,
386 Bound: Get<u32>,
387{
388 type Error = &'static str;
389
390 fn try_collect(self) -> Result<BoundedBTreeSet<T, Bound>, Self::Error> {
391 if self.len() > Bound::get() as usize {
392 Err("iterator length too big")
393 } else {
394 Ok(BoundedBTreeSet::<T, Bound>::unchecked_from(self.collect::<BTreeSet<T>>()))
395 }
396 }
397}
398
/// Unit tests for `BoundedBTreeSet`: encoding, bounded mutation, comparison, collection and
/// (behind the `serde` feature) JSON round-trips.
#[cfg(test)]
mod test {
    use super::*;
    use crate::ConstU32;
    use alloc::{vec, vec::Vec};
    use codec::CompactLen;

    /// Builds a plain `BTreeSet` from a slice of keys.
    fn set_from_keys<T>(keys: &[T]) -> BTreeSet<T>
    where
        T: Ord + Copy,
    {
        keys.iter().copied().collect()
    }

    /// Builds a `BoundedBTreeSet` from a slice of keys, panicking if they exceed the bound.
    fn boundedset_from_keys<T, S>(keys: &[T]) -> BoundedBTreeSet<T, S>
    where
        T: Ord + Copy,
        S: Get<u32>,
    {
        set_from_keys(keys).try_into().unwrap()
    }

    #[test]
    fn encoding_same_as_unbounded_set() {
        let b = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
        let m = set_from_keys(&[1, 2, 3, 4, 5, 6]);

        assert_eq!(b.encode(), m.encode());
    }

    #[test]
    fn try_insert_works() {
        let mut bounded = boundedset_from_keys::<u32, ConstU32<4>>(&[1, 2, 3]);
        bounded.try_insert(0).unwrap();
        assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));

        // The set is full now: the insert fails and leaves the contents untouched.
        assert!(bounded.try_insert(9).is_err());
        assert_eq!(*bounded, set_from_keys(&[1, 0, 2, 3]));
    }

    #[test]
    fn deref_coercion_works() {
        let bounded = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3]);
        // These methods come from `BTreeSet` and are reached through `Deref`.
        assert_eq!(bounded.len(), 3);
        assert!(bounded.iter().next().is_some());
        assert!(!bounded.is_empty());
    }

    #[test]
    fn try_mutate_works() {
        let bounded = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
        let bounded = bounded
            .try_mutate(|v| {
                v.insert(7);
            })
            .unwrap();
        assert_eq!(bounded.len(), 7);
        // An eighth element would exceed the bound of seven.
        assert!(bounded
            .try_mutate(|v| {
                v.insert(8);
            })
            .is_none());
    }

    #[test]
    fn btree_set_eq_works() {
        let bounded = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2, 3, 4, 5, 6]);
        assert_eq!(bounded, set_from_keys(&[1, 2, 3, 4, 5, 6]));
    }

    #[test]
    fn too_big_fail_to_decode() {
        let v: Vec<u32> = vec![1, 2, 3, 4, 5];
        assert_eq!(
            BoundedBTreeSet::<u32, ConstU32<4>>::decode(&mut &v.encode()[..]),
            Err("BoundedBTreeSet exceeds its limit".into()),
        );
    }

    #[test]
    fn dont_consume_more_data_than_bounded_len() {
        let s = set_from_keys(&[1, 2, 3, 4, 5, 6]);
        let data = s.encode();
        let data_input = &mut &data[..];

        BoundedBTreeSet::<u32, ConstU32<4>>::decode(data_input).unwrap_err();
        // Only the compact length prefix may have been consumed. That prefix encodes the
        // number of elements, so its size is derived from `s.len()`, not from the byte length.
        assert_eq!(data_input.len(), data.len() - Compact::<u32>::compact_len(&(s.len() as u32)));
    }

    #[test]
    fn unequal_eq_impl_insert_works() {
        // A type whose equality and ordering look at the first field only, so two values can
        // compare equal while still being distinguishable through the second field.
        #[derive(Debug)]
        struct Unequal(u32, bool);

        impl PartialEq for Unequal {
            fn eq(&self, other: &Self) -> bool {
                self.0 == other.0
            }
        }
        impl Eq for Unequal {}

        impl Ord for Unequal {
            fn cmp(&self, other: &Self) -> core::cmp::Ordering {
                self.0.cmp(&other.0)
            }
        }

        impl PartialOrd for Unequal {
            fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
                Some(self.cmp(other))
            }
        }

        let mut set = BoundedBTreeSet::<Unequal, ConstU32<4>>::new();

        // Fill the set up to its bound.
        for i in 0..4 {
            set.try_insert(Unequal(i, false)).unwrap();
        }

        // A genuinely new element is rejected because the set is full.
        set.try_insert(Unequal(5, false)).unwrap_err();

        // An element equal to an existing one is accepted even when full; the length stays
        // the same and the element already stored is kept.
        set.try_insert(Unequal(0, true)).unwrap();
        assert_eq!(set.len(), 4);
        let zero_item = set.get(&Unequal(0, true)).unwrap();
        assert_eq!(zero_item.0, 0);
        assert_eq!(zero_item.1, false);
    }

    #[test]
    fn eq_works() {
        // Same bound type on both sides.
        let b1 = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2]);
        let b2 = boundedset_from_keys::<u32, ConstU32<7>>(&[1, 2]);
        assert_eq!(b1, b2);

        // Different bound types that carry the same value.
        crate::parameter_types! {
            B1: u32 = 7;
            B2: u32 = 7;
        }
        let b1 = boundedset_from_keys::<u32, B1>(&[1, 2]);
        let b2 = boundedset_from_keys::<u32, B2>(&[1, 2]);
        assert_eq!(b1, b2);
    }

    #[test]
    fn can_be_collected() {
        let b1 = boundedset_from_keys::<u32, ConstU32<5>>(&[1, 2, 3, 4]);
        let b2: BoundedBTreeSet<u32, ConstU32<5>> = b1.iter().map(|k| k + 1).try_collect().unwrap();
        assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3, 4, 5]);

        // Collecting exactly as many elements as the bound allows.
        let b2: BoundedBTreeSet<u32, ConstU32<4>> = b1.iter().map(|k| k + 1).try_collect().unwrap();
        assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3, 4, 5]);

        // Adaptors that keep the iterator exact-sized still work.
        let b2: BoundedBTreeSet<u32, ConstU32<5>> = b1.iter().map(|k| k + 1).rev().skip(2).try_collect().unwrap();
        assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3]);

        let b2: BoundedBTreeSet<u32, ConstU32<5>> = b1.iter().map(|k| k + 1).take(2).try_collect().unwrap();
        assert_eq!(b2.into_iter().collect::<Vec<_>>(), vec![2, 3]);

        // Four elements do not fit in a bound of three.
        let b2: Result<BoundedBTreeSet<u32, ConstU32<3>>, _> = b1.iter().map(|k| k + 1).try_collect();
        assert!(b2.is_err());

        // Two remaining elements do not fit in a bound of one.
        let b2: Result<BoundedBTreeSet<u32, ConstU32<1>>, _> = b1.iter().map(|k| k + 1).skip(2).try_collect();
        assert!(b2.is_err());
    }

    // Compile-only check: a struct containing a `BoundedBTreeSet` can derive `Hash`.
    #[test]
    #[cfg(feature = "std")]
    fn container_can_derive_hash() {
        // The struct exists only to exercise the derive; it is never constructed.
        #[allow(dead_code)]
        #[derive(Hash)]
        struct Foo {
            bar: u8,
            set: BoundedBTreeSet<String, ConstU32<16>>,
        }
    }

    #[cfg(feature = "serde")]
    mod serde {
        use super::*;
        use crate::alloc::string::ToString as _;

        #[test]
        fn test_serializer() {
            let mut c = BoundedBTreeSet::<u32, ConstU32<6>>::new();
            c.try_insert(0).unwrap();
            c.try_insert(1).unwrap();
            c.try_insert(2).unwrap();

            assert_eq!(serde_json::json!(&c).to_string(), r#"[0,1,2]"#);
        }

        #[test]
        fn test_deserializer() {
            let c: Result<BoundedBTreeSet<u32, ConstU32<6>>, serde_json::error::Error> =
                serde_json::from_str(r#"[0,1,2]"#);
            assert!(c.is_ok());
            let c = c.unwrap();

            assert_eq!(c.len(), 3);
            assert!(c.contains(&0));
            assert!(c.contains(&1));
            assert!(c.contains(&2));
        }

        #[test]
        fn test_deserializer_bound() {
            // Exactly as many elements as the bound allows.
            let c: Result<BoundedBTreeSet<u32, ConstU32<3>>, serde_json::error::Error> =
                serde_json::from_str(r#"[0,1,2]"#);
            assert!(c.is_ok());
            let c = c.unwrap();

            assert_eq!(c.len(), 3);
            assert!(c.contains(&0));
            assert!(c.contains(&1));
            assert!(c.contains(&2));
        }

        #[test]
        fn test_deserializer_failed() {
            // Five elements against a bound of four.
            let c: Result<BoundedBTreeSet<u32, ConstU32<4>>, serde_json::error::Error> =
                serde_json::from_str(r#"[0,1,2,3,4]"#);

            match c {
                Err(msg) => assert_eq!(msg.to_string(), "out of bounds at line 1 column 11"),
                _ => unreachable!("deserializer must raise error"),
            }
        }
    }
}