1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
68
69#[cfg(feature = "cbor")]
70pub mod cbor;
71mod codec;
72mod handler;
73#[cfg(feature = "json")]
74pub mod json;
75
76pub use codec::Codec;
77pub use handler::ProtocolSupport;
78
79use crate::handler::OutboundMessage;
80use futures::channel::oneshot;
81use handler::Handler;
82use libp2p_core::{transport::PortUse, ConnectedPoint, Endpoint, Multiaddr};
83use libp2p_identity::PeerId;
84use libp2p_swarm::{
85 behaviour::{AddressChange, ConnectionClosed, DialFailure, FromSwarm},
86 dial_opts::DialOpts,
87 ConnectionDenied, ConnectionHandler, ConnectionId, NetworkBehaviour, NotifyHandler,
88 PeerAddresses, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
89};
90use smallvec::SmallVec;
91use std::{
92 collections::{HashMap, HashSet, VecDeque},
93 fmt, io,
94 sync::{atomic::AtomicU64, Arc},
95 task::{Context, Poll},
96 time::Duration,
97};
98
99#[derive(Debug)]
101pub enum Message<TRequest, TResponse, TChannelResponse = TResponse> {
102 Request {
104 request_id: InboundRequestId,
106 request: TRequest,
108 channel: ResponseChannel<TChannelResponse>,
114 },
115 Response {
117 request_id: OutboundRequestId,
121 response: TResponse,
123 },
124}
125
126#[derive(Debug)]
128pub enum Event<TRequest, TResponse, TChannelResponse = TResponse> {
129 Message {
131 peer: PeerId,
133 message: Message<TRequest, TResponse, TChannelResponse>,
135 },
136 OutboundFailure {
138 peer: PeerId,
140 request_id: OutboundRequestId,
142 error: OutboundFailure,
144 },
145 InboundFailure {
147 peer: PeerId,
149 request_id: InboundRequestId,
151 error: InboundFailure,
153 },
154 ResponseSent {
159 peer: PeerId,
161 request_id: InboundRequestId,
163 },
164}
165
/// Reasons why an outbound request did not yield a response.
#[derive(Debug)]
pub enum OutboundFailure {
    /// The request was queued for a peer we were not connected to, and
    /// dialing that peer failed.
    DialFailure,
    /// The connection handler reported that no response arrived in time.
    Timeout,
    /// The connection carrying the request closed before a response arrived.
    ConnectionClosed,
    /// The remote peer supports none of the protocols offered for the request.
    UnsupportedProtocols,
    /// An I/O error occurred on the outbound stream.
    Io(io::Error),
}

impl fmt::Display for OutboundFailure {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            Self::DialFailure => "Failed to dial the requested peer",
            Self::Timeout => "Timeout while waiting for a response",
            Self::ConnectionClosed => "Connection was closed before a response was received",
            Self::UnsupportedProtocols => "The remote supports none of the requested protocols",
            // The only variant that carries data needs real formatting.
            Self::Io(e) => return write!(f, "IO error on outbound stream: {e}"),
        };
        f.write_str(text)
    }
}

impl std::error::Error for OutboundFailure {}
205
/// Reasons why an inbound request could not be answered.
#[derive(Debug)]
pub enum InboundFailure {
    /// The connection handler reported a timeout while receiving the request
    /// or sending the response.
    Timeout,
    /// The connection closed before the response could be sent.
    ConnectionClosed,
    /// The local peer supports none of the protocols the remote requested.
    UnsupportedProtocols,
    /// The [`ResponseChannel`] was dropped without a response being sent.
    ResponseOmission,
    /// An I/O error occurred on the inbound stream.
    Io(io::Error),
}

impl fmt::Display for InboundFailure {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            Self::Timeout => "Timeout while receiving request or sending response",
            Self::ConnectionClosed => "Connection was closed before a response could be sent",
            Self::UnsupportedProtocols => {
                "The local peer supports none of the protocols requested by the remote"
            }
            Self::ResponseOmission => {
                "The response channel was dropped without sending a response to the remote"
            }
            // The only variant that carries data needs real formatting.
            Self::Io(e) => return write!(f, "IO error on inbound stream: {e}"),
        };
        f.write_str(text)
    }
}

impl std::error::Error for InboundFailure {}
251
252#[derive(Debug)]
256pub struct ResponseChannel<TResponse> {
257 sender: oneshot::Sender<TResponse>,
258}
259
260impl<TResponse> ResponseChannel<TResponse> {
261 pub fn is_open(&self) -> bool {
269 !self.sender.is_canceled()
270 }
271}
272
/// Identifier of a request received from a remote peer.
///
/// Inbound and outbound identifiers are numbered independently (both counters
/// start at 1), so an `InboundRequestId` is only meaningful when compared with
/// other inbound IDs.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct InboundRequestId(u64);

impl fmt::Display for InboundRequestId {
    /// Formats the numeric identifier.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to `u64` so width/fill/alignment flags such as `{:>8}` are
        // honoured; `write!(f, "{}", ..)` would silently drop them.
        fmt::Display::fmt(&self.0, f)
    }
}
285
/// Identifier of a request sent to a remote peer, as returned by
/// [`Behaviour::send_request`].
///
/// Inbound and outbound identifiers are numbered independently (both counters
/// start at 1), so an `OutboundRequestId` is only meaningful when compared with
/// other outbound IDs.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct OutboundRequestId(u64);

impl fmt::Display for OutboundRequestId {
    /// Formats the numeric identifier.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to `u64` so width/fill/alignment flags such as `{:>8}` are
        // honoured; `write!(f, "{}", ..)` would silently drop them.
        fmt::Display::fmt(&self.0, f)
    }
}
298
/// Tunables for a request-response [`Behaviour`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Timeout handed to every newly created connection handler.
    request_timeout: Duration,
    /// Concurrent-stream limit handed to every newly created connection handler.
    max_concurrent_streams: usize,
}

impl Default for Config {
    /// A 10 second request timeout and at most 100 concurrent streams.
    fn default() -> Self {
        let request_timeout = Duration::from_secs(10);
        let max_concurrent_streams = 100;
        Config {
            request_timeout,
            max_concurrent_streams,
        }
    }
}

impl Config {
    /// Replaces the request timeout in place and returns `self` for chaining.
    #[deprecated(note = "Use `Config::with_request_timeout` for one-liner constructions.")]
    pub fn set_request_timeout(&mut self, v: Duration) -> &mut Self {
        self.request_timeout = v;
        self
    }

    /// Returns this configuration with the request timeout set to `v`.
    pub fn with_request_timeout(self, v: Duration) -> Self {
        Config {
            request_timeout: v,
            ..self
        }
    }

    /// Returns this configuration with the per-connection limit on concurrent
    /// streams set to `num_streams`.
    pub fn with_max_concurrent_streams(self, num_streams: usize) -> Self {
        Config {
            max_concurrent_streams: num_streams,
            ..self
        }
    }
}
335
336pub struct Behaviour<TCodec>
338where
339 TCodec: Codec + Clone + Send + 'static,
340{
341 inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
343 outbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
345 next_outbound_request_id: OutboundRequestId,
347 next_inbound_request_id: Arc<AtomicU64>,
349 config: Config,
351 codec: TCodec,
353 pending_events:
355 VecDeque<ToSwarm<Event<TCodec::Request, TCodec::Response>, OutboundMessage<TCodec>>>,
356 connected: HashMap<PeerId, SmallVec<[Connection; 2]>>,
359 addresses: PeerAddresses,
361 pending_outbound_requests: HashMap<PeerId, SmallVec<[OutboundMessage<TCodec>; 10]>>,
364}
365
366impl<TCodec> Behaviour<TCodec>
367where
368 TCodec: Codec + Default + Clone + Send + 'static,
369{
370 pub fn new<I>(protocols: I, cfg: Config) -> Self
372 where
373 I: IntoIterator<Item = (TCodec::Protocol, ProtocolSupport)>,
374 {
375 Self::with_codec(TCodec::default(), protocols, cfg)
376 }
377}
378
379impl<TCodec> Behaviour<TCodec>
380where
381 TCodec: Codec + Clone + Send + 'static,
382{
383 pub fn with_codec<I>(codec: TCodec, protocols: I, cfg: Config) -> Self
386 where
387 I: IntoIterator<Item = (TCodec::Protocol, ProtocolSupport)>,
388 {
389 let mut inbound_protocols = SmallVec::new();
390 let mut outbound_protocols = SmallVec::new();
391 for (p, s) in protocols {
392 if s.inbound() {
393 inbound_protocols.push(p.clone());
394 }
395 if s.outbound() {
396 outbound_protocols.push(p.clone());
397 }
398 }
399 Behaviour {
400 inbound_protocols,
401 outbound_protocols,
402 next_outbound_request_id: OutboundRequestId(1),
403 next_inbound_request_id: Arc::new(AtomicU64::new(1)),
404 config: cfg,
405 codec,
406 pending_events: VecDeque::new(),
407 connected: HashMap::new(),
408 pending_outbound_requests: HashMap::new(),
409 addresses: PeerAddresses::default(),
410 }
411 }
412
413 pub fn send_request(&mut self, peer: &PeerId, request: TCodec::Request) -> OutboundRequestId {
426 let request_id = self.next_outbound_request_id();
427 let request = OutboundMessage {
428 request_id,
429 request,
430 protocols: self.outbound_protocols.clone(),
431 };
432
433 if let Some(request) = self.try_send_request(peer, request) {
434 self.pending_events.push_back(ToSwarm::Dial {
435 opts: DialOpts::peer_id(*peer).build(),
436 });
437 self.pending_outbound_requests
438 .entry(*peer)
439 .or_default()
440 .push(request);
441 }
442
443 request_id
444 }
445
446 pub fn send_response(
458 &mut self,
459 ch: ResponseChannel<TCodec::Response>,
460 rs: TCodec::Response,
461 ) -> Result<(), TCodec::Response> {
462 ch.sender.send(rs)
463 }
464
465 #[deprecated(note = "Use `Swarm::add_peer_address` instead.")]
474 pub fn add_address(&mut self, peer: &PeerId, address: Multiaddr) -> bool {
475 self.addresses.add(*peer, address)
476 }
477
478 #[deprecated(note = "Will be removed with the next breaking release and won't be replaced.")]
480 pub fn remove_address(&mut self, peer: &PeerId, address: &Multiaddr) {
481 self.addresses.remove(peer, address);
482 }
483
484 pub fn is_connected(&self, peer: &PeerId) -> bool {
486 if let Some(connections) = self.connected.get(peer) {
487 !connections.is_empty()
488 } else {
489 false
490 }
491 }
492
493 pub fn is_pending_outbound(&self, peer: &PeerId, request_id: &OutboundRequestId) -> bool {
497 let est_conn = self
499 .connected
500 .get(peer)
501 .map(|cs| {
502 cs.iter()
503 .any(|c| c.pending_outbound_responses.contains(request_id))
504 })
505 .unwrap_or(false);
506 let pen_conn = self
508 .pending_outbound_requests
509 .get(peer)
510 .map(|rps| rps.iter().any(|rp| rp.request_id == *request_id))
511 .unwrap_or(false);
512
513 est_conn || pen_conn
514 }
515
516 pub fn is_pending_inbound(&self, peer: &PeerId, request_id: &InboundRequestId) -> bool {
520 self.connected
521 .get(peer)
522 .map(|cs| {
523 cs.iter()
524 .any(|c| c.pending_inbound_responses.contains(request_id))
525 })
526 .unwrap_or(false)
527 }
528
529 fn next_outbound_request_id(&mut self) -> OutboundRequestId {
531 let request_id = self.next_outbound_request_id;
532 self.next_outbound_request_id.0 += 1;
533 request_id
534 }
535
536 fn try_send_request(
540 &mut self,
541 peer: &PeerId,
542 request: OutboundMessage<TCodec>,
543 ) -> Option<OutboundMessage<TCodec>> {
544 if let Some(connections) = self.connected.get_mut(peer) {
545 if connections.is_empty() {
546 return Some(request);
547 }
548 let ix = (request.request_id.0 as usize) % connections.len();
549 let conn = &mut connections[ix];
550 conn.pending_outbound_responses.insert(request.request_id);
551 self.pending_events.push_back(ToSwarm::NotifyHandler {
552 peer_id: *peer,
553 handler: NotifyHandler::One(conn.id),
554 event: request,
555 });
556 None
557 } else {
558 Some(request)
559 }
560 }
561
562 fn remove_pending_outbound_response(
568 &mut self,
569 peer: &PeerId,
570 connection: ConnectionId,
571 request: OutboundRequestId,
572 ) -> bool {
573 self.get_connection_mut(peer, connection)
574 .map(|c| c.pending_outbound_responses.remove(&request))
575 .unwrap_or(false)
576 }
577
578 fn remove_pending_inbound_response(
584 &mut self,
585 peer: &PeerId,
586 connection: ConnectionId,
587 request: InboundRequestId,
588 ) -> bool {
589 self.get_connection_mut(peer, connection)
590 .map(|c| c.pending_inbound_responses.remove(&request))
591 .unwrap_or(false)
592 }
593
594 fn get_connection_mut(
597 &mut self,
598 peer: &PeerId,
599 connection: ConnectionId,
600 ) -> Option<&mut Connection> {
601 self.connected
602 .get_mut(peer)
603 .and_then(|connections| connections.iter_mut().find(|c| c.id == connection))
604 }
605
606 fn on_address_change(
607 &mut self,
608 AddressChange {
609 peer_id,
610 connection_id,
611 new,
612 ..
613 }: AddressChange,
614 ) {
615 let new_address = match new {
616 ConnectedPoint::Dialer { address, .. } => Some(address.clone()),
617 ConnectedPoint::Listener { .. } => None,
618 };
619 let connections = self
620 .connected
621 .get_mut(&peer_id)
622 .expect("Address change can only happen on an established connection.");
623
624 let connection = connections
625 .iter_mut()
626 .find(|c| c.id == connection_id)
627 .expect("Address change can only happen on an established connection.");
628 connection.remote_address = new_address;
629 }
630
631 fn on_connection_closed(
632 &mut self,
633 ConnectionClosed {
634 peer_id,
635 connection_id,
636 remaining_established,
637 ..
638 }: ConnectionClosed,
639 ) {
640 let connections = self
641 .connected
642 .get_mut(&peer_id)
643 .expect("Expected some established connection to peer before closing.");
644
645 let connection = connections
646 .iter()
647 .position(|c| c.id == connection_id)
648 .map(|p: usize| connections.remove(p))
649 .expect("Expected connection to be established before closing.");
650
651 debug_assert_eq!(connections.is_empty(), remaining_established == 0);
652 if connections.is_empty() {
653 self.connected.remove(&peer_id);
654 }
655
656 for request_id in connection.pending_inbound_responses {
657 self.pending_events
658 .push_back(ToSwarm::GenerateEvent(Event::InboundFailure {
659 peer: peer_id,
660 request_id,
661 error: InboundFailure::ConnectionClosed,
662 }));
663 }
664
665 for request_id in connection.pending_outbound_responses {
666 self.pending_events
667 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
668 peer: peer_id,
669 request_id,
670 error: OutboundFailure::ConnectionClosed,
671 }));
672 }
673 }
674
675 fn on_dial_failure(&mut self, DialFailure { peer_id, .. }: DialFailure) {
676 if let Some(peer) = peer_id {
677 if let Some(pending) = self.pending_outbound_requests.remove(&peer) {
684 for request in pending {
685 self.pending_events
686 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
687 peer,
688 request_id: request.request_id,
689 error: OutboundFailure::DialFailure,
690 }));
691 }
692 }
693 }
694 }
695
696 fn preload_new_handler(
698 &mut self,
699 handler: &mut Handler<TCodec>,
700 peer: PeerId,
701 connection_id: ConnectionId,
702 remote_address: Option<Multiaddr>,
703 ) {
704 let mut connection = Connection::new(connection_id, remote_address);
705
706 if let Some(pending_requests) = self.pending_outbound_requests.remove(&peer) {
707 for request in pending_requests {
708 connection
709 .pending_outbound_responses
710 .insert(request.request_id);
711 handler.on_behaviour_event(request);
712 }
713 }
714
715 self.connected.entry(peer).or_default().push(connection);
716 }
717}
718
719impl<TCodec> NetworkBehaviour for Behaviour<TCodec>
720where
721 TCodec: Codec + Send + Clone + 'static,
722{
723 type ConnectionHandler = Handler<TCodec>;
724 type ToSwarm = Event<TCodec::Request, TCodec::Response>;
725
726 fn handle_established_inbound_connection(
727 &mut self,
728 connection_id: ConnectionId,
729 peer: PeerId,
730 _: &Multiaddr,
731 _: &Multiaddr,
732 ) -> Result<THandler<Self>, ConnectionDenied> {
733 let mut handler = Handler::new(
734 self.inbound_protocols.clone(),
735 self.codec.clone(),
736 self.config.request_timeout,
737 self.next_inbound_request_id.clone(),
738 self.config.max_concurrent_streams,
739 );
740
741 self.preload_new_handler(&mut handler, peer, connection_id, None);
742
743 Ok(handler)
744 }
745
746 fn handle_pending_outbound_connection(
747 &mut self,
748 _connection_id: ConnectionId,
749 maybe_peer: Option<PeerId>,
750 _addresses: &[Multiaddr],
751 _effective_role: Endpoint,
752 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
753 let peer = match maybe_peer {
754 None => return Ok(vec![]),
755 Some(peer) => peer,
756 };
757
758 let mut addresses = Vec::new();
759 if let Some(connections) = self.connected.get(&peer) {
760 addresses.extend(connections.iter().filter_map(|c| c.remote_address.clone()))
761 }
762
763 let cached_addrs = self.addresses.get(&peer);
764 addresses.extend(cached_addrs);
765
766 Ok(addresses)
767 }
768
769 fn handle_established_outbound_connection(
770 &mut self,
771 connection_id: ConnectionId,
772 peer: PeerId,
773 remote_address: &Multiaddr,
774 _: Endpoint,
775 _: PortUse,
776 ) -> Result<THandler<Self>, ConnectionDenied> {
777 let mut handler = Handler::new(
778 self.inbound_protocols.clone(),
779 self.codec.clone(),
780 self.config.request_timeout,
781 self.next_inbound_request_id.clone(),
782 self.config.max_concurrent_streams,
783 );
784
785 self.preload_new_handler(
786 &mut handler,
787 peer,
788 connection_id,
789 Some(remote_address.clone()),
790 );
791
792 Ok(handler)
793 }
794
795 fn on_swarm_event(&mut self, event: FromSwarm) {
796 self.addresses.on_swarm_event(&event);
797 match event {
798 FromSwarm::ConnectionEstablished(_) => {}
799 FromSwarm::ConnectionClosed(connection_closed) => {
800 self.on_connection_closed(connection_closed)
801 }
802 FromSwarm::AddressChange(address_change) => self.on_address_change(address_change),
803 FromSwarm::DialFailure(dial_failure) => self.on_dial_failure(dial_failure),
804 _ => {}
805 }
806 }
807
808 fn on_connection_handler_event(
809 &mut self,
810 peer: PeerId,
811 connection: ConnectionId,
812 event: THandlerOutEvent<Self>,
813 ) {
814 match event {
815 handler::Event::Response {
816 request_id,
817 response,
818 } => {
819 let removed = self.remove_pending_outbound_response(&peer, connection, request_id);
820 debug_assert!(
821 removed,
822 "Expect request_id to be pending before receiving response.",
823 );
824
825 let message = Message::Response {
826 request_id,
827 response,
828 };
829 self.pending_events
830 .push_back(ToSwarm::GenerateEvent(Event::Message { peer, message }));
831 }
832 handler::Event::Request {
833 request_id,
834 request,
835 sender,
836 } => match self.get_connection_mut(&peer, connection) {
837 Some(connection) => {
838 let inserted = connection.pending_inbound_responses.insert(request_id);
839 debug_assert!(inserted, "Expect id of new request to be unknown.");
840
841 let channel = ResponseChannel { sender };
842 let message = Message::Request {
843 request_id,
844 request,
845 channel,
846 };
847 self.pending_events
848 .push_back(ToSwarm::GenerateEvent(Event::Message { peer, message }));
849 }
850 None => {
851 tracing::debug!("Connection ({connection}) closed after `Event::Request` ({request_id}) has been emitted.");
852 }
853 },
854 handler::Event::ResponseSent(request_id) => {
855 let removed = self.remove_pending_inbound_response(&peer, connection, request_id);
856 debug_assert!(
857 removed,
858 "Expect request_id to be pending before response is sent."
859 );
860
861 self.pending_events
862 .push_back(ToSwarm::GenerateEvent(Event::ResponseSent {
863 peer,
864 request_id,
865 }));
866 }
867 handler::Event::ResponseOmission(request_id) => {
868 let removed = self.remove_pending_inbound_response(&peer, connection, request_id);
869 debug_assert!(
870 removed,
871 "Expect request_id to be pending before response is omitted.",
872 );
873
874 self.pending_events
875 .push_back(ToSwarm::GenerateEvent(Event::InboundFailure {
876 peer,
877 request_id,
878 error: InboundFailure::ResponseOmission,
879 }));
880 }
881 handler::Event::OutboundTimeout(request_id) => {
882 let removed = self.remove_pending_outbound_response(&peer, connection, request_id);
883 debug_assert!(
884 removed,
885 "Expect request_id to be pending before request times out."
886 );
887
888 self.pending_events
889 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
890 peer,
891 request_id,
892 error: OutboundFailure::Timeout,
893 }));
894 }
895 handler::Event::OutboundUnsupportedProtocols(request_id) => {
896 let removed = self.remove_pending_outbound_response(&peer, connection, request_id);
897 debug_assert!(
898 removed,
899 "Expect request_id to be pending before failing to connect.",
900 );
901
902 self.pending_events
903 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
904 peer,
905 request_id,
906 error: OutboundFailure::UnsupportedProtocols,
907 }));
908 }
909 handler::Event::OutboundStreamFailed { request_id, error } => {
910 let removed = self.remove_pending_outbound_response(&peer, connection, request_id);
911 debug_assert!(removed, "Expect request_id to be pending upon failure");
912
913 self.pending_events
914 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
915 peer,
916 request_id,
917 error: OutboundFailure::Io(error),
918 }))
919 }
920 handler::Event::InboundTimeout(request_id) => {
921 let removed = self.remove_pending_inbound_response(&peer, connection, request_id);
922
923 if removed {
924 self.pending_events
925 .push_back(ToSwarm::GenerateEvent(Event::InboundFailure {
926 peer,
927 request_id,
928 error: InboundFailure::Timeout,
929 }));
930 } else {
931 tracing::debug!(
933 "Inbound request timeout for an unknown request_id ({request_id})"
934 );
935 }
936 }
937 handler::Event::InboundStreamFailed { request_id, error } => {
938 let removed = self.remove_pending_inbound_response(&peer, connection, request_id);
939
940 if removed {
941 self.pending_events
942 .push_back(ToSwarm::GenerateEvent(Event::InboundFailure {
943 peer,
944 request_id,
945 error: InboundFailure::Io(error),
946 }));
947 } else {
948 tracing::debug!("Inbound failure is reported for an unknown request_id ({request_id}): {error}");
950 }
951 }
952 }
953 }
954
955 #[tracing::instrument(level = "trace", name = "NetworkBehaviour::poll", skip(self))]
956 fn poll(&mut self, _: &mut Context<'_>) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
957 if let Some(ev) = self.pending_events.pop_front() {
958 return Poll::Ready(ev);
959 } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
960 self.pending_events.shrink_to_fit();
961 }
962
963 Poll::Pending
964 }
965}
966
/// Capacity above which an emptied event queue gives its buffer back in
/// `poll`, so a one-off burst of events does not pin memory forever.
const EMPTY_QUEUE_SHRINK_THRESHOLD: usize = 100;
972
973struct Connection {
975 id: ConnectionId,
976 remote_address: Option<Multiaddr>,
977 pending_outbound_responses: HashSet<OutboundRequestId>,
981 pending_inbound_responses: HashSet<InboundRequestId>,
984}
985
986impl Connection {
987 fn new(id: ConnectionId, remote_address: Option<Multiaddr>) -> Self {
988 Self {
989 id,
990 remote_address,
991 pending_outbound_responses: Default::default(),
992 pending_inbound_responses: Default::default(),
993 }
994 }
995}