1use bytes::Bytes;
22
23use crate::{
24 protocol::libp2p::kademlia::{
25 message::KademliaMessage,
26 query::{QueryAction, QueryId},
27 types::{Distance, KademliaPeer, Key},
28 },
29 PeerId,
30};
31
32use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
33
/// Logging target for the file.
const LOG_TARGET: &str = "litep2p::ipfs::kademlia::query::find_node";

/// Default value for `FindNodeContext::peer_timeout`: how long a queried peer may stay
/// silent before it stops counting towards the parallelism factor.
const DEFAULT_PEER_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
39
/// The configuration needed to instantiate a new [`FindNodeContext`].
#[derive(Debug, Clone)]
pub struct FindNodeConfig<T: Clone + Into<Vec<u8>>> {
    /// Local peer ID. Peers with this ID that are reported in responses are never added
    /// as candidates.
    pub local_peer_id: PeerId,

    /// Maximum number of peers retained in [`FindNodeContext::responses`].
    pub replication_factor: usize,

    /// Number of counted in-flight requests at which no further peer is scheduled.
    pub parallelism_factor: usize,

    /// ID of the query, attached to every [`QueryAction`] the context emits.
    pub query: QueryId,

    /// Key being looked up. Candidates and responses are ordered by their distance to it.
    pub target: Key<T>,
}
58
/// State of a single find-node query.
#[derive(Debug)]
pub struct FindNodeContext<T: Clone + Into<Vec<u8>>> {
    /// Query configuration.
    pub config: FindNodeConfig<T>,

    /// Message built once at construction by `KademliaMessage::find_node` and cloned
    /// into every `QueryAction::SendMessage`.
    kad_message: Bytes,

    /// Peers that have been scheduled for querying and whose response or failure has not
    /// been registered yet, together with the instant at which they were scheduled.
    pub pending: HashMap<PeerId, (KademliaPeer, std::time::Instant)>,

    /// Peers whose response or failure has been registered. Peers in this set are skipped
    /// when a later response reports them again.
    pub queried: HashSet<PeerId>,

    /// Peers that have not been queried yet, keyed by their distance to the target so the
    /// closest one is scheduled first.
    pub candidates: BTreeMap<Distance, KademliaPeer>,

    /// Peers that responded, keyed by their distance to the target. Holds at most
    /// `config.replication_factor` entries; the furthest one is evicted first.
    pub responses: BTreeMap<Distance, KademliaPeer>,

    /// Time after which a pending peer no longer counts towards the parallelism factor.
    peer_timeout: std::time::Duration,

    /// Number of pending peers that count against `config.parallelism_factor`.
    pending_responses: usize,
}
95
96impl<T: Clone + Into<Vec<u8>>> FindNodeContext<T> {
97 pub fn new(config: FindNodeConfig<T>, in_peers: VecDeque<KademliaPeer>) -> Self {
99 let mut candidates = BTreeMap::new();
100
101 for candidate in &in_peers {
102 let distance = config.target.distance(&candidate.key);
103 candidates.insert(distance, candidate.clone());
104 }
105
106 let kad_message = KademliaMessage::find_node(config.target.clone().into_preimage());
107
108 Self {
109 config,
110 kad_message,
111
112 candidates,
113 pending: HashMap::new(),
114 queried: HashSet::new(),
115 responses: BTreeMap::new(),
116
117 peer_timeout: DEFAULT_PEER_TIMEOUT,
118 pending_responses: 0,
119 }
120 }
121
122 pub fn register_response_failure(&mut self, peer: PeerId) {
124 let Some((peer, instant)) = self.pending.remove(&peer) else {
125 tracing::debug!(target: LOG_TARGET, query = ?self.config.query, ?peer, "pending peer doesn't exist during response failure");
126 return;
127 };
128 self.pending_responses = self.pending_responses.saturating_sub(1);
129
130 tracing::trace!(target: LOG_TARGET, query = ?self.config.query, ?peer, elapsed = ?instant.elapsed(), "peer failed to respond");
131
132 self.queried.insert(peer.peer);
133 }
134
135 pub fn register_response(&mut self, peer: PeerId, peers: Vec<KademliaPeer>) {
137 let Some((peer, instant)) = self.pending.remove(&peer) else {
138 tracing::debug!(target: LOG_TARGET, query = ?self.config.query, ?peer, "received response from peer but didn't expect it");
139 return;
140 };
141 self.pending_responses = self.pending_responses.saturating_sub(1);
142
143 tracing::trace!(target: LOG_TARGET, query = ?self.config.query, ?peer, elapsed = ?instant.elapsed(), "received response from peer");
144
145 let distance = self.config.target.distance(&peer.key);
149
150 self.queried.insert(peer.peer);
152
153 if self.responses.len() < self.config.replication_factor {
154 self.responses.insert(distance, peer);
155 } else {
156 let furthest_distance =
159 self.responses.last_entry().map(|entry| *entry.key()).unwrap_or(distance);
160
161 if distance < furthest_distance {
163 self.responses.insert(distance, peer);
164
165 if self.responses.len() > self.config.replication_factor {
167 self.responses.pop_last();
168 }
169 }
170 }
171
172 let to_query_candidate = peers.into_iter().filter_map(|peer| {
173 if self.queried.contains(&peer.peer) {
175 return None;
176 }
177
178 if self.pending.contains_key(&peer.peer) {
180 return None;
181 }
182
183 if self.config.local_peer_id == peer.peer {
185 return None;
186 }
187
188 Some(peer)
189 });
190
191 for candidate in to_query_candidate {
192 let distance = self.config.target.distance(&candidate.key);
193 self.candidates.insert(distance, candidate);
194 }
195 }
196
197 pub fn next_peer_action(&mut self, peer: &PeerId) -> Option<QueryAction> {
199 self.pending.contains_key(peer).then_some(QueryAction::SendMessage {
200 query: self.config.query,
201 peer: *peer,
202 message: self.kad_message.clone(),
203 })
204 }
205
206 fn schedule_next_peer(&mut self) -> Option<QueryAction> {
208 tracing::trace!(target: LOG_TARGET, query = ?self.config.query, "get next peer");
209
210 let (_, candidate) = self.candidates.pop_first()?;
211 let peer = candidate.peer;
212
213 tracing::trace!(target: LOG_TARGET, query = ?self.config.query, ?peer, "current candidate");
214 self.pending.insert(candidate.peer, (candidate, std::time::Instant::now()));
215 self.pending_responses = self.pending_responses.saturating_add(1);
216
217 Some(QueryAction::SendMessage {
218 query: self.config.query,
219 peer,
220 message: self.kad_message.clone(),
221 })
222 }
223
224 fn is_done(&self) -> bool {
228 self.pending.is_empty() && self.candidates.is_empty()
229 }
230
231 pub fn next_action(&mut self) -> Option<QueryAction> {
233 if self.is_done() {
236 tracing::trace!(
237 target: LOG_TARGET,
238 query = ?self.config.query,
239 pending = self.pending.len(),
240 candidates = self.candidates.len(),
241 "query finished"
242 );
243
244 return if self.responses.is_empty() {
245 Some(QueryAction::QueryFailed {
246 query: self.config.query,
247 })
248 } else {
249 Some(QueryAction::QuerySucceeded {
250 query: self.config.query,
251 })
252 };
253 }
254
255 for (peer, instant) in self.pending.values() {
256 if instant.elapsed() > self.peer_timeout {
257 tracing::trace!(
258 target: LOG_TARGET,
259 query = ?self.config.query,
260 ?peer,
261 elapsed = ?instant.elapsed(),
262 "peer no longer counting towards parallelism factor"
263 );
264 self.pending_responses = self.pending_responses.saturating_sub(1);
265 }
266 }
267
268 if self.pending_responses == self.config.parallelism_factor {
271 return None;
272 }
273
274 if self.responses.len() < self.config.replication_factor {
276 return self.schedule_next_peer();
277 }
278
279 match (
281 self.candidates.first_key_value(),
282 self.responses.last_key_value(),
283 ) {
284 (Some((_, candidate_peer)), Some((worst_response_distance, _))) => {
285 let first_candidate_distance = self.config.target.distance(&candidate_peer.key);
286 if first_candidate_distance < *worst_response_distance {
287 return self.schedule_next_peer();
288 }
289 }
290
291 _ => (),
292 }
293
294 Some(QueryAction::QuerySucceeded {
296 query: self.config.query,
297 })
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::libp2p::kademlia::types::ConnectionType;

    /// Baseline configuration with a fixed byte target; tests override single fields.
    fn default_config() -> FindNodeConfig<Vec<u8>> {
        FindNodeConfig {
            local_peer_id: PeerId::random(),
            replication_factor: 20,
            parallelism_factor: 10,
            query: QueryId(0),
            target: Key::new(vec![1, 2, 3].into()),
        }
    }

    /// Wrap a peer ID into a connected [`KademliaPeer`] without addresses.
    fn peer_to_kad(peer: PeerId) -> KademliaPeer {
        KademliaPeer {
            peer,
            key: Key::from(peer),
            addresses: vec![],
            connection: ConnectionType::Connected,
        }
    }

    /// Generate two random peers and a random target, returning
    /// `(closest, furthest, config)` where the config uses a parallelism and replication
    /// factor of one.
    fn setup_closest_responses() -> (PeerId, PeerId, FindNodeConfig<PeerId>) {
        let peer_a = PeerId::random();
        let peer_b = PeerId::random();
        let target = PeerId::random();

        let distance_a = Key::from(peer_a).distance(&Key::from(target));
        let distance_b = Key::from(peer_b).distance(&Key::from(target));

        let (closest, furthest) = if distance_a < distance_b {
            (peer_a, peer_b)
        } else {
            (peer_b, peer_a)
        };

        let config = FindNodeConfig {
            parallelism_factor: 1,
            replication_factor: 1,
            target: Key::from(target),
            local_peer_id: PeerId::random(),
            query: QueryId(0),
        };

        (closest, furthest, config)
    }

    /// Poll the query once and return the peer targeted by the resulting `SendMessage`.
    ///
    /// Panics unless the action is a `SendMessage` for `QueryId(0)` whose peer has been
    /// added to the pending set.
    #[track_caller]
    fn expect_send_message<T: Clone + Into<Vec<u8>>>(context: &mut FindNodeContext<T>) -> PeerId {
        match context.next_action() {
            Some(QueryAction::SendMessage { query, peer, .. }) => {
                assert_eq!(query, QueryId(0));
                // The scheduled peer must be tracked as pending.
                assert!(context.pending.contains_key(&peer));
                peer
            }
            _ => panic!("Unexpected event"),
        }
    }

    #[test]
    fn completes_when_no_candidates() {
        let config = default_config();
        let mut context = FindNodeContext::new(config, VecDeque::new());
        assert!(context.is_done());
        let event = context.next_action().unwrap();
        assert_eq!(event, QueryAction::QueryFailed { query: QueryId(0) });
    }

    #[test]
    fn fulfill_parallelism() {
        let config = FindNodeConfig {
            parallelism_factor: 3,
            ..default_config()
        };

        let in_peers_set = (0..3).map(|_| PeerId::random()).collect::<HashSet<_>>();
        let in_peers = in_peers_set.iter().copied().map(peer_to_kad).collect();
        let mut context = FindNodeContext::new(config, in_peers);

        for num in 0..3 {
            let peer = expect_send_message(&mut context);
            assert_eq!(context.pending.len(), num + 1);
            // The scheduled peer must be one of the initial peers.
            assert!(in_peers_set.contains(&peer));
        }

        // Fulfilled parallelism.
        assert!(context.next_action().is_none());
    }

    #[test]
    fn fulfill_parallelism_with_timeout_optimization() {
        let config = FindNodeConfig {
            parallelism_factor: 3,
            ..default_config()
        };

        let in_peers_set = (0..4).map(|_| PeerId::random()).collect::<HashSet<_>>();
        let in_peers = in_peers_set.iter().copied().map(peer_to_kad).collect();
        let mut context = FindNodeContext::new(config, in_peers);
        // Shorten the timeout so the test does not have to wait for the default one.
        context.peer_timeout = std::time::Duration::from_secs(1);

        for num in 0..3 {
            let peer = expect_send_message(&mut context);
            assert_eq!(context.pending.len(), num + 1);
            assert!(in_peers_set.contains(&peer));
        }

        // Fulfilled parallelism.
        assert!(context.next_action().is_none());

        // Let all three in-flight requests exceed the peer timeout.
        std::thread::sleep(std::time::Duration::from_secs(2));

        // Nothing is re-evaluated until the query is polled again.
        assert_eq!(context.pending_responses, 3);
        assert_eq!(context.pending.len(), 3);

        // The timed-out peers stop counting, so the fourth peer can be scheduled.
        let peer = expect_send_message(&mut context);
        assert_eq!(context.pending.len(), 4);
        assert!(in_peers_set.contains(&peer));

        // Only the freshly scheduled peer counts towards the parallelism factor.
        assert_eq!(context.pending_responses, 1);
        assert_eq!(context.pending.len(), 4);
    }

    #[test]
    fn completes_when_responses() {
        let config = FindNodeConfig {
            parallelism_factor: 3,
            replication_factor: 3,
            ..default_config()
        };

        let peer_a = PeerId::random();
        let peer_b = PeerId::random();
        let peer_c = PeerId::random();

        let in_peers_set: HashSet<_> = [peer_a, peer_b, peer_c].into_iter().collect();
        assert_eq!(in_peers_set.len(), 3);

        let in_peers = [peer_a, peer_b, peer_c].iter().copied().map(peer_to_kad).collect();
        let mut context = FindNodeContext::new(config, in_peers);

        // Schedule all three initial peers.
        for num in 0..3 {
            let peer = expect_send_message(&mut context);
            assert_eq!(context.pending.len(), num + 1);
            assert!(in_peers_set.contains(&peer));
        }

        // A failure from a peer that was never queried is ignored.
        let peer_d = PeerId::random();
        context.register_response_failure(peer_d);
        assert_eq!(context.pending.len(), 3);
        assert!(context.queried.is_empty());

        // First response, no new peers reported.
        context.register_response(peer_a, vec![]);
        assert_eq!(context.pending.len(), 2);
        assert_eq!(context.queried.len(), 1);
        assert_eq!(context.responses.len(), 1);

        // Second response reports peer D, which becomes a candidate.
        context.register_response(peer_b, vec![peer_to_kad(peer_d)]);
        assert_eq!(context.pending.len(), 1);
        assert_eq!(context.queried.len(), 2);
        assert_eq!(context.responses.len(), 2);
        assert_eq!(context.candidates.len(), 1);

        // Third peer fails: it is marked queried but produces no result.
        context.register_response_failure(peer_c);
        assert!(context.pending.is_empty());
        assert_eq!(context.queried.len(), 3);
        assert_eq!(context.responses.len(), 2);

        // Peer D is the only candidate left and must be scheduled next.
        let peer = expect_send_message(&mut context);
        assert_eq!(context.pending.len(), 1);
        assert_eq!(peer, peer_d);

        context.register_response(peer_d, vec![]);

        // Nothing pending and no candidates left: the query succeeds.
        let event = context.next_action().unwrap();
        assert_eq!(event, QueryAction::QuerySucceeded { query: QueryId(0) });
    }

    #[test]
    fn offers_closest_responses() {
        let (closest, furthest, config) = setup_closest_responses();

        // Insertion order must not matter: the closest peer is scheduled first.
        let in_peers = vec![peer_to_kad(furthest), peer_to_kad(closest)];
        let mut context = FindNodeContext::new(config, in_peers.into_iter().collect());

        let peer = expect_send_message(&mut context);
        assert_eq!(context.pending.len(), 1);
        assert_eq!(closest, peer);

        context.register_response(closest, vec![]);

        // The remaining candidate is further away than the result, so the query is done.
        let event = context.next_action().unwrap();
        assert_eq!(event, QueryAction::QuerySucceeded { query: QueryId(0) });
    }

    #[test]
    fn offers_closest_responses_with_better_candidates() {
        let (closest, furthest, config) = setup_closest_responses();

        // Start with only the furthest peer; the closest one is discovered later.
        let in_peers = vec![peer_to_kad(furthest)];
        let mut context = FindNodeContext::new(config, in_peers.into_iter().collect());

        let peer = expect_send_message(&mut context);
        assert_eq!(context.pending.len(), 1);
        assert_eq!(furthest, peer);

        // The furthest peer reports the closest peer.
        context.register_response(furthest, vec![peer_to_kad(closest)]);

        // The result set is full, but the new candidate is closer and must be queried.
        let peer = expect_send_message(&mut context);
        assert_eq!(context.pending.len(), 1);
        assert_eq!(closest, peer);

        // Parallelism factor of one is reached while the request is in flight.
        assert!(context.next_action().is_none());

        context.register_response(closest, vec![]);

        let event = context.next_action().unwrap();
        assert_eq!(event, QueryAction::QuerySucceeded { query: QueryId(0) });
    }

    #[test]
    fn keep_k_best_results() {
        let mut peers = (0..6).map(|_| PeerId::random()).collect::<Vec<_>>();
        let target = Key::from(PeerId::random());
        // Sort by descending distance: `peers[0]` is the furthest, `peers[5]` the closest.
        peers.sort_by_key(|peer| std::cmp::Reverse(target.distance(&Key::from(*peer))));

        let config = FindNodeConfig {
            parallelism_factor: 3,
            replication_factor: 3,
            target,
            local_peer_id: PeerId::random(),
            query: QueryId(0),
        };

        // Seed the query with the three furthest peers.
        let in_peers = peers[..3].iter().copied().map(peer_to_kad).collect();
        let mut context = FindNodeContext::new(config, in_peers);

        for num in 0..3 {
            expect_send_message(&mut context);
            assert_eq!(context.pending.len(), num + 1);
        }

        // Each of the furthest peers reports one of the three closest peers.
        context.register_response(peers[0], vec![peer_to_kad(peers[3])]);
        context.register_response(peers[1], vec![peer_to_kad(peers[4])]);
        context.register_response(peers[2], vec![peer_to_kad(peers[5])]);

        // The result set is full, yet every new candidate is closer and gets queried.
        for num in 0..3 {
            expect_send_message(&mut context);
            assert_eq!(context.pending.len(), num + 1);
        }

        context.register_response(peers[3], vec![]);
        context.register_response(peers[4], vec![]);
        context.register_response(peers[5], vec![]);

        let event = context.next_action().unwrap();
        assert_eq!(event, QueryAction::QuerySucceeded { query: QueryId(0) });

        // Only the three closest peers remain, ordered from closest to furthest.
        let responses = context.responses.values().map(|peer| peer.peer).collect::<Vec<_>>();
        assert_eq!(responses, [peers[5], peers[4], peers[3]]);
    }
}