litep2p/protocol/
transport_service.rs

1// Copyright 2023 litep2p developers
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::{
22    addresses::PublicAddresses,
23    error::{Error, ImmediateDialError, SubstreamError},
24    protocol::{connection::ConnectionHandle, InnerTransportEvent, TransportEvent},
25    transport::{manager::TransportManagerHandle, Endpoint},
26    types::{protocol::ProtocolName, ConnectionId, SubstreamId},
27    PeerId, DEFAULT_CHANNEL_SIZE,
28};
29
30use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt};
31use multiaddr::{Multiaddr, Protocol};
32use tokio::sync::mpsc::{channel, Receiver, Sender};
33
34use std::{
35    collections::{HashMap, HashSet},
36    fmt::Debug,
37    pin::Pin,
38    sync::{
39        atomic::{AtomicUsize, Ordering},
40        Arc,
41    },
42    task::{Context, Poll, Waker},
43    time::{Duration, Instant},
44};
45
46/// Logging target for the file.
47const LOG_TARGET: &str = "litep2p::transport-service";
48
/// Connection context for the peer.
///
/// Each peer is allowed to have at most two connections open. The first open connection is the
/// primary connection, which the local node uses to open substreams to the remote. A secondary
/// connection may be open if local and remote opened connections at the same time.
///
/// The secondary connection may be promoted to the primary connection if the primary connection
/// closes while the secondary connection remains open.
#[derive(Debug)]
struct ConnectionContext {
    /// Primary connection.
    primary: ConnectionHandle,

    /// Secondary connection, if it exists.
    secondary: Option<ConnectionHandle>,
}
65
66impl ConnectionContext {
67    /// Create new [`ConnectionContext`].
68    fn new(primary: ConnectionHandle) -> Self {
69        Self {
70            primary,
71            secondary: None,
72        }
73    }
74
75    /// Downgrade connection to non-active which means it will be closed
76    /// if there are no substreams open over it.
77    fn downgrade(&mut self, connection_id: &ConnectionId) {
78        if self.primary.connection_id() == connection_id {
79            self.primary.close();
80            return;
81        }
82
83        if let Some(handle) = &mut self.secondary {
84            if handle.connection_id() == connection_id {
85                handle.close();
86                return;
87            }
88        }
89
90        tracing::debug!(
91            target: LOG_TARGET,
92            primary = ?self.primary.connection_id(),
93            secondary = ?self.secondary.as_ref().map(|handle| handle.connection_id()),
94            ?connection_id,
95            "connection doesn't exist, cannot downgrade",
96        );
97    }
98
99    /// Try to upgrade the connection to active state.
100    fn try_upgrade(&mut self, connection_id: &ConnectionId) {
101        if self.primary.connection_id() == connection_id {
102            self.primary.try_upgrade();
103            return;
104        }
105
106        if let Some(handle) = &mut self.secondary {
107            if handle.connection_id() == connection_id {
108                handle.try_upgrade();
109                return;
110            }
111        }
112
113        tracing::debug!(
114            target: LOG_TARGET,
115            primary = ?self.primary.connection_id(),
116            secondary = ?self.secondary.as_ref().map(|handle| handle.connection_id()),
117            ?connection_id,
118            "connection doesn't exist, cannot upgrade",
119        );
120    }
121}
122
123/// Tracks connection keep-alive timeouts.
124///
125/// A connection keep-alive timeout is started when a connection is established.
126/// If no substreams are opened over the connection within the timeout,
127/// the connection is downgraded. However, if a substream is opened over the connection,
128/// the timeout is reset.
129#[derive(Debug)]
130struct KeepAliveTracker {
131    /// Close the connection if no substreams are open within this time frame.
132    keep_alive_timeout: Duration,
133
134    /// Track substream last activity.
135    last_activity: HashMap<(PeerId, ConnectionId), Instant>,
136
137    /// Pending keep-alive timeouts.
138    pending_keep_alive_timeouts: FuturesUnordered<BoxFuture<'static, (PeerId, ConnectionId)>>,
139
140    /// Saved waker.
141    waker: Option<Waker>,
142}
143
144impl KeepAliveTracker {
145    /// Create new [`KeepAliveTracker`].
146    pub fn new(keep_alive_timeout: Duration) -> Self {
147        Self {
148            keep_alive_timeout,
149            last_activity: HashMap::new(),
150            pending_keep_alive_timeouts: FuturesUnordered::new(),
151            waker: None,
152        }
153    }
154
155    /// Called on connection established event to add a new keep-alive timeout.
156    pub fn on_connection_established(&mut self, peer: PeerId, connection_id: ConnectionId) {
157        self.substream_activity(peer, connection_id);
158    }
159
160    /// Called on connection closed event.
161    pub fn on_connection_closed(&mut self, peer: PeerId, connection_id: ConnectionId) {
162        self.last_activity.remove(&(peer, connection_id));
163    }
164
165    /// Called on substream opened event to track the last activity.
166    pub fn substream_activity(&mut self, peer: PeerId, connection_id: ConnectionId) {
167        // Keep track of the connection ID and the time the substream was opened.
168        if self.last_activity.insert((peer, connection_id), Instant::now()).is_none() {
169            // Refill futures if there is no pending keep-alive timeout.
170            let timeout = self.keep_alive_timeout;
171            self.pending_keep_alive_timeouts.push(Box::pin(async move {
172                tokio::time::sleep(timeout).await;
173                (peer, connection_id)
174            }));
175        }
176
177        tracing::trace!(
178            target: LOG_TARGET,
179            ?peer,
180            ?connection_id,
181            ?self.keep_alive_timeout,
182            last_activity = ?self.last_activity.len(),
183            pending_keep_alive_timeouts = ?self.pending_keep_alive_timeouts.len(),
184            "substream activity",
185        );
186
187        // Wake any pending poll.
188        if let Some(waker) = self.waker.take() {
189            waker.wake()
190        }
191    }
192}
193
194impl Stream for KeepAliveTracker {
195    type Item = (PeerId, ConnectionId);
196
197    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
198        if self.pending_keep_alive_timeouts.is_empty() {
199            // No pending keep-alive timeouts.
200            self.waker = Some(cx.waker().clone());
201            return Poll::Pending;
202        }
203
204        match self.pending_keep_alive_timeouts.poll_next_unpin(cx) {
205            Poll::Ready(Some(key)) => {
206                // Check last-activity time.
207                let Some(last_activity) = self.last_activity.get(&key) else {
208                    tracing::debug!(
209                        target: LOG_TARGET,
210                        peer = ?key.0,
211                        connection_id = ?key.1,
212                        "Last activity no longer tracks the connection (closed event triggered)",
213                    );
214
215                    // We have effectively ignored this `Poll::Ready` event. To prevent the
216                    // future from getting stuck, we need to tell the executor to poll again
217                    // for more events.
218                    cx.waker().wake_by_ref();
219                    return Poll::Pending;
220                };
221
222                // Keep-alive timeout not reached yet.
223                let inactive_for = last_activity.elapsed();
224                if inactive_for < self.keep_alive_timeout {
225                    let timeout = self.keep_alive_timeout.saturating_sub(inactive_for);
226
227                    tracing::trace!(
228                        target: LOG_TARGET,
229                        peer = ?key.0,
230                        connection_id = ?key.1,
231                        ?timeout,
232                        "keep-alive timeout not yet reached",
233                    );
234
235                    // Refill the keep alive timeouts.
236                    self.pending_keep_alive_timeouts.push(Box::pin(async move {
237                        tokio::time::sleep(timeout).await;
238                        key
239                    }));
240
241                    // This is similar to the `last_activity` check above, we need to inform
242                    // the executor that this object may produce more events.
243                    cx.waker().wake_by_ref();
244                    return Poll::Pending;
245                }
246
247                // Keep-alive timeout reached.
248                tracing::debug!(
249                    target: LOG_TARGET,
250                    peer = ?key.0,
251                    connection_id = ?key.1,
252                    "keep-alive timeout triggered",
253                );
254                self.last_activity.remove(&key);
255                Poll::Ready(Some(key))
256            }
257            Poll::Ready(None) | Poll::Pending => Poll::Pending,
258        }
259    }
260}
261
/// Whether this protocol substream activity can keep connection alive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubstreamKeepAlive {
    /// Substream activity of this protocol keeps the connection alive.
    Yes,
    /// Substream activity of this protocol does not keep the connection alive.
    No,
}

impl SubstreamKeepAlive {
    /// Call `f` and return its result in `Some` for [`SubstreamKeepAlive::Yes`].
    ///
    /// Equivalent to `(self == SubstreamKeepAlive::Yes).then(f)`: for
    /// [`SubstreamKeepAlive::No`], `f` is never called and `None` is returned.
    #[inline]
    pub fn then<T, F: FnOnce() -> T>(&self, f: F) -> Option<T> {
        match self {
            Self::Yes => Some(f()),
            Self::No => None,
        }
    }
}
278
/// Provides an interface for [`Litep2p`](crate::Litep2p) protocols to interact
/// with the underlying transport protocols.
#[derive(Debug)]
pub struct TransportService {
    /// Local peer ID.
    local_peer_id: PeerId,

    /// Protocol.
    protocol: ProtocolName,

    /// Fallback names for the protocol.
    fallback_names: Vec<ProtocolName>,

    /// Open connections.
    connections: HashMap<PeerId, ConnectionContext>,

    /// Transport handle.
    transport_handle: TransportManagerHandle,

    /// RX channel for receiving events from transports and connections.
    rx: Receiver<InnerTransportEvent>,

    /// Next substream ID.
    next_substream_id: Arc<AtomicUsize>,

    /// Close the connection if no substreams are open within this time frame.
    keep_alive_tracker: KeepAliveTracker,

    /// Whether this protocol's substreams should keep connection alive.
    substream_keep_alive: SubstreamKeepAlive,
}
310
311impl TransportService {
312    /// Create new [`TransportService`].
313    pub(crate) fn new(
314        local_peer_id: PeerId,
315        protocol: ProtocolName,
316        fallback_names: Vec<ProtocolName>,
317        next_substream_id: Arc<AtomicUsize>,
318        transport_handle: TransportManagerHandle,
319        keep_alive_timeout: Duration,
320        substream_keep_alive: SubstreamKeepAlive,
321    ) -> (Self, Sender<InnerTransportEvent>) {
322        let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE);
323
324        let keep_alive_tracker = KeepAliveTracker::new(keep_alive_timeout);
325
326        (
327            Self {
328                rx,
329                protocol,
330                local_peer_id,
331                fallback_names,
332                transport_handle,
333                next_substream_id,
334                connections: HashMap::new(),
335                keep_alive_tracker,
336                substream_keep_alive,
337            },
338            tx,
339        )
340    }
341
342    /// Get the list of public addresses of the node.
343    pub fn public_addresses(&self) -> PublicAddresses {
344        self.transport_handle.public_addresses()
345    }
346
347    /// Get the list of listen addresses of the node.
348    pub fn listen_addresses(&self) -> HashSet<Multiaddr> {
349        self.transport_handle.listen_addresses()
350    }
351
352    /// Handle connection established event.
353    fn on_connection_established(
354        &mut self,
355        peer: PeerId,
356        endpoint: Endpoint,
357        connection_id: ConnectionId,
358        handle: ConnectionHandle,
359    ) -> Option<TransportEvent> {
360        tracing::debug!(
361            target: LOG_TARGET,
362            ?peer,
363            ?endpoint,
364            ?connection_id,
365            protocol = %self.protocol,
366            current_state = ?self.connections.get(&peer),
367            "on connection established",
368        );
369
370        match self.connections.get_mut(&peer) {
371            Some(context) => match context.secondary {
372                Some(_) => {
373                    tracing::debug!(
374                        target: LOG_TARGET,
375                        ?peer,
376                        ?connection_id,
377                        ?endpoint,
378                        protocol = %self.protocol,
379                        "ignoring third connection",
380                    );
381                    None
382                }
383                None => {
384                    self.keep_alive_tracker.on_connection_established(peer, connection_id);
385
386                    tracing::trace!(
387                        target: LOG_TARGET,
388                        ?peer,
389                        ?endpoint,
390                        ?connection_id,
391                        protocol = %self.protocol,
392                        "secondary connection established",
393                    );
394
395                    context.secondary = Some(handle);
396
397                    None
398                }
399            },
400            None => {
401                tracing::trace!(
402                    target: LOG_TARGET,
403                    ?peer,
404                    ?endpoint,
405                    ?connection_id,
406                    protocol = %self.protocol,
407                    "primary connection established",
408                );
409
410                self.connections.insert(peer, ConnectionContext::new(handle));
411
412                self.keep_alive_tracker.on_connection_established(peer, connection_id);
413
414                Some(TransportEvent::ConnectionEstablished { peer, endpoint })
415            }
416        }
417    }
418
419    /// Handle connection closed event.
420    fn on_connection_closed(
421        &mut self,
422        peer: PeerId,
423        connection_id: ConnectionId,
424    ) -> Option<TransportEvent> {
425        tracing::debug!(
426            target: LOG_TARGET,
427            ?peer,
428            ?connection_id,
429            protocol = %self.protocol,
430            current_state = ?self.connections.get(&peer),
431            "on connection closed",
432        );
433
434        self.keep_alive_tracker.on_connection_closed(peer, connection_id);
435
436        let Some(context) = self.connections.get_mut(&peer) else {
437            tracing::warn!(
438                target: LOG_TARGET,
439                ?peer,
440                ?connection_id,
441                protocol = %self.protocol,
442                "connection closed to a non-existent peer",
443            );
444
445            debug_assert!(false);
446            return None;
447        };
448
449        // if the primary connection was closed, check if there exist a secondary connection
450        // and if it does, convert the secondary connection a primary connection
451        if context.primary.connection_id() == &connection_id {
452            tracing::trace!(
453                target: LOG_TARGET,
454                ?peer,
455                ?connection_id,
456                protocol = %self.protocol,
457                "primary connection closed"
458            );
459
460            match context.secondary.take() {
461                None => {
462                    self.connections.remove(&peer);
463                    return Some(TransportEvent::ConnectionClosed { peer });
464                }
465                Some(handle) => {
466                    tracing::debug!(
467                        target: LOG_TARGET,
468                        ?peer,
469                        ?connection_id,
470                        protocol = %self.protocol,
471                        "switch to secondary connection",
472                    );
473
474                    context.primary = handle;
475                    return None;
476                }
477            }
478        }
479
480        match context.secondary.take() {
481            Some(handle) if handle.connection_id() == &connection_id => {
482                tracing::trace!(
483                    target: LOG_TARGET,
484                    ?peer,
485                    ?connection_id,
486                    protocol = %self.protocol,
487                    "secondary connection closed",
488                );
489
490                None
491            }
492            connection_state => {
493                tracing::debug!(
494                    target: LOG_TARGET,
495                    ?peer,
496                    ?connection_id,
497                    ?connection_state,
498                    protocol = %self.protocol,
499                    "connection closed but it doesn't exist",
500                );
501
502                None
503            }
504        }
505    }
506
507    /// Dial `peer` using `PeerId`.
508    ///
509    /// Call fails if `Litep2p` doesn't have a known address for the peer.
510    pub fn dial(&mut self, peer: &PeerId) -> Result<(), ImmediateDialError> {
511        tracing::trace!(
512            target: LOG_TARGET,
513            ?peer,
514            protocol = %self.protocol,
515            "Dial peer requested",
516        );
517
518        self.transport_handle.dial(peer)
519    }
520
521    /// Dial peer using a `Multiaddr`.
522    ///
523    /// Call fails if the address is not in correct format or it contains an unsupported/disabled
524    /// transport.
525    ///
526    /// Calling this function is only necessary for those addresses that are discovered out-of-band
527    /// since `Litep2p` internally keeps track of all peer addresses it has learned through user
528    /// calling this function, Kademlia peer discoveries and `Identify` responses.
529    pub fn dial_address(&mut self, address: Multiaddr) -> Result<(), ImmediateDialError> {
530        tracing::trace!(
531            target: LOG_TARGET,
532            ?address,
533            protocol = %self.protocol,
534            "Dial address requested",
535        );
536
537        self.transport_handle.dial_address(address)
538    }
539
540    /// Add one or more addresses for `peer`.
541    ///
542    /// The list is filtered for duplicates and unsupported transports.
543    pub fn add_known_address(&mut self, peer: &PeerId, addresses: impl Iterator<Item = Multiaddr>) {
544        let addresses: HashSet<Multiaddr> = addresses
545            .map(|address| {
546                if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) {
547                    address.with(Protocol::P2p((*peer).into()))
548                } else {
549                    address
550                }
551            })
552            .collect();
553
554        self.transport_handle.add_known_address(peer, addresses.into_iter());
555    }
556
557    /// Open substream to `peer`.
558    ///
559    /// Call fails if there is no connection open to `peer` or the channel towards
560    /// the connection is clogged.
561    pub fn open_substream(&mut self, peer: PeerId) -> Result<SubstreamId, SubstreamError> {
562        // always prefer the primary connection
563        let connection = &mut self
564            .connections
565            .get_mut(&peer)
566            .ok_or(SubstreamError::PeerDoesNotExist(peer))?
567            .primary;
568
569        let connection_id = *connection.connection_id();
570
571        // This permit will be passed on until the substream is reported back to
572        // [`TransportService`] in [`InnerTransportEvent::SubstreamOpened`] and connection
573        // upgraded.
574        let permit = connection.try_get_permit().ok_or(SubstreamError::ConnectionClosed)?;
575
576        let substream_id =
577            SubstreamId::from(self.next_substream_id.fetch_add(1usize, Ordering::Relaxed));
578
579        tracing::trace!(
580            target: LOG_TARGET,
581            ?peer,
582            protocol = %self.protocol,
583            ?substream_id,
584            ?connection_id,
585            "open substream",
586        );
587
588        if self.substream_keep_alive == SubstreamKeepAlive::Yes {
589            self.keep_alive_tracker.substream_activity(peer, connection_id);
590            connection.try_upgrade();
591        }
592
593        connection
594            .open_substream(
595                self.protocol.clone(),
596                self.fallback_names.clone(),
597                substream_id,
598                permit,
599                self.substream_keep_alive,
600            )
601            .map(|_| substream_id)
602    }
603
604    /// Forcibly close the connection, even if other protocols have substreams open over it.
605    pub fn force_close(&mut self, peer: PeerId) -> crate::Result<()> {
606        let connection =
607            &mut self.connections.get_mut(&peer).ok_or(Error::PeerDoesntExist(peer))?;
608
609        tracing::trace!(
610            target: LOG_TARGET,
611            ?peer,
612            protocol = %self.protocol,
613            secondary = ?connection.secondary,
614            "forcibly closing the connection",
615        );
616
617        if let Some(ref mut connection) = connection.secondary {
618            let _ = connection.force_close();
619        }
620
621        connection.primary.force_close()
622    }
623
624    /// Get local peer ID.
625    pub fn local_peer_id(&self) -> PeerId {
626        self.local_peer_id
627    }
628
629    /// Dynamically unregister a protocol.
630    ///
631    /// This must be called when a protocol is no longer needed (e.g. user dropped the protocol
632    /// handle).
633    pub fn unregister_protocol(&self) {
634        self.transport_handle.unregister_protocol(self.protocol.clone());
635    }
636}
637
638impl Stream for TransportService {
639    type Item = TransportEvent;
640
641    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
642        let protocol_name = self.protocol.clone();
643        let keep_alive_timeout = self.keep_alive_tracker.keep_alive_timeout;
644
645        while let Poll::Ready(event) = self.rx.poll_recv(cx) {
646            match event {
647                None => {
648                    tracing::warn!(
649                        target: LOG_TARGET,
650                        protocol = ?protocol_name,
651                        "transport service closed"
652                    );
653                    return Poll::Ready(None);
654                }
655                Some(InnerTransportEvent::ConnectionEstablished {
656                    peer,
657                    endpoint,
658                    sender,
659                    connection,
660                }) => {
661                    if let Some(event) =
662                        self.on_connection_established(peer, endpoint, connection, sender)
663                    {
664                        return Poll::Ready(Some(event));
665                    }
666                }
667                Some(InnerTransportEvent::ConnectionClosed { peer, connection }) => {
668                    if let Some(event) = self.on_connection_closed(peer, connection) {
669                        return Poll::Ready(Some(event));
670                    }
671                }
672                Some(InnerTransportEvent::SubstreamOpened {
673                    peer,
674                    protocol,
675                    fallback,
676                    direction,
677                    substream,
678                    connection_id,
679                    opening_permit,
680                }) => {
681                    if protocol == self.protocol
682                        && self.substream_keep_alive == SubstreamKeepAlive::Yes
683                    {
684                        self.keep_alive_tracker.substream_activity(peer, connection_id);
685                        if let Some(context) = self.connections.get_mut(&peer) {
686                            context.try_upgrade(&connection_id);
687                        }
688                    }
689
690                    // Connection is upgraded, we must now drop the permit.
691                    // This is for the reader, not for compiler.
692                    drop(opening_permit);
693
694                    return Poll::Ready(Some(TransportEvent::SubstreamOpened {
695                        peer,
696                        protocol,
697                        fallback,
698                        direction,
699                        substream,
700                    }));
701                }
702                Some(event) => return Poll::Ready(Some(event.into())),
703            }
704        }
705
706        while let Poll::Ready(Some((peer, connection_id))) =
707            self.keep_alive_tracker.poll_next_unpin(cx)
708        {
709            if let Some(context) = self.connections.get_mut(&peer) {
710                tracing::debug!(
711                    target: LOG_TARGET,
712                    ?peer,
713                    ?connection_id,
714                    protocol = ?protocol_name,
715                    timeout = ?keep_alive_timeout,
716                    "keep-alive timeout over, downgrade connection",
717                );
718
719                context.downgrade(&connection_id);
720            }
721        }
722
723        Poll::Pending
724    }
725}
726
727#[cfg(test)]
728mod tests {
729    use super::*;
730    use crate::{
731        protocol::{ProtocolCommand, SubstreamKeepAlive, TransportService},
732        transport::{
733            manager::{handle::InnerTransportManagerCommand, TransportManagerHandle},
734            KEEP_ALIVE_TIMEOUT,
735        },
736    };
737    use futures::StreamExt;
738    use parking_lot::RwLock;
739    use std::collections::HashSet;
740
741    /// Create new `TransportService`
742    fn transport_service() -> (
743        TransportService,
744        Sender<InnerTransportEvent>,
745        Receiver<InnerTransportManagerCommand>,
746    ) {
747        let (cmd_tx, cmd_rx) = channel(64);
748        let peer = PeerId::random();
749
750        let handle = TransportManagerHandle::new(
751            peer,
752            Arc::new(RwLock::new(HashMap::new())),
753            cmd_tx,
754            HashSet::new(),
755            Default::default(),
756            PublicAddresses::new(peer),
757        );
758
759        let (service, sender) = TransportService::new(
760            peer,
761            ProtocolName::from("/notif/1"),
762            Vec::new(),
763            Arc::new(AtomicUsize::new(0usize)),
764            handle,
765            KEEP_ALIVE_TIMEOUT,
766            SubstreamKeepAlive::Yes,
767        );
768
769        (service, sender, cmd_rx)
770    }
771
772    #[tokio::test]
773    async fn secondary_connection_stored() {
774        let (mut service, sender, _) = transport_service();
775        let peer = PeerId::random();
776
777        // register first connection
778        let (cmd_tx1, _cmd_rx1) = channel(64);
779        sender
780            .send(InnerTransportEvent::ConnectionEstablished {
781                peer,
782                connection: ConnectionId::from(0usize),
783                endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(0usize)),
784                sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
785            })
786            .await
787            .unwrap();
788
789        if let Some(TransportEvent::ConnectionEstablished {
790            peer: connected_peer,
791            endpoint,
792        }) = service.next().await
793        {
794            assert_eq!(connected_peer, peer);
795            assert_eq!(endpoint.address(), &Multiaddr::empty());
796        } else {
797            panic!("expected event from `TransportService`");
798        };
799
800        // register secondary connection
801        let (cmd_tx2, _cmd_rx2) = channel(64);
802        sender
803            .send(InnerTransportEvent::ConnectionEstablished {
804                peer,
805                connection: ConnectionId::from(1usize),
806                endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1usize)),
807                sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
808            })
809            .await
810            .unwrap();
811
812        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
813            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
814            std::task::Poll::Pending => std::task::Poll::Ready(()),
815        })
816        .await;
817
818        let context = service.connections.get(&peer).unwrap();
819        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
820        assert_eq!(
821            context.secondary.as_ref().unwrap().connection_id(),
822            &ConnectionId::from(1usize)
823        );
824    }
825
826    #[tokio::test]
827    async fn tertiary_connection_ignored() {
828        let (mut service, sender, _) = transport_service();
829        let peer = PeerId::random();
830
831        // register first connection
832        let (cmd_tx1, _cmd_rx1) = channel(64);
833        sender
834            .send(InnerTransportEvent::ConnectionEstablished {
835                peer,
836                connection: ConnectionId::from(0usize),
837                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)),
838                sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
839            })
840            .await
841            .unwrap();
842
843        if let Some(TransportEvent::ConnectionEstablished {
844            peer: connected_peer,
845            endpoint,
846        }) = service.next().await
847        {
848            assert_eq!(connected_peer, peer);
849            assert_eq!(endpoint.address(), &Multiaddr::empty());
850        } else {
851            panic!("expected event from `TransportService`");
852        };
853
854        // register secondary connection
855        let (cmd_tx2, _cmd_rx2) = channel(64);
856        sender
857            .send(InnerTransportEvent::ConnectionEstablished {
858                peer,
859                connection: ConnectionId::from(1usize),
860                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1usize)),
861                sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
862            })
863            .await
864            .unwrap();
865
866        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
867            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
868            std::task::Poll::Pending => std::task::Poll::Ready(()),
869        })
870        .await;
871
872        let context = service.connections.get(&peer).unwrap();
873        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
874        assert_eq!(
875            context.secondary.as_ref().unwrap().connection_id(),
876            &ConnectionId::from(1usize)
877        );
878
879        // try to register tertiary connection and verify it's ignored
880        let (cmd_tx3, mut cmd_rx3) = channel(64);
881        sender
882            .send(InnerTransportEvent::ConnectionEstablished {
883                peer,
884                connection: ConnectionId::from(2usize),
885                endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(2usize)),
886                sender: ConnectionHandle::new(ConnectionId::from(2usize), cmd_tx3),
887            })
888            .await
889            .unwrap();
890
891        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
892            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
893            std::task::Poll::Pending => std::task::Poll::Ready(()),
894        })
895        .await;
896
897        let context = service.connections.get(&peer).unwrap();
898        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
899        assert_eq!(
900            context.secondary.as_ref().unwrap().connection_id(),
901            &ConnectionId::from(1usize)
902        );
903        assert!(cmd_rx3.try_recv().is_err());
904    }
905
906    #[tokio::test]
907    async fn secondary_closing_does_not_emit_event() {
908        let _ = tracing_subscriber::fmt()
909            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
910            .try_init();
911
912        let (mut service, sender, _) = transport_service();
913        let peer = PeerId::random();
914
915        // register first connection
916        let (cmd_tx1, _cmd_rx1) = channel(64);
917        sender
918            .send(InnerTransportEvent::ConnectionEstablished {
919                peer,
920                connection: ConnectionId::from(0usize),
921                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)),
922                sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
923            })
924            .await
925            .unwrap();
926
927        if let Some(TransportEvent::ConnectionEstablished {
928            peer: connected_peer,
929            endpoint,
930        }) = service.next().await
931        {
932            assert_eq!(connected_peer, peer);
933            assert_eq!(endpoint.address(), &Multiaddr::empty());
934        } else {
935            panic!("expected event from `TransportService`");
936        };
937
938        // register secondary connection
939        let (cmd_tx2, _cmd_rx2) = channel(64);
940        sender
941            .send(InnerTransportEvent::ConnectionEstablished {
942                peer,
943                connection: ConnectionId::from(1usize),
944                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1usize)),
945                sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
946            })
947            .await
948            .unwrap();
949
950        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
951            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
952            std::task::Poll::Pending => std::task::Poll::Ready(()),
953        })
954        .await;
955
956        let context = service.connections.get(&peer).unwrap();
957        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
958        assert_eq!(
959            context.secondary.as_ref().unwrap().connection_id(),
960            &ConnectionId::from(1usize)
961        );
962
963        // close the secondary connection
964        sender
965            .send(InnerTransportEvent::ConnectionClosed {
966                peer,
967                connection: ConnectionId::from(1usize),
968            })
969            .await
970            .unwrap();
971
972        // verify that the protocol is not notified
973        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
974            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
975            std::task::Poll::Pending => std::task::Poll::Ready(()),
976        })
977        .await;
978
979        // verify that the secondary connection doesn't exist anymore
980        let context = service.connections.get(&peer).unwrap();
981        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
982        assert!(context.secondary.is_none());
983    }
984
985    #[tokio::test]
986    async fn convert_secondary_to_primary() {
987        let (mut service, sender, _) = transport_service();
988        let peer = PeerId::random();
989
990        // register first connection
991        let (cmd_tx1, mut cmd_rx1) = channel(64);
992        sender
993            .send(InnerTransportEvent::ConnectionEstablished {
994                peer,
995                connection: ConnectionId::from(0usize),
996                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)),
997                sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
998            })
999            .await
1000            .unwrap();
1001
1002        if let Some(TransportEvent::ConnectionEstablished {
1003            peer: connected_peer,
1004            endpoint,
1005        }) = service.next().await
1006        {
1007            assert_eq!(connected_peer, peer);
1008            assert_eq!(endpoint.address(), &Multiaddr::empty());
1009        } else {
1010            panic!("expected event from `TransportService`");
1011        };
1012
1013        // register secondary connection
1014        let (cmd_tx2, mut cmd_rx2) = channel(64);
1015        sender
1016            .send(InnerTransportEvent::ConnectionEstablished {
1017                peer,
1018                connection: ConnectionId::from(1usize),
1019                endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1usize)),
1020                sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
1021            })
1022            .await
1023            .unwrap();
1024
1025        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
1026            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
1027            std::task::Poll::Pending => std::task::Poll::Ready(()),
1028        })
1029        .await;
1030
1031        let context = service.connections.get(&peer).unwrap();
1032        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
1033        assert_eq!(
1034            context.secondary.as_ref().unwrap().connection_id(),
1035            &ConnectionId::from(1usize)
1036        );
1037
1038        // close the primary connection
1039        sender
1040            .send(InnerTransportEvent::ConnectionClosed {
1041                peer,
1042                connection: ConnectionId::from(0usize),
1043            })
1044            .await
1045            .unwrap();
1046
1047        // verify that the protocol is not notified
1048        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
1049            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
1050            std::task::Poll::Pending => std::task::Poll::Ready(()),
1051        })
1052        .await;
1053
1054        // verify that the primary connection has been replaced
1055        let context = service.connections.get(&peer).unwrap();
1056        assert_eq!(context.primary.connection_id(), &ConnectionId::from(1usize));
1057        assert!(context.secondary.is_none());
1058        assert!(cmd_rx1.try_recv().is_err());
1059
1060        // close the secondary connection as well
1061        sender
1062            .send(InnerTransportEvent::ConnectionClosed {
1063                peer,
1064                connection: ConnectionId::from(1usize),
1065            })
1066            .await
1067            .unwrap();
1068
1069        if let Some(TransportEvent::ConnectionClosed {
1070            peer: disconnected_peer,
1071        }) = service.next().await
1072        {
1073            assert_eq!(disconnected_peer, peer);
1074        } else {
1075            panic!("expected event from `TransportService`");
1076        };
1077
1078        // verify that the primary connection has been replaced
1079        assert!(!service.connections.contains_key(&peer));
1080        assert!(cmd_rx2.try_recv().is_err());
1081    }
1082
1083    #[tokio::test]
1084    async fn keep_alive_timeout_expires_for_a_stale_connection() {
1085        let _ = tracing_subscriber::fmt()
1086            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1087            .try_init();
1088
1089        let (mut service, sender, _) = transport_service();
1090        let peer = PeerId::random();
1091
1092        // register first connection
1093        let (cmd_tx1, _cmd_rx1) = channel(64);
1094        sender
1095            .send(InnerTransportEvent::ConnectionEstablished {
1096                peer,
1097                connection: ConnectionId::from(1337usize),
1098                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
1099                sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
1100            })
1101            .await
1102            .unwrap();
1103
1104        if let Some(TransportEvent::ConnectionEstablished {
1105            peer: connected_peer,
1106            endpoint,
1107        }) = service.next().await
1108        {
1109            assert_eq!(connected_peer, peer);
1110            assert_eq!(endpoint.address(), &Multiaddr::empty());
1111        } else {
1112            panic!("expected event from `TransportService`");
1113        };
1114
1115        // verify the first connection state is correct
1116        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1117        match service.connections.get(&peer) {
1118            Some(context) => {
1119                assert_eq!(
1120                    context.primary.connection_id(),
1121                    &ConnectionId::from(1337usize)
1122                );
1123                assert!(context.secondary.is_none());
1124            }
1125            None => panic!("expected {peer} to exist"),
1126        }
1127
1128        // close the primary connection
1129        sender
1130            .send(InnerTransportEvent::ConnectionClosed {
1131                peer,
1132                connection: ConnectionId::from(1337usize),
1133            })
1134            .await
1135            .unwrap();
1136
1137        // verify that the protocols are notified of the connection closing as well
1138        if let Some(TransportEvent::ConnectionClosed {
1139            peer: connected_peer,
1140        }) = service.next().await
1141        {
1142            assert_eq!(connected_peer, peer);
1143        } else {
1144            panic!("expected event from `TransportService`");
1145        }
1146
1147        // Because the connection was closed, the peer is no longer tracked for keep-alive.
1148        // This leads to better tracking overall since we don't have to track stale connections.
1149        assert!(service.keep_alive_tracker.last_activity.is_empty());
1150        assert!(!service.connections.contains_key(&peer));
1151
1152        // Register new primary connection.
1153        let (cmd_tx1, _cmd_rx1) = channel(64);
1154        sender
1155            .send(InnerTransportEvent::ConnectionEstablished {
1156                peer,
1157                connection: ConnectionId::from(1338usize),
1158                endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1338usize)),
1159                sender: ConnectionHandle::new(ConnectionId::from(1338usize), cmd_tx1),
1160            })
1161            .await
1162            .unwrap();
1163
1164        if let Some(TransportEvent::ConnectionEstablished {
1165            peer: connected_peer,
1166            endpoint,
1167        }) = service.next().await
1168        {
1169            assert_eq!(connected_peer, peer);
1170            assert_eq!(endpoint.address(), &Multiaddr::empty());
1171        } else {
1172            panic!("expected event from `TransportService`");
1173        };
1174
1175        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1176        match service.connections.get(&peer) {
1177            Some(context) => {
1178                assert_eq!(
1179                    context.primary.connection_id(),
1180                    &ConnectionId::from(1338usize)
1181                );
1182                assert!(context.secondary.is_none());
1183            }
1184            None => panic!("expected {peer} to exist"),
1185        }
1186
1187        match tokio::time::timeout(Duration::from_secs(10), service.next()).await {
1188            Ok(event) => panic!("didn't expect an event: {event:?}"),
1189            Err(_) => {}
1190        }
1191    }
1192
1193    async fn poll_service(service: &mut TransportService) {
1194        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
1195            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
1196            std::task::Poll::Pending => std::task::Poll::Ready(()),
1197        })
1198        .await;
1199    }
1200
1201    #[tokio::test]
1202    async fn keep_alive_timeout_downgrades_connections() {
1203        let _ = tracing_subscriber::fmt()
1204            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1205            .try_init();
1206
1207        let (mut service, sender, _) = transport_service();
1208        let peer = PeerId::random();
1209
1210        // register first connection
1211        let (cmd_tx1, _cmd_rx1) = channel(64);
1212        sender
1213            .send(InnerTransportEvent::ConnectionEstablished {
1214                peer,
1215                connection: ConnectionId::from(1337usize),
1216                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
1217                sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
1218            })
1219            .await
1220            .unwrap();
1221
1222        if let Some(TransportEvent::ConnectionEstablished {
1223            peer: connected_peer,
1224            endpoint,
1225        }) = service.next().await
1226        {
1227            assert_eq!(connected_peer, peer);
1228            assert_eq!(endpoint.address(), &Multiaddr::empty());
1229        } else {
1230            panic!("expected event from `TransportService`");
1231        };
1232
1233        // verify the first connection state is correct
1234        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1235        match service.connections.get(&peer) {
1236            Some(context) => {
1237                assert_eq!(
1238                    context.primary.connection_id(),
1239                    &ConnectionId::from(1337usize)
1240                );
1241                // Check the connection is still active.
1242                assert!(context.primary.is_active());
1243                assert!(context.secondary.is_none());
1244            }
1245            None => panic!("expected {peer} to exist"),
1246        }
1247
1248        poll_service(&mut service).await;
1249        tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
1250        poll_service(&mut service).await;
1251
1252        // Verify the connection is downgraded.
1253        match service.connections.get(&peer) {
1254            Some(context) => {
1255                assert_eq!(
1256                    context.primary.connection_id(),
1257                    &ConnectionId::from(1337usize)
1258                );
1259                // Check the connection is not active.
1260                assert!(!context.primary.is_active());
1261                assert!(context.secondary.is_none());
1262            }
1263            None => panic!("expected {peer} to exist"),
1264        }
1265
1266        assert_eq!(service.keep_alive_tracker.last_activity.len(), 0);
1267    }
1268
1269    #[tokio::test]
1270    async fn keep_alive_timeout_reset_when_user_opens_substream() {
1271        let _ = tracing_subscriber::fmt()
1272            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1273            .try_init();
1274
1275        let (mut service, sender, _) = transport_service();
1276        let peer = PeerId::random();
1277
1278        // register first connection
1279        let (cmd_tx1, _cmd_rx1) = channel(64);
1280        sender
1281            .send(InnerTransportEvent::ConnectionEstablished {
1282                peer,
1283                connection: ConnectionId::from(1337usize),
1284                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
1285                sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
1286            })
1287            .await
1288            .unwrap();
1289
1290        if let Some(TransportEvent::ConnectionEstablished {
1291            peer: connected_peer,
1292            endpoint,
1293        }) = service.next().await
1294        {
1295            assert_eq!(connected_peer, peer);
1296            assert_eq!(endpoint.address(), &Multiaddr::empty());
1297        } else {
1298            panic!("expected event from `TransportService`");
1299        };
1300
1301        // verify the first connection state is correct
1302        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1303        match service.connections.get(&peer) {
1304            Some(context) => {
1305                assert_eq!(
1306                    context.primary.connection_id(),
1307                    &ConnectionId::from(1337usize)
1308                );
1309                // Check the connection is still active.
1310                assert!(context.primary.is_active());
1311                assert!(context.secondary.is_none());
1312            }
1313            None => panic!("expected {peer} to exist"),
1314        }
1315
1316        poll_service(&mut service).await;
1317        // Sleep for almost the entire keep-alive timeout.
1318        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1319
1320        // This ensures we reset the keep-alive timer when other protocols
1321        // want to open a substream.
1322        // We are still tracking the same peer.
1323        service.open_substream(peer).unwrap();
1324        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1325
1326        poll_service(&mut service).await;
1327        // The keep alive timeout should be advanced.
1328        tokio::time::sleep(std::time::Duration::from_secs(3)).await;
1329        poll_service(&mut service).await;
1330        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1331        // If the `service.open_substream` wasn't called, the connection would have been downgraded.
1332        // Instead the keep-alive was forwarded `KEEP_ALIVE_TIMEOUT` seconds into the future.
1333        // Verify the connection is still active.
1334        match service.connections.get(&peer) {
1335            Some(context) => {
1336                assert_eq!(
1337                    context.primary.connection_id(),
1338                    &ConnectionId::from(1337usize)
1339                );
1340                assert!(context.primary.is_active());
1341                assert!(context.secondary.is_none());
1342            }
1343            None => panic!("expected {peer} to exist"),
1344        }
1345
1346        poll_service(&mut service).await;
1347        tokio::time::sleep(KEEP_ALIVE_TIMEOUT).await;
1348        poll_service(&mut service).await;
1349
1350        assert_eq!(service.keep_alive_tracker.last_activity.len(), 0);
1351
1352        // The connection had no substream activity for `KEEP_ALIVE_TIMEOUT` seconds.
1353        // Verify the connection is downgraded.
1354        match service.connections.get(&peer) {
1355            Some(context) => {
1356                assert_eq!(
1357                    context.primary.connection_id(),
1358                    &ConnectionId::from(1337usize)
1359                );
1360                assert!(!context.primary.is_active());
1361                assert!(context.secondary.is_none());
1362            }
1363            None => panic!("expected {peer} to exist"),
1364        }
1365    }
1366
1367    #[tokio::test]
1368    async fn downgraded_connection_without_substreams_is_closed() {
1369        let _ = tracing_subscriber::fmt()
1370            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1371            .try_init();
1372
1373        let (mut service, sender, _) = transport_service();
1374        let peer = PeerId::random();
1375
1376        // register first connection
1377        let (cmd_tx1, mut cmd_rx1) = channel(64);
1378        sender
1379            .send(InnerTransportEvent::ConnectionEstablished {
1380                peer,
1381                connection: ConnectionId::from(1337usize),
1382                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
1383                sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
1384            })
1385            .await
1386            .unwrap();
1387
1388        if let Some(TransportEvent::ConnectionEstablished {
1389            peer: connected_peer,
1390            endpoint,
1391        }) = service.next().await
1392        {
1393            assert_eq!(connected_peer, peer);
1394            assert_eq!(endpoint.address(), &Multiaddr::empty());
1395        } else {
1396            panic!("expected event from `TransportService`");
1397        };
1398
1399        // verify the first connection state is correct
1400        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1401        match service.connections.get(&peer) {
1402            Some(context) => {
1403                assert_eq!(
1404                    context.primary.connection_id(),
1405                    &ConnectionId::from(1337usize)
1406                );
1407                // Check the connection is still active.
1408                assert!(context.primary.is_active());
1409                assert!(context.secondary.is_none());
1410            }
1411            None => panic!("expected {peer} to exist"),
1412        }
1413
1414        // Open substreams to the peer.
1415        let substream_id = service.open_substream(peer).unwrap();
1416        let second_substream_id = service.open_substream(peer).unwrap();
1417
1418        // Simulate keep-alive timeout expiration.
1419        poll_service(&mut service).await;
1420        tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
1421        poll_service(&mut service).await;
1422
1423        let mut permits = Vec::new();
1424
1425        // First substream.
1426        let protocol_command = cmd_rx1.recv().await.unwrap();
1427        match protocol_command {
1428            ProtocolCommand::OpenSubstream {
1429                protocol,
1430                substream_id: opened_substream_id,
1431                permit,
1432                ..
1433            } => {
1434                assert_eq!(protocol, ProtocolName::from("/notif/1"));
1435                assert_eq!(substream_id, opened_substream_id);
1436
1437                // Save the substream permit for later.
1438                permits.push(permit);
1439            }
1440            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1441        }
1442
1443        // Second substream.
1444        let protocol_command = cmd_rx1.recv().await.unwrap();
1445        match protocol_command {
1446            ProtocolCommand::OpenSubstream {
1447                protocol,
1448                substream_id: opened_substream_id,
1449                permit,
1450                ..
1451            } => {
1452                assert_eq!(protocol, ProtocolName::from("/notif/1"));
1453                assert_eq!(second_substream_id, opened_substream_id);
1454
1455                // Save the substream permit for later.
1456                permits.push(permit);
1457            }
1458            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1459        }
1460
1461        // Drop one permit.
1462        let permit = permits.pop();
1463        // Individual transports like TCP will open a substream
1464        // and then will generate a `SubstreamOpened` event via
1465        // the protocol-set handler.
1466        //
1467        // The substream is used by individual protocols and then
1468        // is closed. This simulates the substream being closed.
1469        drop(permit);
1470
1471        // Open a new substream to the peer. This will succeed as long as we still have
1472        // one substream open.
1473        let substream_id = service.open_substream(peer).unwrap();
1474        // Handle the substream.
1475        let protocol_command = cmd_rx1.recv().await.unwrap();
1476        match protocol_command {
1477            ProtocolCommand::OpenSubstream {
1478                protocol,
1479                substream_id: opened_substream_id,
1480                permit,
1481                ..
1482            } => {
1483                assert_eq!(protocol, ProtocolName::from("/notif/1"));
1484                assert_eq!(substream_id, opened_substream_id);
1485
1486                // Save the substream permit for later.
1487                permits.push(permit);
1488            }
1489            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1490        }
1491
1492        // Drop all substreams.
1493        drop(permits);
1494
1495        poll_service(&mut service).await;
1496        tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
1497        poll_service(&mut service).await;
1498
1499        // Cannot open a new substream because:
1500        // 1. connection was downgraded by keep-alive timeout
1501        // 2. all substreams were dropped.
1502        assert_eq!(
1503            service.open_substream(peer),
1504            Err(SubstreamError::ConnectionClosed)
1505        );
1506    }
1507
1508    #[tokio::test]
1509    async fn substream_opening_upgrades_connection_and_resets_keep_alive() {
1510        let _ = tracing_subscriber::fmt()
1511            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1512            .try_init();
1513
1514        let (mut service, sender, _) = transport_service();
1515        let peer = PeerId::random();
1516
1517        // register first connection
1518        let (cmd_tx1, mut cmd_rx1) = channel(64);
1519        sender
1520            .send(InnerTransportEvent::ConnectionEstablished {
1521                peer,
1522                connection: ConnectionId::from(1337usize),
1523                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
1524                sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
1525            })
1526            .await
1527            .unwrap();
1528
1529        if let Some(TransportEvent::ConnectionEstablished {
1530            peer: connected_peer,
1531            endpoint,
1532        }) = service.next().await
1533        {
1534            assert_eq!(connected_peer, peer);
1535            assert_eq!(endpoint.address(), &Multiaddr::empty());
1536        } else {
1537            panic!("expected event from `TransportService`");
1538        };
1539
1540        // verify the first connection state is correct
1541        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1542        match service.connections.get(&peer) {
1543            Some(context) => {
1544                assert_eq!(
1545                    context.primary.connection_id(),
1546                    &ConnectionId::from(1337usize)
1547                );
1548                // Check the connection is still active.
1549                assert!(context.primary.is_active());
1550                assert!(context.secondary.is_none());
1551            }
1552            None => panic!("expected {peer} to exist"),
1553        }
1554
1555        // Open substreams to the peer.
1556        let substream_id = service.open_substream(peer).unwrap();
1557        let second_substream_id = service.open_substream(peer).unwrap();
1558
1559        let mut permits = Vec::new();
1560        // First substream.
1561        let protocol_command = cmd_rx1.recv().await.unwrap();
1562        match protocol_command {
1563            ProtocolCommand::OpenSubstream {
1564                protocol,
1565                substream_id: opened_substream_id,
1566                permit,
1567                ..
1568            } => {
1569                assert_eq!(protocol, ProtocolName::from("/notif/1"));
1570                assert_eq!(substream_id, opened_substream_id);
1571
1572                // Save the substream permit for later.
1573                permits.push(permit);
1574            }
1575            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1576        }
1577
1578        // Second substream.
1579        let protocol_command = cmd_rx1.recv().await.unwrap();
1580        match protocol_command {
1581            ProtocolCommand::OpenSubstream {
1582                protocol,
1583                substream_id: opened_substream_id,
1584                permit,
1585                ..
1586            } => {
1587                assert_eq!(protocol, ProtocolName::from("/notif/1"));
1588                assert_eq!(second_substream_id, opened_substream_id);
1589
1590                // Save the substream permit for later.
1591                permits.push(permit);
1592            }
1593            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1594        }
1595
1596        // Sleep to trigger keep-alive timeout.
1597        poll_service(&mut service).await;
1598        tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
1599        poll_service(&mut service).await;
1600
1601        // Verify the connection is downgraded.
1602        match service.connections.get(&peer) {
1603            Some(context) => {
1604                assert_eq!(
1605                    context.primary.connection_id(),
1606                    &ConnectionId::from(1337usize)
1607                );
1608                // Check the connection is not active.
1609                assert!(!context.primary.is_active());
1610                assert!(context.secondary.is_none());
1611            }
1612            None => panic!("expected {peer} to exist"),
1613        }
1614        assert_eq!(service.keep_alive_tracker.last_activity.len(), 0);
1615
1616        // Open a new substream to the peer. This will succeed as long as we still have
1617        // at least substream permit.
1618        let substream_id = service.open_substream(peer).unwrap();
1619        let protocol_command = cmd_rx1.recv().await.unwrap();
1620        match protocol_command {
1621            ProtocolCommand::OpenSubstream {
1622                protocol,
1623                substream_id: opened_substream_id,
1624                permit,
1625                ..
1626            } => {
1627                assert_eq!(protocol, ProtocolName::from("/notif/1"));
1628                assert_eq!(substream_id, opened_substream_id);
1629
1630                // Save the substream permit for later.
1631                permits.push(permit);
1632            }
1633            _ => panic!("expected `ProtocolCommand::OpenSubstream`"),
1634        }
1635
1636        poll_service(&mut service).await;
1637
1638        // Verify the connection is upgraded and keep-alive is tracked.
1639        match service.connections.get(&peer) {
1640            Some(context) => {
1641                assert_eq!(
1642                    context.primary.connection_id(),
1643                    &ConnectionId::from(1337usize)
1644                );
1645                // Check the connection is active, because it was upgraded by the last substream.
1646                assert!(context.primary.is_active());
1647                assert!(context.secondary.is_none());
1648            }
1649            None => panic!("expected {peer} to exist"),
1650        }
1651        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1652
1653        // Drop all substreams
1654        drop(permits);
1655
1656        // The connection is still active, because it was upgraded by the last substream open.
1657        match service.connections.get(&peer) {
1658            Some(context) => {
1659                assert_eq!(
1660                    context.primary.connection_id(),
1661                    &ConnectionId::from(1337usize)
1662                );
1663                // Check the connection is active, because it was upgraded by the last substream.
1664                assert!(context.primary.is_active());
1665                assert!(context.secondary.is_none());
1666            }
1667            None => panic!("expected {peer} to exist"),
1668        }
1669        assert_eq!(service.keep_alive_tracker.last_activity.len(), 1);
1670
1671        // Sleep to trigger keep-alive timeout.
1672        poll_service(&mut service).await;
1673        tokio::time::sleep(KEEP_ALIVE_TIMEOUT + std::time::Duration::from_secs(1)).await;
1674        poll_service(&mut service).await;
1675
1676        match service.connections.get(&peer) {
1677            Some(context) => {
1678                assert_eq!(
1679                    context.primary.connection_id(),
1680                    &ConnectionId::from(1337usize)
1681                );
1682                // No longer active because it was downgraded by keep-alive and no
1683                // substream opens were made.
1684                assert!(!context.primary.is_active());
1685                assert!(context.secondary.is_none());
1686            }
1687            None => panic!("expected {peer} to exist"),
1688        }
1689
1690        // Cannot open a new substream because:
1691        // 1. connection was downgraded by keep-alive timeout
1692        // 2. all substreams were dropped.
1693        assert_eq!(
1694            service.open_substream(peer),
1695            Err(SubstreamError::ConnectionClosed)
1696        );
1697    }
1698
1699    #[tokio::test]
1700    async fn keep_alive_pop_elements() {
1701        let mut tracker = KeepAliveTracker::new(Duration::from_secs(1));
1702
1703        let (peer1, connection1) = (PeerId::random(), ConnectionId::from(1usize));
1704        let (peer2, connection2) = (PeerId::random(), ConnectionId::from(2usize));
1705        let added_keys = HashSet::from([(peer1, connection1), (peer2, connection2)]);
1706
1707        tracker.on_connection_established(peer1, connection1);
1708        tracker.on_connection_established(peer2, connection2);
1709
1710        tokio::time::sleep(Duration::from_secs(2)).await;
1711
1712        let key = tracker.next().await.unwrap();
1713        assert!(added_keys.contains(&key));
1714
1715        let key = tracker.next().await.unwrap();
1716        assert!(added_keys.contains(&key));
1717
1718        // No more elements.
1719        assert!(tracker.pending_keep_alive_timeouts.is_empty());
1720        assert!(tracker.last_activity.is_empty());
1721    }
1722}