litep2p/transport/manager/address.rs

1// Copyright 2023 litep2p developers
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::{types::ConnectionId, PeerId};
22
23use multiaddr::{Multiaddr, Protocol};
24use multihash::Multihash;
25
26use std::collections::{BinaryHeap, HashSet};
27
/// A peer address together with its score and, optionally, the ID of the
/// connection associated with it.
///
/// Equality and ordering are defined on `score` alone (see the manual
/// `PartialEq`/`Ord` impls in this file), so a [`BinaryHeap`] of records pops the
/// highest-scoring address first.
// NOTE(review): `Hash` is derived over *all* fields while `PartialEq` only compares
// `score`, so two records that compare equal can hash differently. This violates the
// `Hash`/`Eq` contract (`a == b` must imply `hash(a) == hash(b)`), which is what the
// clippy lint below is silencing. Avoid using `AddressRecord` itself as a
// `HashMap`/`HashSet` key; key by `address()` instead.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Debug, Clone, Hash)]
pub struct AddressRecord {
    /// Address score. Higher-scoring records are preferred by `AddressStore`.
    score: i32,

    /// Address. Both constructors in this file ensure it ends with a `P2p` component.
    address: Multiaddr,

    /// Connection ID, if specified.
    connection_id: Option<ConnectionId>,
}
40
41impl AsRef<Multiaddr> for AddressRecord {
42    fn as_ref(&self) -> &Multiaddr {
43        &self.address
44    }
45}
46
47impl AddressRecord {
48    /// Create new `AddressRecord` and if `address` doesn't contain `P2p`,
49    /// append the provided `PeerId` to the address.
50    pub fn new(
51        peer: &PeerId,
52        address: Multiaddr,
53        score: i32,
54        connection_id: Option<ConnectionId>,
55    ) -> Self {
56        let address = if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) {
57            address.with(Protocol::P2p(
58                Multihash::from_bytes(&peer.to_bytes()).expect("valid peer id"),
59            ))
60        } else {
61            address
62        };
63
64        Self {
65            address,
66            score,
67            connection_id,
68        }
69    }
70
71    /// Create `AddressRecord` from `Multiaddr`.
72    ///
73    /// If `address` doesn't contain `PeerId`, return `None` to indicate that this
74    /// an invalid `Multiaddr` from the perspective of the `TransportManager`.
75    pub fn from_multiaddr(address: Multiaddr) -> Option<AddressRecord> {
76        if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) {
77            return None;
78        }
79
80        Some(AddressRecord {
81            address,
82            score: 0i32,
83            connection_id: None,
84        })
85    }
86
87    /// Get address score.
88    #[cfg(test)]
89    pub fn score(&self) -> i32 {
90        self.score
91    }
92
93    /// Get address.
94    pub fn address(&self) -> &Multiaddr {
95        &self.address
96    }
97
98    /// Get connection ID.
99    pub fn connection_id(&self) -> &Option<ConnectionId> {
100        &self.connection_id
101    }
102
103    /// Update score of an address.
104    pub fn update_score(&mut self, score: i32) {
105        self.score = self.score.saturating_add(score);
106    }
107
108    /// Set `ConnectionId` for the [`AddressRecord`].
109    pub fn set_connection_id(&mut self, connection_id: ConnectionId) {
110        self.connection_id = Some(connection_id);
111    }
112}
113
114impl PartialEq for AddressRecord {
115    fn eq(&self, other: &Self) -> bool {
116        self.score.eq(&other.score)
117    }
118}
119
120impl Eq for AddressRecord {}
121
122impl PartialOrd for AddressRecord {
123    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
124        Some(self.score.cmp(&other.score))
125    }
126}
127
128impl Ord for AddressRecord {
129    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
130        self.score.cmp(&other.score)
131    }
132}
133
134/// Store for peer addresses.
135#[derive(Debug)]
136pub struct AddressStore {
137    //// Addresses sorted by score.
138    pub by_score: BinaryHeap<AddressRecord>,
139
140    /// Addresses queryable by hashing them for faster lookup.
141    pub by_address: HashSet<Multiaddr>,
142}
143
144impl FromIterator<Multiaddr> for AddressStore {
145    fn from_iter<T: IntoIterator<Item = Multiaddr>>(iter: T) -> Self {
146        let mut store = AddressStore::new();
147        for address in iter {
148            if let Some(address) = AddressRecord::from_multiaddr(address) {
149                store.insert(address);
150            }
151        }
152
153        store
154    }
155}
156
157impl FromIterator<AddressRecord> for AddressStore {
158    fn from_iter<T: IntoIterator<Item = AddressRecord>>(iter: T) -> Self {
159        let mut store = AddressStore::new();
160        for record in iter {
161            store.by_address.insert(record.address.clone());
162            store.by_score.push(record);
163        }
164
165        store
166    }
167}
168
169impl Extend<AddressRecord> for AddressStore {
170    fn extend<T: IntoIterator<Item = AddressRecord>>(&mut self, iter: T) {
171        for record in iter {
172            self.insert(record)
173        }
174    }
175}
176
177impl<'a> Extend<&'a AddressRecord> for AddressStore {
178    fn extend<T: IntoIterator<Item = &'a AddressRecord>>(&mut self, iter: T) {
179        for record in iter {
180            self.insert(record.clone())
181        }
182    }
183}
184
185impl AddressStore {
186    /// Create new [`AddressStore`].
187    pub fn new() -> Self {
188        Self {
189            by_score: BinaryHeap::new(),
190            by_address: HashSet::new(),
191        }
192    }
193
194    /// Check if [`AddressStore`] is empty.
195    pub fn is_empty(&self) -> bool {
196        self.by_score.is_empty()
197    }
198
199    /// Check if address is already in the a
200    pub fn contains(&self, address: &Multiaddr) -> bool {
201        self.by_address.contains(address)
202    }
203
204    /// Insert new address record into [`AddressStore`] with default address score.
205    pub fn insert(&mut self, mut record: AddressRecord) {
206        if self.by_address.contains(record.address()) {
207            return;
208        }
209
210        record.connection_id = None;
211        self.by_address.insert(record.address.clone());
212        self.by_score.push(record);
213    }
214
215    /// Pop address with the highest score from [`AddressStore`].
216    pub fn pop(&mut self) -> Option<AddressRecord> {
217        self.by_score.pop().map(|record| {
218            self.by_address.remove(&record.address);
219            record
220        })
221    }
222
223    /// Take at most `limit` `AddressRecord`s from [`AddressStore`].
224    pub fn take(&mut self, limit: usize) -> Vec<AddressRecord> {
225        let mut records = Vec::new();
226
227        for _ in 0..limit {
228            match self.pop() {
229                Some(record) => records.push(record),
230                None => break,
231            }
232        }
233
234        records
235    }
236}
237
#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        net::{Ipv4Addr, SocketAddr, SocketAddrV4},
    };

    use super::*;
    use rand::{rngs::ThreadRng, Rng};

    /// Random IPv4 socket address with a non-zero first octet and a non-zero port.
    fn random_socket_address(rng: &mut ThreadRng) -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(
            Ipv4Addr::new(
                rng.gen_range(1..=255),
                rng.gen_range(0..=255),
                rng.gen_range(0..=255),
                rng.gen_range(0..=255),
            ),
            rng.gen_range(1..=65535),
        ))
    }

    /// Record for a fresh random peer at `address`, with a random score and no connection ID.
    fn random_record(rng: &mut ThreadRng, address: Multiaddr) -> AddressRecord {
        let peer = PeerId::random();
        let score: i32 = rng.gen();

        AddressRecord::new(&peer, address, score, None)
    }

    fn tcp_address_record(rng: &mut ThreadRng) -> AddressRecord {
        let address = random_socket_address(rng);

        random_record(
            rng,
            Multiaddr::empty()
                .with(Protocol::from(address.ip()))
                .with(Protocol::Tcp(address.port())),
        )
    }

    fn ws_address_record(rng: &mut ThreadRng) -> AddressRecord {
        let address = random_socket_address(rng);

        random_record(
            rng,
            Multiaddr::empty()
                .with(Protocol::from(address.ip()))
                .with(Protocol::Tcp(address.port()))
                .with(Protocol::Ws(std::borrow::Cow::Owned("/".to_string()))),
        )
    }

    fn quic_address_record(rng: &mut ThreadRng) -> AddressRecord {
        let address = random_socket_address(rng);

        random_record(
            rng,
            Multiaddr::empty()
                .with(Protocol::from(address.ip()))
                .with(Protocol::Udp(address.port()))
                .with(Protocol::QuicV1),
        )
    }

    /// Records must come out of the store in non-increasing score order. Equal
    /// scores are legal, so the comparison is `>=`, not `>`.
    fn assert_non_increasing(records: &[AddressRecord]) {
        assert!(records.windows(2).all(|pair| pair[0].score >= pair[1].score));
    }

    #[test]
    fn take_multiple_records() {
        let mut store = AddressStore::new();
        let mut rng = rand::thread_rng();

        for _ in 0..rng.gen_range(1..5) {
            store.insert(tcp_address_record(&mut rng));
        }
        for _ in 0..rng.gen_range(1..5) {
            store.insert(ws_address_record(&mut rng));
        }
        for _ in 0..rng.gen_range(1..5) {
            store.insert(quic_address_record(&mut rng));
        }

        let known_addresses = store.by_address.len();
        assert!(known_addresses >= 3);

        let taken = store.take(known_addresses - 2);
        assert_eq!(known_addresses - 2, taken.len());
        assert!(!store.is_empty());

        for record in &taken {
            assert!(!store.contains(record.address()));
        }
        assert_non_increasing(&taken);

        // Everything left behind must score no higher than what was taken.
        let lowest_taken = taken.last().expect("at least one record taken").score;
        assert!(store.by_score.iter().all(|record| record.score <= lowest_taken));
    }

    #[test]
    fn attempt_to_take_excess_records() {
        let mut store = AddressStore::new();
        let mut rng = rand::thread_rng();

        store.insert(tcp_address_record(&mut rng));
        store.insert(ws_address_record(&mut rng));
        store.insert(quic_address_record(&mut rng));

        assert_eq!(store.by_address.len(), 3);

        let taken = store.take(8usize);
        assert_eq!(taken.len(), 3);
        assert!(store.is_empty());
        assert_non_increasing(&taken);
    }

    #[test]
    fn insert_ignores_known_address() {
        let mut store = AddressStore::new();
        let mut rng = rand::thread_rng();

        let record = tcp_address_record(&mut rng);
        let mut duplicate = record.clone();
        duplicate.update_score(1);

        store.insert(record.clone());
        store.insert(duplicate);

        assert_eq!(store.by_address.len(), 1);
        assert_eq!(store.by_score.len(), 1);
        assert_eq!(store.pop().unwrap().score(), record.score());
        assert!(store.is_empty());
    }

    #[test]
    fn extend_from_iterator() {
        let mut store = AddressStore::new();
        let mut rng = rand::thread_rng();

        let records = (0..10)
            .map(|i| {
                if i % 2 == 0 {
                    tcp_address_record(&mut rng)
                } else if i % 3 == 0 {
                    quic_address_record(&mut rng)
                } else {
                    ws_address_record(&mut rng)
                }
            })
            .collect::<Vec<_>>();

        assert!(store.is_empty());
        let cloned = records
            .iter()
            .cloned()
            .map(|record| (record.address().clone(), record))
            .collect::<HashMap<_, _>>();
        store.extend(records);

        assert_eq!(store.by_score.len(), cloned.len());
        assert_eq!(store.by_address.len(), cloned.len());

        for record in store.by_score {
            let stored = cloned.get(record.address()).unwrap();
            assert_eq!(stored.score(), record.score());
            assert_eq!(stored.connection_id(), record.connection_id());
            assert_eq!(stored.address(), record.address());
        }
    }

    #[test]
    fn extend_from_iterator_ref() {
        let mut store = AddressStore::new();
        let mut rng = rand::thread_rng();

        let records = (0..10)
            .map(|i| {
                let record = if i % 2 == 0 {
                    tcp_address_record(&mut rng)
                } else if i % 3 == 0 {
                    quic_address_record(&mut rng)
                } else {
                    ws_address_record(&mut rng)
                };
                (record.address().clone(), record)
            })
            .collect::<Vec<_>>();

        assert!(store.is_empty());
        let cloned = records.iter().cloned().collect::<HashMap<_, _>>();
        store.extend(records.iter().map(|(_, record)| record));

        assert_eq!(store.by_score.len(), cloned.len());
        assert_eq!(store.by_address.len(), cloned.len());

        for record in store.by_score {
            let stored = cloned.get(record.address()).unwrap();
            assert_eq!(stored.score(), record.score());
            assert_eq!(stored.connection_id(), record.connection_id());
            assert_eq!(stored.address(), record.address());
        }
    }
}