litep2p/transport/tcp/
mod.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2// Copyright 2023 litep2p developers
3//
4// Permission is hereby granted, free of charge, to any person obtaining a
5// copy of this software and associated documentation files (the "Software"),
6// to deal in the Software without restriction, including without limitation
7// the rights to use, copy, modify, merge, publish, distribute, sublicense,
8// and/or sell copies of the Software, and to permit persons to whom the
9// Software is furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20// DEALINGS IN THE SOFTWARE.
21
22//! TCP transport.
23
24use crate::{
25    config::Role,
26    error::{DialError, Error},
27    transport::{
28        common::listener::{DialAddresses, GetSocketAddr, SocketListener, TcpAddress},
29        manager::TransportHandle,
30        tcp::{
31            config::Config,
32            connection::{NegotiatedConnection, TcpConnection},
33        },
34        Transport, TransportBuilder, TransportEvent,
35    },
36    types::ConnectionId,
37};
38
39use futures::{
40    future::BoxFuture,
41    stream::{FuturesUnordered, Stream, StreamExt},
42};
43use multiaddr::Multiaddr;
44use socket2::{Domain, Socket, Type};
45use tokio::net::TcpStream;
46
47use std::{
48    collections::{HashMap, HashSet},
49    net::SocketAddr,
50    pin::Pin,
51    task::{Context, Poll},
52    time::Duration,
53};
54
55pub(crate) use substream::Substream;
56
57mod connection;
58mod substream;
59
60pub mod config;
61
/// Logging target used by every `tracing` event emitted from this file.
const LOG_TARGET: &str = "litep2p::tcp";
64
65/// Pending inbound connection.
66struct PendingInboundConnection {
67    /// Socket address of the remote peer.
68    connection: TcpStream,
69    /// Address of the remote peer.
70    address: SocketAddr,
71}
72
/// TCP transport.
///
/// Accepts inbound connections through `listener`, dials outbound connections, and tracks each
/// connection in the maps/future sets below until it is either handed off or dropped.
pub(crate) struct TcpTransport {
    /// Transport context (keypair, executor, connection ID allocator, protocol set, bandwidth
    /// sink) shared with the transport manager.
    context: TransportHandle,

    /// Transport configuration.
    ///
    /// `listen_addresses` has already been moved out into `listener` by `new()`.
    config: Config,

    /// TCP listener, polled in `poll_next()` for new inbound connections.
    listener: SocketListener,

    /// Addresses of outbound connections that are being dialed or negotiated, keyed by
    /// connection ID. Used to build `TransportEvent::DialFailure` when the attempt fails.
    pending_dials: HashMap<ConnectionId, Multiaddr>,

    /// Local addresses that outbound sockets may bind to (see `dial_peer()`).
    dial_addresses: DialAddresses,

    /// Inbound connections announced via `TransportEvent::PendingInboundConnection`, waiting
    /// for `accept_pending()` or `reject_pending()`.
    pending_inbound_connections: HashMap<ConnectionId, PendingInboundConnection>,

    /// In-flight dials and/or negotiations, both inbound and outbound. Each future resolves to a
    /// negotiated connection or to the connection ID together with the dial error.
    pending_connections: FuturesUnordered<
        BoxFuture<'static, Result<NegotiatedConnection, (ConnectionId, DialError)>>,
    >,

    /// Pending raw, unnegotiated connections started by `open()`. Each future races dials to
    /// all given addresses and resolves to the first connected stream, or to every dial error.
    pending_raw_connections: FuturesUnordered<
        BoxFuture<
            'static,
            Result<
                (ConnectionId, Multiaddr, TcpStream),
                (ConnectionId, Vec<(Multiaddr, DialError)>),
            >,
        >,
    >,

    /// Opened raw connection, waiting for approval/rejection from `TransportManager`.
    /// Consumed by `negotiate()`.
    opened_raw: HashMap<ConnectionId, (TcpStream, Multiaddr)>,

    /// Connection IDs passed to `cancel()`. When the matching `open()` future resolves, its
    /// result is discarded and the ID is removed from this set.
    canceled: HashSet<ConnectionId>,

    /// Connections which have been opened and negotiated but are being validated by the
    /// `TransportManager`. Consumed by `accept()` or `reject()`.
    pending_open: HashMap<ConnectionId, NegotiatedConnection>,
}
119
120impl TcpTransport {
121    /// Handle inbound TCP connection.
122    fn on_inbound_connection(
123        &mut self,
124        connection_id: ConnectionId,
125        connection: TcpStream,
126        address: SocketAddr,
127    ) {
128        let yamux_config = self.config.yamux_config.clone();
129        let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
130        let max_write_buffer_size = self.config.noise_write_buffer_size;
131        let connection_open_timeout = self.config.connection_open_timeout;
132        let substream_open_timeout = self.config.substream_open_timeout;
133        let keypair = self.context.keypair.clone();
134
135        tracing::trace!(
136            target: LOG_TARGET,
137            ?connection_id,
138            ?address,
139            "accept connection",
140        );
141
142        self.pending_connections.push(Box::pin(async move {
143            TcpConnection::accept_connection(
144                connection,
145                connection_id,
146                keypair,
147                address,
148                yamux_config,
149                max_read_ahead_factor,
150                max_write_buffer_size,
151                connection_open_timeout,
152                substream_open_timeout,
153            )
154            .await
155            .map_err(|error| (connection_id, error.into()))
156        }));
157    }
158
159    /// Dial remote peer
160    async fn dial_peer(
161        address: Multiaddr,
162        dial_addresses: DialAddresses,
163        connection_open_timeout: Duration,
164        nodelay: bool,
165    ) -> Result<(Multiaddr, TcpStream), DialError> {
166        let (socket_address, _) = TcpAddress::multiaddr_to_socket_address(&address)?;
167
168        let remote_address =
169            match tokio::time::timeout(connection_open_timeout, socket_address.lookup_ip()).await {
170                Err(_) => {
171                    tracing::debug!(
172                        target: LOG_TARGET,
173                        ?address,
174                        ?connection_open_timeout,
175                        "failed to resolve address within timeout",
176                    );
177                    return Err(DialError::Timeout);
178                }
179                Ok(Err(error)) => return Err(error.into()),
180                Ok(Ok(address)) => address,
181            };
182
183        let domain = match remote_address.is_ipv4() {
184            true => Domain::IPV4,
185            false => Domain::IPV6,
186        };
187        let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?;
188        if remote_address.is_ipv6() {
189            socket.set_only_v6(true)?;
190        }
191        socket.set_nonblocking(true)?;
192        socket.set_nodelay(nodelay)?;
193
194        match dial_addresses.local_dial_address(&remote_address.ip()) {
195            Ok(Some(dial_address)) => {
196                socket.set_reuse_address(true)?;
197                #[cfg(unix)]
198                socket.set_reuse_port(true)?;
199                socket.bind(&dial_address.into())?;
200            }
201            Ok(None) => {}
202            Err(()) => {
203                tracing::debug!(
204                    target: LOG_TARGET,
205                    ?remote_address,
206                    "tcp listener not enabled for remote address, using ephemeral port",
207                );
208            }
209        }
210
211        let future = async move {
212            match socket.connect(&remote_address.into()) {
213                Ok(()) => {}
214                Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {}
215                Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {}
216                Err(err) => return Err(err),
217            }
218
219            let stream = TcpStream::try_from(Into::<std::net::TcpStream>::into(socket))?;
220            stream.writable().await?;
221
222            if let Some(e) = stream.take_error()? {
223                return Err(e);
224            }
225
226            Ok((address, stream))
227        };
228
229        match tokio::time::timeout(connection_open_timeout, future).await {
230            Err(_) => {
231                tracing::debug!(
232                    target: LOG_TARGET,
233                    ?connection_open_timeout,
234                    "failed to connect within timeout",
235                );
236                Err(DialError::Timeout)
237            }
238            Ok(Err(error)) => Err(error.into()),
239            Ok(Ok((address, stream))) => {
240                tracing::debug!(
241                    target: LOG_TARGET,
242                    ?address,
243                    "connected",
244                );
245
246                Ok((address, stream))
247            }
248        }
249    }
250}
251
252impl TransportBuilder for TcpTransport {
253    type Config = Config;
254    type Transport = TcpTransport;
255
256    /// Create new [`TcpTransport`].
257    fn new(
258        context: TransportHandle,
259        mut config: Self::Config,
260    ) -> crate::Result<(Self, Vec<Multiaddr>)> {
261        tracing::debug!(
262            target: LOG_TARGET,
263            listen_addresses = ?config.listen_addresses,
264            "start tcp transport",
265        );
266
267        // start tcp listeners for all listen addresses
268        let (listener, listen_addresses, dial_addresses) = SocketListener::new::<TcpAddress>(
269            std::mem::take(&mut config.listen_addresses),
270            config.reuse_port,
271            config.nodelay,
272        );
273
274        Ok((
275            Self {
276                listener,
277                config,
278                context,
279                dial_addresses,
280                canceled: HashSet::new(),
281                opened_raw: HashMap::new(),
282                pending_open: HashMap::new(),
283                pending_dials: HashMap::new(),
284                pending_inbound_connections: HashMap::new(),
285                pending_connections: FuturesUnordered::new(),
286                pending_raw_connections: FuturesUnordered::new(),
287            },
288            listen_addresses,
289        ))
290    }
291}
292
293impl Transport for TcpTransport {
294    fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> {
295        tracing::debug!(target: LOG_TARGET, ?connection_id, ?address, "open connection");
296
297        let (socket_address, peer) = TcpAddress::multiaddr_to_socket_address(&address)?;
298        let yamux_config = self.config.yamux_config.clone();
299        let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
300        let max_write_buffer_size = self.config.noise_write_buffer_size;
301        let connection_open_timeout = self.config.connection_open_timeout;
302        let substream_open_timeout = self.config.substream_open_timeout;
303        let dial_addresses = self.dial_addresses.clone();
304        let keypair = self.context.keypair.clone();
305        let nodelay = self.config.nodelay;
306
307        self.pending_dials.insert(connection_id, address.clone());
308        self.pending_connections.push(Box::pin(async move {
309            let (_, stream) =
310                TcpTransport::dial_peer(address, dial_addresses, connection_open_timeout, nodelay)
311                    .await
312                    .map_err(|error| (connection_id, error))?;
313
314            TcpConnection::open_connection(
315                connection_id,
316                keypair,
317                stream,
318                socket_address,
319                peer,
320                yamux_config,
321                max_read_ahead_factor,
322                max_write_buffer_size,
323                connection_open_timeout,
324                substream_open_timeout,
325            )
326            .await
327            .map_err(|error| (connection_id, error.into()))
328        }));
329
330        Ok(())
331    }
332
333    fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
334        let context = self
335            .pending_open
336            .remove(&connection_id)
337            .ok_or(Error::ConnectionDoesntExist(connection_id))?;
338        let protocol_set = self.context.protocol_set(connection_id);
339        let bandwidth_sink = self.context.bandwidth_sink.clone();
340        let next_substream_id = self.context.next_substream_id.clone();
341
342        tracing::trace!(
343            target: LOG_TARGET,
344            ?connection_id,
345            "start connection",
346        );
347
348        self.context.executor.run(Box::pin(async move {
349            if let Err(error) =
350                TcpConnection::new(context, protocol_set, bandwidth_sink, next_substream_id)
351                    .start()
352                    .await
353            {
354                tracing::debug!(
355                    target: LOG_TARGET,
356                    ?connection_id,
357                    ?error,
358                    "connection exited with error",
359                );
360            }
361        }));
362
363        Ok(())
364    }
365
366    fn accept_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
367        let pending = self
368            .pending_inbound_connections
369            .remove(&connection_id)
370            .ok_or(Error::ConnectionDoesntExist(connection_id))?;
371
372        self.on_inbound_connection(connection_id, pending.connection, pending.address);
373
374        Ok(())
375    }
376
377    fn reject_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
378        self.pending_inbound_connections
379            .remove(&connection_id)
380            .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(()))
381    }
382
383    fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
384        self.pending_open
385            .remove(&connection_id)
386            .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(()))
387    }
388
389    fn open(
390        &mut self,
391        connection_id: ConnectionId,
392        addresses: Vec<Multiaddr>,
393    ) -> crate::Result<()> {
394        let num_addresses = addresses.len();
395        let mut futures: FuturesUnordered<_> = addresses
396            .into_iter()
397            .map(|address| {
398                let dial_addresses = self.dial_addresses.clone();
399                let connection_open_timeout = self.config.connection_open_timeout;
400                let nodelay = self.config.nodelay;
401
402                async move {
403                    TcpTransport::dial_peer(
404                        address.clone(),
405                        dial_addresses,
406                        connection_open_timeout,
407                        nodelay,
408                    )
409                    .await
410                    .map_err(|error| (address, error))
411                }
412            })
413            .collect();
414
415        self.pending_raw_connections.push(Box::pin(async move {
416            let mut errors = Vec::with_capacity(num_addresses);
417            while let Some(result) = futures.next().await {
418                match result {
419                    Ok((address, stream)) => return Ok((connection_id, address, stream)),
420                    Err(error) => {
421                        tracing::debug!(
422                            target: LOG_TARGET,
423                            ?connection_id,
424                            ?error,
425                            "failed to open connection",
426                        );
427                        errors.push(error)
428                    }
429                }
430            }
431
432            Err((connection_id, errors))
433        }));
434
435        Ok(())
436    }
437
438    fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
439        let (stream, address) = self
440            .opened_raw
441            .remove(&connection_id)
442            .ok_or(Error::ConnectionDoesntExist(connection_id))?;
443
444        let (socket_address, peer) = TcpAddress::multiaddr_to_socket_address(&address)?;
445        let yamux_config = self.config.yamux_config.clone();
446        let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
447        let max_write_buffer_size = self.config.noise_write_buffer_size;
448        let connection_open_timeout = self.config.connection_open_timeout;
449        let substream_open_timeout = self.config.substream_open_timeout;
450        let keypair = self.context.keypair.clone();
451
452        tracing::trace!(
453            target: LOG_TARGET,
454            ?peer,
455            ?connection_id,
456            ?address,
457            "negotiate connection",
458        );
459
460        self.pending_dials.insert(connection_id, address);
461        self.pending_connections.push(Box::pin(async move {
462            match tokio::time::timeout(connection_open_timeout, async move {
463                TcpConnection::negotiate_connection(
464                    stream,
465                    peer,
466                    connection_id,
467                    keypair,
468                    Role::Dialer,
469                    socket_address,
470                    yamux_config,
471                    max_read_ahead_factor,
472                    max_write_buffer_size,
473                    substream_open_timeout,
474                )
475                .await
476                .map_err(|error| (connection_id, error.into()))
477            })
478            .await
479            {
480                Err(_) => Err((connection_id, DialError::Timeout)),
481                Ok(Err(error)) => Err(error),
482                Ok(Ok(connection)) => Ok(connection),
483            }
484        }));
485
486        Ok(())
487    }
488
489    fn cancel(&mut self, connection_id: ConnectionId) {
490        self.canceled.insert(connection_id);
491    }
492}
493
494impl Stream for TcpTransport {
495    type Item = TransportEvent;
496
497    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
498        if let Poll::Ready(event) = self.listener.poll_next_unpin(cx) {
499            return match event {
500                None | Some(Err(_)) => Poll::Ready(None),
501                Some(Ok((connection, address))) => {
502                    let connection_id = self.context.next_connection_id();
503                    tracing::trace!(
504                        target: LOG_TARGET,
505                        ?connection_id,
506                        ?address,
507                        "pending inbound TCP connection",
508                    );
509
510                    self.pending_inbound_connections.insert(
511                        connection_id,
512                        PendingInboundConnection {
513                            connection,
514                            address,
515                        },
516                    );
517
518                    Poll::Ready(Some(TransportEvent::PendingInboundConnection {
519                        connection_id,
520                    }))
521                }
522            };
523        }
524
525        while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) {
526            match result {
527                Ok((connection_id, address, stream)) => {
528                    tracing::trace!(
529                        target: LOG_TARGET,
530                        ?connection_id,
531                        ?address,
532                        canceled = self.canceled.contains(&connection_id),
533                        "connection opened",
534                    );
535
536                    if !self.canceled.remove(&connection_id) {
537                        self.opened_raw.insert(connection_id, (stream, address.clone()));
538
539                        return Poll::Ready(Some(TransportEvent::ConnectionOpened {
540                            connection_id,
541                            address,
542                        }));
543                    }
544                }
545                Err((connection_id, errors)) =>
546                    if !self.canceled.remove(&connection_id) {
547                        return Poll::Ready(Some(TransportEvent::OpenFailure {
548                            connection_id,
549                            errors,
550                        }));
551                    },
552            }
553        }
554
555        while let Poll::Ready(Some(connection)) = self.pending_connections.poll_next_unpin(cx) {
556            match connection {
557                Ok(connection) => {
558                    let peer = connection.peer();
559                    let endpoint = connection.endpoint();
560                    self.pending_open.insert(connection.connection_id(), connection);
561
562                    return Poll::Ready(Some(TransportEvent::ConnectionEstablished {
563                        peer,
564                        endpoint,
565                    }));
566                }
567                Err((connection_id, error)) => {
568                    if let Some(address) = self.pending_dials.remove(&connection_id) {
569                        return Poll::Ready(Some(TransportEvent::DialFailure {
570                            connection_id,
571                            address,
572                            error,
573                        }));
574                    } else {
575                        tracing::debug!(target: LOG_TARGET, ?error, ?connection_id, "Pending inbound connection failed");
576                    }
577                }
578            }
579        }
580
581        Poll::Pending
582    }
583}
584
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        codec::ProtocolCodec,
        crypto::ed25519::Keypair,
        executor::DefaultExecutor,
        transport::manager::{
            limits::ConnectionLimitsConfig, ProtocolContext, SupportedTransport, TransportManager,
        },
        types::protocol::ProtocolName,
        BandwidthSink, PeerId,
    };
    use multiaddr::Protocol;
    use multihash::Multihash;
    use std::{collections::HashSet, sync::Arc};
    use tokio::sync::mpsc::channel;

    // transport2 dials transport1 over IPv6 loopback; transport1 accepts the pending inbound
    // connection and both sides must report `ConnectionEstablished`.
    #[tokio::test]
    async fn connect_and_accept_works() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let keypair1 = Keypair::generate();
        // Receivers are bound (not dropped) so the channels stay open for the test's lifetime.
        let (tx1, _rx1) = channel(64);
        let (event_tx1, _event_rx1) = channel(64);
        let bandwidth_sink = BandwidthSink::new();

        let handle1 = crate::transport::manager::TransportHandle {
            executor: Arc::new(DefaultExecutor {}),
            next_substream_id: Default::default(),
            next_connection_id: Default::default(),
            keypair: keypair1.clone(),
            tx: event_tx1,
            bandwidth_sink: bandwidth_sink.clone(),

            protocols: HashMap::from_iter([(
                ProtocolName::from("/notif/1"),
                ProtocolContext {
                    tx: tx1,
                    codec: ProtocolCodec::Identity(32),
                    fallback_names: Vec::new(),
                },
            )]),
        };
        let transport_config1 = Config {
            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
            ..Default::default()
        };

        let (mut transport1, listen_addresses) =
            TcpTransport::new(handle1, transport_config1).unwrap();
        let listen_address = listen_addresses[0].clone();

        let keypair2 = Keypair::generate();
        let (tx2, _rx2) = channel(64);
        let (event_tx2, _event_rx2) = channel(64);

        let handle2 = crate::transport::manager::TransportHandle {
            executor: Arc::new(DefaultExecutor {}),
            next_substream_id: Default::default(),
            next_connection_id: Default::default(),
            keypair: keypair2.clone(),
            tx: event_tx2,
            bandwidth_sink: bandwidth_sink.clone(),

            protocols: HashMap::from_iter([(
                ProtocolName::from("/notif/1"),
                ProtocolContext {
                    tx: tx2,
                    codec: ProtocolCodec::Identity(32),
                    fallback_names: Vec::new(),
                },
            )]),
        };
        let transport_config2 = Config {
            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
            ..Default::default()
        };

        let (mut transport2, _) = TcpTransport::new(handle2, transport_config2).unwrap();
        transport2.dial(ConnectionId::new(), listen_address).unwrap();

        // Drive transport2 in the background so its dial/negotiation progresses while
        // transport1 is polled here; its first event is forwarded back over the channel.
        let (tx, mut from_transport2) = channel(64);
        tokio::spawn(async move {
            let event = transport2.next().await;
            tx.send(event).await.unwrap();
        });

        let event = transport1.next().await.unwrap();
        match event {
            TransportEvent::PendingInboundConnection { connection_id } => {
                transport1.accept_pending(connection_id).unwrap();
            }
            _ => panic!("unexpected event"),
        }

        let event = transport1.next().await;
        assert!(std::matches!(
            event,
            Some(TransportEvent::ConnectionEstablished { .. })
        ));

        let event = from_transport2.recv().await.unwrap();
        assert!(std::matches!(
            event,
            Some(TransportEvent::ConnectionEstablished { .. })
        ));
    }

    // Same setup as above, but transport1 rejects the pending inbound connection (dropping the
    // accepted stream); the dialing side must then report `DialFailure`.
    #[tokio::test]
    async fn connect_and_reject_works() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let keypair1 = Keypair::generate();
        let (tx1, _rx1) = channel(64);
        let (event_tx1, _event_rx1) = channel(64);
        let bandwidth_sink = BandwidthSink::new();

        let handle1 = crate::transport::manager::TransportHandle {
            executor: Arc::new(DefaultExecutor {}),
            next_substream_id: Default::default(),
            next_connection_id: Default::default(),
            keypair: keypair1.clone(),
            tx: event_tx1,
            bandwidth_sink: bandwidth_sink.clone(),

            protocols: HashMap::from_iter([(
                ProtocolName::from("/notif/1"),
                ProtocolContext {
                    tx: tx1,
                    codec: ProtocolCodec::Identity(32),
                    fallback_names: Vec::new(),
                },
            )]),
        };
        let transport_config1 = Config {
            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
            ..Default::default()
        };

        let (mut transport1, listen_addresses) =
            TcpTransport::new(handle1, transport_config1).unwrap();
        let listen_address = listen_addresses[0].clone();

        let keypair2 = Keypair::generate();
        let (tx2, _rx2) = channel(64);
        let (event_tx2, _event_rx2) = channel(64);

        let handle2 = crate::transport::manager::TransportHandle {
            executor: Arc::new(DefaultExecutor {}),
            next_substream_id: Default::default(),
            next_connection_id: Default::default(),
            keypair: keypair2.clone(),
            tx: event_tx2,
            bandwidth_sink: bandwidth_sink.clone(),

            protocols: HashMap::from_iter([(
                ProtocolName::from("/notif/1"),
                ProtocolContext {
                    tx: tx2,
                    codec: ProtocolCodec::Identity(32),
                    fallback_names: Vec::new(),
                },
            )]),
        };
        let transport_config2 = Config {
            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
            ..Default::default()
        };

        let (mut transport2, _) = TcpTransport::new(handle2, transport_config2).unwrap();
        transport2.dial(ConnectionId::new(), listen_address).unwrap();

        // Drive transport2 in the background and forward its first event back.
        let (tx, mut from_transport2) = channel(64);
        tokio::spawn(async move {
            let event = transport2.next().await;
            tx.send(event).await.unwrap();
        });

        // Reject connection.
        let event = transport1.next().await.unwrap();
        match event {
            TransportEvent::PendingInboundConnection { connection_id } => {
                transport1.reject_pending(connection_id).unwrap();
            }
            _ => panic!("unexpected event"),
        }

        let event = from_transport2.recv().await.unwrap();
        assert!(std::matches!(
            event,
            Some(TransportEvent::DialFailure { .. })
        ));
    }

    // Dial [::1]:8888, where nothing is expected to be listening, and expect `DialFailure`.
    // NOTE(review): the test assumes port 8888 is unused on the test host.
    #[tokio::test]
    async fn dial_failure() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let keypair1 = Keypair::generate();
        let (tx1, _rx1) = channel(64);
        let (event_tx1, mut event_rx1) = channel(64);
        let bandwidth_sink = BandwidthSink::new();

        let handle1 = crate::transport::manager::TransportHandle {
            executor: Arc::new(DefaultExecutor {}),
            next_substream_id: Default::default(),
            next_connection_id: Default::default(),
            keypair: keypair1.clone(),
            tx: event_tx1,
            bandwidth_sink: bandwidth_sink.clone(),

            protocols: HashMap::from_iter([(
                ProtocolName::from("/notif/1"),
                ProtocolContext {
                    tx: tx1,
                    codec: ProtocolCodec::Identity(32),
                    fallback_names: Vec::new(),
                },
            )]),
        };
        let (mut transport1, _) = TcpTransport::new(handle1, Default::default()).unwrap();

        // Keep transport1 polled in the background, ignoring every event it produces.
        tokio::spawn(async move {
            while let Some(event) = transport1.next().await {
                match event {
                    TransportEvent::ConnectionEstablished { .. } => {}
                    TransportEvent::ConnectionClosed { .. } => {}
                    TransportEvent::DialFailure { .. } => {}
                    TransportEvent::ConnectionOpened { .. } => {}
                    TransportEvent::OpenFailure { .. } => {}
                    TransportEvent::PendingInboundConnection { .. } => {}
                }
            }
        });

        let keypair2 = Keypair::generate();
        let (tx2, _rx2) = channel(64);
        let (event_tx2, _event_rx2) = channel(64);

        let handle2 = crate::transport::manager::TransportHandle {
            executor: Arc::new(DefaultExecutor {}),
            next_substream_id: Default::default(),
            next_connection_id: Default::default(),
            keypair: keypair2.clone(),
            tx: event_tx2,
            bandwidth_sink: bandwidth_sink.clone(),

            protocols: HashMap::from_iter([(
                ProtocolName::from("/notif/1"),
                ProtocolContext {
                    tx: tx2,
                    codec: ProtocolCodec::Identity(32),
                    fallback_names: Vec::new(),
                },
            )]),
        };

        let (mut transport2, _) = TcpTransport::new(handle2, Default::default()).unwrap();

        let peer1: PeerId = PeerId::from_public_key(&keypair1.public().into());
        let peer2: PeerId = PeerId::from_public_key(&keypair2.public().into());

        tracing::info!(target: LOG_TARGET, "peer1 {peer1}, peer2 {peer2}");

        let address = Multiaddr::empty()
            .with(Protocol::Ip6(std::net::Ipv6Addr::new(
                0, 0, 0, 0, 0, 0, 0, 1,
            )))
            .with(Protocol::Tcp(8888))
            .with(Protocol::P2p(
                Multihash::from_bytes(&peer1.to_bytes()).unwrap(),
            ));

        transport2.dial(ConnectionId::new(), address).unwrap();

        // spawn the other connection in the background as it won't return anything
        tokio::spawn(async move {
            loop {
                let _ = event_rx1.recv().await;
            }
        });

        assert!(std::matches!(
            transport2.next().await,
            Some(TransportEvent::DialFailure { .. })
        ));
    }

    // A failed outbound dial must remove its `pending_dials` record when `DialFailure` is
    // reported.
    #[tokio::test]
    async fn dial_error_reported_for_outbound_connections() {
        let (mut manager, _handle) = TransportManager::new(
            Keypair::generate(),
            HashSet::new(),
            BandwidthSink::new(),
            8usize,
            ConnectionLimitsConfig::default(),
        );
        let handle = manager.transport_handle(Arc::new(DefaultExecutor {}));
        manager.register_transport(
            SupportedTransport::Tcp,
            Box::new(crate::transport::dummy::DummyTransport::new()),
        );
        let (mut transport, _) = TcpTransport::new(
            handle,
            Config {
                listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()],
                ..Default::default()
            },
        )
        .unwrap();

        let keypair = Keypair::generate();
        let peer_id = PeerId::from_public_key(&keypair.public().into());
        let multiaddr = Multiaddr::empty()
            .with(Protocol::Ip4(std::net::Ipv4Addr::new(255, 254, 253, 252)))
            .with(Protocol::Tcp(8888))
            .with(Protocol::P2p(
                Multihash::from_bytes(&peer_id.to_bytes()).unwrap(),
            ));
        manager.dial_address(multiaddr.clone()).await.unwrap();

        assert!(transport.pending_dials.is_empty());

        match transport.dial(ConnectionId::from(0usize), multiaddr) {
            Ok(()) => {}
            _ => panic!("invalid result for `on_dial_peer()`"),
        }

        assert!(!transport.pending_dials.is_empty());
        // Inject an immediate failure for the same connection ID so the test does not depend on
        // how dialing the unroutable address above behaves or how long it takes.
        transport.pending_connections.push(Box::pin(async move {
            Err((ConnectionId::from(0usize), DialError::Timeout))
        }));

        assert!(std::matches!(
            transport.next().await,
            Some(TransportEvent::DialFailure { .. })
        ));
        assert!(transport.pending_dials.is_empty());
    }
}