1use crate::proto;
30use crate::record_priv::{self, Record};
31use asynchronous_codec::{Decoder, Encoder, Framed};
32use bytes::BytesMut;
33use futures::prelude::*;
34use instant::Instant;
35use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo};
36use libp2p_core::Multiaddr;
37use libp2p_identity::PeerId;
38use libp2p_swarm::StreamProtocol;
39use std::marker::PhantomData;
40use std::{convert::TryFrom, time::Duration};
41use std::{io, iter};
42
/// Protocol name advertised by a default [`ProtocolConfig`] (see its `Default` impl and
/// `UpgradeInfo::protocol_info`).
pub(crate) const DEFAULT_PROTO_NAME: StreamProtocol = StreamProtocol::new("/ipfs/kad/1.0.0");
/// Default maximum packet size (16 KiB) handed to the protobuf codec by a default
/// [`ProtocolConfig`].
pub(crate) const DEFAULT_MAX_PACKET_SIZE: usize = 16 * 1024;
/// Connection status of a peer as reported in a Kademlia peer entry.
///
/// Each variant mirrors one value of the protobuf `ConnectionType` enum; the conversions
/// below map them by name.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
pub enum ConnectionType {
    /// Wire value `NOT_CONNECTED`.
    NotConnected = 0,
    /// Wire value `CONNECTED`.
    Connected = 1,
    /// Wire value `CAN_CONNECT`.
    CanConnect = 2,
    /// Wire value `CANNOT_CONNECT`.
    CannotConnect = 3,
}
59
60impl From<proto::ConnectionType> for ConnectionType {
61 fn from(raw: proto::ConnectionType) -> ConnectionType {
62 use proto::ConnectionType::*;
63 match raw {
64 NOT_CONNECTED => ConnectionType::NotConnected,
65 CONNECTED => ConnectionType::Connected,
66 CAN_CONNECT => ConnectionType::CanConnect,
67 CANNOT_CONNECT => ConnectionType::CannotConnect,
68 }
69 }
70}
71
72impl From<ConnectionType> for proto::ConnectionType {
73 fn from(val: ConnectionType) -> Self {
74 use proto::ConnectionType::*;
75 match val {
76 ConnectionType::NotConnected => NOT_CONNECTED,
77 ConnectionType::Connected => CONNECTED,
78 ConnectionType::CanConnect => CAN_CONNECT,
79 ConnectionType::CannotConnect => CANNOT_CONNECT,
80 }
81 }
82}
83
84#[derive(Debug, Clone, PartialEq, Eq)]
86pub struct KadPeer {
87 pub node_id: PeerId,
89 pub multiaddrs: Vec<Multiaddr>,
91 pub connection_ty: ConnectionType,
93}
94
95impl TryFrom<proto::Peer> for KadPeer {
97 type Error = io::Error;
98
99 fn try_from(peer: proto::Peer) -> Result<KadPeer, Self::Error> {
100 let node_id = PeerId::from_bytes(&peer.id).map_err(|_| invalid_data("invalid peer id"))?;
103
104 let mut addrs = Vec::with_capacity(peer.addrs.len());
105 for addr in peer.addrs.into_iter() {
106 match Multiaddr::try_from(addr) {
107 Ok(a) => addrs.push(a),
108 Err(e) => {
109 log::debug!("Unable to parse multiaddr: {e}");
110 }
111 };
112 }
113
114 Ok(KadPeer {
115 node_id,
116 multiaddrs: addrs,
117 connection_ty: peer.connection.into(),
118 })
119 }
120}
121
122impl From<KadPeer> for proto::Peer {
123 fn from(peer: KadPeer) -> Self {
124 proto::Peer {
125 id: peer.node_id.to_bytes(),
126 addrs: peer.multiaddrs.into_iter().map(|a| a.to_vec()).collect(),
127 connection: peer.connection_ty.into(),
128 }
129 }
130}
131
132#[derive(Debug, Clone)]
138pub struct ProtocolConfig {
139 protocol_names: Vec<StreamProtocol>,
140 max_packet_size: usize,
142}
143
144impl ProtocolConfig {
145 pub fn protocol_names(&self) -> &[StreamProtocol] {
147 &self.protocol_names
148 }
149
150 pub fn set_protocol_names(&mut self, names: Vec<StreamProtocol>) {
153 self.protocol_names = names;
154 }
155
156 pub fn set_max_packet_size(&mut self, size: usize) {
158 self.max_packet_size = size;
159 }
160}
161
162impl Default for ProtocolConfig {
163 fn default() -> Self {
164 ProtocolConfig {
165 protocol_names: iter::once(DEFAULT_PROTO_NAME).collect(),
166 max_packet_size: DEFAULT_MAX_PACKET_SIZE,
167 }
168 }
169}
170
171impl UpgradeInfo for ProtocolConfig {
172 type Info = StreamProtocol;
173 type InfoIter = std::vec::IntoIter<Self::Info>;
174
175 fn protocol_info(&self) -> Self::InfoIter {
176 self.protocol_names.clone().into_iter()
177 }
178}
179
180pub struct Codec<A, B> {
182 codec: quick_protobuf_codec::Codec<proto::Message>,
183 __phantom: PhantomData<(A, B)>,
184}
185impl<A, B> Codec<A, B> {
186 fn new(max_packet_size: usize) -> Self {
187 Codec {
188 codec: quick_protobuf_codec::Codec::new(max_packet_size),
189 __phantom: PhantomData,
190 }
191 }
192}
193
194impl<A: Into<proto::Message>, B> Encoder for Codec<A, B> {
195 type Error = io::Error;
196 type Item = A;
197
198 fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
199 Ok(self.codec.encode(item.into(), dst)?)
200 }
201}
202impl<A, B: TryFrom<proto::Message, Error = io::Error>> Decoder for Codec<A, B> {
203 type Error = io::Error;
204 type Item = B;
205
206 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
207 self.codec.decode(src)?.map(B::try_from).transpose()
208 }
209}
210
211pub(crate) type KadInStreamSink<S> = Framed<S, Codec<KadResponseMsg, KadRequestMsg>>;
213pub(crate) type KadOutStreamSink<S> = Framed<S, Codec<KadRequestMsg, KadResponseMsg>>;
215
216impl<C> InboundUpgrade<C> for ProtocolConfig
217where
218 C: AsyncRead + AsyncWrite + Unpin,
219{
220 type Output = KadInStreamSink<C>;
221 type Future = future::Ready<Result<Self::Output, io::Error>>;
222 type Error = io::Error;
223
224 fn upgrade_inbound(self, incoming: C, _: Self::Info) -> Self::Future {
225 let codec = Codec::new(self.max_packet_size);
226
227 future::ok(Framed::new(incoming, codec))
228 }
229}
230
231impl<C> OutboundUpgrade<C> for ProtocolConfig
232where
233 C: AsyncRead + AsyncWrite + Unpin,
234{
235 type Output = KadOutStreamSink<C>;
236 type Future = future::Ready<Result<Self::Output, io::Error>>;
237 type Error = io::Error;
238
239 fn upgrade_outbound(self, incoming: C, _: Self::Info) -> Self::Future {
240 let codec = Codec::new(self.max_packet_size);
241
242 future::ok(Framed::new(incoming, codec))
243 }
244}
245
246#[derive(Debug, Clone, PartialEq, Eq)]
248pub enum KadRequestMsg {
249 Ping,
251
252 FindNode {
255 key: Vec<u8>,
257 },
258
259 GetProviders {
262 key: record_priv::Key,
264 },
265
266 AddProvider {
268 key: record_priv::Key,
270 provider: KadPeer,
272 },
273
274 GetValue {
276 key: record_priv::Key,
278 },
279
280 PutValue { record: Record },
282}
283
284#[derive(Debug, Clone, PartialEq, Eq)]
286pub enum KadResponseMsg {
287 Pong,
289
290 FindNode {
292 closer_peers: Vec<KadPeer>,
294 },
295
296 GetProviders {
298 closer_peers: Vec<KadPeer>,
300 provider_peers: Vec<KadPeer>,
302 },
303
304 GetValue {
306 record: Option<Record>,
308 closer_peers: Vec<KadPeer>,
310 },
311
312 PutValue {
314 key: record_priv::Key,
316 value: Vec<u8>,
318 },
319}
320
321impl From<KadRequestMsg> for proto::Message {
322 fn from(kad_msg: KadRequestMsg) -> Self {
323 req_msg_to_proto(kad_msg)
324 }
325}
326impl From<KadResponseMsg> for proto::Message {
327 fn from(kad_msg: KadResponseMsg) -> Self {
328 resp_msg_to_proto(kad_msg)
329 }
330}
331impl TryFrom<proto::Message> for KadRequestMsg {
332 type Error = io::Error;
333
334 fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
335 proto_to_req_msg(message)
336 }
337}
338impl TryFrom<proto::Message> for KadResponseMsg {
339 type Error = io::Error;
340
341 fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
342 proto_to_resp_msg(message)
343 }
344}
345
346fn req_msg_to_proto(kad_msg: KadRequestMsg) -> proto::Message {
348 match kad_msg {
349 KadRequestMsg::Ping => proto::Message {
350 type_pb: proto::MessageType::PING,
351 ..proto::Message::default()
352 },
353 KadRequestMsg::FindNode { key } => proto::Message {
354 type_pb: proto::MessageType::FIND_NODE,
355 key,
356 clusterLevelRaw: 10,
357 ..proto::Message::default()
358 },
359 KadRequestMsg::GetProviders { key } => proto::Message {
360 type_pb: proto::MessageType::GET_PROVIDERS,
361 key: key.to_vec(),
362 clusterLevelRaw: 10,
363 ..proto::Message::default()
364 },
365 KadRequestMsg::AddProvider { key, provider } => proto::Message {
366 type_pb: proto::MessageType::ADD_PROVIDER,
367 clusterLevelRaw: 10,
368 key: key.to_vec(),
369 providerPeers: vec![provider.into()],
370 ..proto::Message::default()
371 },
372 KadRequestMsg::GetValue { key } => proto::Message {
373 type_pb: proto::MessageType::GET_VALUE,
374 clusterLevelRaw: 10,
375 key: key.to_vec(),
376 ..proto::Message::default()
377 },
378 KadRequestMsg::PutValue { record } => proto::Message {
379 type_pb: proto::MessageType::PUT_VALUE,
380 key: record.key.to_vec(),
381 record: Some(record_to_proto(record)),
382 ..proto::Message::default()
383 },
384 }
385}
386
387fn resp_msg_to_proto(kad_msg: KadResponseMsg) -> proto::Message {
389 match kad_msg {
390 KadResponseMsg::Pong => proto::Message {
391 type_pb: proto::MessageType::PING,
392 ..proto::Message::default()
393 },
394 KadResponseMsg::FindNode { closer_peers } => proto::Message {
395 type_pb: proto::MessageType::FIND_NODE,
396 clusterLevelRaw: 9,
397 closerPeers: closer_peers.into_iter().map(KadPeer::into).collect(),
398 ..proto::Message::default()
399 },
400 KadResponseMsg::GetProviders {
401 closer_peers,
402 provider_peers,
403 } => proto::Message {
404 type_pb: proto::MessageType::GET_PROVIDERS,
405 clusterLevelRaw: 9,
406 closerPeers: closer_peers.into_iter().map(KadPeer::into).collect(),
407 providerPeers: provider_peers.into_iter().map(KadPeer::into).collect(),
408 ..proto::Message::default()
409 },
410 KadResponseMsg::GetValue {
411 record,
412 closer_peers,
413 } => proto::Message {
414 type_pb: proto::MessageType::GET_VALUE,
415 clusterLevelRaw: 9,
416 closerPeers: closer_peers.into_iter().map(KadPeer::into).collect(),
417 record: record.map(record_to_proto),
418 ..proto::Message::default()
419 },
420 KadResponseMsg::PutValue { key, value } => proto::Message {
421 type_pb: proto::MessageType::PUT_VALUE,
422 key: key.to_vec(),
423 record: Some(proto::Record {
424 key: key.to_vec(),
425 value,
426 ..proto::Record::default()
427 }),
428 ..proto::Message::default()
429 },
430 }
431}
432
433fn proto_to_req_msg(message: proto::Message) -> Result<KadRequestMsg, io::Error> {
437 match message.type_pb {
438 proto::MessageType::PING => Ok(KadRequestMsg::Ping),
439 proto::MessageType::PUT_VALUE => {
440 let record = record_from_proto(message.record.unwrap_or_default())?;
441 Ok(KadRequestMsg::PutValue { record })
442 }
443 proto::MessageType::GET_VALUE => Ok(KadRequestMsg::GetValue {
444 key: record_priv::Key::from(message.key),
445 }),
446 proto::MessageType::FIND_NODE => Ok(KadRequestMsg::FindNode { key: message.key }),
447 proto::MessageType::GET_PROVIDERS => Ok(KadRequestMsg::GetProviders {
448 key: record_priv::Key::from(message.key),
449 }),
450 proto::MessageType::ADD_PROVIDER => {
451 let provider = message
455 .providerPeers
456 .into_iter()
457 .find_map(|peer| KadPeer::try_from(peer).ok());
458
459 if let Some(provider) = provider {
460 let key = record_priv::Key::from(message.key);
461 Ok(KadRequestMsg::AddProvider { key, provider })
462 } else {
463 Err(invalid_data("AddProvider message with no valid peer."))
464 }
465 }
466 }
467}
468
469fn proto_to_resp_msg(message: proto::Message) -> Result<KadResponseMsg, io::Error> {
473 match message.type_pb {
474 proto::MessageType::PING => Ok(KadResponseMsg::Pong),
475 proto::MessageType::GET_VALUE => {
476 let record = if let Some(r) = message.record {
477 Some(record_from_proto(r)?)
478 } else {
479 None
480 };
481
482 let closer_peers = message
483 .closerPeers
484 .into_iter()
485 .filter_map(|peer| KadPeer::try_from(peer).ok())
486 .collect();
487
488 Ok(KadResponseMsg::GetValue {
489 record,
490 closer_peers,
491 })
492 }
493
494 proto::MessageType::FIND_NODE => {
495 let closer_peers = message
496 .closerPeers
497 .into_iter()
498 .filter_map(|peer| KadPeer::try_from(peer).ok())
499 .collect();
500
501 Ok(KadResponseMsg::FindNode { closer_peers })
502 }
503
504 proto::MessageType::GET_PROVIDERS => {
505 let closer_peers = message
506 .closerPeers
507 .into_iter()
508 .filter_map(|peer| KadPeer::try_from(peer).ok())
509 .collect();
510
511 let provider_peers = message
512 .providerPeers
513 .into_iter()
514 .filter_map(|peer| KadPeer::try_from(peer).ok())
515 .collect();
516
517 Ok(KadResponseMsg::GetProviders {
518 closer_peers,
519 provider_peers,
520 })
521 }
522
523 proto::MessageType::PUT_VALUE => {
524 let key = record_priv::Key::from(message.key);
525 let rec = message
526 .record
527 .ok_or_else(|| invalid_data("received PutValue message with no record"))?;
528
529 Ok(KadResponseMsg::PutValue {
530 key,
531 value: rec.value,
532 })
533 }
534
535 proto::MessageType::ADD_PROVIDER => {
536 Err(invalid_data("received an unexpected AddProvider message"))
537 }
538 }
539}
540
541fn record_from_proto(record: proto::Record) -> Result<Record, io::Error> {
542 let key = record_priv::Key::from(record.key);
543 let value = record.value;
544
545 let publisher = if !record.publisher.is_empty() {
546 PeerId::from_bytes(&record.publisher)
547 .map(Some)
548 .map_err(|_| invalid_data("Invalid publisher peer ID."))?
549 } else {
550 None
551 };
552
553 let expires = if record.ttl > 0 {
554 Some(Instant::now() + Duration::from_secs(record.ttl as u64))
555 } else {
556 None
557 };
558
559 Ok(Record {
560 key,
561 value,
562 publisher,
563 expires,
564 })
565}
566
567fn record_to_proto(record: Record) -> proto::Record {
568 proto::Record {
569 key: record.key.to_vec(),
570 value: record.value,
571 publisher: record.publisher.map(|id| id.to_bytes()).unwrap_or_default(),
572 ttl: record
573 .expires
574 .map(|t| {
575 let now = Instant::now();
576 if t > now {
577 (t - now).as_secs() as u32
578 } else {
579 1 }
581 })
582 .unwrap_or(0),
583 timeReceived: String::new(),
584 }
585}
586
/// Wraps `e` in an [`io::Error`] of kind [`io::ErrorKind::InvalidData`].
///
/// Used for every protocol-level decoding failure in this module.
fn invalid_data<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let cause: Box<dyn std::error::Error + Send + Sync> = e.into();
    io::Error::new(io::ErrorKind::InvalidData, cause)
}
594
#[cfg(test)]
mod tests {
    use super::*;

    /// A peer entry with one good and one unparseable address keeps only the good one.
    #[test]
    fn skip_invalid_multiaddr() {
        let good_addr: Multiaddr = "/ip6/2001:db8::/tcp/1234".parse().unwrap();

        let bad_addr_bytes = vec![255u8; 8];
        assert!(Multiaddr::try_from(bad_addr_bytes.clone()).is_err());

        let raw_peer = proto::Peer {
            id: PeerId::random().to_bytes(),
            addrs: vec![good_addr.to_vec(), bad_addr_bytes],
            connection: proto::ConnectionType::CAN_CONNECT,
        };

        let parsed = KadPeer::try_from(raw_peer).expect("not to fail");

        assert_eq!(parsed.multiaddrs, vec![good_addr])
    }
}