litep2p/protocol/transport_service.rs

1// Copyright 2023 litep2p developers
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::{
22    addresses::PublicAddresses,
23    error::{Error, ImmediateDialError, SubstreamError},
24    protocol::{connection::ConnectionHandle, InnerTransportEvent, TransportEvent},
25    transport::{manager::TransportManagerHandle, Endpoint},
26    types::{protocol::ProtocolName, ConnectionId, SubstreamId},
27    PeerId, DEFAULT_CHANNEL_SIZE,
28};
29
30use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt};
31use multiaddr::{Multiaddr, Protocol};
32use multihash::Multihash;
33use tokio::sync::mpsc::{channel, Receiver, Sender};
34
35use std::{
36    collections::{HashMap, HashSet},
37    fmt::Debug,
38    pin::Pin,
39    sync::{
40        atomic::{AtomicUsize, Ordering},
41        Arc,
42    },
43    task::{Context, Poll, Waker},
44    time::{Duration, Instant},
45};
46
47/// Logging target for the file.
48const LOG_TARGET: &str = "litep2p::transport-service";
49
50/// Connection context for the peer.
51///
52/// Each peer is allowed to have at most two connections open. The first open connection is the
53/// primary connection, which the local node uses to open substreams to the remote peer. A secondary
54/// connection may be opened if the local and remote nodes open connections to each other at the same time.
55///
56/// The secondary connection may be promoted to the primary connection if the primary connection
57/// closes while the secondary connection remains open.
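///
/// A minimal sketch of the promotion behavior described above (illustrative only,
/// not a runnable doctest; `primary_handle` and `secondary_handle` are placeholders):
///
/// ```ignore
/// let mut context = ConnectionContext::new(primary_handle);
/// context.secondary = Some(secondary_handle);
///
/// // When the primary connection closes, the secondary connection (if any) is
/// // promoted instead of reporting the peer as disconnected.
/// if let Some(handle) = context.secondary.take() {
///     context.primary = handle;
/// }
/// ```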
58#[derive(Debug)]
59struct ConnectionContext {
60    /// Primary connection.
61    primary: ConnectionHandle,
62
63    /// Secondary connection, if it exists.
64    secondary: Option<ConnectionHandle>,
65}
66
67impl ConnectionContext {
68    /// Create new [`ConnectionContext`].
69    fn new(primary: ConnectionHandle) -> Self {
70        Self {
71            primary,
72            secondary: None,
73        }
74    }
75
76    /// Downgrade the connection to a non-active state, which means it will be closed
77    /// if there are no substreams open over it.
78    fn downgrade(&mut self, connection_id: &ConnectionId) {
79        if self.primary.connection_id() == connection_id {
80            self.primary.close();
81            return;
82        }
83
84        if let Some(handle) = &mut self.secondary {
85            if handle.connection_id() == connection_id {
86                handle.close();
87                return;
88            }
89        }
90
91        tracing::debug!(
92            target: LOG_TARGET,
93            primary = ?self.primary.connection_id(),
94            secondary = ?self.secondary.as_ref().map(|handle| handle.connection_id()),
95            ?connection_id,
96            "connection doesn't exist, cannot downgrade",
97        );
98    }
99
100    /// Try to upgrade the connection to active state.
101    fn try_upgrade(&mut self, connection_id: &ConnectionId) {
102        if self.primary.connection_id() == connection_id {
103            self.primary.try_upgrade();
104            return;
105        }
106
107        if let Some(handle) = &mut self.secondary {
108            if handle.connection_id() == connection_id {
109                handle.try_upgrade();
110                return;
111            }
112        }
113
114        tracing::debug!(
115            target: LOG_TARGET,
116            primary = ?self.primary.connection_id(),
117            secondary = ?self.secondary.as_ref().map(|handle| handle.connection_id()),
118            ?connection_id,
119            "connection doesn't exist, cannot upgrade",
120        );
121    }
122}
123
124/// Tracks connection keep-alive timeouts.
125///
126/// A connection keep-alive timeout is started when a connection is established.
127/// If no substreams are opened over the connection within the timeout,
128/// the connection is downgraded. However, if a substream is opened over the connection,
129/// the timeout is reset.
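///
/// A minimal lifecycle sketch (illustrative only, not a runnable doctest; `peer` and
/// `connection_id` are placeholders):
///
/// ```ignore
/// let mut tracker = KeepAliveTracker::new(Duration::from_secs(5));
///
/// // Start the keep-alive timer when a connection is established.
/// tracker.on_connection_established(peer, connection_id);
///
/// // Every substream opened over the connection resets the timer.
/// tracker.substream_activity(peer, connection_id);
///
/// // Once the timeout elapses with no further activity, the stream yields the
/// // `(PeerId, ConnectionId)` pair so the connection can be downgraded.
/// if let Some((peer, connection_id)) = tracker.next().await {
///     // downgrade the connection
/// }
/// ```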
130#[derive(Debug)]
131struct KeepAliveTracker {
132    /// Close the connection if no substreams are open within this time frame.
133    keep_alive_timeout: Duration,
134
135    /// Track substream last activity.
136    last_activity: HashMap<(PeerId, ConnectionId), Instant>,
137
138    /// Pending keep-alive timeouts.
139    pending_keep_alive_timeouts: FuturesUnordered<BoxFuture<'static, (PeerId, ConnectionId)>>,
140
141    /// Saved waker.
142    waker: Option<Waker>,
143}
144
145impl KeepAliveTracker {
146    /// Create new [`KeepAliveTracker`].
147    pub fn new(keep_alive_timeout: Duration) -> Self {
148        Self {
149            keep_alive_timeout,
150            last_activity: HashMap::new(),
151            pending_keep_alive_timeouts: FuturesUnordered::new(),
152            waker: None,
153        }
154    }
155
156    /// Called on connection established event to add a new keep-alive timeout.
157    pub fn on_connection_established(&mut self, peer: PeerId, connection_id: ConnectionId) {
158        self.substream_activity(peer, connection_id);
159    }
160
161    /// Called on connection closed event.
162    pub fn on_connection_closed(&mut self, peer: PeerId, connection_id: ConnectionId) {
163        self.last_activity.remove(&(peer, connection_id));
164    }
165
166    /// Called on substream opened event to track the last activity.
167    pub fn substream_activity(&mut self, peer: PeerId, connection_id: ConnectionId) {
168        // Keep track of the connection ID and the time the substream was opened.
169        if self.last_activity.insert((peer, connection_id), Instant::now()).is_none() {
170            // Schedule a keep-alive timeout future only if this connection isn't already tracked.
171            let timeout = self.keep_alive_timeout;
172            self.pending_keep_alive_timeouts.push(Box::pin(async move {
173                tokio::time::sleep(timeout).await;
174                (peer, connection_id)
175            }));
176        }
177
178        tracing::trace!(
179            target: LOG_TARGET,
180            ?peer,
181            ?connection_id,
182            ?self.keep_alive_timeout,
183            last_activity = ?self.last_activity.len(),
184            pending_keep_alive_timeouts = ?self.pending_keep_alive_timeouts.len(),
185            "substream activity",
186        );
187
188        // Wake any pending poll.
189        if let Some(waker) = self.waker.take() {
190            waker.wake()
191        }
192    }
193}
194
195impl Stream for KeepAliveTracker {
196    type Item = (PeerId, ConnectionId);
197
198    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
199        if self.pending_keep_alive_timeouts.is_empty() {
200            // No pending keep-alive timeouts.
201            self.waker = Some(cx.waker().clone());
202            return Poll::Pending;
203        }
204
205        match self.pending_keep_alive_timeouts.poll_next_unpin(cx) {
206            Poll::Ready(Some(key)) => {
207                // Check last-activity time.
208                let Some(last_activity) = self.last_activity.get(&key) else {
209                    tracing::debug!(
210                        target: LOG_TARGET,
211                        peer = ?key.0,
212                        connection_id = ?key.1,
213                        "Last activity no longer tracks the connection (closed event triggered)",
214                    );
215
216                    // We have effectively ignored this `Poll::Ready` event. To prevent the
217                    // future from getting stuck, we need to tell the executor to poll again
218                    // for more events.
219                    cx.waker().wake_by_ref();
220                    return Poll::Pending;
221                };
222
223                // Keep-alive timeout not reached yet.
224                let inactive_for = last_activity.elapsed();
225                if inactive_for < self.keep_alive_timeout {
226                    let timeout = self.keep_alive_timeout.saturating_sub(inactive_for);
227
228                    tracing::trace!(
229                        target: LOG_TARGET,
230                        peer = ?key.0,
231                        connection_id = ?key.1,
232                        ?timeout,
233                        "keep-alive timeout not yet reached",
234                    );
235
236                    // Refill the keep-alive timeouts.
237                    self.pending_keep_alive_timeouts.push(Box::pin(async move {
238                        tokio::time::sleep(timeout).await;
239                        key
240                    }));
241
242                    // This is similar to the `last_activity` check above, we need to inform
243                    // the executor that this object may produce more events.
244                    cx.waker().wake_by_ref();
245                    return Poll::Pending;
246                }
247
248                // Keep-alive timeout reached.
249                tracing::debug!(
250                    target: LOG_TARGET,
251                    peer = ?key.0,
252                    connection_id = ?key.1,
253                    "keep-alive timeout triggered",
254                );
255                self.last_activity.remove(&key);
256                Poll::Ready(Some(key))
257            }
258            Poll::Ready(None) | Poll::Pending => Poll::Pending,
259        }
260    }
261}
262
263/// Provides an interface for [`Litep2p`](crate::Litep2p) protocols to interact
264/// with the underlying transport protocols.
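///
/// A protocol implementation typically drives the service as a stream of
/// [`TransportEvent`]s. A minimal sketch (illustrative only, not a runnable doctest):
///
/// ```ignore
/// while let Some(event) = service.next().await {
///     match event {
///         TransportEvent::ConnectionEstablished { peer, .. } => {
///             // A primary connection is now open; substreams can be requested.
///             let _ = service.open_substream(peer);
///         }
///         TransportEvent::SubstreamOpened { peer, substream, .. } => {
///             // Hand the substream over to the protocol's own handler.
///         }
///         TransportEvent::ConnectionClosed { peer } => {
///             // Clean up any per-peer state.
///         }
///         _ => {}
///     }
/// }
/// ```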
265#[derive(Debug)]
266pub struct TransportService {
267    /// Local peer ID.
268    local_peer_id: PeerId,
269
270    /// Protocol.
271    protocol: ProtocolName,
272
273    /// Fallback names for the protocol.
274    fallback_names: Vec<ProtocolName>,
275
276    /// Open connections.
277    connections: HashMap<PeerId, ConnectionContext>,
278
279    /// Transport handle.
280    transport_handle: TransportManagerHandle,
281
282    /// RX channel for receiving events from transports and connections.
283    rx: Receiver<InnerTransportEvent>,
284
285    /// Next substream ID.
286    next_substream_id: Arc<AtomicUsize>,
287
288    /// Keep-alive tracker for downgrading connections with no substream activity within the keep-alive timeout.
289    keep_alive_tracker: KeepAliveTracker,
290}
291
292impl TransportService {
293    /// Create new [`TransportService`].
294    pub(crate) fn new(
295        local_peer_id: PeerId,
296        protocol: ProtocolName,
297        fallback_names: Vec<ProtocolName>,
298        next_substream_id: Arc<AtomicUsize>,
299        transport_handle: TransportManagerHandle,
300        keep_alive_timeout: Duration,
301    ) -> (Self, Sender<InnerTransportEvent>) {
302        let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE);
303
304        let keep_alive_tracker = KeepAliveTracker::new(keep_alive_timeout);
305
306        (
307            Self {
308                rx,
309                protocol,
310                local_peer_id,
311                fallback_names,
312                transport_handle,
313                next_substream_id,
314                connections: HashMap::new(),
315                keep_alive_tracker,
316            },
317            tx,
318        )
319    }
320
321    /// Get the list of public addresses of the node.
322    pub fn public_addresses(&self) -> PublicAddresses {
323        self.transport_handle.public_addresses()
324    }
325
326    /// Get the list of listen addresses of the node.
327    pub fn listen_addresses(&self) -> HashSet<Multiaddr> {
328        self.transport_handle.listen_addresses()
329    }
330
331    /// Handle connection established event.
332    fn on_connection_established(
333        &mut self,
334        peer: PeerId,
335        endpoint: Endpoint,
336        connection_id: ConnectionId,
337        handle: ConnectionHandle,
338    ) -> Option<TransportEvent> {
339        tracing::debug!(
340            target: LOG_TARGET,
341            ?peer,
342            ?endpoint,
343            ?connection_id,
344            protocol = %self.protocol,
345            current_state = ?self.connections.get(&peer),
346            "on connection established",
347        );
348
349        match self.connections.get_mut(&peer) {
350            Some(context) => match context.secondary {
351                Some(_) => {
352                    tracing::debug!(
353                        target: LOG_TARGET,
354                        ?peer,
355                        ?connection_id,
356                        ?endpoint,
357                        protocol = %self.protocol,
358                        "ignoring third connection",
359                    );
360                    None
361                }
362                None => {
363                    self.keep_alive_tracker.on_connection_established(peer, connection_id);
364
365                    tracing::trace!(
366                        target: LOG_TARGET,
367                        ?peer,
368                        ?endpoint,
369                        ?connection_id,
370                        protocol = %self.protocol,
371                        "secondary connection established",
372                    );
373
374                    context.secondary = Some(handle);
375
376                    None
377                }
378            },
379            None => {
380                tracing::trace!(
381                    target: LOG_TARGET,
382                    ?peer,
383                    ?endpoint,
384                    ?connection_id,
385                    protocol = %self.protocol,
386                    "primary connection established",
387                );
388
389                self.connections.insert(peer, ConnectionContext::new(handle));
390
391                self.keep_alive_tracker.on_connection_established(peer, connection_id);
392
393                Some(TransportEvent::ConnectionEstablished { peer, endpoint })
394            }
395        }
396    }
397
398    /// Handle connection closed event.
399    fn on_connection_closed(
400        &mut self,
401        peer: PeerId,
402        connection_id: ConnectionId,
403    ) -> Option<TransportEvent> {
404        tracing::debug!(
405            target: LOG_TARGET,
406            ?peer,
407            ?connection_id,
408            protocol = %self.protocol,
409            current_state = ?self.connections.get(&peer),
410            "on connection closed",
411        );
412
413        self.keep_alive_tracker.on_connection_closed(peer, connection_id);
414
415        let Some(context) = self.connections.get_mut(&peer) else {
416            tracing::warn!(
417                target: LOG_TARGET,
418                ?peer,
419                ?connection_id,
420                protocol = %self.protocol,
421                "connection closed to a non-existent peer",
422            );
423
424            debug_assert!(false);
425            return None;
426        };
427
428        // if the primary connection was closed, check if there exists a secondary connection
429        // and if it does, promote the secondary connection to the primary connection
430        if context.primary.connection_id() == &connection_id {
431            tracing::trace!(
432                target: LOG_TARGET,
433                ?peer,
434                ?connection_id,
435                protocol = %self.protocol,
436                "primary connection closed"
437            );
438
439            match context.secondary.take() {
440                None => {
441                    self.connections.remove(&peer);
442                    return Some(TransportEvent::ConnectionClosed { peer });
443                }
444                Some(handle) => {
445                    tracing::debug!(
446                        target: LOG_TARGET,
447                        ?peer,
448                        ?connection_id,
449                        protocol = %self.protocol,
450                        "switch to secondary connection",
451                    );
452
453                    context.primary = handle;
454                    return None;
455                }
456            }
457        }
458
459        match context.secondary.take() {
460            Some(handle) if handle.connection_id() == &connection_id => {
461                tracing::trace!(
462                    target: LOG_TARGET,
463                    ?peer,
464                    ?connection_id,
465                    protocol = %self.protocol,
466                    "secondary connection closed",
467                );
468
469                None
470            }
471            connection_state => {
472                tracing::debug!(
473                    target: LOG_TARGET,
474                    ?peer,
475                    ?connection_id,
476                    ?connection_state,
477                    protocol = %self.protocol,
478                    "connection closed but it doesn't exist",
479                );
480
481                None
482            }
483        }
484    }
485
486    /// Dial `peer` using `PeerId`.
487    ///
488    /// Call fails if `Litep2p` doesn't have a known address for the peer.
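    ///
    /// A minimal sketch (illustrative only, not a runnable doctest):
    ///
    /// ```ignore
    /// service.dial(&peer)?;
    /// ```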
489    pub fn dial(&mut self, peer: &PeerId) -> Result<(), ImmediateDialError> {
490        tracing::trace!(
491            target: LOG_TARGET,
492            ?peer,
493            protocol = %self.protocol,
494            "Dial peer requested",
495        );
496
497        self.transport_handle.dial(peer)
498    }
499
500    /// Dial peer using a `Multiaddr`.
501    ///
502    /// Call fails if the address is not in the correct format or contains an unsupported/disabled
503    /// transport.
504    ///
505    /// Calling this function is only necessary for addresses that are discovered out-of-band,
506    /// since `Litep2p` internally keeps track of all peer addresses it has learned through users
507    /// calling this function, Kademlia peer discovery, and `Identify` responses.
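    ///
    /// A minimal sketch (illustrative only, not a runnable doctest; the address is a placeholder):
    ///
    /// ```ignore
    /// let address: Multiaddr = "/ip4/192.0.2.1/tcp/30333".parse()?;
    /// service.dial_address(address)?;
    /// ```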
508    pub fn dial_address(&mut self, address: Multiaddr) -> Result<(), ImmediateDialError> {
509        tracing::trace!(
510            target: LOG_TARGET,
511            ?address,
512            protocol = %self.protocol,
513            "Dial address requested",
514        );
515
516        self.transport_handle.dial_address(address)
517    }
518
519    /// Add one or more addresses for `peer`.
520    ///
521    /// The list is filtered for duplicates and unsupported transports.
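    ///
    /// A minimal sketch (illustrative only, not a runnable doctest; the address is a placeholder):
    ///
    /// ```ignore
    /// let addresses = vec!["/ip4/192.0.2.1/tcp/30333".parse::<Multiaddr>()?];
    /// service.add_known_address(&peer, addresses.into_iter());
    /// ```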
522    pub fn add_known_address(&mut self, peer: &PeerId, addresses: impl Iterator<Item = Multiaddr>) {
523        let addresses: HashSet<Multiaddr> = addresses
524            .filter_map(|address| {
525                if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) {
526                    Some(address.with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).ok()?)))
527                } else {
528                    Some(address)
529                }
530            })
531            .collect();
532
533        self.transport_handle.add_known_address(peer, addresses.into_iter());
534    }
535
536    /// Open substream to `peer`.
537    ///
538    /// Call fails if there is no connection open to `peer` or the channel towards
539    /// the connection is clogged.
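    ///
    /// A minimal sketch (illustrative only, not a runnable doctest):
    ///
    /// ```ignore
    /// // The negotiated substream is later delivered to the protocol via the
    /// // service's `TransportEvent::SubstreamOpened` event.
    /// let substream_id = service.open_substream(peer)?;
    /// ```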
540    pub fn open_substream(&mut self, peer: PeerId) -> Result<SubstreamId, SubstreamError> {
541        // always prefer the primary connection
542        let connection = &mut self
543            .connections
544            .get_mut(&peer)
545            .ok_or(SubstreamError::PeerDoesNotExist(peer))?
546            .primary;
547
548        let connection_id = *connection.connection_id();
549
550        let permit = connection.try_get_permit().ok_or(SubstreamError::ConnectionClosed)?;
551        let substream_id =
552            SubstreamId::from(self.next_substream_id.fetch_add(1usize, Ordering::Relaxed));
553
554        tracing::trace!(
555            target: LOG_TARGET,
556            ?peer,
557            protocol = %self.protocol,
558            ?substream_id,
559            ?connection_id,
560            "open substream",
561        );
562
563        self.keep_alive_tracker.substream_activity(peer, connection_id);
564        connection.try_upgrade();
565
566        connection
567            .open_substream(
568                self.protocol.clone(),
569                self.fallback_names.clone(),
570                substream_id,
571                permit,
572            )
573            .map(|_| substream_id)
574    }
575
576    /// Forcibly close the connection, even if other protocols have substreams open over it.
577    pub fn force_close(&mut self, peer: PeerId) -> crate::Result<()> {
578        let connection =
579            &mut self.connections.get_mut(&peer).ok_or(Error::PeerDoesntExist(peer))?;
580
581        tracing::trace!(
582            target: LOG_TARGET,
583            ?peer,
584            protocol = %self.protocol,
585            secondary = ?connection.secondary,
586            "forcibly closing the connection",
587        );
588
589        if let Some(ref mut connection) = connection.secondary {
590            let _ = connection.force_close();
591        }
592
593        connection.primary.force_close()
594    }
595
596    /// Get local peer ID.
597    pub fn local_peer_id(&self) -> PeerId {
598        self.local_peer_id
599    }
600
601    /// Dynamically unregister a protocol.
602    ///
603    /// This must be called when a protocol is no longer needed (e.g. user dropped the protocol
604    /// handle).
605    pub fn unregister_protocol(&self) {
606        self.transport_handle.unregister_protocol(self.protocol.clone());
607    }
608}
609
610impl Stream for TransportService {
611    type Item = TransportEvent;
612
613    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
614        let protocol_name = self.protocol.clone();
615        let duration = self.keep_alive_tracker.keep_alive_timeout;
616
617        while let Poll::Ready(event) = self.rx.poll_recv(cx) {
618            match event {
619                None => {
620                    tracing::warn!(
621                        target: LOG_TARGET,
622                        protocol = ?protocol_name,
623                        "transport service closed"
624                    );
625                    return Poll::Ready(None);
626                }
627                Some(InnerTransportEvent::ConnectionEstablished {
628                    peer,
629                    endpoint,
630                    sender,
631                    connection,
632                }) => {
633                    if let Some(event) =
634                        self.on_connection_established(peer, endpoint, connection, sender)
635                    {
636                        return Poll::Ready(Some(event));
637                    }
638                }
639                Some(InnerTransportEvent::ConnectionClosed { peer, connection }) => {
640                    if let Some(event) = self.on_connection_closed(peer, connection) {
641                        return Poll::Ready(Some(event));
642                    }
643                }
644                Some(InnerTransportEvent::SubstreamOpened {
645                    peer,
646                    protocol,
647                    fallback,
648                    direction,
649                    substream,
650                    connection_id,
651                }) => {
652                    if protocol == self.protocol {
653                        self.keep_alive_tracker.substream_activity(peer, connection_id);
654                        if let Some(context) = self.connections.get_mut(&peer) {
655                            context.try_upgrade(&connection_id);
656                        }
657                    }
658
659                    return Poll::Ready(Some(TransportEvent::SubstreamOpened {
660                        peer,
661                        protocol,
662                        fallback,
663                        direction,
664                        substream,
665                    }));
666                }
667                Some(event) => return Poll::Ready(Some(event.into())),
668            }
669        }
670
671        while let Poll::Ready(Some((peer, connection_id))) =
672            self.keep_alive_tracker.poll_next_unpin(cx)
673        {
674            if let Some(context) = self.connections.get_mut(&peer) {
675                tracing::debug!(
676                    target: LOG_TARGET,
677                    ?peer,
678                    ?connection_id,
679                    protocol = ?protocol_name,
680                    ?duration,
681                    "keep-alive timeout over, downgrade connection",
682                );
683
684                context.downgrade(&connection_id);
685            }
686        }
687
688        Poll::Pending
689    }
690}
691
692#[cfg(test)]
693mod tests {
694    use super::*;
695    use crate::{
696        protocol::{ProtocolCommand, TransportService},
697        transport::{
698            manager::{handle::InnerTransportManagerCommand, TransportManagerHandle},
699            KEEP_ALIVE_TIMEOUT,
700        },
701    };
702    use futures::StreamExt;
703    use parking_lot::RwLock;
704    use std::collections::HashSet;
705
706    /// Create new `TransportService`
707    fn transport_service() -> (
708        TransportService,
709        Sender<InnerTransportEvent>,
710        Receiver<InnerTransportManagerCommand>,
711    ) {
712        let (cmd_tx, cmd_rx) = channel(64);
713        let peer = PeerId::random();
714
715        let handle = TransportManagerHandle::new(
716            peer,
717            Arc::new(RwLock::new(HashMap::new())),
718            cmd_tx,
719            HashSet::new(),
720            Default::default(),
721            PublicAddresses::new(peer),
722        );
723
724        let (service, sender) = TransportService::new(
725            peer,
726            ProtocolName::from("/notif/1"),
727            Vec::new(),
728            Arc::new(AtomicUsize::new(0usize)),
729            handle,
730            KEEP_ALIVE_TIMEOUT,
731        );
732
733        (service, sender, cmd_rx)
734    }
735
736    #[tokio::test]
737    async fn secondary_connection_stored() {
738        let (mut service, sender, _) = transport_service();
739        let peer = PeerId::random();
740
741        // register first connection
742        let (cmd_tx1, _cmd_rx1) = channel(64);
743        sender
744            .send(InnerTransportEvent::ConnectionEstablished {
745                peer,
746                connection: ConnectionId::from(0usize),
747                endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(0usize)),
748                sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
749            })
750            .await
751            .unwrap();
752
753        if let Some(TransportEvent::ConnectionEstablished {
754            peer: connected_peer,
755            endpoint,
756        }) = service.next().await
757        {
758            assert_eq!(connected_peer, peer);
759            assert_eq!(endpoint.address(), &Multiaddr::empty());
760        } else {
761            panic!("expected event from `TransportService`");
762        };
763
764        // register secondary connection
765        let (cmd_tx2, _cmd_rx2) = channel(64);
766        sender
767            .send(InnerTransportEvent::ConnectionEstablished {
768                peer,
769                connection: ConnectionId::from(1usize),
770                endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1usize)),
771                sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
772            })
773            .await
774            .unwrap();
775
776        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
777            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
778            std::task::Poll::Pending => std::task::Poll::Ready(()),
779        })
780        .await;
781
782        let context = service.connections.get(&peer).unwrap();
783        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
784        assert_eq!(
785            context.secondary.as_ref().unwrap().connection_id(),
786            &ConnectionId::from(1usize)
787        );
788    }
789
790    #[tokio::test]
791    async fn tertiary_connection_ignored() {
792        let (mut service, sender, _) = transport_service();
793        let peer = PeerId::random();
794
795        // register first connection
796        let (cmd_tx1, _cmd_rx1) = channel(64);
797        sender
798            .send(InnerTransportEvent::ConnectionEstablished {
799                peer,
800                connection: ConnectionId::from(0usize),
801                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)),
802                sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
803            })
804            .await
805            .unwrap();
806
807        if let Some(TransportEvent::ConnectionEstablished {
808            peer: connected_peer,
809            endpoint,
810        }) = service.next().await
811        {
812            assert_eq!(connected_peer, peer);
813            assert_eq!(endpoint.address(), &Multiaddr::empty());
814        } else {
815            panic!("expected event from `TransportService`");
816        };
817
818        // register secondary connection
819        let (cmd_tx2, _cmd_rx2) = channel(64);
820        sender
821            .send(InnerTransportEvent::ConnectionEstablished {
822                peer,
823                connection: ConnectionId::from(1usize),
824                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1usize)),
825                sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
826            })
827            .await
828            .unwrap();
829
830        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
831            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
832            std::task::Poll::Pending => std::task::Poll::Ready(()),
833        })
834        .await;
835
836        let context = service.connections.get(&peer).unwrap();
837        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
838        assert_eq!(
839            context.secondary.as_ref().unwrap().connection_id(),
840            &ConnectionId::from(1usize)
841        );
842
843        // try to register tertiary connection and verify it's ignored
844        let (cmd_tx3, mut cmd_rx3) = channel(64);
845        sender
846            .send(InnerTransportEvent::ConnectionEstablished {
847                peer,
848                connection: ConnectionId::from(2usize),
849                endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(2usize)),
850                sender: ConnectionHandle::new(ConnectionId::from(2usize), cmd_tx3),
851            })
852            .await
853            .unwrap();
854
855        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
856            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
857            std::task::Poll::Pending => std::task::Poll::Ready(()),
858        })
859        .await;
860
861        let context = service.connections.get(&peer).unwrap();
862        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
863        assert_eq!(
864            context.secondary.as_ref().unwrap().connection_id(),
865            &ConnectionId::from(1usize)
866        );
867        assert!(cmd_rx3.try_recv().is_err());
868    }
869
870    #[tokio::test]
871    async fn secondary_closing_does_not_emit_event() {
872        let _ = tracing_subscriber::fmt()
873            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
874            .try_init();
875
876        let (mut service, sender, _) = transport_service();
877        let peer = PeerId::random();
878
879        // register first connection
880        let (cmd_tx1, _cmd_rx1) = channel(64);
881        sender
882            .send(InnerTransportEvent::ConnectionEstablished {
883                peer,
884                connection: ConnectionId::from(0usize),
885                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)),
886                sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
887            })
888            .await
889            .unwrap();
890
891        if let Some(TransportEvent::ConnectionEstablished {
892            peer: connected_peer,
893            endpoint,
894        }) = service.next().await
895        {
896            assert_eq!(connected_peer, peer);
897            assert_eq!(endpoint.address(), &Multiaddr::empty());
898        } else {
899            panic!("expected event from `TransportService`");
900        };
901
902        // register secondary connection
903        let (cmd_tx2, _cmd_rx2) = channel(64);
904        sender
905            .send(InnerTransportEvent::ConnectionEstablished {
906                peer,
907                connection: ConnectionId::from(1usize),
908                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1usize)),
909                sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
910            })
911            .await
912            .unwrap();
913
914        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
915            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
916            std::task::Poll::Pending => std::task::Poll::Ready(()),
917        })
918        .await;
919
920        let context = service.connections.get(&peer).unwrap();
921        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
922        assert_eq!(
923            context.secondary.as_ref().unwrap().connection_id(),
924            &ConnectionId::from(1usize)
925        );
926
927        // close the secondary connection
928        sender
929            .send(InnerTransportEvent::ConnectionClosed {
930                peer,
931                connection: ConnectionId::from(1usize),
932            })
933            .await
934            .unwrap();
935
936        // verify that the protocol is not notified
937        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
938            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
939            std::task::Poll::Pending => std::task::Poll::Ready(()),
940        })
941        .await;
942
943        // verify that the secondary connection doesn't exist anymore
944        let context = service.connections.get(&peer).unwrap();
945        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
946        assert!(context.secondary.is_none());
947    }
948
949    #[tokio::test]
950    async fn convert_secondary_to_primary() {
951        let (mut service, sender, _) = transport_service();
952        let peer = PeerId::random();
953
954        // register first connection
955        let (cmd_tx1, mut cmd_rx1) = channel(64);
956        sender
957            .send(InnerTransportEvent::ConnectionEstablished {
958                peer,
959                connection: ConnectionId::from(0usize),
960                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)),
961                sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
962            })
963            .await
964            .unwrap();
965
966        if let Some(TransportEvent::ConnectionEstablished {
967            peer: connected_peer,
968            endpoint,
969        }) = service.next().await
970        {
971            assert_eq!(connected_peer, peer);
972            assert_eq!(endpoint.address(), &Multiaddr::empty());
973        } else {
974            panic!("expected event from `TransportService`");
975        };
976
977        // register secondary connection
978        let (cmd_tx2, mut cmd_rx2) = channel(64);
979        sender
980            .send(InnerTransportEvent::ConnectionEstablished {
981                peer,
982                connection: ConnectionId::from(1usize),
983                endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1usize)),
984                sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
985            })
986            .await
987            .unwrap();
988
989        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
990            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
991            std::task::Poll::Pending => std::task::Poll::Ready(()),
992        })
993        .await;
994
995        let context = service.connections.get(&peer).unwrap();
996        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
997        assert_eq!(
998            context.secondary.as_ref().unwrap().connection_id(),
999            &ConnectionId::from(1usize)
1000        );
1001
1002        // close the primary connection
1003        sender
1004            .send(InnerTransportEvent::ConnectionClosed {
1005                peer,
1006                connection: ConnectionId::from(0usize),
1007            })
1008            .await
1009            .unwrap();
1010
1011        // verify that the protocol is not notified
1012        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
1013            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
1014            std::task::Poll::Pending => std::task::Poll::Ready(()),
1015        })
1016        .await;
1017
1018        // verify that the primary connection has been replaced
1019        let context = service.connections.get(&peer).unwrap();
1020        assert_eq!(context.primary.connection_id(), &ConnectionId::from(1usize));
1021        assert!(context.secondary.is_none());
1022        assert!(cmd_rx1.try_recv().is_err());
1023
1024        // close the remaining connection (the former secondary, now primary) as well
1025        sender
1026            .send(InnerTransportEvent::ConnectionClosed {
1027                peer,
1028                connection: ConnectionId::from(1usize),
1029            })
1030            .await
1031            .unwrap();
1032
1033        if let Some(TransportEvent::ConnectionClosed {
1034            peer: disconnected_peer,
1035        }) = service.next().await
1036        {
1037            assert_eq!(disconnected_peer, peer);
1038        } else {
1039            panic!("expected event from `TransportService`");
1040        };
1041
1042        // verify that the peer no longer has any tracked connections
1043        assert!(service.connections.get(&peer).is_none());
1044        assert!(cmd_rx2.try_recv().is_err());
1045    }
1046
1047    #[tokio::test]
1048    async fn keep_alive_timeout_expires_for_a_stale_connection() {
1049        let _ = tracing_subscriber::fmt()
1050            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1051            .try_init();
1052
1053        let (mut service, sender, _) = transport_service();
1054        let peer = PeerId::random();
1055
1056        // register first connection
1057        let (cmd_tx1, _cmd_rx1) = channel(64);
1058        sender
1059            .send(InnerTransportEvent::ConnectionEstablished {
1060                peer,
1061                connection: ConnectionId::from(1337usize),
1062                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
1063                sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
1064            })
1065            .await
1066            .unwrap();
1067
1068        if let Some(TransportEvent::ConnectionEstablished {
1069            peer: connected_peer,
1070            endpoint,
1071        }) = service.next().await
1072        {
1073            assert_eq!(connected_peer, peer);
1074            assert_eq!(endpoint.address(), &Multiaddr::empty());
1075        } else {
1076            panic!("expected event from `TransportService`");
1077        };
1078
1079        // verify the first connection state is correct
1080        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1081        match service.connections.get(&peer) {
1082            Some(context) => {
1083                assert_eq!(
1084                    context.primary.connection_id(),
1085                    &ConnectionId::from(1337usize)
1086                );
1087                assert!(context.secondary.is_none());
1088            }
1089            None => panic!("expected {peer} to exist"),
1090        }
1091
1092        // close the primary connection
1093        sender
1094            .send(InnerTransportEvent::ConnectionClosed {
1095                peer,
1096                connection: ConnectionId::from(1337usize),
1097            })
1098            .await
1099            .unwrap();
1100
1101        // verify that the protocols are notified of the connection closing as well
1102        if let Some(TransportEvent::ConnectionClosed {
1103            peer: connected_peer,
1104        }) = service.next().await
1105        {
1106            assert_eq!(connected_peer, peer);
1107        } else {
1108            panic!("expected event from `TransportService`");
1109        }
1110
1111        // Because the connection was closed, the peer is no longer tracked for keep-alive.
1112        // This keeps the tracker small, since stale connections no longer need to be tracked.
1113        assert!(service.keep_alive_tracker.last_activity.is_empty());
1114        assert!(service.connections.get(&peer).is_none());
1115
1116        // Register new primary connection.
1117        let (cmd_tx1, _cmd_rx1) = channel(64);
1118        sender
1119            .send(InnerTransportEvent::ConnectionEstablished {
1120                peer,
1121                connection: ConnectionId::from(1338usize),
1122                endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1338usize)),
1123                sender: ConnectionHandle::new(ConnectionId::from(1338usize), cmd_tx1),
1124            })
1125            .await
1126            .unwrap();
1127
1128        if let Some(TransportEvent::ConnectionEstablished {
1129            peer: connected_peer,
1130            endpoint,
1131        }) = service.next().await
1132        {
1133            assert_eq!(connected_peer, peer);
1134            assert_eq!(endpoint.address(), &Multiaddr::empty());
1135        } else {
1136            panic!("expected event from `TransportService`");
1137        };
1138
1139        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1140        match service.connections.get(&peer) {
1141            Some(context) => {
1142                assert_eq!(
1143                    context.primary.connection_id(),
1144                    &ConnectionId::from(1338usize)
1145                );
1146                assert!(context.secondary.is_none());
1147            }
1148            None => panic!("expected {peer} to exist"),
1149        }
1150
1151        match tokio::time::timeout(Duration::from_secs(10), service.next()).await {
1152            Ok(event) => panic!("didn't expect an event: {event:?}"),
1153            Err(_) => {}
1154        }
1155    }
1156
1157    async fn poll_service(service: &mut TransportService) {
1158        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
1159            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
1160            std::task::Poll::Pending => std::task::Poll::Ready(()),
1161        })
1162        .await;
1163    }
1164
1165    #[tokio::test]
1166    async fn keep_alive_timeout_downgrades_connections() {
1167        let _ = tracing_subscriber::fmt()
1168            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1169            .try_init();
1170
1171        let (mut service, sender, _) = transport_service();
1172        let peer = PeerId::random();
1173
1174        // register first connection
1175        let (cmd_tx1, _cmd_rx1) = channel(64);
1176        sender
1177            .send(InnerTransportEvent::ConnectionEstablished {
1178                peer,
1179                connection: ConnectionId::from(1337usize),
1180                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
1181                sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
1182            })
1183            .await
1184            .unwrap();
1185
1186        if let Some(TransportEvent::ConnectionEstablished {
1187            peer: connected_peer,
1188            endpoint,
1189        }) = service.next().await
1190        {
1191            assert_eq!(connected_peer, peer);
1192            assert_eq!(endpoint.address(), &Multiaddr::empty());
1193        } else {
1194            panic!("expected event from `TransportService`");
1195        };
1196
1197        // verify the first connection state is correct
1198        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1199        match service.connections.get(&peer) {
1200            Some(context) => {
1201                assert_eq!(
1202                    context.primary.connection_id(),
1203                    &ConnectionId::from(1337usize)
1204                );
1205                // Check the connection is still active.
1206                assert!(context.primary.is_active());
1207                assert!(context.secondary.is_none());
1208            }
1209            None => panic!("expected {peer} to exist"),
1210        }
1211
1212        poll_service(&mut service).await;
1213        tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
1214        poll_service(&mut service).await;
1215
1216        // Verify the connection is downgraded.
1217        match service.connections.get(&peer) {
1218            Some(context) => {
1219                assert_eq!(
1220                    context.primary.connection_id(),
1221                    &ConnectionId::from(1337usize)
1222                );
1223                // Check the connection is not active.
1224                assert!(!context.primary.is_active());
1225                assert!(context.secondary.is_none());
1226            }
1227            None => panic!("expected {peer} to exist"),
1228        }
1229
1230        assert_eq!(service.keep_alive_tracker.last_activity.len(), 0);
1231    }
1232
1233    #[tokio::test]
1234    async fn keep_alive_timeout_reset_when_user_opens_substream() {
1235        let _ = tracing_subscriber::fmt()
1236            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1237            .try_init();
1238
1239        let (mut service, sender, _) = transport_service();
1240        let peer = PeerId::random();
1241
1242        // register first connection
1243        let (cmd_tx1, _cmd_rx1) = channel(64);
1244        sender
1245            .send(InnerTransportEvent::ConnectionEstablished {
1246                peer,
1247                connection: ConnectionId::from(1337usize),
1248                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
1249                sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
1250            })
1251            .await
1252            .unwrap();
1253
1254        if let Some(TransportEvent::ConnectionEstablished {
1255            peer: connected_peer,
1256            endpoint,
1257        }) = service.next().await
1258        {
1259            assert_eq!(connected_peer, peer);
1260            assert_eq!(endpoint.address(), &Multiaddr::empty());
1261        } else {
1262            panic!("expected event from `TransportService`");
1263        };
1264
1265        // verify the first connection state is correct
1266        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1267        match service.connections.get(&peer) {
1268            Some(context) => {
1269                assert_eq!(
1270                    context.primary.connection_id(),
1271                    &ConnectionId::from(1337usize)
1272                );
1273                // Check the connection is still active.
1274                assert!(context.primary.is_active());
1275                assert!(context.secondary.is_none());
1276            }
1277            None => panic!("expected {peer} to exist"),
1278        }
1279
1280        poll_service(&mut service).await;
1281        // Sleep for almost the entire keep-alive timeout.
1282        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1283
1284        // This ensures we reset the keep-alive timer when other protocols
1285        // want to open a substream.
1286        // We are still tracking the same peer.
1287        service.open_substream(peer).unwrap();
1288        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1289
1290        poll_service(&mut service).await;
1291        // The keep-alive timeout should have been advanced.
1292        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1293        poll_service(&mut service).await;
1294        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1295        // If `service.open_substream` hadn't been called, the connection would have been downgraded.
1296        // Instead, the keep-alive deadline was pushed `KEEP_ALIVE_TIMEOUT` seconds into the future.
1297        // Verify the connection is still active.
1298        match service.connections.get(&peer) {
1299            Some(context) => {
1300                assert_eq!(
1301                    context.primary.connection_id(),
1302                    &ConnectionId::from(1337usize)
1303                );
1304                assert!(context.primary.is_active());
1305                assert!(context.secondary.is_none());
1306            }
1307            None => panic!("expected {peer} to exist"),
1308        }
1309
1310        poll_service(&mut service).await;
1311        tokio::time::sleep(KEEP_ALIVE_TIMEOUT).await;
1312        poll_service(&mut service).await;
1313
1314        assert_eq!(service.keep_alive_tracker.last_activity.len(), 0);
1315
1316        // The connection had no substream activity for `KEEP_ALIVE_TIMEOUT` seconds.
1317        // Verify the connection is downgraded.
1318        match service.connections.get(&peer) {
1319            Some(context) => {
1320                assert_eq!(
1321                    context.primary.connection_id(),
1322                    &ConnectionId::from(1337usize)
1323                );
1324                assert!(!context.primary.is_active());
1325                assert!(context.secondary.is_none());
1326            }
1327            None => panic!("expected {peer} to exist"),
1328        }
1329    }
1330
1331    #[tokio::test]
1332    async fn downgraded_connection_without_substreams_is_closed() {
1333        let _ = tracing_subscriber::fmt()
1334            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1335            .try_init();
1336
1337        let (mut service, sender, _) = transport_service();
1338        let peer = PeerId::random();
1339
1340        // register first connection
1341        let (cmd_tx1, mut cmd_rx1) = channel(64);
1342        sender
1343            .send(InnerTransportEvent::ConnectionEstablished {
1344                peer,
1345                connection: ConnectionId::from(1337usize),
1346                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
1347                sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
1348            })
1349            .await
1350            .unwrap();
1351
1352        if let Some(TransportEvent::ConnectionEstablished {
1353            peer: connected_peer,
1354            endpoint,
1355        }) = service.next().await
1356        {
1357            assert_eq!(connected_peer, peer);
1358            assert_eq!(endpoint.address(), &Multiaddr::empty());
1359        } else {
1360            panic!("expected event from `TransportService`");
1361        };
1362
1363        // verify the first connection state is correct
1364        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1365        match service.connections.get(&peer) {
1366            Some(context) => {
1367                assert_eq!(
1368                    context.primary.connection_id(),
1369                    &ConnectionId::from(1337usize)
1370                );
1371                // Check the connection is still active.
1372                assert!(context.primary.is_active());
1373                assert!(context.secondary.is_none());
1374            }
1375            None => panic!("expected {peer} to exist"),
1376        }
1377
1378        // Open substreams to the peer.
1379        let substream_id = service.open_substream(peer).unwrap();
1380        let second_substream_id = service.open_substream(peer).unwrap();
1381
1382        // Simulate keep-alive timeout expiration.
1383        poll_service(&mut service).await;
1384        tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
1385        poll_service(&mut service).await;
1386
1387        let mut permits = Vec::new();
1388
1389        // First substream.
1390        let protocol_command = cmd_rx1.recv().await.unwrap();
1391        match protocol_command {
1392            ProtocolCommand::OpenSubstream {
1393                protocol,
1394                substream_id: opened_substream_id,
1395                permit,
1396                ..
1397            } => {
1398                assert_eq!(protocol, ProtocolName::from("/notif/1"));
1399                assert_eq!(substream_id, opened_substream_id);
1400
1401                // Save the substream permit for later.
1402                permits.push(permit);
1403            }
1404            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1405        }
1406
1407        // Second substream.
1408        let protocol_command = cmd_rx1.recv().await.unwrap();
1409        match protocol_command {
1410            ProtocolCommand::OpenSubstream {
1411                protocol,
1412                substream_id: opened_substream_id,
1413                permit,
1414                ..
1415            } => {
1416                assert_eq!(protocol, ProtocolName::from("/notif/1"));
1417                assert_eq!(second_substream_id, opened_substream_id);
1418
1419                // Save the substream permit for later.
1420                permits.push(permit);
1421            }
1422            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1423        }
1424
1425        // Drop one permit.
1426        let permit = permits.pop();
1427        // Individual transports like TCP will open a substream
1428        // and then generate a `SubstreamOpened` event via
1429        // the protocol-set handler.
1430        //
1431        // The substream is used by individual protocols and then
1432        // closed. Dropping the permit here simulates the substream being closed.
1433        drop(permit);
1434
1435        // Open a new substream to the peer. This will succeed as long as we still have
1436        // one substream open.
1437        let substream_id = service.open_substream(peer).unwrap();
1438        // Handle the substream.
1439        let protocol_command = cmd_rx1.recv().await.unwrap();
1440        match protocol_command {
1441            ProtocolCommand::OpenSubstream {
1442                protocol,
1443                substream_id: opened_substream_id,
1444                permit,
1445                ..
1446            } => {
1447                assert_eq!(protocol, ProtocolName::from("/notif/1"));
1448                assert_eq!(substream_id, opened_substream_id);
1449
1450                // Save the substream permit for later.
1451                permits.push(permit);
1452            }
1453            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1454        }
1455
1456        // Drop all substreams.
1457        drop(permits);
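        // With every permit released, nothing keeps the connection alive once the
        // keep-alive timeout expires below.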
1458
1459        poll_service(&mut service).await;
1460        tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
1461        poll_service(&mut service).await;
1462
1463        // Cannot open a new substream because:
1464        // 1. connection was downgraded by keep-alive timeout
1465        // 2. all substreams were dropped.
1466        assert_eq!(
1467            service.open_substream(peer),
1468            Err(SubstreamError::ConnectionClosed)
1469        );
1470    }
1471
1472    #[tokio::test]
1473    async fn substream_opening_upgrades_connection_and_resets_keep_alive() {
1474        let _ = tracing_subscriber::fmt()
1475            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1476            .try_init();
1477
1478        let (mut service, sender, _) = transport_service();
1479        let peer = PeerId::random();
1480
1481        // Register the first connection.
1482        let (cmd_tx1, mut cmd_rx1) = channel(64);
1483        sender
1484            .send(InnerTransportEvent::ConnectionEstablished {
1485                peer,
1486                connection: ConnectionId::from(1337usize),
1487                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
1488                sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
1489            })
1490            .await
1491            .unwrap();
1492
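        // The service should report the new connection as `TransportEvent::ConnectionEstablished`.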
1493        if let Some(TransportEvent::ConnectionEstablished {
1494            peer: connected_peer,
1495            endpoint,
1496        }) = service.next().await
1497        {
1498            assert_eq!(connected_peer, peer);
1499            assert_eq!(endpoint.address(), &Multiaddr::empty());
1500        } else {
1501            panic!("expected event from `TransportService`");
1502        };
1503
1504        // Verify the first connection state is correct.
1505        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1506        match service.connections.get(&peer) {
1507            Some(context) => {
1508                assert_eq!(
1509                    context.primary.connection_id(),
1510                    &ConnectionId::from(1337usize)
1511                );
1512                // Check the connection is still active.
1513                assert!(context.primary.is_active());
1514                assert!(context.secondary.is_none());
1515            }
1516            None => panic!("expected {peer} to exist"),
1517        }
1518
1519        // Open substreams to the peer.
1520        let substream_id = service.open_substream(peer).unwrap();
1521        let second_substream_id = service.open_substream(peer).unwrap();
1522
1523        let mut permits = Vec::new();
1524        // First substream.
1525        let protocol_command = cmd_rx1.recv().await.unwrap();
1526        match protocol_command {
1527            ProtocolCommand::OpenSubstream {
1528                protocol,
1529                substream_id: opened_substream_id,
1530                permit,
1531                ..
1532            } => {
1533                assert_eq!(protocol, ProtocolName::from("/notif/1"));
1534                assert_eq!(substream_id, opened_substream_id);
1535
1536                // Save the substream permit for later.
1537                permits.push(permit);
1538            }
1539            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1540        }
1541
1542        // Second substream.
1543        let protocol_command = cmd_rx1.recv().await.unwrap();
1544        match protocol_command {
1545            ProtocolCommand::OpenSubstream {
1546                protocol,
1547                substream_id: opened_substream_id,
1548                permit,
1549                ..
1550            } => {
1551                assert_eq!(protocol, ProtocolName::from("/notif/1"));
1552                assert_eq!(second_substream_id, opened_substream_id);
1553
1554                // Save the substream permit for later.
1555                permits.push(permit);
1556            }
1557            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1558        }
1559
1560        // Sleep to trigger keep-alive timeout.
1561        poll_service(&mut service).await;
1562        tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
1563        poll_service(&mut service).await;
1564
1565        // Verify the connection is downgraded.
1566        match service.connections.get(&peer) {
1567            Some(context) => {
1568                assert_eq!(
1569                    context.primary.connection_id(),
1570                    &ConnectionId::from(1337usize)
1571                );
1572                // Check the connection is not active.
1573                assert!(!context.primary.is_active());
1574                assert!(context.secondary.is_none());
1575            }
1576            None => panic!("expected {peer} to exist"),
1577        }
1578        assert_eq!(service.keep_alive_tracker.last_activity.len(), 0);
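        // The keep-alive tracker dropped its entry for the connection once the timeout expired.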
1579
1580        // Open a new substream to the peer. This will succeed as long as we still hold
1581        // at least one substream permit.
1582        let substream_id = service.open_substream(peer).unwrap();
1583        let protocol_command = cmd_rx1.recv().await.unwrap();
1584        match protocol_command {
1585            ProtocolCommand::OpenSubstream {
1586                protocol,
1587                substream_id: opened_substream_id,
1588                permit,
1589                ..
1590            } => {
1591                assert_eq!(protocol, ProtocolName::from("/notif/1"));
1592                assert_eq!(substream_id, opened_substream_id);
1593
1594                // Save the substream permit for later.
1595                permits.push(permit);
1596            }
1597            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1598        }
1599
1600        poll_service(&mut service).await;
1601
1602        // Verify the connection is upgraded and keep-alive is tracked.
1603        match service.connections.get(&peer) {
1604            Some(context) => {
1605                assert_eq!(
1606                    context.primary.connection_id(),
1607                    &ConnectionId::from(1337usize)
1608                );
1609                // Check the connection is active, because it was upgraded by the last substream.
1610                assert!(context.primary.is_active());
1611                assert!(context.secondary.is_none());
1612            }
1613            None => panic!("expected {peer} to exist"),
1614        }
1615        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
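        // Opening the substream re-registered the connection with the keep-alive tracker.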
1616
1617        // Drop all substreams.
1618        drop(permits);
1619
1620        // The connection is still active, because it was upgraded by the last substream open.
1621        match service.connections.get(&peer) {
1622            Some(context) => {
1623                assert_eq!(
1624                    context.primary.connection_id(),
1625                    &ConnectionId::from(1337usize)
1626                );
1627                // Check the connection is active, because it was upgraded by the last substream.
1628                assert!(context.primary.is_active());
1629                assert!(context.secondary.is_none());
1630            }
1631            None => panic!("expected {peer} to exist"),
1632        }
1633        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1634
1635        // Sleep to trigger keep-alive timeout.
1636        poll_service(&mut service).await;
1637        tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
1638        poll_service(&mut service).await;
1639
1640        match service.connections.get(&peer) {
1641            Some(context) => {
1642                assert_eq!(
1643                    context.primary.connection_id(),
1644                    &ConnectionId::from(1337usize)
1645                );
1646                // No longer active because the connection was downgraded by the keep-alive
1647                // timeout and no new substreams were opened.
1648                assert!(!context.primary.is_active());
1649                assert!(context.secondary.is_none());
1650            }
1651            None => panic!("expected {peer} to exist"),
1652        }
1653
1654        // Cannot open a new substream because:
1655        // 1. connection was downgraded by keep-alive timeout
1656        // 2. all substreams were dropped.
1657        assert_eq!(
1658            service.open_substream(peer),
1659            Err(SubstreamError::ConnectionClosed)
1660        );
1661    }
1662
1663    #[tokio::test]
1664    async fn keep_alive_pop_elements() {
1665        let mut tracker = KeepAliveTracker::new(Duration::from_secs(1));
1666
1667        let (peer1, connection1) = (PeerId::random(), ConnectionId::from(1usize));
1668        let (peer2, connection2) = (PeerId::random(), ConnectionId::from(2usize));
1669        let added_keys = HashSet::from([(peer1, connection1), (peer2, connection2)]);
1670
1671        tracker.on_connection_established(peer1, connection1);
1672        tracker.on_connection_established(peer2, connection2);
1673
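        // Sleep past the keep-alive duration so both tracked entries expire.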
1674        tokio::time::sleep(Duration::from_secs(2)).await;
1675
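        // Expired entries may be yielded in either order, hence the set-membership checks.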
1676        let key = tracker.next().await.unwrap();
1677        assert!(added_keys.contains(&key));
1678
1679        let key = tracker.next().await.unwrap();
1680        assert!(added_keys.contains(&key));
1681
1682        // No more elements.
1683        assert!(tracker.pending_keep_alive_timeouts.is_empty());
1684        assert!(tracker.last_activity.is_empty());
1685    }
1686}