1use crate::proto;
30use crate::record::{self, Record};
31use asynchronous_codec::{Decoder, Encoder, Framed};
32use bytes::BytesMut;
33use futures::prelude::*;
34use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo};
35use libp2p_core::Multiaddr;
36use libp2p_identity::PeerId;
37use libp2p_swarm::StreamProtocol;
38use std::marker::PhantomData;
39use std::time::Duration;
40use std::{io, iter};
41use tracing::debug;
42use web_time::Instant;
43
44pub(crate) const DEFAULT_PROTO_NAME: StreamProtocol = StreamProtocol::new("/ipfs/kad/1.0.0");
46pub(crate) const DEFAULT_MAX_PACKET_SIZE: usize = 16 * 1024;
/// Connection status of a peer as carried inside Kademlia messages.
///
/// Each variant converts to and from the identically named variant of
/// `proto::ConnectionType`.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum ConnectionType {
    /// Counterpart of `proto::ConnectionType::NOT_CONNECTED`.
    NotConnected = 0,
    /// Counterpart of `proto::ConnectionType::CONNECTED`.
    Connected = 1,
    /// Counterpart of `proto::ConnectionType::CAN_CONNECT`.
    CanConnect = 2,
    /// Counterpart of `proto::ConnectionType::CANNOT_CONNECT`.
    CannotConnect = 3,
}
60
61impl From<proto::ConnectionType> for ConnectionType {
62 fn from(raw: proto::ConnectionType) -> ConnectionType {
63 use proto::ConnectionType::*;
64 match raw {
65 NOT_CONNECTED => ConnectionType::NotConnected,
66 CONNECTED => ConnectionType::Connected,
67 CAN_CONNECT => ConnectionType::CanConnect,
68 CANNOT_CONNECT => ConnectionType::CannotConnect,
69 }
70 }
71}
72
73impl From<ConnectionType> for proto::ConnectionType {
74 fn from(val: ConnectionType) -> Self {
75 use proto::ConnectionType::*;
76 match val {
77 ConnectionType::NotConnected => NOT_CONNECTED,
78 ConnectionType::Connected => CONNECTED,
79 ConnectionType::CanConnect => CAN_CONNECT,
80 ConnectionType::CannotConnect => CANNOT_CONNECT,
81 }
82 }
83}
84
85#[derive(Debug, Clone, PartialEq, Eq)]
87pub struct KadPeer {
88 pub node_id: PeerId,
90 pub multiaddrs: Vec<Multiaddr>,
92 pub connection_ty: ConnectionType,
94}
95
96impl TryFrom<proto::Peer> for KadPeer {
98 type Error = io::Error;
99
100 fn try_from(peer: proto::Peer) -> Result<KadPeer, Self::Error> {
101 let node_id = PeerId::from_bytes(&peer.id).map_err(|_| invalid_data("invalid peer id"))?;
104
105 let mut addrs = Vec::with_capacity(peer.addrs.len());
106 for addr in peer.addrs.into_iter() {
107 match Multiaddr::try_from(addr).map(|addr| addr.with_p2p(node_id)) {
108 Ok(Ok(a)) => addrs.push(a),
109 Ok(Err(a)) => {
110 debug!("Unable to parse multiaddr: {a} is not compatible with {node_id}")
111 }
112 Err(e) => debug!("Unable to parse multiaddr: {e}"),
113 };
114 }
115
116 Ok(KadPeer {
117 node_id,
118 multiaddrs: addrs,
119 connection_ty: peer.connection.into(),
120 })
121 }
122}
123
124impl From<KadPeer> for proto::Peer {
125 fn from(peer: KadPeer) -> Self {
126 proto::Peer {
127 id: peer.node_id.to_bytes(),
128 addrs: peer.multiaddrs.into_iter().map(|a| a.to_vec()).collect(),
129 connection: peer.connection_ty.into(),
130 }
131 }
132}
133
134#[derive(Debug, Clone)]
140pub struct ProtocolConfig {
141 protocol_names: Vec<StreamProtocol>,
142 max_packet_size: usize,
144}
145
146impl ProtocolConfig {
147 pub fn new(protocol_name: StreamProtocol) -> Self {
149 ProtocolConfig {
150 protocol_names: vec![protocol_name],
151 max_packet_size: DEFAULT_MAX_PACKET_SIZE,
152 }
153 }
154
155 #[deprecated(note = "Use `ProtocolConfig::new` instead")]
157 #[allow(clippy::should_implement_trait)]
158 pub fn default() -> Self {
159 Default::default()
160 }
161
162 pub fn protocol_names(&self) -> &[StreamProtocol] {
164 &self.protocol_names
165 }
166
167 #[deprecated(note = "Use `ProtocolConfig::new` instead")]
170 pub fn set_protocol_names(&mut self, names: Vec<StreamProtocol>) {
171 self.protocol_names = names;
172 }
173
174 pub fn set_max_packet_size(&mut self, size: usize) {
176 self.max_packet_size = size;
177 }
178}
179
180impl Default for ProtocolConfig {
181 fn default() -> Self {
185 ProtocolConfig {
186 protocol_names: iter::once(DEFAULT_PROTO_NAME).collect(),
187 max_packet_size: DEFAULT_MAX_PACKET_SIZE,
188 }
189 }
190}
191
192impl UpgradeInfo for ProtocolConfig {
193 type Info = StreamProtocol;
194 type InfoIter = std::vec::IntoIter<Self::Info>;
195
196 fn protocol_info(&self) -> Self::InfoIter {
197 self.protocol_names.clone().into_iter()
198 }
199}
200
201pub struct Codec<A, B> {
203 codec: quick_protobuf_codec::Codec<proto::Message>,
204 __phantom: PhantomData<(A, B)>,
205}
206impl<A, B> Codec<A, B> {
207 fn new(max_packet_size: usize) -> Self {
208 Codec {
209 codec: quick_protobuf_codec::Codec::new(max_packet_size),
210 __phantom: PhantomData,
211 }
212 }
213}
214
215impl<A: Into<proto::Message>, B> Encoder for Codec<A, B> {
216 type Error = io::Error;
217 type Item<'a> = A;
218
219 fn encode(&mut self, item: Self::Item<'_>, dst: &mut BytesMut) -> Result<(), Self::Error> {
220 Ok(self.codec.encode(item.into(), dst)?)
221 }
222}
223impl<A, B: TryFrom<proto::Message, Error = io::Error>> Decoder for Codec<A, B> {
224 type Error = io::Error;
225 type Item = B;
226
227 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
228 self.codec.decode(src)?.map(B::try_from).transpose()
229 }
230}
231
232pub(crate) type KadInStreamSink<S> = Framed<S, Codec<KadResponseMsg, KadRequestMsg>>;
234pub(crate) type KadOutStreamSink<S> = Framed<S, Codec<KadRequestMsg, KadResponseMsg>>;
236
237impl<C> InboundUpgrade<C> for ProtocolConfig
238where
239 C: AsyncRead + AsyncWrite + Unpin,
240{
241 type Output = KadInStreamSink<C>;
242 type Future = future::Ready<Result<Self::Output, io::Error>>;
243 type Error = io::Error;
244
245 fn upgrade_inbound(self, incoming: C, _: Self::Info) -> Self::Future {
246 let codec = Codec::new(self.max_packet_size);
247
248 future::ok(Framed::new(incoming, codec))
249 }
250}
251
252impl<C> OutboundUpgrade<C> for ProtocolConfig
253where
254 C: AsyncRead + AsyncWrite + Unpin,
255{
256 type Output = KadOutStreamSink<C>;
257 type Future = future::Ready<Result<Self::Output, io::Error>>;
258 type Error = io::Error;
259
260 fn upgrade_outbound(self, incoming: C, _: Self::Info) -> Self::Future {
261 let codec = Codec::new(self.max_packet_size);
262
263 future::ok(Framed::new(incoming, codec))
264 }
265}
266
267#[derive(Debug, Clone, PartialEq, Eq)]
269pub enum KadRequestMsg {
270 Ping,
272
273 FindNode {
276 key: Vec<u8>,
278 },
279
280 GetProviders {
283 key: record::Key,
285 },
286
287 AddProvider {
289 key: record::Key,
291 provider: KadPeer,
293 },
294
295 GetValue {
297 key: record::Key,
299 },
300
301 PutValue { record: Record },
303}
304
305#[derive(Debug, Clone, PartialEq, Eq)]
307pub enum KadResponseMsg {
308 Pong,
310
311 FindNode {
313 closer_peers: Vec<KadPeer>,
315 },
316
317 GetProviders {
319 closer_peers: Vec<KadPeer>,
321 provider_peers: Vec<KadPeer>,
323 },
324
325 GetValue {
327 record: Option<Record>,
329 closer_peers: Vec<KadPeer>,
331 },
332
333 PutValue {
335 key: record::Key,
337 value: Vec<u8>,
339 },
340}
341
342impl From<KadRequestMsg> for proto::Message {
343 fn from(kad_msg: KadRequestMsg) -> Self {
344 req_msg_to_proto(kad_msg)
345 }
346}
347impl From<KadResponseMsg> for proto::Message {
348 fn from(kad_msg: KadResponseMsg) -> Self {
349 resp_msg_to_proto(kad_msg)
350 }
351}
352impl TryFrom<proto::Message> for KadRequestMsg {
353 type Error = io::Error;
354
355 fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
356 proto_to_req_msg(message)
357 }
358}
359impl TryFrom<proto::Message> for KadResponseMsg {
360 type Error = io::Error;
361
362 fn try_from(message: proto::Message) -> Result<Self, Self::Error> {
363 proto_to_resp_msg(message)
364 }
365}
366
367fn req_msg_to_proto(kad_msg: KadRequestMsg) -> proto::Message {
369 match kad_msg {
370 KadRequestMsg::Ping => proto::Message {
371 type_pb: proto::MessageType::PING,
372 ..proto::Message::default()
373 },
374 KadRequestMsg::FindNode { key } => proto::Message {
375 type_pb: proto::MessageType::FIND_NODE,
376 key,
377 clusterLevelRaw: 10,
378 ..proto::Message::default()
379 },
380 KadRequestMsg::GetProviders { key } => proto::Message {
381 type_pb: proto::MessageType::GET_PROVIDERS,
382 key: key.to_vec(),
383 clusterLevelRaw: 10,
384 ..proto::Message::default()
385 },
386 KadRequestMsg::AddProvider { key, provider } => proto::Message {
387 type_pb: proto::MessageType::ADD_PROVIDER,
388 clusterLevelRaw: 10,
389 key: key.to_vec(),
390 providerPeers: vec![provider.into()],
391 ..proto::Message::default()
392 },
393 KadRequestMsg::GetValue { key } => proto::Message {
394 type_pb: proto::MessageType::GET_VALUE,
395 clusterLevelRaw: 10,
396 key: key.to_vec(),
397 ..proto::Message::default()
398 },
399 KadRequestMsg::PutValue { record } => proto::Message {
400 type_pb: proto::MessageType::PUT_VALUE,
401 key: record.key.to_vec(),
402 record: Some(record_to_proto(record)),
403 ..proto::Message::default()
404 },
405 }
406}
407
408fn resp_msg_to_proto(kad_msg: KadResponseMsg) -> proto::Message {
410 match kad_msg {
411 KadResponseMsg::Pong => proto::Message {
412 type_pb: proto::MessageType::PING,
413 ..proto::Message::default()
414 },
415 KadResponseMsg::FindNode { closer_peers } => proto::Message {
416 type_pb: proto::MessageType::FIND_NODE,
417 clusterLevelRaw: 9,
418 closerPeers: closer_peers.into_iter().map(KadPeer::into).collect(),
419 ..proto::Message::default()
420 },
421 KadResponseMsg::GetProviders {
422 closer_peers,
423 provider_peers,
424 } => proto::Message {
425 type_pb: proto::MessageType::GET_PROVIDERS,
426 clusterLevelRaw: 9,
427 closerPeers: closer_peers.into_iter().map(KadPeer::into).collect(),
428 providerPeers: provider_peers.into_iter().map(KadPeer::into).collect(),
429 ..proto::Message::default()
430 },
431 KadResponseMsg::GetValue {
432 record,
433 closer_peers,
434 } => proto::Message {
435 type_pb: proto::MessageType::GET_VALUE,
436 clusterLevelRaw: 9,
437 closerPeers: closer_peers.into_iter().map(KadPeer::into).collect(),
438 record: record.map(record_to_proto),
439 ..proto::Message::default()
440 },
441 KadResponseMsg::PutValue { key, value } => proto::Message {
442 type_pb: proto::MessageType::PUT_VALUE,
443 key: key.to_vec(),
444 record: Some(proto::Record {
445 key: key.to_vec(),
446 value,
447 ..proto::Record::default()
448 }),
449 ..proto::Message::default()
450 },
451 }
452}
453
454fn proto_to_req_msg(message: proto::Message) -> Result<KadRequestMsg, io::Error> {
458 match message.type_pb {
459 proto::MessageType::PING => Ok(KadRequestMsg::Ping),
460 proto::MessageType::PUT_VALUE => {
461 let record = record_from_proto(message.record.unwrap_or_default())?;
462 Ok(KadRequestMsg::PutValue { record })
463 }
464 proto::MessageType::GET_VALUE => Ok(KadRequestMsg::GetValue {
465 key: record::Key::from(message.key),
466 }),
467 proto::MessageType::FIND_NODE => Ok(KadRequestMsg::FindNode { key: message.key }),
468 proto::MessageType::GET_PROVIDERS => Ok(KadRequestMsg::GetProviders {
469 key: record::Key::from(message.key),
470 }),
471 proto::MessageType::ADD_PROVIDER => {
472 let provider = message
476 .providerPeers
477 .into_iter()
478 .find_map(|peer| KadPeer::try_from(peer).ok());
479
480 if let Some(provider) = provider {
481 let key = record::Key::from(message.key);
482 Ok(KadRequestMsg::AddProvider { key, provider })
483 } else {
484 Err(invalid_data("AddProvider message with no valid peer."))
485 }
486 }
487 }
488}
489
490fn proto_to_resp_msg(message: proto::Message) -> Result<KadResponseMsg, io::Error> {
494 match message.type_pb {
495 proto::MessageType::PING => Ok(KadResponseMsg::Pong),
496 proto::MessageType::GET_VALUE => {
497 let record = if let Some(r) = message.record {
498 Some(record_from_proto(r)?)
499 } else {
500 None
501 };
502
503 let closer_peers = message
504 .closerPeers
505 .into_iter()
506 .filter_map(|peer| KadPeer::try_from(peer).ok())
507 .collect();
508
509 Ok(KadResponseMsg::GetValue {
510 record,
511 closer_peers,
512 })
513 }
514
515 proto::MessageType::FIND_NODE => {
516 let closer_peers = message
517 .closerPeers
518 .into_iter()
519 .filter_map(|peer| KadPeer::try_from(peer).ok())
520 .collect();
521
522 Ok(KadResponseMsg::FindNode { closer_peers })
523 }
524
525 proto::MessageType::GET_PROVIDERS => {
526 let closer_peers = message
527 .closerPeers
528 .into_iter()
529 .filter_map(|peer| KadPeer::try_from(peer).ok())
530 .collect();
531
532 let provider_peers = message
533 .providerPeers
534 .into_iter()
535 .filter_map(|peer| KadPeer::try_from(peer).ok())
536 .collect();
537
538 Ok(KadResponseMsg::GetProviders {
539 closer_peers,
540 provider_peers,
541 })
542 }
543
544 proto::MessageType::PUT_VALUE => {
545 let key = record::Key::from(message.key);
546 let rec = message
547 .record
548 .ok_or_else(|| invalid_data("received PutValue message with no record"))?;
549
550 Ok(KadResponseMsg::PutValue {
551 key,
552 value: rec.value,
553 })
554 }
555
556 proto::MessageType::ADD_PROVIDER => {
557 Err(invalid_data("received an unexpected AddProvider message"))
558 }
559 }
560}
561
562fn record_from_proto(record: proto::Record) -> Result<Record, io::Error> {
563 let key = record::Key::from(record.key);
564 let value = record.value;
565
566 let publisher = if !record.publisher.is_empty() {
567 PeerId::from_bytes(&record.publisher)
568 .map(Some)
569 .map_err(|_| invalid_data("Invalid publisher peer ID."))?
570 } else {
571 None
572 };
573
574 let expires = if record.ttl > 0 {
575 Some(Instant::now() + Duration::from_secs(record.ttl as u64))
576 } else {
577 None
578 };
579
580 Ok(Record {
581 key,
582 value,
583 publisher,
584 expires,
585 })
586}
587
588fn record_to_proto(record: Record) -> proto::Record {
589 proto::Record {
590 key: record.key.to_vec(),
591 value: record.value,
592 publisher: record.publisher.map(|id| id.to_bytes()).unwrap_or_default(),
593 ttl: record
594 .expires
595 .map(|t| {
596 let now = Instant::now();
597 if t > now {
598 (t - now).as_secs() as u32
599 } else {
600 1 }
602 })
603 .unwrap_or(0),
604 timeReceived: String::new(),
605 }
606}
607
/// Wraps `e` in an [`io::Error`] of kind [`io::ErrorKind::InvalidData`].
fn invalid_data<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, e)
}
615
#[cfg(test)]
mod tests {
    use super::*;

    /// Address used by the tests; it has no `/p2p` component of its own.
    const BASE_ADDR: &str = "/ip6/2001:db8::/tcp/1234";

    /// An address without a peer id gets the peer's id appended on conversion.
    #[test]
    fn append_p2p() {
        let peer_id = PeerId::random();
        let base = BASE_ADDR.parse::<Multiaddr>().unwrap();
        let expected = base.clone().with_p2p(peer_id).unwrap();

        let peer = KadPeer::try_from(proto::Peer {
            id: peer_id.to_bytes(),
            addrs: vec![base.to_vec()],
            connection: proto::ConnectionType::CAN_CONNECT,
        })
        .unwrap();

        assert_eq!(peer.multiaddrs, vec![expected]);
    }

    /// Addresses naming another peer, or that are not multiaddrs at all, are dropped.
    #[test]
    fn skip_invalid_multiaddr() {
        let peer_id = PeerId::random();
        let base = BASE_ADDR.parse::<Multiaddr>().unwrap();

        let valid_multiaddr = base.clone().with_p2p(peer_id).unwrap();

        let other_peer_id = PeerId::random();
        assert_ne!(peer_id, other_peer_id);
        let multiaddr_with_incorrect_peer_id = base.with_p2p(other_peer_id).unwrap();

        let invalid_multiaddr = vec![255; 8];
        assert!(Multiaddr::try_from(invalid_multiaddr.clone()).is_err());

        let peer = KadPeer::try_from(proto::Peer {
            id: peer_id.to_bytes(),
            addrs: vec![
                valid_multiaddr.to_vec(),
                multiaddr_with_incorrect_peer_id.to_vec(),
                invalid_multiaddr,
            ],
            connection: proto::ConnectionType::CAN_CONNECT,
        })
        .unwrap();

        assert_eq!(peer.multiaddrs, vec![valid_multiaddr]);
    }
}