litep2p/transport/tcp/
mod.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2// Copyright 2023 litep2p developers
3//
4// Permission is hereby granted, free of charge, to any person obtaining a
5// copy of this software and associated documentation files (the "Software"),
6// to deal in the Software without restriction, including without limitation
7// the rights to use, copy, modify, merge, publish, distribute, sublicense,
8// and/or sell copies of the Software, and to permit persons to whom the
9// Software is furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20// DEALINGS IN THE SOFTWARE.
21
22//! TCP transport.
23
24use crate::{
25    error::{DialError, Error},
26    transport::{
27        common::listener::{DialAddresses, GetSocketAddr, SocketListener, TcpAddress},
28        manager::TransportHandle,
29        tcp::{
30            config::Config,
31            connection::{NegotiatedConnection, TcpConnection},
32        },
33        Transport, TransportBuilder, TransportEvent,
34    },
35    types::ConnectionId,
36    utils::futures_stream::FuturesStream,
37};
38
39use futures::{
40    future::BoxFuture,
41    stream::{AbortHandle, FuturesUnordered, Stream, StreamExt},
42    TryFutureExt,
43};
44use hickory_resolver::TokioResolver;
45use multiaddr::Multiaddr;
46use socket2::{Domain, Socket, Type};
47use tokio::net::TcpStream;
48
49use std::{
50    collections::HashMap,
51    net::SocketAddr,
52    pin::Pin,
53    sync::Arc,
54    task::{Context, Poll},
55    time::Duration,
56};
57
58pub(crate) use substream::Substream;
59
60mod connection;
61mod substream;
62
63pub mod config;
64
/// `tracing` target attached to every log event emitted by the TCP transport.
const LOG_TARGET: &str = "litep2p::tcp";
67
68/// Pending inbound connection.
69struct PendingInboundConnection {
70    /// Socket address of the remote peer.
71    connection: TcpStream,
72    /// Address of the remote peer.
73    address: SocketAddr,
74}
75
76#[derive(Debug)]
77enum RawConnectionResult {
78    /// The first successful connection.
79    Connected {
80        negotiated: NegotiatedConnection,
81        errors: Vec<(Multiaddr, DialError)>,
82    },
83
84    /// All connection attempts failed.
85    Failed {
86        connection_id: ConnectionId,
87        errors: Vec<(Multiaddr, DialError)>,
88    },
89
90    /// Future was canceled.
91    Canceled { connection_id: ConnectionId },
92}
93
94/// TCP transport.
95pub(crate) struct TcpTransport {
96    /// Transport context.
97    context: TransportHandle,
98
99    /// Transport configuration.
100    config: Config,
101
102    /// TCP listener.
103    listener: SocketListener,
104
105    /// Pending dials.
106    pending_dials: HashMap<ConnectionId, Multiaddr>,
107
108    /// Dial addresses.
109    dial_addresses: DialAddresses,
110
111    /// Pending inbound connections.
112    pending_inbound_connections: HashMap<ConnectionId, PendingInboundConnection>,
113
114    /// Pending opening connections.
115    pending_connections:
116        FuturesStream<BoxFuture<'static, Result<NegotiatedConnection, (ConnectionId, DialError)>>>,
117
118    /// Pending raw, unnegotiated connections.
119    pending_raw_connections: FuturesStream<BoxFuture<'static, RawConnectionResult>>,
120
121    /// Opened raw connection, waiting for approval/rejection from `TransportManager`.
122    opened: HashMap<ConnectionId, NegotiatedConnection>,
123
124    /// Cancel raw connections futures.
125    ///
126    /// This is cancelling `Self::pending_raw_connections`.
127    cancel_futures: HashMap<ConnectionId, AbortHandle>,
128
129    /// Connections which have been opened and negotiated but are being validated by the
130    /// `TransportManager`.
131    pending_open: HashMap<ConnectionId, NegotiatedConnection>,
132
133    /// DNS resolver.
134    resolver: Arc<TokioResolver>,
135}
136
137impl TcpTransport {
138    /// Handle inbound TCP connection.
139    fn on_inbound_connection(
140        &mut self,
141        connection_id: ConnectionId,
142        connection: TcpStream,
143        address: SocketAddr,
144    ) {
145        let yamux_config = self.config.yamux_config.clone();
146        let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
147        let max_write_buffer_size = self.config.noise_write_buffer_size;
148        let connection_open_timeout = self.config.connection_open_timeout;
149        let substream_open_timeout = self.config.substream_open_timeout;
150        let keypair = self.context.keypair.clone();
151
152        tracing::trace!(
153            target: LOG_TARGET,
154            ?connection_id,
155            ?address,
156            "accept connection",
157        );
158
159        self.pending_connections.push(Box::pin(async move {
160            TcpConnection::accept_connection(
161                connection,
162                connection_id,
163                keypair,
164                address,
165                yamux_config,
166                max_read_ahead_factor,
167                max_write_buffer_size,
168                connection_open_timeout,
169                substream_open_timeout,
170            )
171            .await
172            .map_err(|error| (connection_id, error.into()))
173        }));
174    }
175
176    /// Dial remote peer
177    async fn dial_peer(
178        address: Multiaddr,
179        dial_addresses: DialAddresses,
180        connection_open_timeout: Duration,
181        nodelay: bool,
182        resolver: Arc<TokioResolver>,
183    ) -> Result<(Multiaddr, TcpStream), DialError> {
184        let (socket_address, _) = TcpAddress::multiaddr_to_socket_address(&address)?;
185
186        let remote_address =
187            match tokio::time::timeout(connection_open_timeout, socket_address.lookup_ip(resolver))
188                .await
189            {
190                Err(_) => {
191                    tracing::debug!(
192                        target: LOG_TARGET,
193                        ?address,
194                        ?connection_open_timeout,
195                        "failed to resolve address within timeout",
196                    );
197                    return Err(DialError::Timeout);
198                }
199                Ok(Err(error)) => return Err(error.into()),
200                Ok(Ok(address)) => address,
201            };
202
203        let domain = match remote_address.is_ipv4() {
204            true => Domain::IPV4,
205            false => Domain::IPV6,
206        };
207        let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?;
208        if remote_address.is_ipv6() {
209            socket.set_only_v6(true)?;
210        }
211        socket.set_nonblocking(true)?;
212        socket.set_nodelay(nodelay)?;
213
214        match dial_addresses.local_dial_address(&remote_address.ip()) {
215            Ok(Some(dial_address)) => {
216                socket.set_reuse_address(true)?;
217                #[cfg(unix)]
218                socket.set_reuse_port(true)?;
219                socket.bind(&dial_address.into())?;
220            }
221            Ok(None) => {}
222            Err(()) => {
223                tracing::debug!(
224                    target: LOG_TARGET,
225                    ?remote_address,
226                    "tcp listener not enabled for remote address, using ephemeral port",
227                );
228            }
229        }
230
231        let future = async move {
232            match socket.connect(&remote_address.into()) {
233                Ok(()) => {}
234                Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {}
235                Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {}
236                Err(err) => return Err(err),
237            }
238
239            let stream = TcpStream::try_from(Into::<std::net::TcpStream>::into(socket))?;
240            stream.writable().await?;
241
242            if let Some(e) = stream.take_error()? {
243                return Err(e);
244            }
245
246            Ok((address, stream))
247        };
248
249        match tokio::time::timeout(connection_open_timeout, future).await {
250            Err(_) => {
251                tracing::debug!(
252                    target: LOG_TARGET,
253                    ?connection_open_timeout,
254                    "failed to connect within timeout",
255                );
256                Err(DialError::Timeout)
257            }
258            Ok(Err(error)) => Err(error.into()),
259            Ok(Ok((address, stream))) => {
260                tracing::debug!(
261                    target: LOG_TARGET,
262                    ?address,
263                    "connected",
264                );
265
266                Ok((address, stream))
267            }
268        }
269    }
270}
271
272impl TransportBuilder for TcpTransport {
273    type Config = Config;
274    type Transport = TcpTransport;
275
276    /// Create new [`TcpTransport`].
277    fn new(
278        context: TransportHandle,
279        mut config: Self::Config,
280        resolver: Arc<TokioResolver>,
281    ) -> crate::Result<(Self, Vec<Multiaddr>)> {
282        tracing::debug!(
283            target: LOG_TARGET,
284            listen_addresses = ?config.listen_addresses,
285            "start tcp transport",
286        );
287
288        // start tcp listeners for all listen addresses
289        let (listener, listen_addresses, dial_addresses) = SocketListener::new::<TcpAddress>(
290            std::mem::take(&mut config.listen_addresses),
291            config.reuse_port,
292            config.nodelay,
293        );
294
295        Ok((
296            Self {
297                listener,
298                config,
299                context,
300                dial_addresses,
301                opened: HashMap::new(),
302                pending_open: HashMap::new(),
303                pending_dials: HashMap::new(),
304                pending_inbound_connections: HashMap::new(),
305                pending_connections: FuturesStream::new(),
306                pending_raw_connections: FuturesStream::new(),
307                cancel_futures: HashMap::new(),
308                resolver,
309            },
310            listen_addresses,
311        ))
312    }
313}
314
315impl Transport for TcpTransport {
316    fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> {
317        tracing::debug!(target: LOG_TARGET, ?connection_id, ?address, "open connection");
318
319        let (socket_address, peer) = TcpAddress::multiaddr_to_socket_address(&address)?;
320        let yamux_config = self.config.yamux_config.clone();
321        let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
322        let max_write_buffer_size = self.config.noise_write_buffer_size;
323        let connection_open_timeout = self.config.connection_open_timeout;
324        let substream_open_timeout = self.config.substream_open_timeout;
325        let dial_addresses = self.dial_addresses.clone();
326        let keypair = self.context.keypair.clone();
327        let nodelay = self.config.nodelay;
328        let resolver = self.resolver.clone();
329
330        self.pending_dials.insert(connection_id, address.clone());
331        self.pending_connections.push(Box::pin(async move {
332            let (_, stream) = TcpTransport::dial_peer(
333                address,
334                dial_addresses,
335                connection_open_timeout,
336                nodelay,
337                resolver,
338            )
339            .await
340            .map_err(|error| (connection_id, error))?;
341
342            TcpConnection::open_connection(
343                connection_id,
344                keypair,
345                stream,
346                socket_address,
347                peer,
348                yamux_config,
349                max_read_ahead_factor,
350                max_write_buffer_size,
351                connection_open_timeout,
352                substream_open_timeout,
353            )
354            .await
355            .map_err(|error| (connection_id, error.into()))
356        }));
357
358        Ok(())
359    }
360
361    fn accept_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
362        let pending = self.pending_inbound_connections.remove(&connection_id).ok_or_else(|| {
363            tracing::error!(
364                target: LOG_TARGET,
365                ?connection_id,
366                "Cannot accept non existent pending connection",
367            );
368
369            Error::ConnectionDoesntExist(connection_id)
370        })?;
371
372        self.on_inbound_connection(connection_id, pending.connection, pending.address);
373
374        Ok(())
375    }
376
377    fn reject_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
378        self.pending_inbound_connections.remove(&connection_id).map_or_else(
379            || {
380                tracing::error!(
381                    target: LOG_TARGET,
382                    ?connection_id,
383                    "Cannot reject non existent pending connection",
384                );
385
386                Err(Error::ConnectionDoesntExist(connection_id))
387            },
388            |_| Ok(()),
389        )
390    }
391
392    fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
393        let context = self
394            .pending_open
395            .remove(&connection_id)
396            .ok_or(Error::ConnectionDoesntExist(connection_id))?;
397        let protocol_set = self.context.protocol_set(connection_id);
398        let bandwidth_sink = self.context.bandwidth_sink.clone();
399        let next_substream_id = self.context.next_substream_id.clone();
400
401        tracing::trace!(
402            target: LOG_TARGET,
403            ?connection_id,
404            "start connection",
405        );
406
407        self.context.executor.run(Box::pin(async move {
408            if let Err(error) =
409                TcpConnection::new(context, protocol_set, bandwidth_sink, next_substream_id)
410                    .start()
411                    .await
412            {
413                tracing::debug!(
414                    target: LOG_TARGET,
415                    ?connection_id,
416                    ?error,
417                    "connection exited with error",
418                );
419            }
420        }));
421
422        Ok(())
423    }
424
425    fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
426        self.pending_open
427            .remove(&connection_id)
428            .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(()))
429    }
430
431    fn open(
432        &mut self,
433        connection_id: ConnectionId,
434        addresses: Vec<Multiaddr>,
435    ) -> crate::Result<()> {
436        let num_addresses = addresses.len();
437        let mut futures: FuturesUnordered<_> = addresses
438            .into_iter()
439            .map(|address| {
440                let yamux_config = self.config.yamux_config.clone();
441                let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
442                let max_write_buffer_size = self.config.noise_write_buffer_size;
443                let connection_open_timeout = self.config.connection_open_timeout;
444                let substream_open_timeout = self.config.substream_open_timeout;
445                let dial_addresses = self.dial_addresses.clone();
446                let keypair = self.context.keypair.clone();
447                let nodelay = self.config.nodelay;
448                let resolver = self.resolver.clone();
449
450                async move {
451                    let (address, stream) = TcpTransport::dial_peer(
452                        address.clone(),
453                        dial_addresses,
454                        connection_open_timeout,
455                        nodelay,
456                        resolver,
457                    )
458                    .await
459                    .map_err(|error| (address, error))?;
460
461                    let open_address = address.clone();
462                    let (socket_address, peer) = TcpAddress::multiaddr_to_socket_address(&address)
463                        .map_err(|error| (address, error.into()))?;
464
465                    TcpConnection::open_connection(
466                        connection_id,
467                        keypair,
468                        stream,
469                        socket_address,
470                        peer,
471                        yamux_config,
472                        max_read_ahead_factor,
473                        max_write_buffer_size,
474                        connection_open_timeout,
475                        substream_open_timeout,
476                    )
477                    .await
478                    .map_err(|error| (open_address, error.into()))
479                }
480            })
481            .collect();
482
483        // Future that will resolve to the first successful connection.
484        let future = async move {
485            let mut errors = Vec::with_capacity(num_addresses);
486            while let Some(result) = futures.next().await {
487                match result {
488                    Ok(negotiated) => return RawConnectionResult::Connected { negotiated, errors },
489                    Err(error) => {
490                        tracing::debug!(
491                            target: LOG_TARGET,
492                            ?connection_id,
493                            ?error,
494                            "failed to open connection",
495                        );
496                        errors.push(error)
497                    }
498                }
499            }
500
501            RawConnectionResult::Failed {
502                connection_id,
503                errors,
504            }
505        };
506
507        let (fut, handle) = futures::future::abortable(future);
508        let fut = fut.unwrap_or_else(move |_| RawConnectionResult::Canceled { connection_id });
509        self.pending_raw_connections.push(Box::pin(fut));
510        self.cancel_futures.insert(connection_id, handle);
511
512        Ok(())
513    }
514
515    fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
516        let negotiated = self
517            .opened
518            .remove(&connection_id)
519            .ok_or(Error::ConnectionDoesntExist(connection_id))?;
520
521        self.pending_connections.push(Box::pin(async move { Ok(negotiated) }));
522
523        Ok(())
524    }
525
526    fn cancel(&mut self, connection_id: ConnectionId) {
527        // Cancel the future if it exists.
528        // State clean-up happens inside the `poll_next`.
529        if let Some(handle) = self.cancel_futures.get(&connection_id) {
530            handle.abort();
531        }
532    }
533}
534
535impl Stream for TcpTransport {
536    type Item = TransportEvent;
537
538    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
539        if let Poll::Ready(event) = self.listener.poll_next_unpin(cx) {
540            return match event {
541                None => {
542                    tracing::error!(
543                        target: LOG_TARGET,
544                        "TCP listener terminated, ignore if the node is stopping",
545                    );
546
547                    Poll::Ready(None)
548                }
549                Some(Err(error)) => {
550                    tracing::error!(
551                        target: LOG_TARGET,
552                        ?error,
553                        "TCP listener terminated with error",
554                    );
555
556                    Poll::Ready(None)
557                }
558                Some(Ok((connection, address))) => {
559                    let connection_id = self.context.next_connection_id();
560                    tracing::trace!(
561                        target: LOG_TARGET,
562                        ?connection_id,
563                        ?address,
564                        "pending inbound TCP connection",
565                    );
566
567                    self.pending_inbound_connections.insert(
568                        connection_id,
569                        PendingInboundConnection {
570                            connection,
571                            address,
572                        },
573                    );
574
575                    Poll::Ready(Some(TransportEvent::PendingInboundConnection {
576                        connection_id,
577                    }))
578                }
579            };
580        }
581
582        while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) {
583            tracing::trace!(target: LOG_TARGET, ?result, "raw connection result");
584
585            match result {
586                RawConnectionResult::Connected { negotiated, errors } => {
587                    let Some(handle) = self.cancel_futures.remove(&negotiated.connection_id())
588                    else {
589                        tracing::warn!(
590                            target: LOG_TARGET,
591                            connection_id = ?negotiated.connection_id(),
592                            address = ?negotiated.endpoint().address(),
593                            ?errors,
594                            "raw connection without a cancel handle",
595                        );
596                        continue;
597                    };
598
599                    if !handle.is_aborted() {
600                        let connection_id = negotiated.connection_id();
601                        let address = negotiated.endpoint().address().clone();
602
603                        self.opened.insert(connection_id, negotiated);
604
605                        return Poll::Ready(Some(TransportEvent::ConnectionOpened {
606                            connection_id,
607                            address,
608                        }));
609                    }
610                }
611
612                RawConnectionResult::Failed {
613                    connection_id,
614                    errors,
615                } => {
616                    let Some(handle) = self.cancel_futures.remove(&connection_id) else {
617                        tracing::warn!(
618                            target: LOG_TARGET,
619                            ?connection_id,
620                            ?errors,
621                            "raw connection without a cancel handle",
622                        );
623                        continue;
624                    };
625
626                    if !handle.is_aborted() {
627                        return Poll::Ready(Some(TransportEvent::OpenFailure {
628                            connection_id,
629                            errors,
630                        }));
631                    }
632                }
633                RawConnectionResult::Canceled { connection_id } => {
634                    if self.cancel_futures.remove(&connection_id).is_none() {
635                        tracing::warn!(
636                            target: LOG_TARGET,
637                            ?connection_id,
638                            "raw cancelled connection without a cancel handle",
639                        );
640                    }
641                }
642            }
643        }
644
645        while let Poll::Ready(Some(connection)) = self.pending_connections.poll_next_unpin(cx) {
646            match connection {
647                Ok(connection) => {
648                    let peer = connection.peer();
649                    let endpoint = connection.endpoint();
650                    self.pending_dials.remove(&connection.connection_id());
651                    self.pending_open.insert(connection.connection_id(), connection);
652
653                    return Poll::Ready(Some(TransportEvent::ConnectionEstablished {
654                        peer,
655                        endpoint,
656                    }));
657                }
658                Err((connection_id, error)) => {
659                    if let Some(address) = self.pending_dials.remove(&connection_id) {
660                        return Poll::Ready(Some(TransportEvent::DialFailure {
661                            connection_id,
662                            address,
663                            error,
664                        }));
665                    } else {
666                        tracing::debug!(target: LOG_TARGET, ?error, ?connection_id, "Pending inbound connection failed");
667                    }
668                }
669            }
670        }
671
672        Poll::Pending
673    }
674}
675
676#[cfg(test)]
677mod tests {
678    use super::*;
679    use crate::{
680        codec::ProtocolCodec,
681        crypto::ed25519::Keypair,
682        executor::DefaultExecutor,
683        transport::manager::{ProtocolContext, SupportedTransport, TransportManagerBuilder},
684        types::protocol::ProtocolName,
685        BandwidthSink, PeerId,
686    };
687    use multiaddr::Protocol;
688    use multihash::Multihash;
689    use std::sync::Arc;
690    use tokio::sync::mpsc::channel;
691
692    #[tokio::test]
693    async fn connect_and_accept_works() {
694        let _ = tracing_subscriber::fmt()
695            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
696            .try_init();
697
698        let keypair1 = Keypair::generate();
699        let (tx1, _rx1) = channel(64);
700        let (event_tx1, _event_rx1) = channel(64);
701        let bandwidth_sink = BandwidthSink::new();
702
703        let handle1 = crate::transport::manager::TransportHandle {
704            executor: Arc::new(DefaultExecutor {}),
705            next_substream_id: Default::default(),
706            next_connection_id: Default::default(),
707            keypair: keypair1.clone(),
708            tx: event_tx1,
709            bandwidth_sink: bandwidth_sink.clone(),
710
711            protocols: HashMap::from_iter([(
712                ProtocolName::from("/notif/1"),
713                ProtocolContext {
714                    tx: tx1,
715                    codec: ProtocolCodec::Identity(32),
716                    fallback_names: Vec::new(),
717                },
718            )]),
719        };
720        let transport_config1 = Config {
721            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
722            ..Default::default()
723        };
724        let resolver = Arc::new(TokioResolver::builder_tokio().unwrap().build());
725
726        let (mut transport1, listen_addresses) =
727            TcpTransport::new(handle1, transport_config1, resolver.clone()).unwrap();
728        let listen_address = listen_addresses[0].clone();
729
730        let keypair2 = Keypair::generate();
731        let (tx2, _rx2) = channel(64);
732        let (event_tx2, _event_rx2) = channel(64);
733
734        let handle2 = crate::transport::manager::TransportHandle {
735            executor: Arc::new(DefaultExecutor {}),
736            next_substream_id: Default::default(),
737            next_connection_id: Default::default(),
738            keypair: keypair2.clone(),
739            tx: event_tx2,
740            bandwidth_sink: bandwidth_sink.clone(),
741
742            protocols: HashMap::from_iter([(
743                ProtocolName::from("/notif/1"),
744                ProtocolContext {
745                    tx: tx2,
746                    codec: ProtocolCodec::Identity(32),
747                    fallback_names: Vec::new(),
748                },
749            )]),
750        };
751        let transport_config2 = Config {
752            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
753            ..Default::default()
754        };
755
756        let (mut transport2, _) = TcpTransport::new(handle2, transport_config2, resolver).unwrap();
757        transport2.dial(ConnectionId::new(), listen_address).unwrap();
758
759        let (tx, mut from_transport2) = channel(64);
760        tokio::spawn(async move {
761            let event = transport2.next().await;
762            tx.send(event).await.unwrap();
763        });
764
765        let event = transport1.next().await.unwrap();
766        match event {
767            TransportEvent::PendingInboundConnection { connection_id } => {
768                transport1.accept_pending(connection_id).unwrap();
769            }
770            _ => panic!("unexpected event"),
771        }
772
773        let event = transport1.next().await;
774        assert!(std::matches!(
775            event,
776            Some(TransportEvent::ConnectionEstablished { .. })
777        ));
778
779        let event = from_transport2.recv().await.unwrap();
780        assert!(std::matches!(
781            event,
782            Some(TransportEvent::ConnectionEstablished { .. })
783        ));
784    }
785
786    #[tokio::test]
787    async fn connect_and_reject_works() {
788        let _ = tracing_subscriber::fmt()
789            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
790            .try_init();
791
792        let keypair1 = Keypair::generate();
793        let (tx1, _rx1) = channel(64);
794        let (event_tx1, _event_rx1) = channel(64);
795        let bandwidth_sink = BandwidthSink::new();
796
797        let handle1 = crate::transport::manager::TransportHandle {
798            executor: Arc::new(DefaultExecutor {}),
799            next_substream_id: Default::default(),
800            next_connection_id: Default::default(),
801            keypair: keypair1.clone(),
802            tx: event_tx1,
803            bandwidth_sink: bandwidth_sink.clone(),
804
805            protocols: HashMap::from_iter([(
806                ProtocolName::from("/notif/1"),
807                ProtocolContext {
808                    tx: tx1,
809                    codec: ProtocolCodec::Identity(32),
810                    fallback_names: Vec::new(),
811                },
812            )]),
813        };
814        let transport_config1 = Config {
815            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
816            ..Default::default()
817        };
818        let resolver = Arc::new(TokioResolver::builder_tokio().unwrap().build());
819
820        let (mut transport1, listen_addresses) =
821            TcpTransport::new(handle1, transport_config1, resolver.clone()).unwrap();
822        let listen_address = listen_addresses[0].clone();
823
824        let keypair2 = Keypair::generate();
825        let (tx2, _rx2) = channel(64);
826        let (event_tx2, _event_rx2) = channel(64);
827
828        let handle2 = crate::transport::manager::TransportHandle {
829            executor: Arc::new(DefaultExecutor {}),
830            next_substream_id: Default::default(),
831            next_connection_id: Default::default(),
832            keypair: keypair2.clone(),
833            tx: event_tx2,
834            bandwidth_sink: bandwidth_sink.clone(),
835
836            protocols: HashMap::from_iter([(
837                ProtocolName::from("/notif/1"),
838                ProtocolContext {
839                    tx: tx2,
840                    codec: ProtocolCodec::Identity(32),
841                    fallback_names: Vec::new(),
842                },
843            )]),
844        };
845        let transport_config2 = Config {
846            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
847            ..Default::default()
848        };
849
850        let (mut transport2, _) = TcpTransport::new(handle2, transport_config2, resolver).unwrap();
851        transport2.dial(ConnectionId::new(), listen_address).unwrap();
852
853        let (tx, mut from_transport2) = channel(64);
854        tokio::spawn(async move {
855            let event = transport2.next().await;
856            tx.send(event).await.unwrap();
857        });
858
859        // Reject connection.
860        let event = transport1.next().await.unwrap();
861        match event {
862            TransportEvent::PendingInboundConnection { connection_id } => {
863                transport1.reject_pending(connection_id).unwrap();
864            }
865            _ => panic!("unexpected event"),
866        }
867
868        let event = from_transport2.recv().await.unwrap();
869        assert!(std::matches!(
870            event,
871            Some(TransportEvent::DialFailure { .. })
872        ));
873    }
874
875    #[tokio::test]
876    async fn dial_failure() {
877        let _ = tracing_subscriber::fmt()
878            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
879            .try_init();
880
881        let keypair1 = Keypair::generate();
882        let (tx1, _rx1) = channel(64);
883        let (event_tx1, mut event_rx1) = channel(64);
884        let bandwidth_sink = BandwidthSink::new();
885
886        let handle1 = crate::transport::manager::TransportHandle {
887            executor: Arc::new(DefaultExecutor {}),
888            next_substream_id: Default::default(),
889            next_connection_id: Default::default(),
890            keypair: keypair1.clone(),
891            tx: event_tx1,
892            bandwidth_sink: bandwidth_sink.clone(),
893
894            protocols: HashMap::from_iter([(
895                ProtocolName::from("/notif/1"),
896                ProtocolContext {
897                    tx: tx1,
898                    codec: ProtocolCodec::Identity(32),
899                    fallback_names: Vec::new(),
900                },
901            )]),
902        };
903        let resolver = Arc::new(TokioResolver::builder_tokio().unwrap().build());
904        let (mut transport1, _) =
905            TcpTransport::new(handle1, Default::default(), resolver.clone()).unwrap();
906
907        tokio::spawn(async move {
908            while let Some(event) = transport1.next().await {
909                match event {
910                    TransportEvent::ConnectionEstablished { .. } => {}
911                    TransportEvent::ConnectionClosed { .. } => {}
912                    TransportEvent::DialFailure { .. } => {}
913                    TransportEvent::ConnectionOpened { .. } => {}
914                    TransportEvent::OpenFailure { .. } => {}
915                    TransportEvent::PendingInboundConnection { .. } => {}
916                }
917            }
918        });
919
920        let keypair2 = Keypair::generate();
921        let (tx2, _rx2) = channel(64);
922        let (event_tx2, _event_rx2) = channel(64);
923
924        let handle2 = crate::transport::manager::TransportHandle {
925            executor: Arc::new(DefaultExecutor {}),
926            next_substream_id: Default::default(),
927            next_connection_id: Default::default(),
928            keypair: keypair2.clone(),
929            tx: event_tx2,
930            bandwidth_sink: bandwidth_sink.clone(),
931
932            protocols: HashMap::from_iter([(
933                ProtocolName::from("/notif/1"),
934                ProtocolContext {
935                    tx: tx2,
936                    codec: ProtocolCodec::Identity(32),
937                    fallback_names: Vec::new(),
938                },
939            )]),
940        };
941
942        let (mut transport2, _) = TcpTransport::new(handle2, Default::default(), resolver).unwrap();
943
944        let peer1: PeerId = PeerId::from_public_key(&keypair1.public().into());
945        let peer2: PeerId = PeerId::from_public_key(&keypair2.public().into());
946
947        tracing::info!(target: LOG_TARGET, "peer1 {peer1}, peer2 {peer2}");
948
949        let address = Multiaddr::empty()
950            .with(Protocol::Ip6(std::net::Ipv6Addr::new(
951                0, 0, 0, 0, 0, 0, 0, 1,
952            )))
953            .with(Protocol::Tcp(8888))
954            .with(Protocol::P2p(
955                Multihash::from_bytes(&peer1.to_bytes()).unwrap(),
956            ));
957
958        transport2.dial(ConnectionId::new(), address).unwrap();
959
960        // spawn the other connection in the background as it won't return anything
961        tokio::spawn(async move {
962            loop {
963                let _ = event_rx1.recv().await;
964            }
965        });
966
967        assert!(std::matches!(
968            transport2.next().await,
969            Some(TransportEvent::DialFailure { .. })
970        ));
971    }
972
973    #[tokio::test]
974    async fn dial_error_reported_for_outbound_connections() {
975        let mut manager = TransportManagerBuilder::new().build();
976        let handle = manager.transport_handle(Arc::new(DefaultExecutor {}));
977        let resolver = Arc::new(TokioResolver::builder_tokio().unwrap().build());
978        manager.register_transport(
979            SupportedTransport::Tcp,
980            Box::new(crate::transport::dummy::DummyTransport::new()),
981        );
982        let (mut transport, _) = TcpTransport::new(
983            handle,
984            Config {
985                listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()],
986                ..Default::default()
987            },
988            resolver,
989        )
990        .unwrap();
991
992        let keypair = Keypair::generate();
993        let peer_id = PeerId::from_public_key(&keypair.public().into());
994        let multiaddr = Multiaddr::empty()
995            .with(Protocol::Ip4(std::net::Ipv4Addr::new(255, 254, 253, 252)))
996            .with(Protocol::Tcp(8888))
997            .with(Protocol::P2p(
998                Multihash::from_bytes(&peer_id.to_bytes()).unwrap(),
999            ));
1000        manager.dial_address(multiaddr.clone()).await.unwrap();
1001
1002        assert!(transport.pending_dials.is_empty());
1003
1004        match transport.dial(ConnectionId::from(0usize), multiaddr) {
1005            Ok(()) => {}
1006            _ => panic!("invalid result for `on_dial_peer()`"),
1007        }
1008
1009        assert!(!transport.pending_dials.is_empty());
1010        transport.pending_connections.push(Box::pin(async move {
1011            Err((ConnectionId::from(0usize), DialError::Timeout))
1012        }));
1013
1014        assert!(std::matches!(
1015            transport.next().await,
1016            Some(TransportEvent::DialFailure { .. })
1017        ));
1018        assert!(transport.pending_dials.is_empty());
1019    }
1020}