1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
68
69#[cfg(feature = "cbor")]
70pub mod cbor;
71mod codec;
72mod handler;
73#[cfg(feature = "json")]
74pub mod json;
75
76pub use codec::Codec;
77pub use handler::ProtocolSupport;
78
79use crate::handler::protocol::RequestProtocol;
80use futures::channel::oneshot;
81use handler::Handler;
82use libp2p_core::{ConnectedPoint, Endpoint, Multiaddr};
83use libp2p_identity::PeerId;
84use libp2p_swarm::{
85 behaviour::{AddressChange, ConnectionClosed, DialFailure, FromSwarm},
86 dial_opts::DialOpts,
87 ConnectionDenied, ConnectionHandler, ConnectionId, NetworkBehaviour, NotifyHandler,
88 PollParameters, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
89};
90use smallvec::SmallVec;
91use std::{
92 collections::{HashMap, HashSet, VecDeque},
93 fmt,
94 sync::{atomic::AtomicU64, Arc},
95 task::{Context, Poll},
96 time::Duration,
97};
98
99#[derive(Debug)]
101pub enum Message<TRequest, TResponse, TChannelResponse = TResponse> {
102 Request {
104 request_id: RequestId,
106 request: TRequest,
108 channel: ResponseChannel<TChannelResponse>,
114 },
115 Response {
117 request_id: RequestId,
121 response: TResponse,
123 },
124}
125
126#[derive(Debug)]
128pub enum Event<TRequest, TResponse, TChannelResponse = TResponse> {
129 Message {
131 peer: PeerId,
133 message: Message<TRequest, TResponse, TChannelResponse>,
135 },
136 OutboundFailure {
138 peer: PeerId,
140 request_id: RequestId,
142 error: OutboundFailure,
144 },
145 InboundFailure {
147 peer: PeerId,
149 request_id: RequestId,
151 error: InboundFailure,
153 },
154 ResponseSent {
159 peer: PeerId,
161 request_id: RequestId,
163 },
164}
165
/// Reasons why a request sent by the local node can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutboundFailure {
    /// No connection to the peer could be established (the dial failed).
    DialFailure,
    /// No response arrived before the request timed out.
    Timeout,
    /// The connection was closed before a response was received.
    ConnectionClosed,
    /// The remote supports none of the requested protocols.
    UnsupportedProtocols,
}

impl fmt::Display for OutboundFailure {
    /// Writes a fixed, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            OutboundFailure::DialFailure => "Failed to dial the requested peer",
            OutboundFailure::Timeout => "Timeout while waiting for a response",
            OutboundFailure::ConnectionClosed => {
                "Connection was closed before a response was received"
            }
            OutboundFailure::UnsupportedProtocols => {
                "The remote supports none of the requested protocols"
            }
        };
        f.write_str(description)
    }
}

impl std::error::Error for OutboundFailure {}
202
/// Reasons why a request received from a remote can fail to be answered.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InboundFailure {
    /// Receiving the request or sending the response timed out.
    Timeout,
    /// The connection was closed before a response could be sent.
    ConnectionClosed,
    /// The local peer supports none of the protocols requested by the remote.
    UnsupportedProtocols,
    /// The response channel was dropped without a response being sent.
    ResponseOmission,
}

impl fmt::Display for InboundFailure {
    /// Writes a fixed, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            InboundFailure::Timeout => "Timeout while receiving request or sending response",
            InboundFailure::ConnectionClosed => {
                "Connection was closed before a response could be sent"
            }
            InboundFailure::UnsupportedProtocols => {
                "The local peer supports none of the protocols requested by the remote"
            }
            InboundFailure::ResponseOmission => {
                "The response channel was dropped without sending a response to the remote"
            }
        };
        f.write_str(description)
    }
}

impl std::error::Error for InboundFailure {}
245
246#[derive(Debug)]
250pub struct ResponseChannel<TResponse> {
251 sender: oneshot::Sender<TResponse>,
252}
253
254impl<TResponse> ResponseChannel<TResponse> {
255 pub fn is_open(&self) -> bool {
263 !self.sender.is_canceled()
264 }
265}
266
/// Identifier of a request, wrapping a plain `u64`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct RequestId(u64);

impl fmt::Display for RequestId {
    /// Prints the wrapped number in decimal.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let RequestId(raw) = self;
        write!(f, "{}", raw)
    }
}
282
/// Configuration options for a request-response [`Behaviour`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Timeout handed to every connection handler for requests.
    request_timeout: Duration,
    /// Keep-alive duration handed to every connection handler.
    connection_keep_alive: Duration,
}

impl Default for Config {
    /// Both the request timeout and the connection keep-alive default to 10 seconds.
    fn default() -> Self {
        let ten_seconds = Duration::from_secs(10);
        Config {
            request_timeout: ten_seconds,
            connection_keep_alive: ten_seconds,
        }
    }
}

impl Config {
    /// Sets the connection keep-alive duration and returns `self` for chaining.
    #[deprecated(
        note = "Set a global idle connection timeout via `SwarmBuilder::idle_connection_timeout` instead."
    )]
    pub fn set_connection_keep_alive(&mut self, v: Duration) -> &mut Self {
        let Config {
            connection_keep_alive,
            ..
        } = self;
        *connection_keep_alive = v;
        self
    }

    /// Sets the request timeout and returns `self` for chaining.
    pub fn set_request_timeout(&mut self, v: Duration) -> &mut Self {
        let Config {
            request_timeout, ..
        } = self;
        *request_timeout = v;
        self
    }
}
315
316pub struct Behaviour<TCodec>
318where
319 TCodec: Codec + Clone + Send + 'static,
320{
321 inbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
323 outbound_protocols: SmallVec<[TCodec::Protocol; 2]>,
325 next_request_id: RequestId,
327 next_inbound_id: Arc<AtomicU64>,
329 config: Config,
331 codec: TCodec,
333 pending_events:
335 VecDeque<ToSwarm<Event<TCodec::Request, TCodec::Response>, RequestProtocol<TCodec>>>,
336 connected: HashMap<PeerId, SmallVec<[Connection; 2]>>,
339 addresses: HashMap<PeerId, HashSet<Multiaddr>>,
341 pending_outbound_requests: HashMap<PeerId, SmallVec<[RequestProtocol<TCodec>; 10]>>,
344}
345
346impl<TCodec> Behaviour<TCodec>
347where
348 TCodec: Codec + Default + Clone + Send + 'static,
349{
350 pub fn new<I>(protocols: I, cfg: Config) -> Self
352 where
353 I: IntoIterator<Item = (TCodec::Protocol, ProtocolSupport)>,
354 {
355 Self::with_codec(TCodec::default(), protocols, cfg)
356 }
357}
358
359impl<TCodec> Behaviour<TCodec>
360where
361 TCodec: Codec + Clone + Send + 'static,
362{
363 pub fn with_codec<I>(codec: TCodec, protocols: I, cfg: Config) -> Self
366 where
367 I: IntoIterator<Item = (TCodec::Protocol, ProtocolSupport)>,
368 {
369 let mut inbound_protocols = SmallVec::new();
370 let mut outbound_protocols = SmallVec::new();
371 for (p, s) in protocols {
372 if s.inbound() {
373 inbound_protocols.push(p.clone());
374 }
375 if s.outbound() {
376 outbound_protocols.push(p.clone());
377 }
378 }
379 Behaviour {
380 inbound_protocols,
381 outbound_protocols,
382 next_request_id: RequestId(1),
383 next_inbound_id: Arc::new(AtomicU64::new(1)),
384 config: cfg,
385 codec,
386 pending_events: VecDeque::new(),
387 connected: HashMap::new(),
388 pending_outbound_requests: HashMap::new(),
389 addresses: HashMap::new(),
390 }
391 }
392
393 pub fn send_request(&mut self, peer: &PeerId, request: TCodec::Request) -> RequestId {
406 let request_id = self.next_request_id();
407 let request = RequestProtocol {
408 request_id,
409 codec: self.codec.clone(),
410 protocols: self.outbound_protocols.clone(),
411 request,
412 };
413
414 if let Some(request) = self.try_send_request(peer, request) {
415 self.pending_events.push_back(ToSwarm::Dial {
416 opts: DialOpts::peer_id(*peer).build(),
417 });
418 self.pending_outbound_requests
419 .entry(*peer)
420 .or_default()
421 .push(request);
422 }
423
424 request_id
425 }
426
427 pub fn send_response(
439 &mut self,
440 ch: ResponseChannel<TCodec::Response>,
441 rs: TCodec::Response,
442 ) -> Result<(), TCodec::Response> {
443 ch.sender.send(rs)
444 }
445
446 pub fn add_address(&mut self, peer: &PeerId, address: Multiaddr) {
452 self.addresses.entry(*peer).or_default().insert(address);
453 }
454
455 pub fn remove_address(&mut self, peer: &PeerId, address: &Multiaddr) {
457 let mut last = false;
458 if let Some(addresses) = self.addresses.get_mut(peer) {
459 addresses.retain(|a| a != address);
460 last = addresses.is_empty();
461 }
462 if last {
463 self.addresses.remove(peer);
464 }
465 }
466
467 pub fn is_connected(&self, peer: &PeerId) -> bool {
469 if let Some(connections) = self.connected.get(peer) {
470 !connections.is_empty()
471 } else {
472 false
473 }
474 }
475
476 pub fn is_pending_outbound(&self, peer: &PeerId, request_id: &RequestId) -> bool {
480 let est_conn = self
482 .connected
483 .get(peer)
484 .map(|cs| {
485 cs.iter()
486 .any(|c| c.pending_inbound_responses.contains(request_id))
487 })
488 .unwrap_or(false);
489 let pen_conn = self
491 .pending_outbound_requests
492 .get(peer)
493 .map(|rps| rps.iter().any(|rp| rp.request_id == *request_id))
494 .unwrap_or(false);
495
496 est_conn || pen_conn
497 }
498
499 pub fn is_pending_inbound(&self, peer: &PeerId, request_id: &RequestId) -> bool {
503 self.connected
504 .get(peer)
505 .map(|cs| {
506 cs.iter()
507 .any(|c| c.pending_outbound_responses.contains(request_id))
508 })
509 .unwrap_or(false)
510 }
511
512 fn next_request_id(&mut self) -> RequestId {
514 let request_id = self.next_request_id;
515 self.next_request_id.0 += 1;
516 request_id
517 }
518
519 fn try_send_request(
523 &mut self,
524 peer: &PeerId,
525 request: RequestProtocol<TCodec>,
526 ) -> Option<RequestProtocol<TCodec>> {
527 if let Some(connections) = self.connected.get_mut(peer) {
528 if connections.is_empty() {
529 return Some(request);
530 }
531 let ix = (request.request_id.0 as usize) % connections.len();
532 let conn = &mut connections[ix];
533 conn.pending_inbound_responses.insert(request.request_id);
534 self.pending_events.push_back(ToSwarm::NotifyHandler {
535 peer_id: *peer,
536 handler: NotifyHandler::One(conn.id),
537 event: request,
538 });
539 None
540 } else {
541 Some(request)
542 }
543 }
544
545 fn remove_pending_outbound_response(
551 &mut self,
552 peer: &PeerId,
553 connection: ConnectionId,
554 request: RequestId,
555 ) -> bool {
556 self.get_connection_mut(peer, connection)
557 .map(|c| c.pending_outbound_responses.remove(&request))
558 .unwrap_or(false)
559 }
560
561 fn remove_pending_inbound_response(
567 &mut self,
568 peer: &PeerId,
569 connection: ConnectionId,
570 request: &RequestId,
571 ) -> bool {
572 self.get_connection_mut(peer, connection)
573 .map(|c| c.pending_inbound_responses.remove(request))
574 .unwrap_or(false)
575 }
576
577 fn get_connection_mut(
580 &mut self,
581 peer: &PeerId,
582 connection: ConnectionId,
583 ) -> Option<&mut Connection> {
584 self.connected
585 .get_mut(peer)
586 .and_then(|connections| connections.iter_mut().find(|c| c.id == connection))
587 }
588
589 fn on_address_change(
590 &mut self,
591 AddressChange {
592 peer_id,
593 connection_id,
594 new,
595 ..
596 }: AddressChange,
597 ) {
598 let new_address = match new {
599 ConnectedPoint::Dialer { address, .. } => Some(address.clone()),
600 ConnectedPoint::Listener { .. } => None,
601 };
602 let connections = self
603 .connected
604 .get_mut(&peer_id)
605 .expect("Address change can only happen on an established connection.");
606
607 let connection = connections
608 .iter_mut()
609 .find(|c| c.id == connection_id)
610 .expect("Address change can only happen on an established connection.");
611 connection.remote_address = new_address;
612 }
613
614 fn on_connection_closed(
615 &mut self,
616 ConnectionClosed {
617 peer_id,
618 connection_id,
619 remaining_established,
620 ..
621 }: ConnectionClosed<<Self as NetworkBehaviour>::ConnectionHandler>,
622 ) {
623 let connections = self
624 .connected
625 .get_mut(&peer_id)
626 .expect("Expected some established connection to peer before closing.");
627
628 let connection = connections
629 .iter()
630 .position(|c| c.id == connection_id)
631 .map(|p: usize| connections.remove(p))
632 .expect("Expected connection to be established before closing.");
633
634 debug_assert_eq!(connections.is_empty(), remaining_established == 0);
635 if connections.is_empty() {
636 self.connected.remove(&peer_id);
637 }
638
639 for request_id in connection.pending_outbound_responses {
640 self.pending_events
641 .push_back(ToSwarm::GenerateEvent(Event::InboundFailure {
642 peer: peer_id,
643 request_id,
644 error: InboundFailure::ConnectionClosed,
645 }));
646 }
647
648 for request_id in connection.pending_inbound_responses {
649 self.pending_events
650 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
651 peer: peer_id,
652 request_id,
653 error: OutboundFailure::ConnectionClosed,
654 }));
655 }
656 }
657
658 fn on_dial_failure(&mut self, DialFailure { peer_id, .. }: DialFailure) {
659 if let Some(peer) = peer_id {
660 if let Some(pending) = self.pending_outbound_requests.remove(&peer) {
667 for request in pending {
668 self.pending_events
669 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
670 peer,
671 request_id: request.request_id,
672 error: OutboundFailure::DialFailure,
673 }));
674 }
675 }
676 }
677 }
678
679 fn preload_new_handler(
681 &mut self,
682 handler: &mut Handler<TCodec>,
683 peer: PeerId,
684 connection_id: ConnectionId,
685 remote_address: Option<Multiaddr>,
686 ) {
687 let mut connection = Connection::new(connection_id, remote_address);
688
689 if let Some(pending_requests) = self.pending_outbound_requests.remove(&peer) {
690 for request in pending_requests {
691 connection
692 .pending_inbound_responses
693 .insert(request.request_id);
694 handler.on_behaviour_event(request);
695 }
696 }
697
698 self.connected.entry(peer).or_default().push(connection);
699 }
700}
701
702impl<TCodec> NetworkBehaviour for Behaviour<TCodec>
703where
704 TCodec: Codec + Send + Clone + 'static,
705{
706 type ConnectionHandler = Handler<TCodec>;
707 type ToSwarm = Event<TCodec::Request, TCodec::Response>;
708
709 fn handle_established_inbound_connection(
710 &mut self,
711 connection_id: ConnectionId,
712 peer: PeerId,
713 _: &Multiaddr,
714 _: &Multiaddr,
715 ) -> Result<THandler<Self>, ConnectionDenied> {
716 let mut handler = Handler::new(
717 self.inbound_protocols.clone(),
718 self.codec.clone(),
719 self.config.request_timeout,
720 self.config.connection_keep_alive,
721 self.next_inbound_id.clone(),
722 );
723
724 self.preload_new_handler(&mut handler, peer, connection_id, None);
725
726 Ok(handler)
727 }
728
729 fn handle_pending_outbound_connection(
730 &mut self,
731 _connection_id: ConnectionId,
732 maybe_peer: Option<PeerId>,
733 _addresses: &[Multiaddr],
734 _effective_role: Endpoint,
735 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
736 let peer = match maybe_peer {
737 None => return Ok(vec![]),
738 Some(peer) => peer,
739 };
740
741 let mut addresses = Vec::new();
742 if let Some(connections) = self.connected.get(&peer) {
743 addresses.extend(connections.iter().filter_map(|c| c.remote_address.clone()))
744 }
745 if let Some(more) = self.addresses.get(&peer) {
746 addresses.extend(more.iter().cloned());
747 }
748
749 Ok(addresses)
750 }
751
752 fn handle_established_outbound_connection(
753 &mut self,
754 connection_id: ConnectionId,
755 peer: PeerId,
756 remote_address: &Multiaddr,
757 _: Endpoint,
758 ) -> Result<THandler<Self>, ConnectionDenied> {
759 let mut handler = Handler::new(
760 self.inbound_protocols.clone(),
761 self.codec.clone(),
762 self.config.request_timeout,
763 self.config.connection_keep_alive,
764 self.next_inbound_id.clone(),
765 );
766
767 self.preload_new_handler(
768 &mut handler,
769 peer,
770 connection_id,
771 Some(remote_address.clone()),
772 );
773
774 Ok(handler)
775 }
776
777 fn on_swarm_event(&mut self, event: FromSwarm<Self::ConnectionHandler>) {
778 match event {
779 FromSwarm::ConnectionEstablished(_) => {}
780 FromSwarm::ConnectionClosed(connection_closed) => {
781 self.on_connection_closed(connection_closed)
782 }
783 FromSwarm::AddressChange(address_change) => self.on_address_change(address_change),
784 FromSwarm::DialFailure(dial_failure) => self.on_dial_failure(dial_failure),
785 FromSwarm::ListenFailure(_) => {}
786 FromSwarm::NewListener(_) => {}
787 FromSwarm::NewListenAddr(_) => {}
788 FromSwarm::ExpiredListenAddr(_) => {}
789 FromSwarm::ListenerError(_) => {}
790 FromSwarm::ListenerClosed(_) => {}
791 FromSwarm::NewExternalAddrCandidate(_) => {}
792 FromSwarm::ExternalAddrExpired(_) => {}
793 FromSwarm::ExternalAddrConfirmed(_) => {}
794 }
795 }
796
797 fn on_connection_handler_event(
798 &mut self,
799 peer: PeerId,
800 connection: ConnectionId,
801 event: THandlerOutEvent<Self>,
802 ) {
803 match event {
804 handler::Event::Response {
805 request_id,
806 response,
807 } => {
808 let removed = self.remove_pending_inbound_response(&peer, connection, &request_id);
809 debug_assert!(
810 removed,
811 "Expect request_id to be pending before receiving response.",
812 );
813
814 let message = Message::Response {
815 request_id,
816 response,
817 };
818 self.pending_events
819 .push_back(ToSwarm::GenerateEvent(Event::Message { peer, message }));
820 }
821 handler::Event::Request {
822 request_id,
823 request,
824 sender,
825 } => {
826 let channel = ResponseChannel { sender };
827 let message = Message::Request {
828 request_id,
829 request,
830 channel,
831 };
832 self.pending_events
833 .push_back(ToSwarm::GenerateEvent(Event::Message { peer, message }));
834
835 match self.get_connection_mut(&peer, connection) {
836 Some(connection) => {
837 let inserted = connection.pending_outbound_responses.insert(request_id);
838 debug_assert!(inserted, "Expect id of new request to be unknown.");
839 }
840 None => {
842 self.pending_events.push_back(ToSwarm::GenerateEvent(
843 Event::InboundFailure {
844 peer,
845 request_id,
846 error: InboundFailure::ConnectionClosed,
847 },
848 ));
849 }
850 }
851 }
852 handler::Event::ResponseSent(request_id) => {
853 let removed = self.remove_pending_outbound_response(&peer, connection, request_id);
854 debug_assert!(
855 removed,
856 "Expect request_id to be pending before response is sent."
857 );
858
859 self.pending_events
860 .push_back(ToSwarm::GenerateEvent(Event::ResponseSent {
861 peer,
862 request_id,
863 }));
864 }
865 handler::Event::ResponseOmission(request_id) => {
866 let removed = self.remove_pending_outbound_response(&peer, connection, request_id);
867 debug_assert!(
868 removed,
869 "Expect request_id to be pending before response is omitted.",
870 );
871
872 self.pending_events
873 .push_back(ToSwarm::GenerateEvent(Event::InboundFailure {
874 peer,
875 request_id,
876 error: InboundFailure::ResponseOmission,
877 }));
878 }
879 handler::Event::OutboundTimeout(request_id) => {
880 let removed = self.remove_pending_inbound_response(&peer, connection, &request_id);
881 debug_assert!(
882 removed,
883 "Expect request_id to be pending before request times out."
884 );
885
886 self.pending_events
887 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
888 peer,
889 request_id,
890 error: OutboundFailure::Timeout,
891 }));
892 }
893 handler::Event::OutboundUnsupportedProtocols(request_id) => {
894 let removed = self.remove_pending_inbound_response(&peer, connection, &request_id);
895 debug_assert!(
896 removed,
897 "Expect request_id to be pending before failing to connect.",
898 );
899
900 self.pending_events
901 .push_back(ToSwarm::GenerateEvent(Event::OutboundFailure {
902 peer,
903 request_id,
904 error: OutboundFailure::UnsupportedProtocols,
905 }));
906 }
907 }
908 }
909
910 fn poll(
911 &mut self,
912 _: &mut Context<'_>,
913 _: &mut impl PollParameters,
914 ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
915 if let Some(ev) = self.pending_events.pop_front() {
916 return Poll::Ready(ev);
917 } else if self.pending_events.capacity() > EMPTY_QUEUE_SHRINK_THRESHOLD {
918 self.pending_events.shrink_to_fit();
919 }
920
921 Poll::Pending
922 }
923}
924
/// Capacity above which the (empty) event queue is shrunk when polled.
const EMPTY_QUEUE_SHRINK_THRESHOLD: usize = 100;
930
931struct Connection {
933 id: ConnectionId,
934 remote_address: Option<Multiaddr>,
935 pending_outbound_responses: HashSet<RequestId>,
939 pending_inbound_responses: HashSet<RequestId>,
942}
943
944impl Connection {
945 fn new(id: ConnectionId, remote_address: Option<Multiaddr>) -> Self {
946 Self {
947 id,
948 remote_address,
949 pending_outbound_responses: Default::default(),
950 pending_inbound_responses: Default::default(),
951 }
952 }
953}