libp2p_kad/kbucket/
key.rs1use crate::record_priv;
22use libp2p_core::multihash::Multihash;
23use libp2p_identity::PeerId;
24use sha2::digest::generic_array::{typenum::U32, GenericArray};
25use sha2::{Digest, Sha256};
26use std::borrow::Borrow;
27use std::hash::{Hash, Hasher};
28use uint::*;
29
construct_uint! {
    /// 256-bit unsigned integer (four 64-bit words) used for XOR-distance arithmetic
    /// on 32-byte keys.
    pub(super) struct U256(4);
}
34
/// A key in the DHT keyspace that keeps the value (the "preimage") it was derived from.
///
/// Keys can be built from raw bytes, record keys, multihashes and peer IDs (see the
/// `From` impls below). The preimage is hashed with SHA-256 into [`KeyBytes`], and
/// the distance between two keys is the XOR of those bytes (see [`Key::distance`]).
/// Equality and hashing consider only the hashed bytes, never the preimage.
#[derive(Clone, Debug)]
pub struct Key<T> {
    // The original value; recoverable via `preimage()` / `into_preimage()`.
    preimage: T,
    // Hashed representation used for all comparisons and distance computations.
    bytes: KeyBytes,
}
47
48impl<T> Key<T> {
49 pub fn new(preimage: T) -> Key<T>
55 where
56 T: Borrow<[u8]>,
57 {
58 let bytes = KeyBytes::new(preimage.borrow());
59 Key { preimage, bytes }
60 }
61
62 pub fn preimage(&self) -> &T {
64 &self.preimage
65 }
66
67 pub fn into_preimage(self) -> T {
69 self.preimage
70 }
71
72 pub fn distance<U>(&self, other: &U) -> Distance
74 where
75 U: AsRef<KeyBytes>,
76 {
77 self.bytes.distance(other)
78 }
79
80 pub fn for_distance(&self, d: Distance) -> KeyBytes {
86 self.bytes.for_distance(d)
87 }
88}
89
90impl<T> From<Key<T>> for KeyBytes {
91 fn from(key: Key<T>) -> KeyBytes {
92 key.bytes
93 }
94}
95
96impl<const S: usize> From<Multihash<S>> for Key<Multihash<S>> {
97 fn from(m: Multihash<S>) -> Self {
98 let bytes = KeyBytes(Sha256::digest(m.to_bytes()));
99 Key { preimage: m, bytes }
100 }
101}
102
103impl From<PeerId> for Key<PeerId> {
104 fn from(p: PeerId) -> Self {
105 let bytes = KeyBytes(Sha256::digest(p.to_bytes()));
106 Key { preimage: p, bytes }
107 }
108}
109
110impl From<Vec<u8>> for Key<Vec<u8>> {
111 fn from(b: Vec<u8>) -> Self {
112 Key::new(b)
113 }
114}
115
116impl From<record_priv::Key> for Key<record_priv::Key> {
117 fn from(k: record_priv::Key) -> Self {
118 Key::new(k)
119 }
120}
121
122impl<T> AsRef<KeyBytes> for Key<T> {
123 fn as_ref(&self) -> &KeyBytes {
124 &self.bytes
125 }
126}
127
128impl<T, U> PartialEq<Key<U>> for Key<T> {
129 fn eq(&self, other: &Key<U>) -> bool {
130 self.bytes == other.bytes
131 }
132}
133
// Equality compares only the hashed bytes (which are themselves `Eq`), so the relation is
// total for any preimage type `T`.
impl<T> Eq for Key<T> {}
135
136impl<T> Hash for Key<T> {
137 fn hash<H: Hasher>(&self, state: &mut H) {
138 self.bytes.0.hash(state);
139 }
140}
141
/// The raw 32 bytes of a key in the DHT keyspace.
///
/// Normally a SHA-256 digest (see `KeyBytes::new`); `KeyBytes::for_distance` can also
/// produce arbitrary 32-byte values. Distances are computed by XOR-ing two values
/// interpreted as 256-bit unsigned integers.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct KeyBytes(GenericArray<u8, U32>);
145
146impl KeyBytes {
147 pub fn new<T>(value: T) -> Self
150 where
151 T: Borrow<[u8]>,
152 {
153 KeyBytes(Sha256::digest(value.borrow()))
154 }
155
156 pub fn distance<U>(&self, other: &U) -> Distance
158 where
159 U: AsRef<KeyBytes>,
160 {
161 let a = U256::from(self.0.as_slice());
162 let b = U256::from(other.as_ref().0.as_slice());
163 Distance(a ^ b)
164 }
165
166 pub fn for_distance(&self, d: Distance) -> KeyBytes {
172 let key_int = U256::from(self.0.as_slice()) ^ d.0;
173 KeyBytes(GenericArray::from(<[u8; 32]>::from(key_int)))
174 }
175}
176
177impl AsRef<KeyBytes> for KeyBytes {
178 fn as_ref(&self) -> &KeyBytes {
179 self
180 }
181}
182
/// The XOR distance between two keys, as a 256-bit unsigned integer.
///
/// Distances order numerically, and `Distance::default()` is zero, which is the
/// distance of any key to itself.
#[derive(Copy, Clone, PartialEq, Eq, Default, PartialOrd, Ord, Debug)]
pub struct Distance(pub(super) U256);
186
187impl Distance {
188 pub fn ilog2(&self) -> Option<u32> {
192 (256 - self.0.leading_zeros()).checked_sub(1)
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SHA_256_MH;
    use quickcheck::*;

    impl Arbitrary for Key<PeerId> {
        fn arbitrary(_: &mut Gen) -> Key<PeerId> {
            Key::from(PeerId::random())
        }
    }

    impl Arbitrary for Key<Multihash<64>> {
        fn arbitrary(g: &mut Gen) -> Key<Multihash<64>> {
            let hash: [u8; 32] = core::array::from_fn(|_| u8::arbitrary(g));
            Key::from(Multihash::wrap(SHA_256_MH, &hash).unwrap())
        }
    }

    #[test]
    fn identity() {
        fn prop(a: Key<PeerId>) -> bool {
            a.distance(&a) == Distance::default()
        }
        quickcheck(prop as fn(_) -> _)
    }

    #[test]
    fn symmetry() {
        fn prop(a: Key<PeerId>, b: Key<PeerId>) -> bool {
            a.distance(&b) == b.distance(&a)
        }
        quickcheck(prop as fn(_, _) -> _)
    }

    #[test]
    fn triangle_inequality() {
        fn prop(a: Key<PeerId>, b: Key<PeerId>, c: Key<PeerId>) -> TestResult {
            let ab = a.distance(&b);
            let bc = b.distance(&c);
            let (ab_plus_bc, overflow) = ab.0.overflowing_add(bc.0);
            if overflow {
                TestResult::discard()
            } else {
                TestResult::from_bool(a.distance(&c) <= Distance(ab_plus_bc))
            }
        }
        quickcheck(prop as fn(_, _, _) -> _)
    }

    #[test]
    fn unidirectionality() {
        fn prop(a: Key<PeerId>, b: Key<PeerId>) -> bool {
            let d = a.distance(&b);
            (0..100).all(|_| {
                let c = Key::from(PeerId::random());
                a.distance(&c) != d || b == c
            })
        }
        quickcheck(prop as fn(_, _) -> _)
    }

    /// `for_distance` must invert `distance`: the key at distance `d(a, b)` from `a`
    /// is exactly `b`, and it lies at that same distance from `a`.
    #[test]
    fn for_distance_inverts_distance() {
        fn prop(a: Key<PeerId>, b: Key<PeerId>) -> bool {
            let d = a.distance(&b);
            let target = a.for_distance(d);
            a.distance(&target) == d && target == KeyBytes::from(b)
        }
        quickcheck(prop as fn(_, _) -> _)
    }

    /// `ilog2` is `None` for zero and the index of the highest set bit otherwise.
    #[test]
    fn ilog2_boundaries() {
        assert_eq!(Distance::default().ilog2(), None);
        assert_eq!(Distance(U256::from(1u64)).ilog2(), Some(0));
        assert_eq!(Distance(U256::from(2u64)).ilog2(), Some(1));
        assert_eq!(Distance(U256::from(0x80u64)).ilog2(), Some(7));

        // Keys are interpreted as big-endian, so the first byte holds the top bits.
        let mut top_bit = [0u8; 32];
        top_bit[0] = 0x80;
        assert_eq!(Distance(U256::from(&top_bit[..])).ilog2(), Some(255));

        let all_ones = [0xffu8; 32];
        assert_eq!(Distance(U256::from(&all_ones[..])).ilog2(), Some(255));
    }
}