litep2p/protocol/libp2p/kademlia/
store.rs

1// Copyright 2023 litep2p developers
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! Memory store implementation for Kademlia.
22
23#![allow(unused)]
24use crate::protocol::libp2p::kademlia::record::{Key, ProviderRecord, Record};
25
26use std::{
27    collections::{hash_map::Entry, HashMap},
28    num::NonZeroUsize,
29};
30
/// `tracing` target used by every log statement emitted from this module.
const LOG_TARGET: &str = "litep2p::ipfs::kademlia::store";

/// Events produced by [`MemoryStore`].
///
/// The enum currently has no variants, so no value of this type can be constructed.
pub enum MemoryStoreEvent {}
36
37/// Memory store.
38pub struct MemoryStore {
39    /// Records.
40    records: HashMap<Key, Record>,
41    /// Provider records.
42    provider_keys: HashMap<Key, Vec<ProviderRecord>>,
43    /// Configuration.
44    config: MemoryStoreConfig,
45}
46
47impl MemoryStore {
48    /// Create new [`MemoryStore`].
49    pub fn new() -> Self {
50        Self {
51            records: HashMap::new(),
52            provider_keys: HashMap::new(),
53            config: MemoryStoreConfig::default(),
54        }
55    }
56
57    /// Create new [`MemoryStore`] with the provided configuration.
58    pub fn with_config(config: MemoryStoreConfig) -> Self {
59        Self {
60            records: HashMap::new(),
61            provider_keys: HashMap::new(),
62            config,
63        }
64    }
65
66    /// Try to get record from local store for `key`.
67    pub fn get(&mut self, key: &Key) -> Option<&Record> {
68        let is_expired = self
69            .records
70            .get(key)
71            .map_or(false, |record| record.is_expired(std::time::Instant::now()));
72
73        if is_expired {
74            self.records.remove(key);
75            None
76        } else {
77            self.records.get(key)
78        }
79    }
80
81    /// Store record.
82    pub fn put(&mut self, record: Record) {
83        if record.value.len() >= self.config.max_record_size_bytes {
84            tracing::warn!(
85                target: LOG_TARGET,
86                key = ?record.key,
87                publisher = ?record.publisher,
88                size = record.value.len(),
89                max_size = self.config.max_record_size_bytes,
90                "discarding a DHT record that exceeds the configured size limit",
91            );
92            return;
93        }
94
95        let len = self.records.len();
96        match self.records.entry(record.key.clone()) {
97            Entry::Occupied(mut entry) => {
98                // Lean towards the new record.
99                if let (Some(stored_record_ttl), Some(new_record_ttl)) =
100                    (entry.get().expires, record.expires)
101                {
102                    if stored_record_ttl > new_record_ttl {
103                        return;
104                    }
105                }
106
107                entry.insert(record);
108            }
109
110            Entry::Vacant(entry) => {
111                if len >= self.config.max_records {
112                    tracing::warn!(
113                        target: LOG_TARGET,
114                        max_records = self.config.max_records,
115                        "discarding a DHT record, because maximum memory store size reached",
116                    );
117                    return;
118                }
119
120                entry.insert(record);
121            }
122        }
123    }
124
125    /// Try to get providers from local store for `key`.
126    ///
127    /// Returns a non-empty list of providers, if any.
128    pub fn get_providers(&mut self, key: &Key) -> Vec<ProviderRecord> {
129        let drop = self.provider_keys.get_mut(key).map_or(false, |providers| {
130            let now = std::time::Instant::now();
131            providers.retain(|p| !p.is_expired(now));
132
133            providers.is_empty()
134        });
135
136        if drop {
137            self.provider_keys.remove(key);
138
139            Vec::default()
140        } else {
141            self.provider_keys.get(key).cloned().unwrap_or_else(Vec::default)
142        }
143    }
144
145    /// Try to add a provider for `key`. If there are already `max_providers_per_key` for
146    /// this `key`, the new provider is only inserted if its closer to `key` than
147    /// the furthest already inserted provider. The furthest provider is then discarded.
148    ///
149    /// Returns `true` if the provider was added, `false` otherwise.
150    pub fn put_provider(&mut self, provider_record: ProviderRecord) -> bool {
151        // Make sure we have no more than `max_provider_addresses`.
152        let provider_record = {
153            let mut record = provider_record;
154            record.addresses.truncate(self.config.max_provider_addresses);
155            record
156        };
157
158        let can_insert_new_key = self.provider_keys.len() < self.config.max_provider_keys;
159
160        match self.provider_keys.entry(provider_record.key.clone()) {
161            Entry::Vacant(entry) =>
162                if can_insert_new_key {
163                    entry.insert(vec![provider_record]);
164
165                    true
166                } else {
167                    tracing::warn!(
168                        target: LOG_TARGET,
169                        max_provider_keys = self.config.max_provider_keys,
170                        "discarding a provider record, because the provider key limit reached",
171                    );
172
173                    false
174                },
175            Entry::Occupied(mut entry) => {
176                let mut providers = entry.get_mut();
177
178                // Providers under every key are sorted by distance from the provided key, with
179                // equal distances meaning peer IDs (more strictly, their hashes)
180                // are equal.
181                let provider_position =
182                    providers.binary_search_by(|p| p.distance().cmp(&provider_record.distance()));
183
184                match provider_position {
185                    Ok(i) => {
186                        // Update the provider in place.
187                        providers[i] = provider_record;
188
189                        true
190                    }
191                    Err(i) => {
192                        // `Err(i)` contains the insertion point.
193                        if i == self.config.max_providers_per_key {
194                            tracing::trace!(
195                                target: LOG_TARGET,
196                                key = ?provider_record.key,
197                                provider = ?provider_record.provider,
198                                max_providers_per_key = self.config.max_providers_per_key,
199                                "discarding a provider record, because it's further than \
200                                 existing `max_providers_per_key`",
201                            );
202
203                            false
204                        } else {
205                            if providers.len() == self.config.max_providers_per_key {
206                                providers.pop();
207                            }
208
209                            providers.insert(i, provider_record);
210
211                            true
212                        }
213                    }
214                }
215            }
216        }
217    }
218
219    /// Poll next event from the store.
220    async fn next_event() -> Option<MemoryStoreEvent> {
221        None
222    }
223}
224
/// Limits that [`MemoryStore`] applies to the records and provider records it keeps.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryStoreConfig {
    /// Maximum number of records to store.
    pub max_records: usize,

    /// Maximum size of a record value in bytes.
    pub max_record_size_bytes: usize,

    /// Maximum number of provider keys this node stores.
    pub max_provider_keys: usize,

    /// Maximum number of cached addresses per provider.
    pub max_provider_addresses: usize,

    /// Maximum number of providers per key. Only providers with peer IDs closest to the key are
    /// kept.
    pub max_providers_per_key: usize,
}
242
243impl Default for MemoryStoreConfig {
244    fn default() -> Self {
245        Self {
246            max_records: 1024,
247            max_record_size_bytes: 65 * 1024,
248            max_provider_keys: 1024,
249            max_provider_addresses: 30,
250            max_providers_per_key: 20,
251        }
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::PeerId;
    use multiaddr::{
        multiaddr,
        Protocol::{Ip4, Tcp},
    };

    #[test]
    fn put_get_record() {
        let mut store = MemoryStore::new();
        let key = Key::from(vec![1, 2, 3]);
        let record = Record::new(key.clone(), vec![4, 5, 6]);

        store.put(record.clone());
        assert_eq!(store.get(&key), Some(&record));
    }

    #[test]
    fn max_records() {
        let mut store = MemoryStore::with_config(MemoryStoreConfig {
            max_records: 1,
            max_record_size_bytes: 1024,
            ..Default::default()
        });

        let key1 = Key::from(vec![1, 2, 3]);
        let key2 = Key::from(vec![4, 5, 6]);
        let record1 = Record::new(key1.clone(), vec![4, 5, 6]);
        let record2 = Record::new(key2.clone(), vec![7, 8, 9]);

        store.put(record1.clone());
        store.put(record2.clone());

        assert_eq!(store.get(&key1), Some(&record1));
        assert_eq!(store.get(&key2), None);
    }

    #[test]
    fn expired_record_removed() {
        let mut store = MemoryStore::new();
        let key = Key::from(vec![1, 2, 3]);
        let record = Record {
            key: key.clone(),
            value: vec![4, 5, 6],
            publisher: None,
            expires: Some(std::time::Instant::now() - std::time::Duration::from_secs(5)),
        };
        // Record is already expired.
        assert!(record.is_expired(std::time::Instant::now()));

        store.put(record.clone());
        assert_eq!(store.get(&key), None);
    }

    #[test]
    fn new_record_overwrites() {
        let mut store = MemoryStore::new();
        let key = Key::from(vec![1, 2, 3]);
        let record1 = Record {
            key: key.clone(),
            value: vec![4, 5, 6],
            publisher: None,
            expires: Some(std::time::Instant::now() + std::time::Duration::from_secs(100)),
        };
        let record2 = Record {
            key: key.clone(),
            value: vec![4, 5, 6],
            publisher: None,
            expires: Some(std::time::Instant::now() + std::time::Duration::from_secs(1000)),
        };

        store.put(record1.clone());
        assert_eq!(store.get(&key), Some(&record1));

        store.put(record2.clone());
        assert_eq!(store.get(&key), Some(&record2));
    }

    #[test]
    fn max_record_size() {
        let mut store = MemoryStore::with_config(MemoryStoreConfig {
            max_records: 1024,
            max_record_size_bytes: 2,
            ..Default::default()
        });

        let key = Key::from(vec![1, 2, 3]);
        let record = Record::new(key.clone(), vec![4, 5]);
        store.put(record.clone());
        assert_eq!(store.get(&key), None);

        let record = Record::new(key.clone(), vec![4]);
        store.put(record.clone());
        assert_eq!(store.get(&key), Some(&record));
    }

    #[test]
    fn put_get_provider() {
        let mut store = MemoryStore::new();
        let provider = ProviderRecord {
            key: Key::from(vec![1, 2, 3]),
            provider: PeerId::random(),
            addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
        };

        assert!(store.put_provider(provider.clone()));
        assert_eq!(store.get_providers(&provider.key), vec![provider]);
    }

    #[test]
    fn multiple_providers_per_key() {
        let mut store = MemoryStore::new();
        let key = Key::from(vec![1, 2, 3]);
        let provider1 = ProviderRecord {
            key: key.clone(),
            provider: PeerId::random(),
            addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
        };
        let provider2 = ProviderRecord {
            key: key.clone(),
            provider: PeerId::random(),
            addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
        };

        assert!(store.put_provider(provider1.clone()));
        assert!(store.put_provider(provider2.clone()));

        let got_providers = store.get_providers(&key);
        assert_eq!(got_providers.len(), 2);
        assert!(got_providers.contains(&provider1));
        assert!(got_providers.contains(&provider2));
    }

    #[test]
    fn providers_sorted_by_distance() {
        let mut store = MemoryStore::new();
        let key = Key::from(vec![1, 2, 3]);
        let providers = (0..10)
            .map(|_| ProviderRecord {
                key: key.clone(),
                provider: PeerId::random(),
                addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
                expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
            })
            .collect::<Vec<_>>();

        for provider in &providers {
            store.put_provider(provider.clone());
        }

        let sorted_providers = {
            let mut providers = providers;
            providers.sort_unstable_by_key(ProviderRecord::distance);
            providers
        };

        assert_eq!(store.get_providers(&key), sorted_providers);
    }

    #[test]
    fn max_providers_per_key() {
        let mut store = MemoryStore::with_config(MemoryStoreConfig {
            max_providers_per_key: 10,
            ..Default::default()
        });
        let key = Key::from(vec![1, 2, 3]);
        let providers = (0..20)
            .map(|_| ProviderRecord {
                key: key.clone(),
                provider: PeerId::random(),
                addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
                expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
            })
            .collect::<Vec<_>>();

        for provider in &providers {
            store.put_provider(provider.clone());
        }
        assert_eq!(store.get_providers(&key).len(), 10);
    }

    #[test]
    fn closest_providers_kept() {
        let mut store = MemoryStore::with_config(MemoryStoreConfig {
            max_providers_per_key: 10,
            ..Default::default()
        });
        let key = Key::from(vec![1, 2, 3]);
        let providers = (0..20)
            .map(|_| ProviderRecord {
                key: key.clone(),
                provider: PeerId::random(),
                addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
                expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
            })
            .collect::<Vec<_>>();

        for provider in &providers {
            store.put_provider(provider.clone());
        }

        let closest_providers = {
            let mut providers = providers;
            providers.sort_unstable_by_key(ProviderRecord::distance);
            providers.truncate(10);
            providers
        };

        assert_eq!(store.get_providers(&key), closest_providers);
    }

    #[test]
    fn furthest_provider_discarded() {
        let mut store = MemoryStore::with_config(MemoryStoreConfig {
            max_providers_per_key: 10,
            ..Default::default()
        });
        let key = Key::from(vec![1, 2, 3]);
        let providers = (0..11)
            .map(|_| ProviderRecord {
                key: key.clone(),
                provider: PeerId::random(),
                addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
                expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
            })
            .collect::<Vec<_>>();

        let sorted_providers = {
            let mut providers = providers;
            providers.sort_unstable_by_key(ProviderRecord::distance);
            providers
        };

        // The closest 10 providers are inserted.
        for provider in &sorted_providers[..10] {
            assert!(store.put_provider(provider.clone()));
        }
        assert_eq!(store.get_providers(&key), sorted_providers[..10]);

        // The furthest provider doesn't fit.
        assert!(!store.put_provider(sorted_providers[10].clone()));
        assert_eq!(store.get_providers(&key), sorted_providers[..10]);
    }

    #[test]
    fn update_provider_in_place() {
        let mut store = MemoryStore::with_config(MemoryStoreConfig {
            max_providers_per_key: 10,
            ..Default::default()
        });
        let key = Key::from(vec![1, 2, 3]);
        let peer_ids = (0..10).map(|_| PeerId::random()).collect::<Vec<_>>();
        let peer_id0 = peer_ids[0];
        let providers = peer_ids
            .iter()
            .map(|peer_id| ProviderRecord {
                key: key.clone(),
                provider: *peer_id,
                addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
                expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
            })
            .collect::<Vec<_>>();

        for provider in &providers {
            store.put_provider(provider.clone());
        }

        let sorted_providers = {
            let mut providers = providers;
            providers.sort_unstable_by_key(ProviderRecord::distance);
            providers
        };

        assert_eq!(store.get_providers(&key), sorted_providers);

        let provider0_new = ProviderRecord {
            key: key.clone(),
            provider: peer_id0,
            addresses: vec![multiaddr!(Ip4([192, 168, 0, 1]), Tcp(20000u16))],
            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
        };

        // Provider is updated in place.
        assert!(store.put_provider(provider0_new.clone()));

        let providers_new = sorted_providers
            .into_iter()
            .map(|p| {
                if p.provider == peer_id0 {
                    provider0_new.clone()
                } else {
                    p
                }
            })
            .collect::<Vec<_>>();

        assert_eq!(store.get_providers(&key), providers_new);
    }

    #[test]
    fn provider_record_expires() {
        let mut store = MemoryStore::new();
        let provider = ProviderRecord {
            key: Key::from(vec![1, 2, 3]),
            provider: PeerId::random(),
            addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
            expires: std::time::Instant::now() - std::time::Duration::from_secs(5),
        };

        // Provider record is already expired.
        assert!(provider.is_expired(std::time::Instant::now()));

        store.put_provider(provider.clone());
        assert!(store.get_providers(&provider.key).is_empty());
    }

    #[test]
    fn individual_provider_record_expires() {
        let mut store = MemoryStore::new();
        let key = Key::from(vec![1, 2, 3]);
        let provider1 = ProviderRecord {
            key: key.clone(),
            provider: PeerId::random(),
            addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
            expires: std::time::Instant::now() - std::time::Duration::from_secs(5),
        };
        let provider2 = ProviderRecord {
            key: key.clone(),
            provider: PeerId::random(),
            addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
        };

        assert!(provider1.is_expired(std::time::Instant::now()));

        store.put_provider(provider1.clone());
        store.put_provider(provider2.clone());

        assert_eq!(store.get_providers(&key), vec![provider2]);
    }

    #[test]
    fn max_addresses_per_provider() {
        let mut store = MemoryStore::with_config(MemoryStoreConfig {
            max_provider_addresses: 2,
            ..Default::default()
        });
        let key = Key::from(vec![1, 2, 3]);
        let provider = ProviderRecord {
            key: key.clone(),
            provider: PeerId::random(),
            addresses: vec![
                multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16)),
                multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10001u16)),
                multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10002u16)),
                multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10003u16)),
                multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10004u16)),
            ],
            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
        };

        store.put_provider(provider);

        let got_providers = store.get_providers(&key);
        assert_eq!(got_providers.len(), 1);
        assert_eq!(got_providers.first().unwrap().key, key);
        assert_eq!(got_providers.first().unwrap().addresses.len(), 2);
    }

    #[test]
    fn max_provider_keys() {
        let mut store = MemoryStore::with_config(MemoryStoreConfig {
            max_provider_keys: 2,
            ..Default::default()
        });

        let provider1 = ProviderRecord {
            key: Key::from(vec![1, 2, 3]),
            provider: PeerId::random(),
            addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10001u16))],
            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
        };
        let provider2 = ProviderRecord {
            key: Key::from(vec![4, 5, 6]),
            provider: PeerId::random(),
            addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10002u16))],
            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
        };
        let provider3 = ProviderRecord {
            key: Key::from(vec![7, 8, 9]),
            provider: PeerId::random(),
            addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10003u16))],
            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
        };

        assert!(store.put_provider(provider1.clone()));
        assert!(store.put_provider(provider2.clone()));
        assert!(!store.put_provider(provider3.clone()));

        assert_eq!(store.get_providers(&provider1.key), vec![provider1]);
        assert_eq!(store.get_providers(&provider2.key), vec![provider2]);
        assert_eq!(store.get_providers(&provider3.key), vec![]);
    }
}