use crate::{types::ConnectionId, PeerId};
use multiaddr::{Multiaddr, Protocol};
use multihash::Multihash;
use std::collections::{BinaryHeap, HashSet};
/// An address of a remote peer together with its score.
///
/// Equality, ordering and hashing all look only at `score`, so records are ranked purely by
/// score when stored in a [`BinaryHeap`] (the highest score is popped first).
#[derive(Debug, Clone)]
pub struct AddressRecord {
    /// Score of the address; adjusted with saturating arithmetic by `update_score`.
    score: i32,
    /// The address itself.
    address: Multiaddr,
    /// Connection associated with this address, if any (cleared by `AddressStore::insert`).
    connection_id: Option<ConnectionId>,
}

// `PartialEq`/`Eq` compare only `score`, so `Hash` must hash only `score` as well. A derived
// impl would also hash `address` and `connection_id`, letting two records that compare equal
// produce different hashes, which violates the `Hash`/`Eq` contract.
impl std::hash::Hash for AddressRecord {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.score, state);
    }
}
impl AsRef<Multiaddr> for AddressRecord {
    /// Borrow the record's address, so a record can be passed wherever `AsRef<Multiaddr>` is
    /// accepted.
    fn as_ref(&self) -> &Multiaddr {
        self.address()
    }
}
impl AddressRecord {
    /// Create a record for `peer` reachable at `address`.
    ///
    /// If `address` does not already end in a `/p2p/...` component, one built from `peer` is
    /// appended. An existing trailing `/p2p` component is kept as-is.
    ///
    /// # Panics
    ///
    /// Panics if the bytes of `peer` cannot be parsed as a [`Multihash`].
    pub fn new(
        peer: &PeerId,
        address: Multiaddr,
        score: i32,
        connection_id: Option<ConnectionId>,
    ) -> Self {
        let has_peer_suffix = matches!(address.iter().last(), Some(Protocol::P2p(_)));
        let address = if has_peer_suffix {
            address
        } else {
            let peer_hash = Multihash::from_bytes(&peer.to_bytes()).expect("valid peer id");
            address.with(Protocol::P2p(peer_hash))
        };

        Self {
            address,
            score,
            connection_id,
        }
    }

    /// Create a record with score `0` and no connection from an address that already ends in
    /// a `/p2p/...` component.
    ///
    /// Returns `None` if the last component of `address` is not `/p2p`.
    pub fn from_multiaddr(address: Multiaddr) -> Option<AddressRecord> {
        let has_peer_suffix = matches!(address.iter().last(), Some(Protocol::P2p(_)));
        if has_peer_suffix {
            Some(AddressRecord {
                address,
                score: 0,
                connection_id: None,
            })
        } else {
            None
        }
    }

    /// Current score of the record (only available in tests).
    #[cfg(test)]
    pub fn score(&self) -> i32 {
        self.score
    }

    /// Borrow the record's address.
    pub fn address(&self) -> &Multiaddr {
        &self.address
    }

    /// Borrow the connection ID associated with this record, if any.
    pub fn connection_id(&self) -> &Option<ConnectionId> {
        &self.connection_id
    }

    /// Add `score` to the current score, clamping at `i32::MIN`/`i32::MAX` instead of
    /// overflowing.
    pub fn update_score(&mut self, score: i32) {
        self.score = i32::saturating_add(self.score, score);
    }

    /// Associate `connection_id` with this record, replacing any previous value.
    pub fn set_connection_id(&mut self, connection_id: ConnectionId) {
        self.connection_id = Some(connection_id);
    }
}
impl PartialEq for AddressRecord {
    /// Two records are equal when their scores are equal. Address and connection ID are
    /// ignored so equality agrees with the score-based ordering.
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score
    }
}
// Equality compares only the `i32` score, which is a total equivalence relation, so `Eq` holds.
impl Eq for AddressRecord {}
impl PartialOrd for AddressRecord {
    /// Delegates to the total [`Ord`] implementation, so records are always comparable.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for AddressRecord {
    /// Order records by score alone; a max-heap of records therefore yields the
    /// highest-scored record first.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        i32::cmp(&self.score, &other.score)
    }
}
/// Collection of [`AddressRecord`]s ordered by score, plus a set of their addresses used to
/// detect duplicates.
///
/// NOTE(review): both fields are `pub`, so code outside this type can modify one without the
/// other; the methods here keep them in step.
#[derive(Debug)]
pub struct AddressStore {
    /// Records in a max-heap keyed on score; `pop` yields the highest-scored record first.
    pub by_score: BinaryHeap<AddressRecord>,
    /// Addresses of the stored records, consulted by `contains` and `insert`.
    pub by_address: HashSet<Multiaddr>,
}
impl FromIterator<Multiaddr> for AddressStore {
    /// Build a store from raw addresses.
    ///
    /// Addresses that do not end in a `/p2p` component are skipped (see
    /// [`AddressRecord::from_multiaddr`]). Duplicate addresses are ignored by `insert`.
    fn from_iter<T: IntoIterator<Item = Multiaddr>>(iter: T) -> Self {
        let mut store = AddressStore::new();
        iter.into_iter()
            .filter_map(AddressRecord::from_multiaddr)
            .for_each(|record| store.insert(record));
        store
    }
}
impl FromIterator<AddressRecord> for AddressStore {
    /// Build a store from records, keeping only the first record seen for each address.
    ///
    /// Unlike `AddressStore::insert`, the records' connection IDs are preserved.
    fn from_iter<T: IntoIterator<Item = AddressRecord>>(iter: T) -> Self {
        let mut store = AddressStore::new();
        for record in iter {
            // `HashSet::insert` returns `false` for an address that is already present.
            // Pushing a duplicate would leave `by_score` with more entries than `by_address`.
            // Popping one copy would then make `contains` report the address as gone while
            // another copy is still queued.
            if store.by_address.insert(record.address.clone()) {
                store.by_score.push(record);
            }
        }
        store
    }
}
impl Extend<AddressRecord> for AddressStore {
    /// Insert every record from `iter` via `insert`, skipping already-known addresses.
    fn extend<T: IntoIterator<Item = AddressRecord>>(&mut self, iter: T) {
        iter.into_iter().for_each(|record| self.insert(record));
    }
}
impl<'a> Extend<&'a AddressRecord> for AddressStore {
    /// Clone each borrowed record and insert it via `insert`, skipping already-known
    /// addresses.
    fn extend<T: IntoIterator<Item = &'a AddressRecord>>(&mut self, iter: T) {
        iter.into_iter()
            .cloned()
            .for_each(|record| self.insert(record));
    }
}
impl Default for AddressStore {
    /// Same as [`AddressStore::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl AddressStore {
    /// Create an empty store.
    pub fn new() -> Self {
        Self {
            by_score: BinaryHeap::new(),
            by_address: HashSet::new(),
        }
    }

    /// Returns `true` if the store holds no records.
    pub fn is_empty(&self) -> bool {
        self.by_score.is_empty()
    }

    /// Returns `true` if a record for `address` is currently stored.
    pub fn contains(&self, address: &Multiaddr) -> bool {
        self.by_address.contains(address)
    }

    /// Insert `record` unless a record with the same address is already stored.
    ///
    /// The record's connection ID is cleared before it is stored.
    pub fn insert(&mut self, mut record: AddressRecord) {
        // Check before cloning so duplicates cost no allocation.
        if self.by_address.contains(record.address()) {
            return;
        }

        record.connection_id = None;
        self.by_address.insert(record.address.clone());
        self.by_score.push(record);
    }

    /// Remove and return the highest-scored record, or `None` if the store is empty.
    pub fn pop(&mut self) -> Option<AddressRecord> {
        self.by_score.pop().map(|record| {
            self.by_address.remove(&record.address);
            record
        })
    }

    /// Remove and return up to `limit` records, highest score first.
    ///
    /// Returns fewer than `limit` records if the store runs out.
    pub fn take(&mut self, limit: usize) -> Vec<AddressRecord> {
        let mut records = Vec::with_capacity(limit.min(self.by_score.len()));
        // `Take` stops polling after `limit` items, so no extra record is popped and lost.
        records.extend(std::iter::from_fn(|| self.pop()).take(limit));
        records
    }
}
#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        net::{Ipv4Addr, SocketAddr, SocketAddrV4},
    };

    use super::*;
    use rand::{rngs::ThreadRng, Rng};

    /// Random IPv4 socket address with a non-zero first octet and a non-zero port.
    fn random_socket_address(rng: &mut ThreadRng) -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(
            Ipv4Addr::new(
                rng.gen_range(1..=255),
                rng.gen_range(0..=255),
                rng.gen_range(0..=255),
                rng.gen_range(0..=255),
            ),
            rng.gen_range(1..=65535),
        ))
    }

    /// Wrap `address` in a record for a fresh random peer with a random score.
    fn record_with_random_score(rng: &mut ThreadRng, address: Multiaddr) -> AddressRecord {
        let score: i32 = rng.gen();
        AddressRecord::new(&PeerId::random(), address, score, None)
    }

    fn tcp_address_record(rng: &mut ThreadRng) -> AddressRecord {
        let socket = random_socket_address(rng);
        let address = Multiaddr::empty()
            .with(Protocol::from(socket.ip()))
            .with(Protocol::Tcp(socket.port()));
        record_with_random_score(rng, address)
    }

    fn ws_address_record(rng: &mut ThreadRng) -> AddressRecord {
        let socket = random_socket_address(rng);
        let address = Multiaddr::empty()
            .with(Protocol::from(socket.ip()))
            .with(Protocol::Tcp(socket.port()))
            .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string())));
        record_with_random_score(rng, address)
    }

    fn quic_address_record(rng: &mut ThreadRng) -> AddressRecord {
        let socket = random_socket_address(rng);
        let address = Multiaddr::empty()
            .with(Protocol::from(socket.ip()))
            .with(Protocol::Udp(socket.port()))
            .with(Protocol::QuicV1);
        record_with_random_score(rng, address)
    }

    /// Assert that `records` are in non-increasing score order.
    fn assert_sorted_by_descending_score(records: &[AddressRecord]) {
        for pair in records.windows(2) {
            // Scores are random `i32`s, so ties are possible; a strict `>` would make the
            // tests flaky.
            assert!(pair[0].score >= pair[1].score);
        }
    }

    /// Ten records mixing TCP, QUIC and WebSocket addresses.
    fn mixed_records(rng: &mut ThreadRng) -> Vec<AddressRecord> {
        (0..10)
            .map(|i| {
                if i % 2 == 0 {
                    tcp_address_record(rng)
                } else if i % 3 == 0 {
                    quic_address_record(rng)
                } else {
                    ws_address_record(rng)
                }
            })
            .collect()
    }

    #[test]
    fn take_multiple_records() {
        let mut store = AddressStore::new();
        let mut rng = rand::thread_rng();

        for _ in 0..rng.gen_range(1..5) {
            store.insert(tcp_address_record(&mut rng));
        }
        for _ in 0..rng.gen_range(1..5) {
            store.insert(ws_address_record(&mut rng));
        }
        for _ in 0..rng.gen_range(1..5) {
            store.insert(quic_address_record(&mut rng));
        }

        let known_addresses = store.by_address.len();
        assert!(known_addresses >= 3);

        let taken = store.take(known_addresses - 2);
        assert_eq!(known_addresses - 2, taken.len());
        assert!(!store.is_empty());

        for record in &taken {
            assert!(!store.contains(record.address()));
        }
        assert_sorted_by_descending_score(&taken);
    }

    #[test]
    fn attempt_to_take_excess_records() {
        let mut store = AddressStore::new();
        let mut rng = rand::thread_rng();

        store.insert(tcp_address_record(&mut rng));
        store.insert(ws_address_record(&mut rng));
        store.insert(quic_address_record(&mut rng));
        assert_eq!(store.by_address.len(), 3);

        let taken = store.take(8usize);
        assert_eq!(taken.len(), 3);
        assert!(store.is_empty());
        assert_sorted_by_descending_score(&taken);
    }

    #[test]
    fn extend_from_iterator() {
        let mut store = AddressStore::new();
        let mut rng = rand::thread_rng();
        let records = mixed_records(&mut rng);

        assert!(store.is_empty());
        let expected = records
            .iter()
            .cloned()
            .map(|record| (record.address().clone(), record))
            .collect::<HashMap<_, _>>();

        store.extend(records);
        for record in store.by_score {
            let stored = expected.get(record.address()).unwrap();
            assert_eq!(stored.score(), record.score());
            assert_eq!(stored.connection_id(), record.connection_id());
            assert_eq!(stored.address(), record.address());
        }
    }

    #[test]
    fn extend_from_iterator_ref() {
        let mut store = AddressStore::new();
        let mut rng = rand::thread_rng();
        let records = mixed_records(&mut rng)
            .into_iter()
            .map(|record| (record.address().clone(), record))
            .collect::<Vec<_>>();

        assert!(store.is_empty());
        let expected = records.iter().cloned().collect::<HashMap<_, _>>();

        store.extend(records.iter().map(|(_, record)| record));
        for record in store.by_score {
            let stored = expected.get(record.address()).unwrap();
            assert_eq!(stored.score(), record.score());
            assert_eq!(stored.connection_id(), record.connection_id());
            assert_eq!(stored.address(), record.address());
        }
    }
}