1mod bucket;
70mod entry;
71#[allow(clippy::ptr_offset_with_cast)]
72#[allow(clippy::assign_op_pattern)]
73mod key;
74
75pub use bucket::NodeStatus;
76pub use entry::*;
77
78use bucket::KBucket;
79use std::collections::VecDeque;
80use std::num::NonZeroUsize;
81use std::time::Duration;
82use web_time::Instant;
83
/// Number of k-buckets in a routing table: one per possible position of the
/// most significant bit of a 256-bit distance (bucket indices `0..=255`).
const NUM_BUCKETS: usize = 256;
86
/// Configuration for a [`KBucketsTable`] and the buckets it creates.
#[derive(Debug, Clone, Copy)]
pub(crate) struct KBucketConfig {
    /// Maximum number of nodes a single bucket may hold (defaults to `K_VALUE`).
    bucket_size: usize,
    /// Timeout handed to every `KBucket` for its pending entry. Presumably the
    /// delay after which a node waiting to enter a full bucket may replace a
    /// disconnected node there — interpreted by `KBucket`, confirm in `bucket.rs`.
    pending_timeout: Duration,
}
97
98impl Default for KBucketConfig {
99 fn default() -> Self {
100 KBucketConfig {
101 bucket_size: K_VALUE.get(),
102 pending_timeout: Duration::from_secs(60),
103 }
104 }
105}
106
107impl KBucketConfig {
108 pub(crate) fn set_bucket_size(&mut self, bucket_size: NonZeroUsize) {
110 self.bucket_size = bucket_size.get();
111 }
112
113 pub(crate) fn set_pending_timeout(&mut self, pending_timeout: Duration) {
117 self.pending_timeout = pending_timeout;
118 }
119}
120
/// A Kademlia routing table made of `NUM_BUCKETS` buckets, where a key is
/// placed in the bucket given by its distance from `local_key`
/// (see `BucketIndex`).
#[derive(Debug, Clone)]
pub(crate) struct KBucketsTable<TKey, TVal> {
    /// The key of the local node; all bucket distances are measured from it.
    local_key: TKey,
    /// The buckets, indexed by `BucketIndex::get`.
    buckets: Vec<KBucket<TKey, TVal>>,
    /// Maximum number of nodes per bucket, copied from the `KBucketConfig`.
    bucket_size: usize,
    /// Pending entries that were applied to a bucket while it was accessed,
    /// queued in FIFO order until drained by `take_applied_pending`.
    applied_pending: VecDeque<AppliedPending<TKey, TVal>>,
}
134
/// Index of a bucket in a [`KBucketsTable`]. Bucket `i` covers the distances
/// `[2^i, 2^(i + 1) - 1]` from the local key (see `BucketIndex::range`).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
struct BucketIndex(usize);
139
140impl BucketIndex {
141 fn new(d: &Distance) -> Option<BucketIndex> {
149 d.ilog2().map(|i| BucketIndex(i as usize))
150 }
151
152 fn get(&self) -> usize {
154 self.0
155 }
156
157 fn range(&self) -> (Distance, Distance) {
160 let min = Distance(U256::pow(U256::from(2), U256::from(self.0)));
161 if self.0 == usize::from(u8::MAX) {
162 (min, Distance(U256::MAX))
163 } else {
164 let max = Distance(U256::pow(U256::from(2), U256::from(self.0 + 1)) - 1);
165 (min, max)
166 }
167 }
168
169 fn rand_distance(&self, rng: &mut impl rand::Rng) -> Distance {
171 let mut bytes = [0u8; 32];
172 let quot = self.0 / 8;
173 for i in 0..quot {
174 bytes[31 - i] = rng.gen();
175 }
176 let rem = (self.0 % 8) as u32;
177 let lower = usize::pow(2, rem);
178 let upper = usize::pow(2, rem + 1);
179 bytes[31 - quot] = rng.gen_range(lower..upper) as u8;
180 Distance(U256::from(bytes))
181 }
182}
183
184impl<TKey, TVal> KBucketsTable<TKey, TVal>
185where
186 TKey: Clone + AsRef<KeyBytes>,
187 TVal: Clone,
188{
189 pub(crate) fn new(local_key: TKey, config: KBucketConfig) -> Self {
192 KBucketsTable {
193 local_key,
194 buckets: (0..NUM_BUCKETS).map(|_| KBucket::new(config)).collect(),
195 bucket_size: config.bucket_size,
196 applied_pending: VecDeque::new(),
197 }
198 }
199
200 pub(crate) fn local_key(&self) -> &TKey {
202 &self.local_key
203 }
204
205 pub(crate) fn entry<'a>(&'a mut self, key: &'a TKey) -> Option<Entry<'a, TKey, TVal>> {
210 let index = BucketIndex::new(&self.local_key.as_ref().distance(key))?;
211
212 let bucket = &mut self.buckets[index.get()];
213 if let Some(applied) = bucket.apply_pending() {
214 self.applied_pending.push_back(applied)
215 }
216 Some(Entry::new(bucket, key))
217 }
218
219 pub(crate) fn iter(&mut self) -> impl Iterator<Item = KBucketRef<'_, TKey, TVal>> + '_ {
224 let applied_pending = &mut self.applied_pending;
225 self.buckets.iter_mut().enumerate().map(move |(i, b)| {
226 if let Some(applied) = b.apply_pending() {
227 applied_pending.push_back(applied)
228 }
229 KBucketRef {
230 index: BucketIndex(i),
231 bucket: b,
232 }
233 })
234 }
235
236 pub(crate) fn bucket<K>(&mut self, key: &K) -> Option<KBucketRef<'_, TKey, TVal>>
240 where
241 K: AsRef<KeyBytes>,
242 {
243 let d = self.local_key.as_ref().distance(key);
244 if let Some(index) = BucketIndex::new(&d) {
245 let bucket = &mut self.buckets[index.0];
246 if let Some(applied) = bucket.apply_pending() {
247 self.applied_pending.push_back(applied)
248 }
249 Some(KBucketRef { bucket, index })
250 } else {
251 None
252 }
253 }
254
255 pub(crate) fn take_applied_pending(&mut self) -> Option<AppliedPending<TKey, TVal>> {
268 self.applied_pending.pop_front()
269 }
270
271 pub(crate) fn closest_keys<'a, T>(
274 &'a mut self,
275 target: &'a T,
276 ) -> impl Iterator<Item = TKey> + 'a
277 where
278 T: AsRef<KeyBytes>,
279 {
280 let distance = self.local_key.as_ref().distance(target);
281 let bucket_size = self.bucket_size;
282 ClosestIter {
283 target,
284 iter: None,
285 table: self,
286 buckets_iter: ClosestBucketsIter::new(distance),
287 fmap: move |b: &KBucket<TKey, _>| -> Vec<_> {
288 let mut vec = Vec::with_capacity(bucket_size);
289 vec.extend(b.iter().map(|(n, _)| n.key.clone()));
290 vec
291 },
292 }
293 }
294
295 pub(crate) fn closest<'a, T>(
298 &'a mut self,
299 target: &'a T,
300 ) -> impl Iterator<Item = EntryView<TKey, TVal>> + 'a
301 where
302 T: Clone + AsRef<KeyBytes>,
303 TVal: Clone,
304 {
305 let distance = self.local_key.as_ref().distance(target);
306 let bucket_size = self.bucket_size;
307 ClosestIter {
308 target,
309 iter: None,
310 table: self,
311 buckets_iter: ClosestBucketsIter::new(distance),
312 fmap: move |b: &KBucket<_, TVal>| -> Vec<_> {
313 b.iter()
314 .take(bucket_size)
315 .map(|(n, status)| EntryView {
316 node: n.clone(),
317 status,
318 })
319 .collect()
320 },
321 }
322 }
323
324 pub(crate) fn count_nodes_between<T>(&mut self, target: &T) -> usize
330 where
331 T: AsRef<KeyBytes>,
332 {
333 let local_key = self.local_key.clone();
334 let distance = target.as_ref().distance(&local_key);
335 let mut iter = ClosestBucketsIter::new(distance).take_while(|i| i.get() != 0);
336 if let Some(i) = iter.next() {
337 let num_first = self.buckets[i.get()]
338 .iter()
339 .filter(|(n, _)| n.key.as_ref().distance(&local_key) <= distance)
340 .count();
341 let num_rest: usize = iter.map(|i| self.buckets[i.get()].num_entries()).sum();
342 num_first + num_rest
343 } else {
344 0
345 }
346 }
347}
348
/// Iterator over the contents of a table's buckets in order of increasing
/// distance to `target`. It visits buckets in `ClosestBucketsIter` order,
/// maps each through `fmap` and sorts that bucket's output by distance to
/// `target`.
struct ClosestIter<'a, TTarget, TKey, TVal, TMap, TOut> {
    /// The key whose closest nodes are being enumerated.
    target: &'a TTarget,
    /// The table being iterated. It is mutable because pending entries are
    /// applied as buckets are reached.
    table: &'a mut KBucketsTable<TKey, TVal>,
    /// Yields the bucket indices in order of increasing distance to `target`.
    buckets_iter: ClosestBucketsIter,
    /// The sorted output of the bucket currently being drained, if any.
    iter: Option<std::vec::IntoIter<TOut>>,
    /// Turns a bucket into the items this iterator yields.
    fmap: TMap,
}
368
/// Iterator over bucket indices, ordered by increasing distance of the
/// buckets' contents to a target. The order is derived from the bits of the
/// target's distance from the local key.
struct ClosestBucketsIter {
    /// Distance of the target from the local key.
    distance: Distance,
    /// Current position of the iteration.
    state: ClosestBucketsIterState,
}
378
/// State machine driving [`ClosestBucketsIter`].
enum ClosestBucketsIterState {
    /// Initial state: the next index yielded is the bucket of the target
    /// distance itself, or bucket 0 for a zero distance.
    Start(BucketIndex),
    /// Zooming in: walking down the set bits of the distance. Nodes in these
    /// buckets are closer to the target than the local key. Holds the index
    /// yielded last.
    ZoomIn(BucketIndex),
    /// Zooming out: walking up the unset bits of the distance. Nodes in these
    /// buckets are farther from the target than the local key. Holds the
    /// index yielded last.
    ZoomOut(BucketIndex),
    /// All buckets have been yielded.
    Done,
}
399
400impl ClosestBucketsIter {
401 fn new(distance: Distance) -> Self {
402 let state = match BucketIndex::new(&distance) {
403 Some(i) => ClosestBucketsIterState::Start(i),
404 None => ClosestBucketsIterState::Start(BucketIndex(0)),
405 };
406 Self { distance, state }
407 }
408
409 fn next_in(&self, i: BucketIndex) -> Option<BucketIndex> {
410 (0..i.get()).rev().find_map(|i| {
411 if self.distance.0.bit(i) {
412 Some(BucketIndex(i))
413 } else {
414 None
415 }
416 })
417 }
418
419 fn next_out(&self, i: BucketIndex) -> Option<BucketIndex> {
420 (i.get() + 1..NUM_BUCKETS).find_map(|i| {
421 if !self.distance.0.bit(i) {
422 Some(BucketIndex(i))
423 } else {
424 None
425 }
426 })
427 }
428}
429
430impl Iterator for ClosestBucketsIter {
431 type Item = BucketIndex;
432
433 fn next(&mut self) -> Option<Self::Item> {
434 match self.state {
435 ClosestBucketsIterState::Start(i) => {
436 self.state = ClosestBucketsIterState::ZoomIn(i);
437 Some(i)
438 }
439 ClosestBucketsIterState::ZoomIn(i) => {
440 if let Some(i) = self.next_in(i) {
441 self.state = ClosestBucketsIterState::ZoomIn(i);
442 Some(i)
443 } else {
444 let i = BucketIndex(0);
445 self.state = ClosestBucketsIterState::ZoomOut(i);
446 Some(i)
447 }
448 }
449 ClosestBucketsIterState::ZoomOut(i) => {
450 if let Some(i) = self.next_out(i) {
451 self.state = ClosestBucketsIterState::ZoomOut(i);
452 Some(i)
453 } else {
454 self.state = ClosestBucketsIterState::Done;
455 None
456 }
457 }
458 ClosestBucketsIterState::Done => None,
459 }
460 }
461}
462
463impl<TTarget, TKey, TVal, TMap, TOut> Iterator for ClosestIter<'_, TTarget, TKey, TVal, TMap, TOut>
464where
465 TTarget: AsRef<KeyBytes>,
466 TKey: Clone + AsRef<KeyBytes>,
467 TVal: Clone,
468 TMap: Fn(&KBucket<TKey, TVal>) -> Vec<TOut>,
469 TOut: AsRef<KeyBytes>,
470{
471 type Item = TOut;
472
473 fn next(&mut self) -> Option<Self::Item> {
474 loop {
475 match &mut self.iter {
476 Some(iter) => match iter.next() {
477 Some(k) => return Some(k),
478 None => self.iter = None,
479 },
480 None => {
481 if let Some(i) = self.buckets_iter.next() {
482 let bucket = &mut self.table.buckets[i.get()];
483 if let Some(applied) = bucket.apply_pending() {
484 self.table.applied_pending.push_back(applied)
485 }
486 let mut v = (self.fmap)(bucket);
487 v.sort_by(|a, b| {
488 self.target
489 .as_ref()
490 .distance(a.as_ref())
491 .cmp(&self.target.as_ref().distance(b.as_ref()))
492 });
493 self.iter = Some(v.into_iter());
494 } else {
495 return None;
496 }
497 }
498 }
499 }
500 }
501}
502
/// A mutable reference to one bucket of a [`KBucketsTable`], together with
/// the bucket's index in the table.
pub struct KBucketRef<'a, TKey, TVal> {
    /// Position of the bucket in the table; determines its distance range.
    index: BucketIndex,
    /// The referenced bucket.
    bucket: &'a mut KBucket<TKey, TVal>,
}
508
509impl<'a, TKey, TVal> KBucketRef<'a, TKey, TVal>
510where
511 TKey: Clone + AsRef<KeyBytes>,
512 TVal: Clone,
513{
514 pub fn range(&self) -> (Distance, Distance) {
517 self.index.range()
518 }
519
520 pub fn is_empty(&self) -> bool {
522 self.num_entries() == 0
523 }
524
525 pub fn num_entries(&self) -> usize {
527 self.bucket.num_entries()
528 }
529
530 pub fn has_pending(&self) -> bool {
532 self.bucket.pending().map_or(false, |n| !n.is_ready())
533 }
534
535 pub fn contains(&self, d: &Distance) -> bool {
537 BucketIndex::new(d).map_or(false, |i| i == self.index)
538 }
539
540 pub fn rand_distance(&self, rng: &mut impl rand::Rng) -> Distance {
547 self.index.rand_distance(rng)
548 }
549
550 pub fn iter(&'a self) -> impl Iterator<Item = EntryRefView<'a, TKey, TVal>> {
552 self.bucket.iter().map(move |(n, status)| EntryRefView {
553 node: NodeRefView {
554 key: &n.key,
555 value: &n.value,
556 },
557 status,
558 })
559 }
560}
561
#[cfg(test)]
mod tests {
    use super::*;
    use libp2p_identity::PeerId;
    use quickcheck::*;

    /// Table keyed by raw key bytes with no per-node value, used by the
    /// property tests below.
    type TestTable = KBucketsTable<KeyBytes, ()>;

    /// Generates a table with a random local key and a pending timeout of
    /// 1..360 seconds. It holds fewer than 100 nodes in total, placed
    /// starting from the highest-index (farthest) bucket.
    impl Arbitrary for TestTable {
        fn arbitrary(g: &mut Gen) -> TestTable {
            let local_key = Key::from(PeerId::random());
            let timeout = Duration::from_secs(g.gen_range(1..360));
            let mut config = KBucketConfig::default();
            config.set_pending_timeout(timeout);
            let bucket_size = config.bucket_size;
            let mut table = TestTable::new(local_key.into(), config);
            let mut num_total = g.gen_range(0..100);
            // Walk the buckets from the highest index down. Each bucket's node
            // count is drawn from the remaining budget, capped at the bucket size.
            for (i, b) in &mut table.buckets.iter_mut().enumerate().rev() {
                let ix = BucketIndex(i);
                let num = g.gen_range(0..usize::min(bucket_size, num_total) + 1);
                num_total -= num;
                for _ in 0..num {
                    // A key at a random distance inside bucket `ix`.
                    let distance = ix.rand_distance(&mut rand::thread_rng());
                    let key = local_key.for_distance(distance);
                    let node = Node { key, value: () };
                    let status = NodeStatus::arbitrary(g);
                    match b.insert(node, status) {
                        InsertResult::Inserted => {}
                        _ => panic!(),
                    }
                }
            }
            table
        }
    }

    /// Consecutive bucket ranges must be adjacent and together cover every
    /// non-zero distance up to `U256::MAX`.
    #[test]
    fn buckets_are_non_overlapping_and_exhaustive() {
        let local_key = Key::from(PeerId::random());
        let timeout = Duration::from_secs(0);
        let mut config = KBucketConfig::default();
        config.set_pending_timeout(timeout);
        let mut table = KBucketsTable::<KeyBytes, ()>::new(local_key.into(), config);

        let mut prev_max = U256::from(0);

        for bucket in table.iter() {
            let (min, max) = bucket.range();
            assert_eq!(Distance(prev_max + U256::from(1)), min);
            prev_max = max.0;
        }

        assert_eq!(U256::MAX, prev_max);
    }

    /// A bucket contains both ends of its range but neither neighbouring
    /// distance just outside it.
    #[test]
    fn bucket_contains_range() {
        fn prop(ix: u8) {
            let index = BucketIndex(ix as usize);
            let mut config = KBucketConfig::default();
            config.set_pending_timeout(Duration::from_secs(0));
            let mut bucket = KBucket::<Key<PeerId>, ()>::new(config);
            let bucket_ref = KBucketRef {
                index,
                bucket: &mut bucket,
            };

            let (min, max) = bucket_ref.range();

            assert!(min <= max);

            assert!(bucket_ref.contains(&min));
            assert!(bucket_ref.contains(&max));

            if min != Distance(0.into()) {
                // ^ avoid underflow
                assert!(!bucket_ref.contains(&Distance(min.0 - 1)));
            }

            if max != Distance(U256::MAX) {
                // ^ avoid overflow
                assert!(!bucket_ref.contains(&Distance(max.0 + 1)));
            }
        }

        quickcheck(prop as fn(_));
    }

    /// `rand_distance` for bucket `ix` always lands in `[2^ix, 2^(ix + 1) - 1]`.
    #[test]
    fn rand_distance() {
        fn prop(ix: u8) -> bool {
            let d = BucketIndex(ix as usize).rand_distance(&mut rand::thread_rng());
            let n = U256::from(<[u8; 32]>::from(d.0));
            let b = U256::from(2);
            let e = U256::from(ix);
            let lower = b.pow(e);
            // 2^256 overflows; fall back to U256::MAX before subtracting one.
            let upper = b.checked_pow(e + U256::from(1)).unwrap_or(U256::MAX) - U256::from(1);
            lower <= n && n <= upper
        }
        quickcheck(prop as fn(_) -> _);
    }

    /// An inserted entry is returned by `closest_keys` for its own key.
    #[test]
    fn entry_inserted() {
        let local_key = Key::from(PeerId::random());
        let other_id = Key::from(PeerId::random());

        let mut table = KBucketsTable::<_, ()>::new(local_key, KBucketConfig::default());
        if let Some(Entry::Absent(entry)) = table.entry(&other_id) {
            match entry.insert((), NodeStatus::Connected) {
                InsertResult::Inserted => (),
                _ => panic!(),
            }
        } else {
            panic!()
        }

        let res = table.closest_keys(&other_id).collect::<Vec<_>>();
        assert_eq!(res.len(), 1);
        assert_eq!(res[0], other_id);
    }

    /// The local key has no bucket, so it has no entry.
    #[test]
    fn entry_self() {
        let local_key = Key::from(PeerId::random());
        let mut table = KBucketsTable::<_, ()>::new(local_key, KBucketConfig::default());

        assert!(table.entry(&local_key).is_none())
    }

    /// `closest_keys` returns every key in the table, sorted by distance to
    /// the target.
    #[test]
    fn closest() {
        let local_key = Key::from(PeerId::random());
        let mut table = KBucketsTable::<_, ()>::new(local_key, KBucketConfig::default());
        let mut count = 0;
        // Insert 100 random keys, skipping any that hit a full bucket.
        loop {
            if count == 100 {
                break;
            }
            let key = Key::from(PeerId::random());
            if let Some(Entry::Absent(e)) = table.entry(&key) {
                match e.insert((), NodeStatus::Connected) {
                    InsertResult::Inserted => count += 1,
                    _ => continue,
                }
            } else {
                panic!("entry exists")
            }
        }

        let mut expected_keys: Vec<_> = table
            .buckets
            .iter()
            .flat_map(|t| t.iter().map(|(n, _)| n.key))
            .collect();

        for _ in 0..10 {
            let target_key = Key::from(PeerId::random());
            let keys = table.closest_keys(&target_key).collect::<Vec<_>>();
            // The list of keys is expected to match the result of a full-table scan.
            expected_keys.sort_by_key(|k| k.distance(&target_key));
            assert_eq!(keys, expected_keys);
        }
    }

    /// A connected node inserted into a full bucket becomes pending. Once the
    /// pending entry is ready, the next table access applies it: the
    /// disconnected node is evicted and the event is reported exactly once by
    /// `take_applied_pending`.
    #[test]
    fn applied_pending() {
        let local_key = Key::from(PeerId::random());
        let mut config = KBucketConfig::default();
        config.set_pending_timeout(Duration::from_millis(1));
        let mut table = KBucketsTable::<_, ()>::new(local_key, config);
        let expected_applied;
        let full_bucket_index;
        // Fill buckets with disconnected nodes until one is full, then insert a
        // connected node into that bucket so it becomes pending.
        loop {
            let key = Key::from(PeerId::random());
            if let Some(Entry::Absent(e)) = table.entry(&key) {
                match e.insert((), NodeStatus::Disconnected) {
                    InsertResult::Full => {
                        if let Some(Entry::Absent(e)) = table.entry(&key) {
                            match e.insert((), NodeStatus::Connected) {
                                InsertResult::Pending { disconnected } => {
                                    expected_applied = AppliedPending {
                                        inserted: Node { key, value: () },
                                        evicted: Some(Node {
                                            key: disconnected,
                                            value: (),
                                        }),
                                    };
                                    full_bucket_index = BucketIndex::new(&key.distance(&local_key));
                                    break;
                                }
                                _ => panic!(),
                            }
                        } else {
                            panic!()
                        }
                    }
                    _ => continue,
                }
            } else {
                panic!("entry exists")
            }
        }

        // Expire the timeout for the pending entry on the full bucket.
        let full_bucket = &mut table.buckets[full_bucket_index.unwrap().get()];
        let elapsed = Instant::now().checked_sub(Duration::from_secs(1)).unwrap();
        full_bucket.pending_mut().unwrap().set_ready_at(elapsed);

        match table.entry(&expected_applied.inserted.key) {
            Some(Entry::Present(_, NodeStatus::Connected)) => {}
            x => panic!("Unexpected entry: {x:?}"),
        }

        match table.entry(&expected_applied.evicted.as_ref().unwrap().key) {
            Some(Entry::Absent(_)) => {}
            x => panic!("Unexpected entry: {x:?}"),
        }

        assert_eq!(Some(expected_applied), table.take_applied_pending());
        assert_eq!(None, table.take_applied_pending());
    }

    /// Flipping bit `i` of the target distance gives a key `k`.
    /// - If bit `i` was set, `k` is closer and counts no more nodes than the target.
    /// - If bit `i` was unset, `k` is farther and counts no fewer.
    #[test]
    fn count_nodes_between() {
        fn prop(mut table: TestTable, target: Key<PeerId>) -> bool {
            let num_to_target = table.count_nodes_between(&target);
            let distance = table.local_key.distance(&target);
            let base2 = U256::from(2);
            let mut iter = ClosestBucketsIter::new(distance);
            iter.all(|i| {
                // Flip the bit in the distance of the current bucket to get a
                // key that falls in a different bucket.
                let d = Distance(distance.0 ^ (base2.pow(U256::from(i.get()))));
                let k = table.local_key.for_distance(d);
                if distance.0.bit(i.get()) {
                    // Bit flipped to 0: the key is closer, so at most as many
                    // nodes lie between it and the local key.
                    d < distance && table.count_nodes_between(&k) <= num_to_target
                } else {
                    // Bit flipped to 1: the key is farther, so at least as many
                    // nodes lie between it and the local key.
                    d > distance && table.count_nodes_between(&k) >= num_to_target
                }
            })
        }

        QuickCheck::new()
            .tests(10)
            .quickcheck(prop as fn(_, _) -> _)
    }
}