litep2p/protocol/transport_service.rs

// Copyright 2023 litep2p developers
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use crate::{
    addresses::PublicAddresses,
    error::{Error, ImmediateDialError, SubstreamError},
    protocol::{connection::ConnectionHandle, InnerTransportEvent, TransportEvent},
    transport::{manager::TransportManagerHandle, Endpoint},
    types::{protocol::ProtocolName, ConnectionId, SubstreamId},
    PeerId, DEFAULT_CHANNEL_SIZE,
};

use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt};
use multiaddr::{Multiaddr, Protocol};
use multihash::Multihash;
use tokio::sync::mpsc::{channel, Receiver, Sender};

use std::{
    collections::{HashMap, HashSet},
    fmt::Debug,
    pin::Pin,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
    task::{Context, Poll},
    time::Duration,
};

/// Logging target for the file.
const LOG_TARGET: &str = "litep2p::transport-service";

/// Connection context for the peer.
///
/// Each peer is allowed to have at most two connections open. The first open connection is the
/// primary connection, which the local node uses to open substreams to the remote peer. A
/// secondary connection may be opened if the local and remote nodes opened connections to each
/// other at the same time.
///
/// The secondary connection may be promoted to the primary connection if the primary connection
/// closes while the secondary connection remains open.
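///
/// An illustrative sketch of the promotion step (`context` here is a placeholder for an
/// existing [`ConnectionContext`], not part of this module's API):
///
/// ```ignore
/// if let Some(handle) = context.secondary.take() {
///     // the secondary connection becomes the new primary connection
///     context.primary = handle;
/// }
/// ```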
#[derive(Debug)]
struct ConnectionContext {
    /// Primary connection.
    primary: ConnectionHandle,

    /// Secondary connection, if it exists.
    secondary: Option<ConnectionHandle>,
}

impl ConnectionContext {
    /// Create new [`ConnectionContext`].
    fn new(primary: ConnectionHandle) -> Self {
        Self {
            primary,
            secondary: None,
        }
    }

    /// Downgrade the connection to non-active, which means it will be closed
    /// if there are no substreams open over it.
    fn downgrade(&mut self, connection_id: &ConnectionId) {
        if self.primary.connection_id() == connection_id {
            self.primary.close();
            return;
        }

        if let Some(handle) = &mut self.secondary {
            if handle.connection_id() == connection_id {
                handle.close();
                return;
            }
        }

        tracing::debug!(
            target: LOG_TARGET,
            primary = ?self.primary.connection_id(),
            secondary = ?self.secondary.as_ref().map(|handle| handle.connection_id()),
            ?connection_id,
            "connection doesn't exist, cannot downgrade",
        );
    }
}

/// Provides an interface for [`Litep2p`](crate::Litep2p) protocols to interact
/// with the underlying transport protocols.
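///
/// A rough sketch of how a protocol implementation might drive the service (`service` is
/// the `TransportService` handed to the protocol; error handling and the remaining event
/// variants are omitted):
///
/// ```ignore
/// while let Some(event) = service.next().await {
///     match event {
///         TransportEvent::ConnectionEstablished { peer, .. } => {
///             // a connection to `peer` was established and substreams can be opened
///         }
///         TransportEvent::ConnectionClosed { peer } => {
///             // the last connection to `peer` was closed
///         }
///         _ => {}
///     }
/// }
/// ```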
#[derive(Debug)]
pub struct TransportService {
    /// Local peer ID.
    local_peer_id: PeerId,

    /// Protocol.
    protocol: ProtocolName,

    /// Fallback names for the protocol.
    fallback_names: Vec<ProtocolName>,

    /// Open connections.
    connections: HashMap<PeerId, ConnectionContext>,

    /// Transport handle.
    transport_handle: TransportManagerHandle,

    /// RX channel for receiving events from transports and connections.
    rx: Receiver<InnerTransportEvent>,

    /// Next substream ID.
    next_substream_id: Arc<AtomicUsize>,

    /// Close the connection if no substreams are open within this time frame.
    keep_alive_timeout: Duration,

    /// Pending keep-alive timeouts.
    pending_keep_alive_timeouts: FuturesUnordered<BoxFuture<'static, (PeerId, ConnectionId)>>,
}

impl TransportService {
    /// Create new [`TransportService`].
    pub(crate) fn new(
        local_peer_id: PeerId,
        protocol: ProtocolName,
        fallback_names: Vec<ProtocolName>,
        next_substream_id: Arc<AtomicUsize>,
        transport_handle: TransportManagerHandle,
        keep_alive_timeout: Duration,
    ) -> (Self, Sender<InnerTransportEvent>) {
        let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE);

        (
            Self {
                rx,
                protocol,
                local_peer_id,
                fallback_names,
                transport_handle,
                next_substream_id,
                connections: HashMap::new(),
                keep_alive_timeout,
                pending_keep_alive_timeouts: FuturesUnordered::new(),
            },
            tx,
        )
    }

    /// Get the list of public addresses of the node.
    pub fn public_addresses(&self) -> PublicAddresses {
        self.transport_handle.public_addresses()
    }

    /// Get the list of listen addresses of the node.
    pub fn listen_addresses(&self) -> HashSet<Multiaddr> {
        self.transport_handle.listen_addresses()
    }

    /// Handle connection established event.
    fn on_connection_established(
        &mut self,
        peer: PeerId,
        endpoint: Endpoint,
        connection_id: ConnectionId,
        handle: ConnectionHandle,
    ) -> Option<TransportEvent> {
        tracing::debug!(
            target: LOG_TARGET,
            ?peer,
            protocol = %self.protocol,
            ?endpoint,
            ?connection_id,
            "connection established",
        );
        let keep_alive_timeout = self.keep_alive_timeout;

        match self.connections.get_mut(&peer) {
            Some(context) => match context.secondary {
                Some(_) => {
                    tracing::debug!(
                        target: LOG_TARGET,
                        ?peer,
                        ?connection_id,
                        ?endpoint,
                        "ignoring third connection",
                    );
                    None
                }
                None => {
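                    // a second connection was opened while the primary is still active,
                    // for example when both peers dialed each other at the same time;
                    // store it as the secondary connection and start its keep-alive timer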
                    self.pending_keep_alive_timeouts.push(Box::pin(async move {
                        tokio::time::sleep(keep_alive_timeout).await;
                        (peer, connection_id)
                    }));
                    context.secondary = Some(handle);

                    None
                }
            },
            None => {
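                // first connection to the peer: store it as the primary connection,
                // start its keep-alive timer and report the connection to the protocol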
                self.connections.insert(peer, ConnectionContext::new(handle));
                self.pending_keep_alive_timeouts.push(Box::pin(async move {
                    tokio::time::sleep(keep_alive_timeout).await;
                    (peer, connection_id)
                }));

                Some(TransportEvent::ConnectionEstablished { peer, endpoint })
            }
        }
    }

    /// Handle connection closed event.
    fn on_connection_closed(
        &mut self,
        peer: PeerId,
        connection_id: ConnectionId,
    ) -> Option<TransportEvent> {
        let Some(context) = self.connections.get_mut(&peer) else {
            tracing::warn!(
                target: LOG_TARGET,
                ?peer,
                ?connection_id,
                "connection closed to a non-existent peer",
            );

            debug_assert!(false);
            return None;
        };

        // if the primary connection was closed, check whether a secondary connection exists
        // and, if it does, promote the secondary connection to the primary connection
        if context.primary.connection_id() == &connection_id {
            tracing::trace!(target: LOG_TARGET, ?peer, ?connection_id, "primary connection closed");

            match context.secondary.take() {
                None => {
                    self.connections.remove(&peer);
                    return Some(TransportEvent::ConnectionClosed { peer });
                }
                Some(handle) => {
                    tracing::debug!(
                        target: LOG_TARGET,
                        ?peer,
                        ?connection_id,
                        "switch to secondary connection",
                    );

                    context.primary = handle;
                    return None;
                }
            }
        }

        match context.secondary.take() {
            Some(handle) if handle.connection_id() == &connection_id => {
                tracing::trace!(
                    target: LOG_TARGET,
                    ?peer,
                    ?connection_id,
                    "secondary connection closed",
                );

                None
            }
            connection_state => {
                tracing::debug!(
                    target: LOG_TARGET,
                    ?peer,
                    ?connection_id,
                    ?connection_state,
                    "connection closed but it doesn't exist",
                );

                None
            }
        }
    }

    /// Dial `peer` using `PeerId`.
    ///
    /// Call fails if `Litep2p` doesn't have a known address for the peer.
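    ///
    /// A minimal sketch, assuming `peer` is a `PeerId` whose address is already known to
    /// `Litep2p` (the result of the dial is reported later through the event stream):
    ///
    /// ```ignore
    /// service.dial(&peer)?;
    /// ```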
    pub fn dial(&mut self, peer: &PeerId) -> Result<(), ImmediateDialError> {
        self.transport_handle.dial(peer)
    }

    /// Dial peer using a `Multiaddr`.
    ///
    /// Call fails if the address is not in the correct format or contains an unsupported/disabled
    /// transport.
    ///
    /// Calling this function is only necessary for addresses that are discovered out-of-band:
    /// `Litep2p` internally keeps track of all peer addresses it has learned through calls to
    /// this function, Kademlia peer discoveries and `Identify` responses.
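    ///
    /// A minimal sketch; the address below is illustrative only and assumes a TCP
    /// transport is enabled:
    ///
    /// ```ignore
    /// let address: Multiaddr = "/ip4/192.0.2.1/tcp/30333".parse()?;
    /// service.dial_address(address)?;
    /// ```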
    pub fn dial_address(&mut self, address: Multiaddr) -> Result<(), ImmediateDialError> {
        self.transport_handle.dial_address(address)
    }

    /// Add one or more addresses for `peer`.
    ///
    /// The list is filtered for duplicates and unsupported transports.
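    ///
    /// A minimal sketch with a single illustrative address; the `/p2p/<peer id>` part is
    /// appended automatically if it is missing:
    ///
    /// ```ignore
    /// let address: Multiaddr = "/ip4/192.0.2.1/tcp/30333".parse()?;
    /// service.add_known_address(&peer, std::iter::once(address));
    /// ```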
    pub fn add_known_address(&mut self, peer: &PeerId, addresses: impl Iterator<Item = Multiaddr>) {
        let addresses: HashSet<Multiaddr> = addresses
            .filter_map(|address| {
                if !std::matches!(address.iter().last(), Some(Protocol::P2p(_))) {
                    Some(address.with(Protocol::P2p(Multihash::from_bytes(&peer.to_bytes()).ok()?)))
                } else {
                    Some(address)
                }
            })
            .collect();

        self.transport_handle.add_known_address(peer, addresses.into_iter());
    }

    /// Open substream to `peer`.
    ///
    /// Call fails if there is no connection open to `peer` or the channel towards
    /// the connection is clogged.
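    ///
    /// A minimal sketch, assuming a connection to `peer` is already open; the negotiated
    /// substream is delivered to the protocol later through the event stream:
    ///
    /// ```ignore
    /// let substream_id = service.open_substream(peer)?;
    /// ```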
    pub fn open_substream(&mut self, peer: PeerId) -> Result<SubstreamId, SubstreamError> {
        // always prefer the primary connection
        let connection = &mut self
            .connections
            .get_mut(&peer)
            .ok_or(SubstreamError::PeerDoesNotExist(peer))?
            .primary;

        let permit = connection.try_get_permit().ok_or(SubstreamError::ConnectionClosed)?;
        let substream_id =
            SubstreamId::from(self.next_substream_id.fetch_add(1usize, Ordering::Relaxed));

        tracing::trace!(
            target: LOG_TARGET,
            ?peer,
            protocol = %self.protocol,
            ?substream_id,
            "open substream",
        );

        connection
            .open_substream(
                self.protocol.clone(),
                self.fallback_names.clone(),
                substream_id,
                permit,
            )
            .map(|_| substream_id)
    }

    /// Forcibly close the connection, even if other protocols have substreams open over it.
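    ///
    /// A minimal sketch, assuming a connection to `peer` exists:
    ///
    /// ```ignore
    /// service.force_close(peer)?;
    /// ```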
    pub fn force_close(&mut self, peer: PeerId) -> crate::Result<()> {
        let connection =
            &mut self.connections.get_mut(&peer).ok_or(Error::PeerDoesntExist(peer))?;

        tracing::debug!(
            target: LOG_TARGET,
            ?peer,
            protocol = %self.protocol,
            secondary = ?connection.secondary,
            "forcibly closing the connection",
        );

        if let Some(ref mut connection) = connection.secondary {
            let _ = connection.force_close();
        }

        connection.primary.force_close()
    }

    /// Get local peer ID.
    pub fn local_peer_id(&self) -> PeerId {
        self.local_peer_id
    }
}

impl Stream for TransportService {
    type Item = TransportEvent;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
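        // drain events from transports and connections first, converting them into
        // `TransportEvent`s for the protocol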
        while let Poll::Ready(event) = self.rx.poll_recv(cx) {
            match event {
                None => return Poll::Ready(None),
                Some(InnerTransportEvent::ConnectionEstablished {
                    peer,
                    endpoint,
                    sender,
                    connection,
                }) => {
                    if let Some(event) =
                        self.on_connection_established(peer, endpoint, connection, sender)
                    {
                        return Poll::Ready(Some(event));
                    }
                }
                Some(InnerTransportEvent::ConnectionClosed { peer, connection }) => {
                    if let Some(event) = self.on_connection_closed(peer, connection) {
                        return Poll::Ready(Some(event));
                    }
                }
                Some(event) => return Poll::Ready(Some(event.into())),
            }
        }

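        // poll the pending keep-alive timers; when a timer expires, the corresponding
        // connection is downgraded so it can be closed once no substreams are open over it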
        while let Poll::Ready(Some((peer, connection_id))) =
            self.pending_keep_alive_timeouts.poll_next_unpin(cx)
        {
            if let Some(context) = self.connections.get_mut(&peer) {
                tracing::trace!(
                    target: LOG_TARGET,
                    ?peer,
                    ?connection_id,
                    "keep-alive timeout over, downgrade connection",
                );

                context.downgrade(&connection_id);
            }
        }

        Poll::Pending
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        protocol::TransportService,
        transport::{
            manager::{handle::InnerTransportManagerCommand, TransportManagerHandle},
            KEEP_ALIVE_TIMEOUT,
        },
    };
    use futures::StreamExt;
    use parking_lot::RwLock;
    use std::collections::HashSet;

    /// Create new `TransportService`.
    fn transport_service() -> (
        TransportService,
        Sender<InnerTransportEvent>,
        Receiver<InnerTransportManagerCommand>,
    ) {
        let (cmd_tx, cmd_rx) = channel(64);
        let peer = PeerId::random();

        let handle = TransportManagerHandle::new(
            peer,
            Arc::new(RwLock::new(HashMap::new())),
            cmd_tx,
            HashSet::new(),
            Default::default(),
            PublicAddresses::new(peer),
        );

        let (service, sender) = TransportService::new(
            peer,
            ProtocolName::from("/notif/1"),
            Vec::new(),
            Arc::new(AtomicUsize::new(0usize)),
            handle,
            KEEP_ALIVE_TIMEOUT,
        );

        (service, sender, cmd_rx)
    }

    #[tokio::test]
    async fn secondary_connection_stored() {
        let (mut service, sender, _) = transport_service();
        let peer = PeerId::random();

        // register first connection
        let (cmd_tx1, _cmd_rx1) = channel(64);
        sender
            .send(InnerTransportEvent::ConnectionEstablished {
                peer,
                connection: ConnectionId::from(0usize),
                endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(0usize)),
                sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
            })
            .await
            .unwrap();

        if let Some(TransportEvent::ConnectionEstablished {
            peer: connected_peer,
            endpoint,
        }) = service.next().await
        {
            assert_eq!(connected_peer, peer);
            assert_eq!(endpoint.address(), &Multiaddr::empty());
        } else {
            panic!("expected event from `TransportService`");
        };

        // register secondary connection
        let (cmd_tx2, _cmd_rx2) = channel(64);
        sender
            .send(InnerTransportEvent::ConnectionEstablished {
                peer,
                connection: ConnectionId::from(1usize),
                endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1usize)),
                sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
            })
            .await
            .unwrap();

        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
            std::task::Poll::Pending => std::task::Poll::Ready(()),
        })
        .await;

        let context = service.connections.get(&peer).unwrap();
        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
        assert_eq!(
            context.secondary.as_ref().unwrap().connection_id(),
            &ConnectionId::from(1usize)
        );
    }

    #[tokio::test]
    async fn tertiary_connection_ignored() {
        let (mut service, sender, _) = transport_service();
        let peer = PeerId::random();

        // register first connection
        let (cmd_tx1, _cmd_rx1) = channel(64);
        sender
            .send(InnerTransportEvent::ConnectionEstablished {
                peer,
                connection: ConnectionId::from(0usize),
                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)),
                sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
            })
            .await
            .unwrap();

        if let Some(TransportEvent::ConnectionEstablished {
            peer: connected_peer,
            endpoint,
        }) = service.next().await
        {
            assert_eq!(connected_peer, peer);
            assert_eq!(endpoint.address(), &Multiaddr::empty());
        } else {
            panic!("expected event from `TransportService`");
        };

        // register secondary connection
        let (cmd_tx2, _cmd_rx2) = channel(64);
        sender
            .send(InnerTransportEvent::ConnectionEstablished {
                peer,
                connection: ConnectionId::from(1usize),
                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1usize)),
                sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
            })
            .await
            .unwrap();

        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
            std::task::Poll::Pending => std::task::Poll::Ready(()),
        })
        .await;

        let context = service.connections.get(&peer).unwrap();
        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
        assert_eq!(
            context.secondary.as_ref().unwrap().connection_id(),
            &ConnectionId::from(1usize)
        );

        // try to register tertiary connection and verify it's ignored
        let (cmd_tx3, mut cmd_rx3) = channel(64);
        sender
            .send(InnerTransportEvent::ConnectionEstablished {
                peer,
                connection: ConnectionId::from(2usize),
                endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(2usize)),
                sender: ConnectionHandle::new(ConnectionId::from(2usize), cmd_tx3),
            })
            .await
            .unwrap();

        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
            std::task::Poll::Pending => std::task::Poll::Ready(()),
        })
        .await;

        let context = service.connections.get(&peer).unwrap();
        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
        assert_eq!(
            context.secondary.as_ref().unwrap().connection_id(),
            &ConnectionId::from(1usize)
        );
        assert!(cmd_rx3.try_recv().is_err());
    }

    #[tokio::test]
    async fn secondary_closing_doesnt_emit_event() {
        let (mut service, sender, _) = transport_service();
        let peer = PeerId::random();

        // register first connection
        let (cmd_tx1, _cmd_rx1) = channel(64);
        sender
            .send(InnerTransportEvent::ConnectionEstablished {
                peer,
                connection: ConnectionId::from(0usize),
                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)),
                sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
            })
            .await
            .unwrap();

        if let Some(TransportEvent::ConnectionEstablished {
            peer: connected_peer,
            endpoint,
        }) = service.next().await
        {
            assert_eq!(connected_peer, peer);
            assert_eq!(endpoint.address(), &Multiaddr::empty());
        } else {
            panic!("expected event from `TransportService`");
        };

        // register secondary connection
        let (cmd_tx2, _cmd_rx2) = channel(64);
        sender
            .send(InnerTransportEvent::ConnectionEstablished {
                peer,
                connection: ConnectionId::from(1usize),
                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1usize)),
                sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
            })
            .await
            .unwrap();

        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
            std::task::Poll::Pending => std::task::Poll::Ready(()),
        })
        .await;

        let context = service.connections.get(&peer).unwrap();
        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
        assert_eq!(
            context.secondary.as_ref().unwrap().connection_id(),
            &ConnectionId::from(1usize)
        );

        // close the secondary connection
        sender
            .send(InnerTransportEvent::ConnectionClosed {
                peer,
                connection: ConnectionId::from(1usize),
            })
            .await
            .unwrap();

        // verify that the protocol is not notified
        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
            std::task::Poll::Pending => std::task::Poll::Ready(()),
        })
        .await;

        // verify that the secondary connection doesn't exist anymore
        let context = service.connections.get(&peer).unwrap();
        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
        assert!(context.secondary.is_none());
    }

    #[tokio::test]
    async fn convert_secondary_to_primary() {
        let (mut service, sender, _) = transport_service();
        let peer = PeerId::random();

        // register first connection
        let (cmd_tx1, mut cmd_rx1) = channel(64);
        sender
            .send(InnerTransportEvent::ConnectionEstablished {
                peer,
                connection: ConnectionId::from(0usize),
                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(0usize)),
                sender: ConnectionHandle::new(ConnectionId::from(0usize), cmd_tx1),
            })
            .await
            .unwrap();

        if let Some(TransportEvent::ConnectionEstablished {
            peer: connected_peer,
            endpoint,
        }) = service.next().await
        {
            assert_eq!(connected_peer, peer);
            assert_eq!(endpoint.address(), &Multiaddr::empty());
        } else {
            panic!("expected event from `TransportService`");
        };

        // register secondary connection
        let (cmd_tx2, mut cmd_rx2) = channel(64);
        sender
            .send(InnerTransportEvent::ConnectionEstablished {
                peer,
                connection: ConnectionId::from(1usize),
                endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1usize)),
                sender: ConnectionHandle::new(ConnectionId::from(1usize), cmd_tx2),
            })
            .await
            .unwrap();

        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
            std::task::Poll::Pending => std::task::Poll::Ready(()),
        })
        .await;

        let context = service.connections.get(&peer).unwrap();
        assert_eq!(context.primary.connection_id(), &ConnectionId::from(0usize));
        assert_eq!(
            context.secondary.as_ref().unwrap().connection_id(),
            &ConnectionId::from(1usize)
        );

        // close the primary connection
        sender
            .send(InnerTransportEvent::ConnectionClosed {
                peer,
                connection: ConnectionId::from(0usize),
            })
            .await
            .unwrap();

        // verify that the protocol is not notified
        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
            std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
            std::task::Poll::Pending => std::task::Poll::Ready(()),
        })
        .await;

        // verify that the primary connection has been replaced
        let context = service.connections.get(&peer).unwrap();
        assert_eq!(context.primary.connection_id(), &ConnectionId::from(1usize));
        assert!(context.secondary.is_none());
        assert!(cmd_rx1.try_recv().is_err());

        // close the remaining (formerly secondary) connection as well
        sender
            .send(InnerTransportEvent::ConnectionClosed {
                peer,
                connection: ConnectionId::from(1usize),
            })
            .await
            .unwrap();

        if let Some(TransportEvent::ConnectionClosed {
            peer: disconnected_peer,
        }) = service.next().await
        {
            assert_eq!(disconnected_peer, peer);
        } else {
            panic!("expected event from `TransportService`");
        };

        // verify that the peer's connection context has been removed
        assert!(service.connections.get(&peer).is_none());
        assert!(cmd_rx2.try_recv().is_err());
    }

    #[tokio::test]
    async fn keep_alive_timeout_expires_for_a_stale_connection() {
        let (mut service, sender, _) = transport_service();
        let peer = PeerId::random();

        // register first connection
        let (cmd_tx1, _cmd_rx1) = channel(64);
        sender
            .send(InnerTransportEvent::ConnectionEstablished {
                peer,
                connection: ConnectionId::from(1337usize),
                endpoint: Endpoint::dialer(Multiaddr::empty(), ConnectionId::from(1337usize)),
                sender: ConnectionHandle::new(ConnectionId::from(1337usize), cmd_tx1),
            })
            .await
            .unwrap();

        if let Some(TransportEvent::ConnectionEstablished {
            peer: connected_peer,
            endpoint,
        }) = service.next().await
        {
            assert_eq!(connected_peer, peer);
            assert_eq!(endpoint.address(), &Multiaddr::empty());
        } else {
            panic!("expected event from `TransportService`");
        };

        // verify the first connection state is correct
        assert_eq!(service.pending_keep_alive_timeouts.len(), 1);
        match service.connections.get(&peer) {
            Some(context) => {
                assert_eq!(
                    context.primary.connection_id(),
                    &ConnectionId::from(1337usize)
                );
                assert!(context.secondary.is_none());
            }
            None => panic!("expected {peer} to exist"),
        }

        // close the primary connection
        sender
            .send(InnerTransportEvent::ConnectionClosed {
                peer,
                connection: ConnectionId::from(1337usize),
            })
            .await
            .unwrap();

        // verify that the protocols are notified of the connection closing as well
        if let Some(TransportEvent::ConnectionClosed {
            peer: connected_peer,
        }) = service.next().await
        {
            assert_eq!(connected_peer, peer);
        } else {
            panic!("expected event from `TransportService`");
        }

        // verify that the keep-alive timeout still exists for the peer but the peer itself
        // doesn't exist anymore
        //
        // the peer is removed because there is no connection to them
        assert_eq!(service.pending_keep_alive_timeouts.len(), 1);
        assert!(service.connections.get(&peer).is_none());

        // register new primary connection but verify that there are now two pending keep-alive
        // timeouts
        let (cmd_tx1, _cmd_rx1) = channel(64);
        sender
            .send(InnerTransportEvent::ConnectionEstablished {
                peer,
                connection: ConnectionId::from(1338usize),
                endpoint: Endpoint::listener(Multiaddr::empty(), ConnectionId::from(1338usize)),
                sender: ConnectionHandle::new(ConnectionId::from(1338usize), cmd_tx1),
            })
            .await
            .unwrap();

        if let Some(TransportEvent::ConnectionEstablished {
            peer: connected_peer,
            endpoint,
        }) = service.next().await
        {
            assert_eq!(connected_peer, peer);
            assert_eq!(endpoint.address(), &Multiaddr::empty());
        } else {
            panic!("expected event from `TransportService`");
        };

        // verify the new connection state is correct
        assert_eq!(service.pending_keep_alive_timeouts.len(), 2);
        match service.connections.get(&peer) {
            Some(context) => {
                assert_eq!(
                    context.primary.connection_id(),
                    &ConnectionId::from(1338usize)
                );
                assert!(context.secondary.is_none());
            }
            None => panic!("expected {peer} to exist"),
        }

        match tokio::time::timeout(Duration::from_secs(10), service.next()).await {
            Ok(event) => panic!("didn't expect an event: {event:?}"),
            Err(_) => {}
        }
    }
}