use bytes::Bytes;
use crate::{
protocol::libp2p::kademlia::{
message::KademliaMessage,
query::{QueryAction, QueryId},
types::{Distance, KademliaPeer, Key},
},
PeerId,
};
use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
/// Logging target passed as `target:` to every `tracing` event emitted by this file.
const LOG_TARGET: &str = "litep2p::ipfs::kademlia::query::find_node";
/// Default time a queried peer may stay silent before it stops counting towards the
/// parallelism limit. The peer stays in `pending`, so a late response is still accepted.
const DEFAULT_PEER_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
/// Configuration needed to instantiate a new [`FindNodeContext`].
#[derive(Debug, Clone)]
pub struct FindNodeConfig<T: Clone + Into<Vec<u8>>> {
/// Local peer ID. Peers reported with this ID are never added as query candidates.
pub local_peer_id: PeerId,
/// Upper bound on the number of responding peers kept in the result set.
pub replication_factor: usize,
/// Upper bound on the number of outstanding requests counted as in flight at once.
pub parallelism_factor: usize,
/// Query ID, attached to every `QueryAction` the context emits.
pub query: QueryId,
/// Target key; all peer distances are measured against it.
pub target: Key<T>,
}
/// State of a single in-progress find-node query.
#[derive(Debug)]
pub struct FindNodeContext<T: Clone + Into<Vec<u8>>> {
/// Query configuration.
pub config: FindNodeConfig<T>,
/// Message built once by `KademliaMessage::find_node` and cloned into every `SendMessage`
/// action.
kad_message: Bytes,
/// Peers that have been scheduled and have neither responded nor failed yet, together with
/// the instant at which they were scheduled.
pub pending: HashMap<PeerId, (KademliaPeer, std::time::Instant)>,
/// Peers that have already responded or failed; they are never scheduled again.
pub queried: HashSet<PeerId>,
/// Peers that have not been queried yet, ordered by distance to the target.
pub candidates: BTreeMap<Distance, KademliaPeer>,
/// Peers that responded, ordered by distance to the target and trimmed to at most
/// `replication_factor` entries.
pub responses: BTreeMap<Distance, KademliaPeer>,
/// Time after which a pending peer stops counting towards the parallelism factor.
peer_timeout: std::time::Duration,
/// Number of pending requests currently counted against `parallelism_factor`.
pending_responses: usize,
}
impl<T: Clone + Into<Vec<u8>>> FindNodeContext<T> {
/// Create a new [`FindNodeContext`].
///
/// `in_peers` seeds the candidate set. Each peer is keyed by its distance to
/// `config.target`, so the closest peer is scheduled first. Should two peers map to the
/// same distance, the later one replaces the earlier one.
///
/// The request message is built once here and reused for every peer that gets queried.
pub fn new(config: FindNodeConfig<T>, in_peers: VecDeque<KademliaPeer>) -> Self {
    let mut candidates = BTreeMap::new();
    // `in_peers` is owned, so move each peer into the map instead of cloning it.
    for candidate in in_peers {
        let distance = config.target.distance(&candidate.key);
        candidates.insert(distance, candidate);
    }

    let kad_message = KademliaMessage::find_node(config.target.clone().into_preimage());

    Self {
        config,
        kad_message,
        candidates,
        pending: HashMap::new(),
        queried: HashSet::new(),
        responses: BTreeMap::new(),
        peer_timeout: DEFAULT_PEER_TIMEOUT,
        pending_responses: 0,
    }
}
/// Record that `peer` failed to answer the request.
///
/// A peer that is not pending is ignored (a debug event is logged). Otherwise the peer is
/// dropped from `pending`, stops counting as an outstanding request and is marked as queried
/// so that it is not scheduled again.
pub fn register_response_failure(&mut self, peer: PeerId) {
    let (failed, scheduled_at) = match self.pending.remove(&peer) {
        Some(entry) => entry,
        None => {
            tracing::debug!(target: LOG_TARGET, query = ?self.config.query, ?peer, "pending peer doesn't exist during response failure");
            return;
        }
    };

    self.pending_responses = self.pending_responses.saturating_sub(1);
    tracing::trace!(target: LOG_TARGET, query = ?self.config.query, peer = ?failed, elapsed = ?scheduled_at.elapsed(), "peer failed to respond");

    self.queried.insert(failed.peer);
}
/// Record a successful response from `peer`, which reported `peers` as closer nodes.
///
/// A response from a peer that is not pending is ignored (a debug event is logged).
/// Otherwise:
/// - the responder is marked as queried and stops counting as an outstanding request;
/// - the responder joins `responses` if there is still room, or if it is closer to the
///   target than the furthest stored response, in which case the furthest one is evicted
///   to keep at most `replication_factor` entries;
/// - every reported peer that is neither queried, pending, nor the local peer becomes a
///   candidate.
pub fn register_response(&mut self, peer: PeerId, peers: Vec<KademliaPeer>) {
    let (responder, scheduled_at) = match self.pending.remove(&peer) {
        Some(entry) => entry,
        None => {
            tracing::debug!(target: LOG_TARGET, query = ?self.config.query, ?peer, "received response from peer but didn't expect it");
            return;
        }
    };

    self.pending_responses = self.pending_responses.saturating_sub(1);
    tracing::trace!(target: LOG_TARGET, query = ?self.config.query, peer = ?responder, elapsed = ?scheduled_at.elapsed(), "received response from peer");

    let distance = self.config.target.distance(&responder.key);
    self.queried.insert(responder.peer);

    // Accept the responder when the result set has room, or when it beats the current
    // furthest entry. An empty, zero-capacity result set accepts nothing.
    let limit = self.config.replication_factor;
    let accepted = self.responses.len() < limit
        || self
            .responses
            .last_key_value()
            .map_or(false, |(furthest, _)| distance < *furthest);

    if accepted {
        self.responses.insert(distance, responder);
        // Only trips when a closer peer was added to an already full set.
        if self.responses.len() > limit {
            self.responses.pop_last();
        }
    }

    for reported in peers {
        let skip = self.queried.contains(&reported.peer)
            || self.pending.contains_key(&reported.peer)
            || self.config.local_peer_id == reported.peer;
        if skip {
            continue;
        }

        let distance = self.config.target.distance(&reported.key);
        self.candidates.insert(distance, reported);
    }
}
/// Get the action for a specific peer.
///
/// Returns a `SendMessage` action carrying the query's request message when `peer` is
/// currently pending, and `None` otherwise.
pub fn next_peer_action(&mut self, peer: &PeerId) -> Option<QueryAction> {
    if !self.pending.contains_key(peer) {
        return None;
    }

    Some(QueryAction::SendMessage {
        query: self.config.query,
        peer: *peer,
        message: self.kad_message.clone(),
    })
}
/// Move the closest candidate into `pending` and return the `SendMessage` action for it.
///
/// Returns `None` when no candidates are left. The scheduling instant is stored next to the
/// peer and the outstanding-request counter is incremented.
fn schedule_next_peer(&mut self) -> Option<QueryAction> {
    tracing::trace!(target: LOG_TARGET, query = ?self.config.query, "get next peer");

    let next = self.candidates.pop_first().map(|(_, candidate)| candidate)?;
    let peer = next.peer;
    tracing::trace!(target: LOG_TARGET, query = ?self.config.query, ?peer, "current candidate");

    let scheduled_at = std::time::Instant::now();
    self.pending.insert(peer, (next, scheduled_at));
    self.pending_responses = self.pending_responses.saturating_add(1);

    let message = self.kad_message.clone();
    Some(QueryAction::SendMessage {
        query: self.config.query,
        peer,
        message,
    })
}
/// The query has run dry when no request is outstanding and no candidate is left to query.
fn is_done(&self) -> bool {
    let work_left = !self.pending.is_empty() || !self.candidates.is_empty();
    !work_left
}
/// Get the next action for the query.
///
/// Outcomes, checked in this order:
/// - no pending peers and no candidates: `QueryFailed` if nothing responded, otherwise
///   `QuerySucceeded`;
/// - the number of pending peers that have not exceeded `peer_timeout` has reached
///   `parallelism_factor`: `None` (wait for responses);
/// - fewer than `replication_factor` responses collected: schedule the closest candidate
///   (`SendMessage`, or `None` if there is no candidate);
/// - the closest candidate is closer to the target than the furthest stored response:
///   schedule it (`SendMessage`);
/// - otherwise `QuerySucceeded`.
pub fn next_action(&mut self) -> Option<QueryAction> {
    if self.is_done() {
        tracing::trace!(
            target: LOG_TARGET,
            query = ?self.config.query,
            pending = self.pending.len(),
            candidates = self.candidates.len(),
            "query finished"
        );

        return if self.responses.is_empty() {
            Some(QueryAction::QueryFailed {
                query: self.config.query,
            })
        } else {
            Some(QueryAction::QuerySucceeded {
                query: self.config.query,
            })
        };
    }

    // Rebuild the in-flight counter from `pending` instead of decrementing it in place.
    // Decrementing here would subtract the same timed-out peer again on every call (and
    // once more when its late response or failure arrives), letting the counter drift
    // below the real number of live requests and defeating the parallelism limit.
    let mut in_flight = 0usize;
    for (peer, instant) in self.pending.values() {
        let elapsed = instant.elapsed();
        if elapsed > self.peer_timeout {
            tracing::trace!(
                target: LOG_TARGET,
                query = ?self.config.query,
                ?peer,
                ?elapsed,
                "peer no longer counting towards parallelism factor"
            );
        } else {
            in_flight += 1;
        }
    }
    self.pending_responses = in_flight;

    // `>=` rather than `==` so that the limit still holds if the counter ever exceeds it.
    if self.pending_responses >= self.config.parallelism_factor {
        return None;
    }

    // Not enough results yet: keep querying as long as there are candidates.
    if self.responses.len() < self.config.replication_factor {
        return self.schedule_next_peer();
    }

    // The result set is full: only keep going if the best candidate could improve it.
    if let (Some((_, candidate_peer)), Some((worst_response_distance, _))) = (
        self.candidates.first_key_value(),
        self.responses.last_key_value(),
    ) {
        let first_candidate_distance = self.config.target.distance(&candidate_peer.key);
        if first_candidate_distance < *worst_response_distance {
            return self.schedule_next_peer();
        }
    }

    Some(QueryAction::QuerySucceeded {
        query: self.config.query,
    })
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::protocol::libp2p::kademlia::types::ConnectionType;
fn default_config() -> FindNodeConfig<Vec<u8>> {
FindNodeConfig {
local_peer_id: PeerId::random(),
replication_factor: 20,
parallelism_factor: 10,
query: QueryId(0),
target: Key::new(vec![1, 2, 3].into()),
}
}
/// Wrap a peer ID into a connected `KademliaPeer` with no known addresses.
fn peer_to_kad(peer: PeerId) -> KademliaPeer {
    let key = Key::from(peer);

    KademliaPeer {
        peer,
        key,
        addresses: Vec::new(),
        connection: ConnectionType::Connected,
    }
}
/// Generate two random peers and a random target.
///
/// Returns `(closest, furthest, config)`, where the peers are ordered by their distance to
/// the target and the config uses a parallelism and replication factor of 1.
fn setup_closest_responses() -> (PeerId, PeerId, FindNodeConfig<PeerId>) {
    let [peer_a, peer_b, target] = [PeerId::random(), PeerId::random(), PeerId::random()];

    let distance_to_target = |peer: PeerId| Key::from(peer).distance(&Key::from(target));
    let a_is_closer = distance_to_target(peer_a) < distance_to_target(peer_b);
    let (closest, furthest) = match a_is_closer {
        true => (peer_a, peer_b),
        false => (peer_b, peer_a),
    };

    let config = FindNodeConfig {
        parallelism_factor: 1,
        replication_factor: 1,
        target: Key::from(target),
        local_peer_id: PeerId::random(),
        query: QueryId(0),
    };

    (closest, furthest, config)
}
/// A query without any seed peers is done immediately and reports failure.
#[test]
fn completes_when_no_candidates() {
    let mut context = FindNodeContext::new(default_config(), VecDeque::new());
    assert!(context.is_done());

    assert_eq!(
        context.next_action().unwrap(),
        QueryAction::QueryFailed { query: QueryId(0) }
    );
}
/// The query schedules peers until the parallelism factor is reached, then waits.
#[test]
fn fulfill_parallelism() {
    let config = FindNodeConfig {
        parallelism_factor: 3,
        ..default_config()
    };

    let in_peers_set: HashSet<PeerId> = std::iter::repeat_with(PeerId::random).take(3).collect();
    let in_peers = in_peers_set.iter().copied().map(peer_to_kad).collect();
    let mut context = FindNodeContext::new(config, in_peers);

    for already_pending in 0..3 {
        let QueryAction::SendMessage { query, peer, .. } = context.next_action().unwrap() else {
            panic!("Unexpected event");
        };

        assert_eq!(query, QueryId(0));
        // One extra peer becomes pending with every scheduled request.
        assert_eq!(context.pending.len(), already_pending + 1);
        assert!(context.pending.contains_key(&peer));
        // Only the seed peers may be contacted.
        assert!(in_peers_set.contains(&peer));
    }

    // All three slots are taken, nothing more to do until a response arrives.
    assert!(context.next_action().is_none());
}
/// Pending peers that exceed the peer timeout free their parallelism slot, so another
/// candidate can be scheduled even though the slow peers are still pending.
#[test]
fn fulfill_parallelism_with_timeout_optimization() {
    let config = FindNodeConfig {
        parallelism_factor: 3,
        ..default_config()
    };

    let in_peers_set: HashSet<PeerId> = std::iter::repeat_with(PeerId::random).take(4).collect();
    let in_peers = in_peers_set.iter().copied().map(peer_to_kad).collect();
    let mut context = FindNodeContext::new(config, in_peers);
    // Shorten the timeout so that the test does not have to wait for the default.
    context.peer_timeout = std::time::Duration::from_secs(1);

    for already_pending in 0..3 {
        let QueryAction::SendMessage { query, peer, .. } = context.next_action().unwrap() else {
            panic!("Unexpected event");
        };

        assert_eq!(query, QueryId(0));
        assert_eq!(context.pending.len(), already_pending + 1);
        assert!(context.pending.contains_key(&peer));
        assert!(in_peers_set.contains(&peer));
    }

    // Parallelism factor reached: the fourth peer has to wait.
    assert!(context.next_action().is_none());

    // Let all three pending peers exceed the peer timeout.
    std::thread::sleep(std::time::Duration::from_secs(2));

    // Nothing changes until `next_action` is polled again.
    assert_eq!(context.pending_responses, 3);
    assert_eq!(context.pending.len(), 3);

    let QueryAction::SendMessage { query, peer, .. } = context.next_action().unwrap() else {
        panic!("Unexpected event");
    };
    assert_eq!(query, QueryId(0));
    assert_eq!(context.pending.len(), 4);
    assert!(context.pending.contains_key(&peer));
    assert!(in_peers_set.contains(&peer));

    // Only the freshly scheduled peer counts towards the parallelism factor, while the
    // timed-out peers remain pending.
    assert_eq!(context.pending_responses, 1);
    assert_eq!(context.pending.len(), 4);
}
/// Walks a query through responses, failures and a newly discovered peer until it succeeds.
#[test]
fn completes_when_responses() {
    let config = FindNodeConfig {
        parallelism_factor: 3,
        replication_factor: 3,
        ..default_config()
    };

    let (peer_a, peer_b, peer_c) = (PeerId::random(), PeerId::random(), PeerId::random());
    let seeds = [peer_a, peer_b, peer_c];
    let in_peers_set = seeds.iter().copied().collect::<HashSet<_>>();
    assert_eq!(in_peers_set.len(), 3);

    let in_peers = seeds.iter().copied().map(peer_to_kad).collect();
    let mut context = FindNodeContext::new(config, in_peers);

    // Schedule all three seed peers.
    for already_pending in 0..3 {
        let QueryAction::SendMessage { query, peer, .. } = context.next_action().unwrap() else {
            panic!("Unexpected event");
        };

        assert_eq!(query, QueryId(0));
        assert_eq!(context.pending.len(), already_pending + 1);
        assert!(context.pending.contains_key(&peer));
        assert!(in_peers_set.contains(&peer));
    }

    // A failure from a peer that was never queried is ignored.
    let peer_d = PeerId::random();
    context.register_response_failure(peer_d);
    assert_eq!(context.pending.len(), 3);
    assert!(context.queried.is_empty());

    // Peer A responds without reporting any peers.
    context.register_response(peer_a, vec![]);
    assert_eq!(context.pending.len(), 2);
    assert_eq!(context.queried.len(), 1);
    assert_eq!(context.responses.len(), 1);

    // Peer B responds and reports peer D, which becomes a candidate.
    context.register_response(peer_b, vec![peer_to_kad(peer_d)]);
    assert_eq!(context.pending.len(), 1);
    assert_eq!(context.queried.len(), 2);
    assert_eq!(context.responses.len(), 2);
    assert_eq!(context.candidates.len(), 1);

    // Peer C fails: it counts as queried but contributes no response.
    context.register_response_failure(peer_c);
    assert!(context.pending.is_empty());
    assert_eq!(context.queried.len(), 3);
    assert_eq!(context.responses.len(), 2);

    // The discovered peer D is queried next.
    let QueryAction::SendMessage { query, peer, .. } = context.next_action().unwrap() else {
        panic!("Unexpected event");
    };
    assert_eq!(query, QueryId(0));
    assert_eq!(context.pending.len(), 1);
    assert_eq!(peer, peer_d);

    // Peer D responds; nothing is left to query, so the query succeeds.
    context.register_response(peer_d, vec![]);
    assert_eq!(
        context.next_action().unwrap(),
        QueryAction::QuerySucceeded { query: QueryId(0) }
    );
}
/// With room for a single result, the closest seed is queried first and the query succeeds
/// without contacting the furthest one.
#[test]
fn offers_closest_responses() {
    let (closest, furthest, config) = setup_closest_responses();

    // Seed the furthest peer first to show that ordering is by distance, not by input order.
    let in_peers = VecDeque::from(vec![peer_to_kad(furthest), peer_to_kad(closest)]);
    let mut context = FindNodeContext::new(config, in_peers);

    let QueryAction::SendMessage { query, peer, .. } = context.next_action().unwrap() else {
        panic!("Unexpected event");
    };
    assert_eq!(query, QueryId(0));
    assert_eq!(context.pending.len(), 1);
    assert!(context.pending.contains_key(&peer));
    assert_eq!(closest, peer);

    context.register_response(closest, vec![]);

    // The remaining candidate is further away than the stored response.
    assert_eq!(
        context.next_action().unwrap(),
        QueryAction::QuerySucceeded { query: QueryId(0) }
    );
}
/// A full result set does not end the query while a closer candidate is still known.
#[test]
fn offers_closest_responses_with_better_candidates() {
    let (closest, furthest, config) = setup_closest_responses();

    // Only the furthest peer is known up front.
    let in_peers = VecDeque::from(vec![peer_to_kad(furthest)]);
    let mut context = FindNodeContext::new(config, in_peers);

    let QueryAction::SendMessage { query, peer, .. } = context.next_action().unwrap() else {
        panic!("Unexpected event");
    };
    assert_eq!(query, QueryId(0));
    assert_eq!(context.pending.len(), 1);
    assert!(context.pending.contains_key(&peer));
    assert_eq!(furthest, peer);

    // The furthest peer reports the closest one.
    context.register_response(furthest, vec![peer_to_kad(closest)]);

    // The result set is full, but the new candidate is closer, so it gets queried.
    let QueryAction::SendMessage { query, peer, .. } = context.next_action().unwrap() else {
        panic!("Unexpected event");
    };
    assert_eq!(query, QueryId(0));
    assert_eq!(context.pending.len(), 1);
    assert!(context.pending.contains_key(&peer));
    assert_eq!(closest, peer);

    // Parallelism factor of 1 is exhausted while the request is outstanding.
    assert!(context.next_action().is_none());

    context.register_response(closest, vec![]);
    assert_eq!(
        context.next_action().unwrap(),
        QueryAction::QuerySucceeded { query: QueryId(0) }
    );
}
/// Only the `replication_factor` closest responders are kept, ordered closest first.
#[test]
fn keep_k_best_results() {
    let mut peers: Vec<PeerId> = std::iter::repeat_with(PeerId::random).take(6).collect();
    let target = Key::from(PeerId::random());

    // Furthest peer first, closest peer last.
    peers.sort_by(|lhs, rhs| {
        target.distance(&Key::from(*rhs)).cmp(&target.distance(&Key::from(*lhs)))
    });

    let config = FindNodeConfig {
        parallelism_factor: 3,
        replication_factor: 3,
        target,
        local_peer_id: PeerId::random(),
        query: QueryId(0),
    };

    // Seed the query with the three furthest peers.
    let in_peers = peers[..3].iter().copied().map(peer_to_kad).collect();
    let mut context = FindNodeContext::new(config, in_peers);

    for already_pending in 0..3 {
        let QueryAction::SendMessage { query, peer, .. } = context.next_action().unwrap() else {
            panic!("Unexpected event");
        };
        assert_eq!(query, QueryId(0));
        assert_eq!(context.pending.len(), already_pending + 1);
        assert!(context.pending.contains_key(&peer));
    }

    // Each seed peer reports one of the three closest peers.
    context.register_response(peers[0], vec![peer_to_kad(peers[3])]);
    context.register_response(peers[1], vec![peer_to_kad(peers[4])]);
    context.register_response(peers[2], vec![peer_to_kad(peers[5])]);

    // The closer peers are better than the stored responses, so all of them are queried.
    for already_pending in 0..3 {
        let QueryAction::SendMessage { query, peer, .. } = context.next_action().unwrap() else {
            panic!("Unexpected event");
        };
        assert_eq!(query, QueryId(0));
        assert_eq!(context.pending.len(), already_pending + 1);
        assert!(context.pending.contains_key(&peer));
    }

    context.register_response(peers[3], vec![]);
    context.register_response(peers[4], vec![]);
    context.register_response(peers[5], vec![]);

    assert_eq!(
        context.next_action().unwrap(),
        QueryAction::QuerySucceeded { query: QueryId(0) }
    );

    // The three furthest responders were evicted in favour of the three closest.
    let responses = context.responses.values().map(|peer| peer.peer).collect::<Vec<_>>();
    assert_eq!(responses, [peers[5], peers[4], peers[3]]);
}
}