litep2p/transport/manager/
address.rs1use crate::{types::ConnectionId, PeerId};
22
23use multiaddr::{Multiaddr, Protocol};
24use multihash::Multihash;
25
26use std::collections::{BinaryHeap, HashSet};
27
28#[allow(clippy::derived_hash_with_manual_eq)]
29#[derive(Debug, Clone, Hash)]
30pub struct AddressRecord {
31 score: i32,
33
34 address: Multiaddr,
36
37 connection_id: Option<ConnectionId>,
39}
40
41impl AsRef<Multiaddr> for AddressRecord {
42 fn as_ref(&self) -> &Multiaddr {
43 &self.address
44 }
45}
46
47impl AddressRecord {
48 pub fn new(
51 peer: &PeerId,
52 address: Multiaddr,
53 score: i32,
54 connection_id: Option<ConnectionId>,
55 ) -> Self {
56 let address = if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) {
57 address.with(Protocol::P2p(
58 Multihash::from_bytes(&peer.to_bytes()).expect("valid peer id"),
59 ))
60 } else {
61 address
62 };
63
64 Self {
65 address,
66 score,
67 connection_id,
68 }
69 }
70
71 pub fn from_multiaddr(address: Multiaddr) -> Option<AddressRecord> {
76 if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) {
77 return None;
78 }
79
80 Some(AddressRecord {
81 address,
82 score: 0i32,
83 connection_id: None,
84 })
85 }
86
87 #[cfg(test)]
89 pub fn score(&self) -> i32 {
90 self.score
91 }
92
93 pub fn address(&self) -> &Multiaddr {
95 &self.address
96 }
97
98 pub fn connection_id(&self) -> &Option<ConnectionId> {
100 &self.connection_id
101 }
102
103 pub fn update_score(&mut self, score: i32) {
105 self.score = self.score.saturating_add(score);
106 }
107
108 pub fn set_connection_id(&mut self, connection_id: ConnectionId) {
110 self.connection_id = Some(connection_id);
111 }
112}
113
114impl PartialEq for AddressRecord {
115 fn eq(&self, other: &Self) -> bool {
116 self.score.eq(&other.score)
117 }
118}
119
120impl Eq for AddressRecord {}
121
122impl PartialOrd for AddressRecord {
123 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
124 Some(self.score.cmp(&other.score))
125 }
126}
127
128impl Ord for AddressRecord {
129 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
130 self.score.cmp(&other.score)
131 }
132}
133
134#[derive(Debug)]
136pub struct AddressStore {
137 pub by_score: BinaryHeap<AddressRecord>,
139
140 pub by_address: HashSet<Multiaddr>,
142}
143
144impl FromIterator<Multiaddr> for AddressStore {
145 fn from_iter<T: IntoIterator<Item = Multiaddr>>(iter: T) -> Self {
146 let mut store = AddressStore::new();
147 for address in iter {
148 if let Some(address) = AddressRecord::from_multiaddr(address) {
149 store.insert(address);
150 }
151 }
152
153 store
154 }
155}
156
157impl FromIterator<AddressRecord> for AddressStore {
158 fn from_iter<T: IntoIterator<Item = AddressRecord>>(iter: T) -> Self {
159 let mut store = AddressStore::new();
160 for record in iter {
161 store.by_address.insert(record.address.clone());
162 store.by_score.push(record);
163 }
164
165 store
166 }
167}
168
169impl Extend<AddressRecord> for AddressStore {
170 fn extend<T: IntoIterator<Item = AddressRecord>>(&mut self, iter: T) {
171 for record in iter {
172 self.insert(record)
173 }
174 }
175}
176
177impl<'a> Extend<&'a AddressRecord> for AddressStore {
178 fn extend<T: IntoIterator<Item = &'a AddressRecord>>(&mut self, iter: T) {
179 for record in iter {
180 self.insert(record.clone())
181 }
182 }
183}
184
185impl AddressStore {
186 pub fn new() -> Self {
188 Self {
189 by_score: BinaryHeap::new(),
190 by_address: HashSet::new(),
191 }
192 }
193
194 pub fn is_empty(&self) -> bool {
196 self.by_score.is_empty()
197 }
198
199 pub fn contains(&self, address: &Multiaddr) -> bool {
201 self.by_address.contains(address)
202 }
203
204 pub fn insert(&mut self, mut record: AddressRecord) {
206 if self.by_address.contains(record.address()) {
207 return;
208 }
209
210 record.connection_id = None;
211 self.by_address.insert(record.address.clone());
212 self.by_score.push(record);
213 }
214
215 pub fn pop(&mut self) -> Option<AddressRecord> {
217 self.by_score.pop().map(|record| {
218 self.by_address.remove(&record.address);
219 record
220 })
221 }
222
223 pub fn take(&mut self, limit: usize) -> Vec<AddressRecord> {
225 let mut records = Vec::new();
226
227 for _ in 0..limit {
228 match self.pop() {
229 Some(record) => records.push(record),
230 None => break,
231 }
232 }
233
234 records
235 }
236}
237
#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        net::{Ipv4Addr, SocketAddrV4},
    };

    use super::*;
    use rand::{rngs::ThreadRng, Rng};

    /// Returns a random IPv4 socket address with a non-zero first octet and a non-zero port.
    fn random_socket_address(rng: &mut ThreadRng) -> std::net::SocketAddr {
        std::net::SocketAddr::V4(SocketAddrV4::new(
            Ipv4Addr::new(
                rng.gen_range(1..=255),
                rng.gen_range(0..=255),
                rng.gen_range(0..=255),
                rng.gen_range(0..=255),
            ),
            rng.gen_range(1..=65535),
        ))
    }

    /// Builds a record for a random peer with a random score.
    ///
    /// `transport` receives the `/ip4/...` prefix and the port, and returns the full address.
    fn random_record(
        rng: &mut ThreadRng,
        transport: impl FnOnce(Multiaddr, u16) -> Multiaddr,
    ) -> AddressRecord {
        let socket = random_socket_address(rng);
        let score: i32 = rng.gen();
        let address = transport(
            Multiaddr::empty().with(Protocol::from(socket.ip())),
            socket.port(),
        );

        AddressRecord::new(&PeerId::random(), address, score, None)
    }

    fn tcp_address_record(rng: &mut ThreadRng) -> AddressRecord {
        random_record(rng, |base, port| base.with(Protocol::Tcp(port)))
    }

    fn ws_address_record(rng: &mut ThreadRng) -> AddressRecord {
        random_record(rng, |base, port| {
            base.with(Protocol::Tcp(port))
                .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string())))
        })
    }

    fn quic_address_record(rng: &mut ThreadRng) -> AddressRecord {
        random_record(rng, |base, port| {
            base.with(Protocol::Udp(port)).with(Protocol::QuicV1)
        })
    }

    /// Ten records with a mix of TCP, QUIC and WebSocket addresses.
    fn mixed_records(rng: &mut ThreadRng) -> Vec<AddressRecord> {
        (0..10)
            .map(|i| {
                if i % 2 == 0 {
                    tcp_address_record(rng)
                } else if i % 3 == 0 {
                    quic_address_record(rng)
                } else {
                    ws_address_record(rng)
                }
            })
            .collect()
    }

    /// Asserts that `records` come in non-increasing score order.
    ///
    /// The heap gives no ordering guarantee between equal scores, so ties are allowed.
    fn assert_descending_scores(records: &[AddressRecord]) {
        for pair in records.windows(2) {
            assert!(pair[0].score() >= pair[1].score());
        }
    }

    /// Indexes clones of `records` by address.
    fn index_by_address(records: &[AddressRecord]) -> HashMap<Multiaddr, AddressRecord> {
        records
            .iter()
            .map(|record| (record.address().clone(), record.clone()))
            .collect()
    }

    /// Asserts that every record in `store` matches its counterpart in `expected`.
    fn assert_store_matches(store: AddressStore, expected: &HashMap<Multiaddr, AddressRecord>) {
        assert_eq!(store.by_address.len(), expected.len());
        for record in store.by_score {
            let stored = expected.get(record.address()).unwrap();
            assert_eq!(stored.score(), record.score());
            assert_eq!(stored.connection_id(), record.connection_id());
            assert_eq!(stored.address(), record.address());
        }
    }

    #[test]
    fn take_multiple_records() {
        let mut store = AddressStore::new();
        let mut rng = rand::thread_rng();

        for _ in 0..rng.gen_range(1..5) {
            store.insert(tcp_address_record(&mut rng));
        }
        for _ in 0..rng.gen_range(1..5) {
            store.insert(ws_address_record(&mut rng));
        }
        for _ in 0..rng.gen_range(1..5) {
            store.insert(quic_address_record(&mut rng));
        }

        let known_addresses = store.by_address.len();
        assert!(known_addresses >= 3);

        let taken = store.take(known_addresses - 2);
        assert_eq!(known_addresses - 2, taken.len());
        assert!(!store.is_empty());

        for record in &taken {
            assert!(!store.contains(record.address()));
        }
        assert_descending_scores(&taken);

        // Nothing left in the store may outrank what was already taken.
        if let (Some(last), Some(next)) = (taken.last(), store.by_score.peek()) {
            assert!(last.score() >= next.score());
        }
    }

    #[test]
    fn attempt_to_take_excess_records() {
        let mut store = AddressStore::new();
        let mut rng = rand::thread_rng();

        store.insert(tcp_address_record(&mut rng));
        store.insert(ws_address_record(&mut rng));
        store.insert(quic_address_record(&mut rng));

        assert_eq!(store.by_address.len(), 3);

        let taken = store.take(8usize);
        assert_eq!(taken.len(), 3);
        assert!(store.is_empty());
        assert_descending_scores(&taken);
    }

    #[test]
    fn extend_from_iterator() {
        let mut store = AddressStore::new();
        let mut rng = rand::thread_rng();

        let records = mixed_records(&mut rng);
        let expected = index_by_address(&records);

        assert!(store.is_empty());
        store.extend(records);

        assert_store_matches(store, &expected);
    }

    #[test]
    fn extend_from_iterator_ref() {
        let mut store = AddressStore::new();
        let mut rng = rand::thread_rng();

        let records = mixed_records(&mut rng);
        let expected = index_by_address(&records);

        assert!(store.is_empty());
        store.extend(records.iter());

        assert_store_matches(store, &expected);
    }
}