1use crate::{
22 addresses::PublicAddresses,
23 error::{Error, ImmediateDialError, SubstreamError},
24 protocol::{connection::ConnectionHandle, InnerTransportEvent, TransportEvent},
25 transport::{manager::TransportManagerHandle, Endpoint},
26 types::{protocol::ProtocolName, ConnectionId, SubstreamId},
27 PeerId, DEFAULT_CHANNEL_SIZE,
28};
29
30use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt};
31use multiaddr::{Multiaddr, Protocol};
32use multihash::Multihash;
33use tokio::sync::mpsc::{channel, Receiver, Sender};
34
35use std::{
36 collections::{HashMap, HashSet},
37 fmt::Debug,
38 pin::Pin,
39 sync::{
40 atomic::{AtomicUsize, Ordering},
41 Arc,
42 },
43 task::{Context, Poll, Waker},
44 time::{Duration, Instant},
45};
46
/// Logging target used by every `tracing` call in this module.
const LOG_TARGET: &str = "litep2p::transport-service";
49
50#[derive(Debug)]
59struct ConnectionContext {
60 primary: ConnectionHandle,
62
63 secondary: Option<ConnectionHandle>,
65}
66
67impl ConnectionContext {
68 fn new(primary: ConnectionHandle) -> Self {
70 Self {
71 primary,
72 secondary: None,
73 }
74 }
75
76 fn downgrade(&mut self, connection_id: &ConnectionId) {
79 if self.primary.connection_id() == connection_id {
80 self.primary.close();
81 return;
82 }
83
84 if let Some(handle) = &mut self.secondary {
85 if handle.connection_id() == connection_id {
86 handle.close();
87 return;
88 }
89 }
90
91 tracing::debug!(
92 target: LOG_TARGET,
93 primary = ?self.primary.connection_id(),
94 secondary = ?self.secondary.as_ref().map(|handle| handle.connection_id()),
95 ?connection_id,
96 "connection doesn't exist, cannot downgrade",
97 );
98 }
99
100 fn try_upgrade(&mut self, connection_id: &ConnectionId) {
102 if self.primary.connection_id() == connection_id {
103 self.primary.try_upgrade();
104 return;
105 }
106
107 if let Some(handle) = &mut self.secondary {
108 if handle.connection_id() == connection_id {
109 handle.try_upgrade();
110 return;
111 }
112 }
113
114 tracing::debug!(
115 target: LOG_TARGET,
116 primary = ?self.primary.connection_id(),
117 secondary = ?self.secondary.as_ref().map(|handle| handle.connection_id()),
118 ?connection_id,
119 "connection doesn't exist, cannot upgrade",
120 );
121 }
122}
123
124#[derive(Debug)]
131struct KeepAliveTracker {
132 keep_alive_timeout: Duration,
134
135 last_activity: HashMap<(PeerId, ConnectionId), Instant>,
137
138 pending_keep_alive_timeouts: FuturesUnordered<BoxFuture<'static, (PeerId, ConnectionId)>>,
140
141 waker: Option<Waker>,
143}
144
145impl KeepAliveTracker {
146 pub fn new(keep_alive_timeout: Duration) -> Self {
148 Self {
149 keep_alive_timeout,
150 last_activity: HashMap::new(),
151 pending_keep_alive_timeouts: FuturesUnordered::new(),
152 waker: None,
153 }
154 }
155
156 pub fn on_connection_established(&mut self, peer: PeerId, connection_id: ConnectionId) {
158 self.substream_activity(peer, connection_id);
159 }
160
161 pub fn on_connection_closed(&mut self, peer: PeerId, connection_id: ConnectionId) {
163 self.last_activity.remove(&(peer, connection_id));
164 }
165
166 pub fn substream_activity(&mut self, peer: PeerId, connection_id: ConnectionId) {
168 if self.last_activity.insert((peer, connection_id), Instant::now()).is_none() {
170 let timeout = self.keep_alive_timeout;
172 self.pending_keep_alive_timeouts.push(Box::pin(async move {
173 tokio::time::sleep(timeout).await;
174 (peer, connection_id)
175 }));
176 }
177
178 tracing::trace!(
179 target: LOG_TARGET,
180 ?peer,
181 ?connection_id,
182 ?self.keep_alive_timeout,
183 last_activity = ?self.last_activity.len(),
184 pending_keep_alive_timeouts = ?self.pending_keep_alive_timeouts.len(),
185 "substream activity",
186 );
187
188 if let Some(waker) = self.waker.take() {
190 waker.wake()
191 }
192 }
193}
194
195impl Stream for KeepAliveTracker {
196 type Item = (PeerId, ConnectionId);
197
198 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
199 if self.pending_keep_alive_timeouts.is_empty() {
200 self.waker = Some(cx.waker().clone());
202 return Poll::Pending;
203 }
204
205 match self.pending_keep_alive_timeouts.poll_next_unpin(cx) {
206 Poll::Ready(Some(key)) => {
207 let Some(last_activity) = self.last_activity.get(&key) else {
209 tracing::debug!(
210 target: LOG_TARGET,
211 peer = ?key.0,
212 connection_id = ?key.1,
213 "Last activity no longer tracks the connection (closed event triggered)",
214 );
215
216 cx.waker().wake_by_ref();
220 return Poll::Pending;
221 };
222
223 let inactive_for = last_activity.elapsed();
225 if inactive_for < self.keep_alive_timeout {
226 let timeout = self.keep_alive_timeout.saturating_sub(inactive_for);
227
228 tracing::trace!(
229 target: LOG_TARGET,
230 peer = ?key.0,
231 connection_id = ?key.1,
232 ?timeout,
233 "keep-alive timeout not yet reached",
234 );
235
236 self.pending_keep_alive_timeouts.push(Box::pin(async move {
238 tokio::time::sleep(timeout).await;
239 key
240 }));
241
242 cx.waker().wake_by_ref();
245 return Poll::Pending;
246 }
247
248 tracing::debug!(
250 target: LOG_TARGET,
251 peer = ?key.0,
252 connection_id = ?key.1,
253 "keep-alive timeout triggered",
254 );
255 self.last_activity.remove(&key);
256 Poll::Ready(Some(key))
257 }
258 Poll::Ready(None) | Poll::Pending => Poll::Pending,
259 }
260 }
261}
262
263#[derive(Debug)]
266pub struct TransportService {
267 local_peer_id: PeerId,
269
270 protocol: ProtocolName,
272
273 fallback_names: Vec<ProtocolName>,
275
276 connections: HashMap<PeerId, ConnectionContext>,
278
279 transport_handle: TransportManagerHandle,
281
282 rx: Receiver<InnerTransportEvent>,
284
285 next_substream_id: Arc<AtomicUsize>,
287
288 keep_alive_tracker: KeepAliveTracker,
290}
291
292impl TransportService {
293 pub(crate) fn new(
295 local_peer_id: PeerId,
296 protocol: ProtocolName,
297 fallback_names: Vec<ProtocolName>,
298 next_substream_id: Arc<AtomicUsize>,
299 transport_handle: TransportManagerHandle,
300 keep_alive_timeout: Duration,
301 ) -> (Self, Sender<InnerTransportEvent>) {
302 let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE);
303
304 let keep_alive_tracker = KeepAliveTracker::new(keep_alive_timeout);
305
306 (
307 Self {
308 rx,
309 protocol,
310 local_peer_id,
311 fallback_names,
312 transport_handle,
313 next_substream_id,
314 connections: HashMap::new(),
315 keep_alive_tracker,
316 },
317 tx,
318 )
319 }
320
321 pub fn public_addresses(&self) -> PublicAddresses {
323 self.transport_handle.public_addresses()
324 }
325
326 pub fn listen_addresses(&self) -> HashSet<Multiaddr> {
328 self.transport_handle.listen_addresses()
329 }
330
331 fn on_connection_established(
333 &mut self,
334 peer: PeerId,
335 endpoint: Endpoint,
336 connection_id: ConnectionId,
337 handle: ConnectionHandle,
338 ) -> Option<TransportEvent> {
339 tracing::debug!(
340 target: LOG_TARGET,
341 ?peer,
342 ?endpoint,
343 ?connection_id,
344 protocol = %self.protocol,
345 current_state = ?self.connections.get(&peer),
346 "on connection established",
347 );
348
349 match self.connections.get_mut(&peer) {
350 Some(context) => match context.secondary {
351 Some(_) => {
352 tracing::debug!(
353 target: LOG_TARGET,
354 ?peer,
355 ?connection_id,
356 ?endpoint,
357 protocol = %self.protocol,
358 "ignoring third connection",
359 );
360 None
361 }
362 None => {
363 self.keep_alive_tracker.on_connection_established(peer, connection_id);
364
365 tracing::trace!(
366 target: LOG_TARGET,
367 ?peer,
368 ?endpoint,
369 ?connection_id,
370 protocol = %self.protocol,
371 "secondary connection established",
372 );
373
374 context.secondary = Some(handle);
375
376 None
377 }
378 },
379 None => {
380 tracing::trace!(
381 target: LOG_TARGET,
382 ?peer,
383 ?endpoint,
384 ?connection_id,
385 protocol = %self.protocol,
386 "primary connection established",
387 );
388
389 self.connections.insert(peer, ConnectionContext::new(handle));
390
391 self.keep_alive_tracker.on_connection_established(peer, connection_id);
392
393 Some(TransportEvent::ConnectionEstablished { peer, endpoint })
394 }
395 }
396 }
397
398 fn on_connection_closed(
400 &mut self,
401 peer: PeerId,
402 connection_id: ConnectionId,
403 ) -> Option<TransportEvent> {
404 tracing::debug!(
405 target: LOG_TARGET,
406 ?peer,
407 ?connection_id,
408 protocol = %self.protocol,
409 current_state = ?self.connections.get(&peer),
410 "on connection closed",
411 );
412
413 self.keep_alive_tracker.on_connection_closed(peer, connection_id);
414
415 let Some(context) = self.connections.get_mut(&peer) else {
416 tracing::warn!(
417 target: LOG_TARGET,
418 ?peer,
419 ?connection_id,
420 protocol = %self.protocol,
421 "connection closed to a non-existent peer",
422 );
423
424 debug_assert!(false);
425 return None;
426 };
427
428 if context.primary.connection_id() == &connection_id {
431 tracing::trace!(
432 target: LOG_TARGET,
433 ?peer,
434 ?connection_id,
435 protocol = %self.protocol,
436 "primary connection closed"
437 );
438
439 match context.secondary.take() {
440 None => {
441 self.connections.remove(&peer);
442 return Some(TransportEvent::ConnectionClosed { peer });
443 }
444 Some(handle) => {
445 tracing::debug!(
446 target: LOG_TARGET,
447 ?peer,
448 ?connection_id,
449 protocol = %self.protocol,
450 "switch to secondary connection",
451 );
452
453 context.primary = handle;
454 return None;
455 }
456 }
457 }
458
459 match context.secondary.take() {
460 Some(handle) if handle.connection_id() == &connection_id => {
461 tracing::trace!(
462 target: LOG_TARGET,
463 ?peer,
464 ?connection_id,
465 protocol = %self.protocol,
466 "secondary connection closed",
467 );
468
469 None
470 }
471 connection_state => {
472 tracing::debug!(
473 target: LOG_TARGET,
474 ?peer,
475 ?connection_id,
476 ?connection_state,
477 protocol = %self.protocol,
478 "connection closed but it doesn't exist",
479 );
480
481 None
482 }
483 }
484 }
485
486 pub fn dial(&mut self, peer: &PeerId) -> Result<(), ImmediateDialError> {
490 tracing::trace!(
491 target: LOG_TARGET,
492 ?peer,
493 protocol = %self.protocol,
494 "Dial peer requested",
495 );
496
497 self.transport_handle.dial(peer)
498 }
499
500 pub fn dial_address(&mut self, address: Multiaddr) -> Result<(), ImmediateDialError> {
509 tracing::trace!(
510 target: LOG_TARGET,
511 ?address,
512 protocol = %self.protocol,
513 "Dial address requested",
514 );
515
516 self.transport_handle.dial_address(address)
517 }
518
519 pub fn add_known_address(&mut self, peer: &PeerId, addresses: impl Iterator<Item = Multiaddr>) {
523 let addresses: HashSet<Multiaddr> = addresses
524 .filter_map(|address| {
525 if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) {
526 Some(address.with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).ok()?)))
527 } else {
528 Some(address)
529 }
530 })
531 .collect();
532
533 self.transport_handle.add_known_address(peer, addresses.into_iter());
534 }
535
536 pub fn open_substream(&mut self, peer: PeerId) -> Result<SubstreamId, SubstreamError> {
541 let connection = &mut self
543 .connections
544 .get_mut(&peer)
545 .ok_or(SubstreamError::PeerDoesNotExist(peer))?
546 .primary;
547
548 let connection_id = *connection.connection_id();
549
550 let permit = connection.try_get_permit().ok_or(SubstreamError::ConnectionClosed)?;
551 let substream_id =
552 SubstreamId::from(self.next_substream_id.fetch_add(1usize, Ordering::Relaxed));
553
554 tracing::trace!(
555 target: LOG_TARGET,
556 ?peer,
557 protocol = %self.protocol,
558 ?substream_id,
559 ?connection_id,
560 "open substream",
561 );
562
563 self.keep_alive_tracker.substream_activity(peer, connection_id);
564 connection.try_upgrade();
565
566 connection
567 .open_substream(
568 self.protocol.clone(),
569 self.fallback_names.clone(),
570 substream_id,
571 permit,
572 )
573 .map(|_| substream_id)
574 }
575
576 pub fn force_close(&mut self, peer: PeerId) -> crate::Result<()> {
578 let connection =
579 &mut self.connections.get_mut(&peer).ok_or(Error::PeerDoesntExist(peer))?;
580
581 tracing::trace!(
582 target: LOG_TARGET,
583 ?peer,
584 protocol = %self.protocol,
585 secondary = ?connection.secondary,
586 "forcibly closing the connection",
587 );
588
589 if let Some(ref mut connection) = connection.secondary {
590 let _ = connection.force_close();
591 }
592
593 connection.primary.force_close()
594 }
595
596 pub fn local_peer_id(&self) -> PeerId {
598 self.local_peer_id
599 }
600
601 pub fn unregister_protocol(&self) {
606 self.transport_handle.unregister_protocol(self.protocol.clone());
607 }
608}
609
610impl Stream for TransportService {
611 type Item = TransportEvent;
612
613 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
614 let protocol_name = self.protocol.clone();
615 let duration = self.keep_alive_tracker.keep_alive_timeout;
616
617 while let Poll::Ready(event) = self.rx.poll_recv(cx) {
618 match event {
619 None => {
620 tracing::warn!(
621 target: LOG_TARGET,
622 protocol = ?protocol_name,
623 "transport service closed"
624 );
625 return Poll::Ready(None);
626 }
627 Some(InnerTransportEvent::ConnectionEstablished {
628 peer,
629 endpoint,
630 sender,
631 connection,
632 }) => {
633 if let Some(event) =
634 self.on_connection_established(peer, endpoint, connection, sender)
635 {
636 return Poll::Ready(Some(event));
637 }
638 }
639 Some(InnerTransportEvent::ConnectionClosed { peer, connection }) => {
640 if let Some(event) = self.on_connection_closed(peer, connection) {
641 return Poll::Ready(Some(event));
642 }
643 }
644 Some(InnerTransportEvent::SubstreamOpened {
645 peer,
646 protocol,
647 fallback,
648 direction,
649 substream,
650 connection_id,
651 }) => {
652 if protocol == self.protocol {
653 self.keep_alive_tracker.substream_activity(peer, connection_id);
654 if let Some(context) = self.connections.get_mut(&peer) {
655 context.try_upgrade(&connection_id);
656 }
657 }
658
659 return Poll::Ready(Some(TransportEvent::SubstreamOpened {
660 peer,
661 protocol,
662 fallback,
663 direction,
664 substream,
665 }));
666 }
667 Some(event) => return Poll::Ready(Some(event.into())),
668 }
669 }
670
671 while let Poll::Ready(Some((peer, connection_id))) =
672 self.keep_alive_tracker.poll_next_unpin(cx)
673 {
674 if let Some(context) = self.connections.get_mut(&peer) {
675 tracing::debug!(
676 target: LOG_TARGET,
677 ?peer,
678 ?connection_id,
679 protocol = ?protocol_name,
680 ?duration,
681 "keep-alive timeout over, downgrade connection",
682 );
683
684 context.downgrade(&connection_id);
685 }
686 }
687
688 Poll::Pending
689 }
690}
691
692#[cfg(test)]
693mod tests {
694 use super::*;
695 use crate::{
696 protocol::{ProtocolCommand, TransportService},
697 transport::{
698 manager::{handle::InnerTransportManagerCommand, TransportManagerHandle},
699 KEEP_ALIVE_TIMEOUT,
700 },
701 };
702 use futures::StreamExt;
703 use parking_lot::RwLock;
704 use std::collections::HashSet;
705
706 fn transport_service() -> (
708 TransportService,
709 Sender<InnerTransportEvent>,
710 Receiver<InnerTransportManagerCommand>,
711 ) {
712 let (cmd_tx, cmd_rx) = channel(64);
713 let peer = PeerId::random();
714
715 let handle = TransportManagerHandle::new(
716 peer,
717 Arc::new(RwLock::new(HashMap::new())),
718 cmd_tx,
719 HashSet::new(),
720 Default::default(),
721 PublicAddresses::new(peer),
722 );
723
724 let (service, sender) = TransportService::new(
725 peer,
726 ProtocolName::from("/notif/1"),
727 Vec::new(),
728 Arc::new(AtomicUsize::new(0usize)),
729 handle,
730 KEEP_ALIVE_TIMEOUT,
731 );
732
733 (service, sender, cmd_rx)
734 }
735
736 #[tokio::test]
737 async fn secondary_connection_stored() {
738 let (mut service, sender, _) = transport_service();
739 let peer = PeerId::random();
740
741 let (cmd_tx1, _cmd_rx1) = channel(64);
743 sender
744 .send(InnerTransportEvent::ConnectionEstablished {
745 peer,
746 connection: ConnectionId::from(0usize),
747 endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(0usize)),
748 sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
749 })
750 .await
751 .unwrap();
752
753 if let Some(TransportEvent::ConnectionEstablished {
754 peer: connected_peer,
755 endpoint,
756 }) = service.next().await
757 {
758 assert_eq!(connected_peer, peer);
759 assert_eq!(endpoint.address(), &Multiaddr::empty());
760 } else {
761 panic!("expected event from `TransportService`");
762 };
763
764 let (cmd_tx2, _cmd_rx2) = channel(64);
766 sender
767 .send(InnerTransportEvent::ConnectionEstablished {
768 peer,
769 connection: ConnectionId::from(1usize),
770 endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1usize)),
771 sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
772 })
773 .await
774 .unwrap();
775
776 futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
777 std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
778 std::task::Poll::Pending => std::task::Poll::Ready(()),
779 })
780 .await;
781
782 let context = service.connections.get(&peer).unwrap();
783 assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
784 assert_eq!(
785 context.secondary.as_ref().unwrap().connection_id(),
786 &ConnectionId::from(1usize)
787 );
788 }
789
790 #[tokio::test]
791 async fn tertiary_connection_ignored() {
792 let (mut service, sender, _) = transport_service();
793 let peer = PeerId::random();
794
795 let (cmd_tx1, _cmd_rx1) = channel(64);
797 sender
798 .send(InnerTransportEvent::ConnectionEstablished {
799 peer,
800 connection: ConnectionId::from(0usize),
801 endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)),
802 sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
803 })
804 .await
805 .unwrap();
806
807 if let Some(TransportEvent::ConnectionEstablished {
808 peer: connected_peer,
809 endpoint,
810 }) = service.next().await
811 {
812 assert_eq!(connected_peer, peer);
813 assert_eq!(endpoint.address(), &Multiaddr::empty());
814 } else {
815 panic!("expected event from `TransportService`");
816 };
817
818 let (cmd_tx2, _cmd_rx2) = channel(64);
820 sender
821 .send(InnerTransportEvent::ConnectionEstablished {
822 peer,
823 connection: ConnectionId::from(1usize),
824 endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1usize)),
825 sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
826 })
827 .await
828 .unwrap();
829
830 futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
831 std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
832 std::task::Poll::Pending => std::task::Poll::Ready(()),
833 })
834 .await;
835
836 let context = service.connections.get(&peer).unwrap();
837 assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
838 assert_eq!(
839 context.secondary.as_ref().unwrap().connection_id(),
840 &ConnectionId::from(1usize)
841 );
842
843 let (cmd_tx3, mut cmd_rx3) = channel(64);
845 sender
846 .send(InnerTransportEvent::ConnectionEstablished {
847 peer,
848 connection: ConnectionId::from(2usize),
849 endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(2usize)),
850 sender: ConnectionHandle::new(ConnectionId::from(2usize), cmd_tx3),
851 })
852 .await
853 .unwrap();
854
855 futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
856 std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
857 std::task::Poll::Pending => std::task::Poll::Ready(()),
858 })
859 .await;
860
861 let context = service.connections.get(&peer).unwrap();
862 assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
863 assert_eq!(
864 context.secondary.as_ref().unwrap().connection_id(),
865 &ConnectionId::from(1usize)
866 );
867 assert!(cmd_rx3.try_recv().is_err());
868 }
869
870 #[tokio::test]
871 async fn secondary_closing_does_not_emit_event() {
872 let _ = tracing_subscriber::fmt()
873 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
874 .try_init();
875
876 let (mut service, sender, _) = transport_service();
877 let peer = PeerId::random();
878
879 let (cmd_tx1, _cmd_rx1) = channel(64);
881 sender
882 .send(InnerTransportEvent::ConnectionEstablished {
883 peer,
884 connection: ConnectionId::from(0usize),
885 endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)),
886 sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
887 })
888 .await
889 .unwrap();
890
891 if let Some(TransportEvent::ConnectionEstablished {
892 peer: connected_peer,
893 endpoint,
894 }) = service.next().await
895 {
896 assert_eq!(connected_peer, peer);
897 assert_eq!(endpoint.address(), &Multiaddr::empty());
898 } else {
899 panic!("expected event from `TransportService`");
900 };
901
902 let (cmd_tx2, _cmd_rx2) = channel(64);
904 sender
905 .send(InnerTransportEvent::ConnectionEstablished {
906 peer,
907 connection: ConnectionId::from(1usize),
908 endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1usize)),
909 sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
910 })
911 .await
912 .unwrap();
913
914 futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
915 std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
916 std::task::Poll::Pending => std::task::Poll::Ready(()),
917 })
918 .await;
919
920 let context = service.connections.get(&peer).unwrap();
921 assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
922 assert_eq!(
923 context.secondary.as_ref().unwrap().connection_id(),
924 &ConnectionId::from(1usize)
925 );
926
927 sender
929 .send(InnerTransportEvent::ConnectionClosed {
930 peer,
931 connection: ConnectionId::from(1usize),
932 })
933 .await
934 .unwrap();
935
936 futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
938 std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
939 std::task::Poll::Pending => std::task::Poll::Ready(()),
940 })
941 .await;
942
943 let context = service.connections.get(&peer).unwrap();
945 assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
946 assert!(context.secondary.is_none());
947 }
948
949 #[tokio::test]
950 async fn convert_secondary_to_primary() {
951 let (mut service, sender, _) = transport_service();
952 let peer = PeerId::random();
953
954 let (cmd_tx1, mut cmd_rx1) = channel(64);
956 sender
957 .send(InnerTransportEvent::ConnectionEstablished {
958 peer,
959 connection: ConnectionId::from(0usize),
960 endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)),
961 sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
962 })
963 .await
964 .unwrap();
965
966 if let Some(TransportEvent::ConnectionEstablished {
967 peer: connected_peer,
968 endpoint,
969 }) = service.next().await
970 {
971 assert_eq!(connected_peer, peer);
972 assert_eq!(endpoint.address(), &Multiaddr::empty());
973 } else {
974 panic!("expected event from `TransportService`");
975 };
976
977 let (cmd_tx2, mut cmd_rx2) = channel(64);
979 sender
980 .send(InnerTransportEvent::ConnectionEstablished {
981 peer,
982 connection: ConnectionId::from(1usize),
983 endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1usize)),
984 sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
985 })
986 .await
987 .unwrap();
988
989 futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
990 std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
991 std::task::Poll::Pending => std::task::Poll::Ready(()),
992 })
993 .await;
994
995 let context = service.connections.get(&peer).unwrap();
996 assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
997 assert_eq!(
998 context.secondary.as_ref().unwrap().connection_id(),
999 &ConnectionId::from(1usize)
1000 );
1001
1002 sender
1004 .send(InnerTransportEvent::ConnectionClosed {
1005 peer,
1006 connection: ConnectionId::from(0usize),
1007 })
1008 .await
1009 .unwrap();
1010
1011 futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
1013 std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
1014 std::task::Poll::Pending => std::task::Poll::Ready(()),
1015 })
1016 .await;
1017
1018 let context = service.connections.get(&peer).unwrap();
1020 assert_eq!(context.primary.connection_id(), &ConnectionId::from(1usize));
1021 assert!(context.secondary.is_none());
1022 assert!(cmd_rx1.try_recv().is_err());
1023
1024 sender
1026 .send(InnerTransportEvent::ConnectionClosed {
1027 peer,
1028 connection: ConnectionId::from(1usize),
1029 })
1030 .await
1031 .unwrap();
1032
1033 if let Some(TransportEvent::ConnectionClosed {
1034 peer: disconnected_peer,
1035 }) = service.next().await
1036 {
1037 assert_eq!(disconnected_peer, peer);
1038 } else {
1039 panic!("expected event from `TransportService`");
1040 };
1041
1042 assert!(service.connections.get(&peer).is_none());
1044 assert!(cmd_rx2.try_recv().is_err());
1045 }
1046
1047 #[tokio::test]
1048 async fn keep_alive_timeout_expires_for_a_stale_connection() {
1049 let _ = tracing_subscriber::fmt()
1050 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1051 .try_init();
1052
1053 let (mut service, sender, _) = transport_service();
1054 let peer = PeerId::random();
1055
1056 let (cmd_tx1, _cmd_rx1) = channel(64);
1058 sender
1059 .send(InnerTransportEvent::ConnectionEstablished {
1060 peer,
1061 connection: ConnectionId::from(1337usize),
1062 endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
1063 sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
1064 })
1065 .await
1066 .unwrap();
1067
1068 if let Some(TransportEvent::ConnectionEstablished {
1069 peer: connected_peer,
1070 endpoint,
1071 }) = service.next().await
1072 {
1073 assert_eq!(connected_peer, peer);
1074 assert_eq!(endpoint.address(), &Multiaddr::empty());
1075 } else {
1076 panic!("expected event from `TransportService`");
1077 };
1078
1079 assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1081 match service.connections.get(&peer) {
1082 Some(context) => {
1083 assert_eq!(
1084 context.primary.connection_id(),
1085 &ConnectionId::from(1337usize)
1086 );
1087 assert!(context.secondary.is_none());
1088 }
1089 None => panic!("expected {peer} to exist"),
1090 }
1091
1092 sender
1094 .send(InnerTransportEvent::ConnectionClosed {
1095 peer,
1096 connection: ConnectionId::from(1337usize),
1097 })
1098 .await
1099 .unwrap();
1100
1101 if let Some(TransportEvent::ConnectionClosed {
1103 peer: connected_peer,
1104 }) = service.next().await
1105 {
1106 assert_eq!(connected_peer, peer);
1107 } else {
1108 panic!("expected event from `TransportService`");
1109 }
1110
1111 assert!(service.keep_alive_tracker.last_activity.is_empty());
1114 assert!(service.connections.get(&peer).is_none());
1115
1116 let (cmd_tx1, _cmd_rx1) = channel(64);
1118 sender
1119 .send(InnerTransportEvent::ConnectionEstablished {
1120 peer,
1121 connection: ConnectionId::from(1338usize),
1122 endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1338usize)),
1123 sender: ConnectionHandle::new(ConnectionId::from(1338usize), cmd_tx1),
1124 })
1125 .await
1126 .unwrap();
1127
1128 if let Some(TransportEvent::ConnectionEstablished {
1129 peer: connected_peer,
1130 endpoint,
1131 }) = service.next().await
1132 {
1133 assert_eq!(connected_peer, peer);
1134 assert_eq!(endpoint.address(), &Multiaddr::empty());
1135 } else {
1136 panic!("expected event from `TransportService`");
1137 };
1138
1139 assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1140 match service.connections.get(&peer) {
1141 Some(context) => {
1142 assert_eq!(
1143 context.primary.connection_id(),
1144 &ConnectionId::from(1338usize)
1145 );
1146 assert!(context.secondary.is_none());
1147 }
1148 None => panic!("expected {peer} to exist"),
1149 }
1150
1151 match tokio::time::timeout(Duration::from_secs(10), service.next()).await {
1152 Ok(event) => panic!("didn't expect an event: {event:?}"),
1153 Err(_) => {}
1154 }
1155 }
1156
1157 async fn poll_service(service: &mut TransportService) {
1158 futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
1159 std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
1160 std::task::Poll::Pending => std::task::Poll::Ready(()),
1161 })
1162 .await;
1163 }
1164
1165 #[tokio::test]
1166 async fn keep_alive_timeout_downgrades_connections() {
1167 let _ = tracing_subscriber::fmt()
1168 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1169 .try_init();
1170
1171 let (mut service, sender, _) = transport_service();
1172 let peer = PeerId::random();
1173
1174 let (cmd_tx1, _cmd_rx1) = channel(64);
1176 sender
1177 .send(InnerTransportEvent::ConnectionEstablished {
1178 peer,
1179 connection: ConnectionId::from(1337usize),
1180 endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
1181 sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
1182 })
1183 .await
1184 .unwrap();
1185
1186 if let Some(TransportEvent::ConnectionEstablished {
1187 peer: connected_peer,
1188 endpoint,
1189 }) = service.next().await
1190 {
1191 assert_eq!(connected_peer, peer);
1192 assert_eq!(endpoint.address(), &Multiaddr::empty());
1193 } else {
1194 panic!("expected event from `TransportService`");
1195 };
1196
1197 assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1199 match service.connections.get(&peer) {
1200 Some(context) => {
1201 assert_eq!(
1202 context.primary.connection_id(),
1203 &ConnectionId::from(1337usize)
1204 );
1205 assert!(context.primary.is_active());
1207 assert!(context.secondary.is_none());
1208 }
1209 None => panic!("expected {peer} to exist"),
1210 }
1211
1212 poll_service(&mut service).await;
1213 tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
1214 poll_service(&mut service).await;
1215
1216 match service.connections.get(&peer) {
1218 Some(context) => {
1219 assert_eq!(
1220 context.primary.connection_id(),
1221 &ConnectionId::from(1337usize)
1222 );
1223 assert!(!context.primary.is_active());
1225 assert!(context.secondary.is_none());
1226 }
1227 None => panic!("expected {peer} to exist"),
1228 }
1229
1230 assert_eq!(service.keep_alive_tracker.last_activity.len(), 0);
1231 }
1232
1233 #[tokio::test]
1234 async fn keep_alive_timeout_reset_when_user_opens_substream() {
1235 let _ = tracing_subscriber::fmt()
1236 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1237 .try_init();
1238
1239 let (mut service, sender, _) = transport_service();
1240 let peer = PeerId::random();
1241
1242 let (cmd_tx1, _cmd_rx1) = channel(64);
1244 sender
1245 .send(InnerTransportEvent::ConnectionEstablished {
1246 peer,
1247 connection: ConnectionId::from(1337usize),
1248 endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
1249 sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
1250 })
1251 .await
1252 .unwrap();
1253
1254 if let Some(TransportEvent::ConnectionEstablished {
1255 peer: connected_peer,
1256 endpoint,
1257 }) = service.next().await
1258 {
1259 assert_eq!(connected_peer, peer);
1260 assert_eq!(endpoint.address(), &Multiaddr::empty());
1261 } else {
1262 panic!("expected event from `TransportService`");
1263 };
1264
1265 assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1267 match service.connections.get(&peer) {
1268 Some(context) => {
1269 assert_eq!(
1270 context.primary.connection_id(),
1271 &ConnectionId::from(1337usize)
1272 );
1273 assert!(context.primary.is_active());
1275 assert!(context.secondary.is_none());
1276 }
1277 None => panic!("expected {peer} to exist"),
1278 }
1279
1280 poll_service(&mut service).await;
1281 tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1283
1284 service.open_substream(peer).unwrap();
1288 assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1289
1290 poll_service(&mut service).await;
1291 tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1293 poll_service(&mut service).await;
1294 assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1295 match service.connections.get(&peer) {
1299 Some(context) => {
1300 assert_eq!(
1301 context.primary.connection_id(),
1302 &ConnectionId::from(1337usize)
1303 );
1304 assert!(context.primary.is_active());
1305 assert!(context.secondary.is_none());
1306 }
1307 None => panic!("expected {peer} to exist"),
1308 }
1309
1310 poll_service(&mut service).await;
1311 tokio::time::sleep(KEEP_ALIVE_TIMEOUT).await;
1312 poll_service(&mut service).await;
1313
1314 assert_eq!(service.keep_alive_tracker.last_activity.len(), 0);
1315
1316 match service.connections.get(&peer) {
1319 Some(context) => {
1320 assert_eq!(
1321 context.primary.connection_id(),
1322 &ConnectionId::from(1337usize)
1323 );
1324 assert!(!context.primary.is_active());
1325 assert!(context.secondary.is_none());
1326 }
1327 None => panic!("expected {peer} to exist"),
1328 }
1329 }
1330
1331 #[tokio::test]
1332 async fn downgraded_connection_without_substreams_is_closed() {
1333 let _ = tracing_subscriber::fmt()
1334 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1335 .try_init();
1336
1337 let (mut service, sender, _) = transport_service();
1338 let peer = PeerId::random();
1339
1340 let (cmd_tx1, mut cmd_rx1) = channel(64);
1342 sender
1343 .send(InnerTransportEvent::ConnectionEstablished {
1344 peer,
1345 connection: ConnectionId::from(1337usize),
1346 endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
1347 sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
1348 })
1349 .await
1350 .unwrap();
1351
1352 if let Some(TransportEvent::ConnectionEstablished {
1353 peer: connected_peer,
1354 endpoint,
1355 }) = service.next().await
1356 {
1357 assert_eq!(connected_peer, peer);
1358 assert_eq!(endpoint.address(), &Multiaddr::empty());
1359 } else {
1360 panic!("expected event from `TransportService`");
1361 };
1362
1363 assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1365 match service.connections.get(&peer) {
1366 Some(context) => {
1367 assert_eq!(
1368 context.primary.connection_id(),
1369 &ConnectionId::from(1337usize)
1370 );
1371 assert!(context.primary.is_active());
1373 assert!(context.secondary.is_none());
1374 }
1375 None => panic!("expected {peer} to exist"),
1376 }
1377
1378 let substream_id = service.open_substream(peer).unwrap();
1380 let second_substream_id = service.open_substream(peer).unwrap();
1381
1382 poll_service(&mut service).await;
1384 tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
1385 poll_service(&mut service).await;
1386
1387 let mut permits = Vec::new();
1388
1389 let protocol_command = cmd_rx1.recv().await.unwrap();
1391 match protocol_command {
1392 ProtocolCommand::OpenSubstream {
1393 protocol,
1394 substream_id: opened_substream_id,
1395 permit,
1396 ..
1397 } => {
1398 assert_eq!(protocol, ProtocolName::from("/notif/1"));
1399 assert_eq!(substream_id, opened_substream_id);
1400
1401 permits.push(permit);
1403 }
1404 _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1405 }
1406
1407 let protocol_command = cmd_rx1.recv().await.unwrap();
1409 match protocol_command {
1410 ProtocolCommand::OpenSubstream {
1411 protocol,
1412 substream_id: opened_substream_id,
1413 permit,
1414 ..
1415 } => {
1416 assert_eq!(protocol, ProtocolName::from("/notif/1"));
1417 assert_eq!(second_substream_id, opened_substream_id);
1418
1419 permits.push(permit);
1421 }
1422 _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1423 }
1424
1425 let permit = permits.pop();
1427 drop(permit);
1434
1435 let substream_id = service.open_substream(peer).unwrap();
1438 let protocol_command = cmd_rx1.recv().await.unwrap();
1440 match protocol_command {
1441 ProtocolCommand::OpenSubstream {
1442 protocol,
1443 substream_id: opened_substream_id,
1444 permit,
1445 ..
1446 } => {
1447 assert_eq!(protocol, ProtocolName::from("/notif/1"));
1448 assert_eq!(substream_id, opened_substream_id);
1449
1450 permits.push(permit);
1452 }
1453 _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1454 }
1455
1456 drop(permits);
1458
1459 poll_service(&mut service).await;
1460 tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
1461 poll_service(&mut service).await;
1462
1463 assert_eq!(
1467 service.open_substream(peer),
1468 Err(SubstreamError::ConnectionClosed)
1469 );
1470 }
1471
1472 #[tokio::test]
1473 async fn substream_opening_upgrades_connection_and_resets_keep_alive() {
1474 let _ = tracing_subscriber::fmt()
1475 .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1476 .try_init();
1477
1478 let (mut service, sender, _) = transport_service();
1479 let peer = PeerId::random();
1480
1481 let (cmd_tx1, mut cmd_rx1) = channel(64);
1483 sender
1484 .send(InnerTransportEvent::ConnectionEstablished {
1485 peer,
1486 connection: ConnectionId::from(1337usize),
1487 endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
1488 sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
1489 })
1490 .await
1491 .unwrap();
1492
1493 if let Some(TransportEvent::ConnectionEstablished {
1494 peer: connected_peer,
1495 endpoint,
1496 }) = service.next().await
1497 {
1498 assert_eq!(connected_peer, peer);
1499 assert_eq!(endpoint.address(), &Multiaddr::empty());
1500 } else {
1501 panic!("expected event from `TransportService`");
1502 };
1503
1504 assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1506 match service.connections.get(&peer) {
1507 Some(context) => {
1508 assert_eq!(
1509 context.primary.connection_id(),
1510 &ConnectionId::from(1337usize)
1511 );
1512 assert!(context.primary.is_active());
1514 assert!(context.secondary.is_none());
1515 }
1516 None => panic!("expected {peer} to exist"),
1517 }
1518
1519 let substream_id = service.open_substream(peer).unwrap();
1521 let second_substream_id = service.open_substream(peer).unwrap();
1522
1523 let mut permits = Vec::new();
1524 let protocol_command = cmd_rx1.recv().await.unwrap();
1526 match protocol_command {
1527 ProtocolCommand::OpenSubstream {
1528 protocol,
1529 substream_id: opened_substream_id,
1530 permit,
1531 ..
1532 } => {
1533 assert_eq!(protocol, ProtocolName::from("/notif/1"));
1534 assert_eq!(substream_id, opened_substream_id);
1535
1536 permits.push(permit);
1538 }
1539 _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1540 }
1541
1542 let protocol_command = cmd_rx1.recv().await.unwrap();
1544 match protocol_command {
1545 ProtocolCommand::OpenSubstream {
1546 protocol,
1547 substream_id: opened_substream_id,
1548 permit,
1549 ..
1550 } => {
1551 assert_eq!(protocol, ProtocolName::from("/notif/1"));
1552 assert_eq!(second_substream_id, opened_substream_id);
1553
1554 permits.push(permit);
1556 }
1557 _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1558 }
1559
1560 poll_service(&mut service).await;
1562 tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
1563 poll_service(&mut service).await;
1564
1565 match service.connections.get(&peer) {
1567 Some(context) => {
1568 assert_eq!(
1569 context.primary.connection_id(),
1570 &ConnectionId::from(1337usize)
1571 );
1572 assert!(!context.primary.is_active());
1574 assert!(context.secondary.is_none());
1575 }
1576 None => panic!("expected {peer} to exist"),
1577 }
1578 assert_eq!(service.keep_alive_tracker.last_activity.len(), 0);
1579
1580 let substream_id = service.open_substream(peer).unwrap();
1583 let protocol_command = cmd_rx1.recv().await.unwrap();
1584 match protocol_command {
1585 ProtocolCommand::OpenSubstream {
1586 protocol,
1587 substream_id: opened_substream_id,
1588 permit,
1589 ..
1590 } => {
1591 assert_eq!(protocol, ProtocolName::from("/notif/1"));
1592 assert_eq!(substream_id, opened_substream_id);
1593
1594 permits.push(permit);
1596 }
1597 _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1598 }
1599
1600 poll_service(&mut service).await;
1601
1602 match service.connections.get(&peer) {
1604 Some(context) => {
1605 assert_eq!(
1606 context.primary.connection_id(),
1607 &ConnectionId::from(1337usize)
1608 );
1609 assert!(context.primary.is_active());
1611 assert!(context.secondary.is_none());
1612 }
1613 None => panic!("expected {peer} to exist"),
1614 }
1615 assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1616
1617 drop(permits);
1619
1620 match service.connections.get(&peer) {
1622 Some(context) => {
1623 assert_eq!(
1624 context.primary.connection_id(),
1625 &ConnectionId::from(1337usize)
1626 );
1627 assert!(context.primary.is_active());
1629 assert!(context.secondary.is_none());
1630 }
1631 None => panic!("expected {peer} to exist"),
1632 }
1633 assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1634
1635 poll_service(&mut service).await;
1637 tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
1638 poll_service(&mut service).await;
1639
1640 match service.connections.get(&peer) {
1641 Some(context) => {
1642 assert_eq!(
1643 context.primary.connection_id(),
1644 &ConnectionId::from(1337usize)
1645 );
1646 assert!(!context.primary.is_active());
1649 assert!(context.secondary.is_none());
1650 }
1651 None => panic!("expected {peer} to exist"),
1652 }
1653
1654 assert_eq!(
1658 service.open_substream(peer),
1659 Err(SubstreamError::ConnectionClosed)
1660 );
1661 }
1662
1663 #[tokio::test]
1664 async fn keep_alive_pop_elements() {
1665 let mut tracker = KeepAliveTracker::new(Duration::from_secs(1));
1666
1667 let (peer1, connection1) = (PeerId::random(), ConnectionId::from(1usize));
1668 let (peer2, connection2) = (PeerId::random(), ConnectionId::from(2usize));
1669 let added_keys = HashSet::from([(peer1, connection1), (peer2, connection2)]);
1670
1671 tracker.on_connection_established(peer1, connection1);
1672 tracker.on_connection_established(peer2, connection2);
1673
1674 tokio::time::sleep(Duration::from_secs(2)).await;
1675
1676 let key = tracker.next().await.unwrap();
1677 assert!(added_keys.contains(&key));
1678
1679 let key = tracker.next().await.unwrap();
1680 assert!(added_keys.contains(&key));
1681
1682 assert!(tracker.pending_keep_alive_timeouts.is_empty());
1684 assert!(tracker.last_activity.is_empty());
1685 }
1686}