litep2p/transport/tcp/
mod.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2// Copyright 2023 litep2p developers
3//
4// Permission is hereby granted, free of charge, to any person obtaining a
5// copy of this software and associated documentation files (the "Software"),
6// to deal in the Software without restriction, including without limitation
7// the rights to use, copy, modify, merge, publish, distribute, sublicense,
8// and/or sell copies of the Software, and to permit persons to whom the
9// Software is furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20// DEALINGS IN THE SOFTWARE.
21
22//! TCP transport.
23
24use crate::{
25    error::{DialError, Error},
26    transport::{
27        common::listener::{DialAddresses, GetSocketAddr, SocketListener, TcpAddress},
28        manager::TransportHandle,
29        tcp::{
30            config::Config,
31            connection::{NegotiatedConnection, TcpConnection},
32        },
33        Transport, TransportBuilder, TransportEvent, DIAL_DEADLINE_MULTIPLIER,
34    },
35    types::ConnectionId,
36    utils::futures_stream::FuturesStream,
37};
38
39use futures::{
40    future::BoxFuture,
41    stream::{AbortHandle, Stream, StreamExt},
42    TryFutureExt,
43};
44use hickory_resolver::TokioResolver;
45use multiaddr::Multiaddr;
46use socket2::{Domain, Socket, Type};
47use tokio::net::TcpStream;
48
49use std::{
50    collections::HashMap,
51    net::SocketAddr,
52    pin::Pin,
53    sync::Arc,
54    task::{Context, Poll},
55    time::Duration,
56};
57
58pub(crate) use substream::Substream;
59
60mod connection;
61mod substream;
62
63pub mod config;
64
65/// Logging target for the file.
66const LOG_TARGET: &str = "litep2p::tcp";
67
68/// Pending inbound connection.
69struct PendingInboundConnection {
70    /// Socket address of the remote peer.
71    connection: TcpStream,
72    /// Address of the remote peer.
73    address: SocketAddr,
74}
75
76#[derive(Debug)]
77enum RawConnectionResult {
78    /// The first successful connection.
79    Connected {
80        negotiated: NegotiatedConnection,
81        errors: Vec<(Multiaddr, DialError)>,
82    },
83
84    /// All connection attempts failed.
85    Failed {
86        connection_id: ConnectionId,
87        errors: Vec<(Multiaddr, DialError)>,
88    },
89
90    /// Future was canceled.
91    Canceled { connection_id: ConnectionId },
92}
93
94/// TCP transport.
95pub(crate) struct TcpTransport {
96    /// Transport context.
97    context: TransportHandle,
98
99    /// Transport configuration.
100    config: Config,
101
102    /// TCP listener.
103    listener: SocketListener,
104
105    /// Pending dials.
106    pending_dials: HashMap<ConnectionId, Multiaddr>,
107
108    /// Dial addresses.
109    dial_addresses: DialAddresses,
110
111    /// Pending inbound connections.
112    pending_inbound_connections: HashMap<ConnectionId, PendingInboundConnection>,
113
114    /// Pending opening connections.
115    pending_connections:
116        FuturesStream<BoxFuture<'static, Result<NegotiatedConnection, (ConnectionId, DialError)>>>,
117
118    /// Pending raw, unnegotiated connections.
119    pending_raw_connections: FuturesStream<BoxFuture<'static, RawConnectionResult>>,
120
121    /// Opened raw connection, waiting for approval/rejection from `TransportManager`.
122    opened: HashMap<ConnectionId, NegotiatedConnection>,
123
124    /// Cancel raw connections futures.
125    ///
126    /// This is cancelling `Self::pending_raw_connections`.
127    cancel_futures: HashMap<ConnectionId, AbortHandle>,
128
129    /// Connections which have been opened and negotiated but are being validated by the
130    /// `TransportManager`.
131    pending_open: HashMap<ConnectionId, NegotiatedConnection>,
132
133    /// DNS resolver.
134    resolver: Arc<TokioResolver>,
135}
136
137impl TcpTransport {
138    /// Handle inbound TCP connection.
139    fn on_inbound_connection(
140        &mut self,
141        connection_id: ConnectionId,
142        connection: TcpStream,
143        address: SocketAddr,
144    ) {
145        let yamux_config = self.config.yamux_config.clone();
146        let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
147        let max_write_buffer_size = self.config.noise_write_buffer_size;
148        let connection_open_timeout = self.config.connection_open_timeout;
149        let substream_open_timeout = self.config.substream_open_timeout;
150        let keypair = self.context.keypair.clone();
151
152        tracing::trace!(
153            target: LOG_TARGET,
154            ?connection_id,
155            ?address,
156            "accept connection",
157        );
158
159        self.pending_connections.push(Box::pin(async move {
160            TcpConnection::accept_connection(
161                connection,
162                connection_id,
163                keypair,
164                address,
165                yamux_config,
166                max_read_ahead_factor,
167                max_write_buffer_size,
168                connection_open_timeout,
169                substream_open_timeout,
170            )
171            .await
172            .map_err(|error| (connection_id, error.into()))
173        }));
174    }
175
176    /// Dial remote peer
177    async fn dial_peer(
178        address: Multiaddr,
179        dial_addresses: DialAddresses,
180        connection_open_timeout: Duration,
181        nodelay: bool,
182        resolver: Arc<TokioResolver>,
183    ) -> Result<(Multiaddr, TcpStream), DialError> {
184        let (socket_address, _) = TcpAddress::multiaddr_to_socket_address(&address)?;
185
186        let remote_address =
187            match tokio::time::timeout(connection_open_timeout, socket_address.lookup_ip(resolver))
188                .await
189            {
190                Err(_) => {
191                    tracing::debug!(
192                        target: LOG_TARGET,
193                        ?address,
194                        ?connection_open_timeout,
195                        "failed to resolve address within timeout",
196                    );
197                    return Err(DialError::Timeout);
198                }
199                Ok(Err(error)) => return Err(error.into()),
200                Ok(Ok(address)) => address,
201            };
202
203        let domain = match remote_address.is_ipv4() {
204            true => Domain::IPV4,
205            false => Domain::IPV6,
206        };
207        let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?;
208        if remote_address.is_ipv6() {
209            socket.set_only_v6(true)?;
210        }
211        socket.set_nonblocking(true)?;
212        socket.set_nodelay(nodelay)?;
213
214        match dial_addresses.local_dial_address(&remote_address.ip()) {
215            Ok(Some(dial_address)) => {
216                socket.set_reuse_address(true)?;
217                #[cfg(unix)]
218                socket.set_reuse_port(true)?;
219                socket.bind(&dial_address.into())?;
220            }
221            Ok(None) => {}
222            Err(()) => {
223                tracing::debug!(
224                    target: LOG_TARGET,
225                    ?remote_address,
226                    "tcp listener not enabled for remote address, using ephemeral port",
227                );
228            }
229        }
230
231        let future = async move {
232            match socket.connect(&remote_address.into()) {
233                Ok(()) => {}
234                Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {}
235                Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {}
236                Err(err) => return Err(err),
237            }
238
239            let stream = TcpStream::try_from(Into::<std::net::TcpStream>::into(socket))?;
240            stream.writable().await?;
241
242            if let Some(e) = stream.take_error()? {
243                return Err(e);
244            }
245
246            Ok((address, stream))
247        };
248
249        match tokio::time::timeout(connection_open_timeout, future).await {
250            Err(_) => {
251                tracing::debug!(
252                    target: LOG_TARGET,
253                    ?connection_open_timeout,
254                    "failed to connect within timeout",
255                );
256                Err(DialError::Timeout)
257            }
258            Ok(Err(error)) => Err(error.into()),
259            Ok(Ok((address, stream))) => {
260                tracing::debug!(
261                    target: LOG_TARGET,
262                    ?address,
263                    "connected",
264                );
265
266                Ok((address, stream))
267            }
268        }
269    }
270}
271
272impl TransportBuilder for TcpTransport {
273    type Config = Config;
274    type Transport = TcpTransport;
275
276    /// Create new [`TcpTransport`].
277    fn new(
278        context: TransportHandle,
279        mut config: Self::Config,
280        resolver: Arc<TokioResolver>,
281    ) -> crate::Result<(Self, Vec<Multiaddr>)> {
282        tracing::debug!(
283            target: LOG_TARGET,
284            listen_addresses = ?config.listen_addresses,
285            "start tcp transport",
286        );
287
288        // start tcp listeners for all listen addresses
289        let (listener, listen_addresses, dial_addresses) = SocketListener::new::<TcpAddress>(
290            std::mem::take(&mut config.listen_addresses),
291            config.reuse_port,
292            config.nodelay,
293        );
294
295        Ok((
296            Self {
297                listener,
298                config,
299                context,
300                dial_addresses,
301                opened: HashMap::new(),
302                pending_open: HashMap::new(),
303                pending_dials: HashMap::new(),
304                pending_inbound_connections: HashMap::new(),
305                pending_connections: FuturesStream::new(),
306                pending_raw_connections: FuturesStream::new(),
307                cancel_futures: HashMap::new(),
308                resolver,
309            },
310            listen_addresses,
311        ))
312    }
313}
314
315impl Transport for TcpTransport {
316    fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> {
317        tracing::debug!(target: LOG_TARGET, ?connection_id, ?address, "open connection");
318
319        let (socket_address, peer) = TcpAddress::multiaddr_to_socket_address(&address)?;
320        let yamux_config = self.config.yamux_config.clone();
321        let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
322        let max_write_buffer_size = self.config.noise_write_buffer_size;
323        let connection_open_timeout = self.config.connection_open_timeout;
324        let substream_open_timeout = self.config.substream_open_timeout;
325        let dial_addresses = self.dial_addresses.clone();
326        let keypair = self.context.keypair.clone();
327        let nodelay = self.config.nodelay;
328        let resolver = self.resolver.clone();
329
330        self.pending_dials.insert(connection_id, address.clone());
331        self.pending_connections.push(Box::pin(async move {
332            let (_, stream) = TcpTransport::dial_peer(
333                address,
334                dial_addresses,
335                connection_open_timeout,
336                nodelay,
337                resolver,
338            )
339            .await
340            .map_err(|error| (connection_id, error))?;
341
342            TcpConnection::open_connection(
343                connection_id,
344                keypair,
345                stream,
346                socket_address,
347                peer,
348                yamux_config,
349                max_read_ahead_factor,
350                max_write_buffer_size,
351                connection_open_timeout,
352                substream_open_timeout,
353            )
354            .await
355            .map_err(|error| (connection_id, error.into()))
356        }));
357
358        Ok(())
359    }
360
361    fn accept_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
362        let pending = self.pending_inbound_connections.remove(&connection_id).ok_or_else(|| {
363            tracing::error!(
364                target: LOG_TARGET,
365                ?connection_id,
366                "Cannot accept non existent pending connection",
367            );
368
369            Error::ConnectionDoesntExist(connection_id)
370        })?;
371
372        self.on_inbound_connection(connection_id, pending.connection, pending.address);
373
374        Ok(())
375    }
376
377    fn reject_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
378        self.pending_inbound_connections.remove(&connection_id).map_or_else(
379            || {
380                tracing::error!(
381                    target: LOG_TARGET,
382                    ?connection_id,
383                    "Cannot reject non existent pending connection",
384                );
385
386                Err(Error::ConnectionDoesntExist(connection_id))
387            },
388            |_| Ok(()),
389        )
390    }
391
392    fn accept(
393        &mut self,
394        connection_id: ConnectionId,
395    ) -> crate::Result<BoxFuture<'static, crate::Result<()>>> {
396        let context = self
397            .pending_open
398            .remove(&connection_id)
399            .ok_or(Error::ConnectionDoesntExist(connection_id))?;
400        let mut protocol_set = self.context.protocol_set(connection_id);
401        let bandwidth_sink = self.context.bandwidth_sink.clone();
402        let next_substream_id = self.context.next_substream_id.clone();
403        let executor = self.context.executor.clone();
404
405        tracing::trace!(
406            target: LOG_TARGET,
407            ?connection_id,
408            "start connection",
409        );
410
411        let peer = context.peer();
412        let endpoint = context.endpoint().clone();
413
414        Ok(Box::pin(async move {
415            // First, notify all protocols about the connection establishment
416            // This ensures that when the accept() future completes, protocols are ready
417            protocol_set.report_connection_established(peer, endpoint).await?;
418
419            // After protocols are notified, spawn the connection event loop
420            executor.run(Box::pin(async move {
421                if let Err(error) =
422                    TcpConnection::new(context, protocol_set, bandwidth_sink, next_substream_id)
423                        .start()
424                        .await
425                {
426                    tracing::debug!(
427                        target: LOG_TARGET,
428                        ?connection_id,
429                        ?error,
430                        "connection exited with error",
431                    );
432                }
433            }));
434
435            Ok(())
436        }))
437    }
438
439    fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
440        self.pending_open
441            .remove(&connection_id)
442            .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(()))
443    }
444
445    fn open(
446        &mut self,
447        connection_id: ConnectionId,
448        addresses: Vec<Multiaddr>,
449    ) -> crate::Result<()> {
450        let num_addresses = addresses.len();
451
452        let yamux_config = self.config.yamux_config.clone();
453        let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
454        let max_write_buffer_size = self.config.noise_write_buffer_size;
455        let connection_open_timeout = self.config.connection_open_timeout;
456        let substream_open_timeout = self.config.substream_open_timeout;
457        let max_parallel_dials = self.config.max_parallel_dials;
458        let dial_addresses = self.dial_addresses.clone();
459        let keypair = self.context.keypair.clone();
460        let nodelay = self.config.nodelay;
461        let resolver = self.resolver.clone();
462
463        let futures = futures::stream::iter(addresses.into_iter().map(move |address| {
464            let yamux_config = yamux_config.clone();
465            let dial_addresses = dial_addresses.clone();
466            let keypair = keypair.clone();
467            let resolver = resolver.clone();
468
469            async move {
470                let (address, stream) = TcpTransport::dial_peer(
471                    address.clone(),
472                    dial_addresses,
473                    connection_open_timeout,
474                    nodelay,
475                    resolver,
476                )
477                .await
478                .map_err(|error| (address, error))?;
479
480                let open_address = address.clone();
481                let (socket_address, peer) = TcpAddress::multiaddr_to_socket_address(&address)
482                    .map_err(|error| (address, error.into()))?;
483
484                TcpConnection::open_connection(
485                    connection_id,
486                    keypair,
487                    stream,
488                    socket_address,
489                    peer,
490                    yamux_config,
491                    max_read_ahead_factor,
492                    max_write_buffer_size,
493                    connection_open_timeout,
494                    substream_open_timeout,
495                )
496                .await
497                .map_err(|error| (open_address, error.into()))
498            }
499        }))
500        .buffer_unordered(max_parallel_dials);
501
502        // Future that will resolve to the first successful connection.
503        let future = async move {
504            let mut errors = Vec::with_capacity(num_addresses);
505            // Deadline for the overall dial attempt, including all retries. This is to prevent
506            // retry attempts from indefinitely delaying the dial result.
507            let dial_deadline = DIAL_DEADLINE_MULTIPLIER * connection_open_timeout;
508            let deadline = tokio::time::sleep(dial_deadline);
509
510            tokio::pin!(deadline);
511            tokio::pin!(futures);
512
513            loop {
514                tokio::select! {
515                    result = futures.next() => {
516                        match result {
517                            Some(Ok(negotiated)) => {
518                                return RawConnectionResult::Connected {
519                                    negotiated,
520                                    errors,
521                                };
522                            }
523                            Some(Err(error)) => {
524                                tracing::debug!(
525                                    target: LOG_TARGET,
526                                    ?connection_id,
527                                    ?error,
528                                    "failed to open connection",
529                                );
530                                errors.push(error);
531                            }
532                            None => {
533                                return RawConnectionResult::Failed {
534                                    connection_id,
535                                    errors,
536                                };
537                            }
538                        }
539                    }
540                    _ = &mut deadline => {
541                        tracing::debug!(
542                            target: LOG_TARGET,
543                            ?connection_id,
544                            ?dial_deadline,
545                            "overall dial timeout exceeded",
546                        );
547                        return RawConnectionResult::Failed {
548                            connection_id,
549                            errors,
550                        };
551                    }
552                }
553            }
554        };
555
556        let (fut, handle) = futures::future::abortable(future);
557        let fut = fut.unwrap_or_else(move |_| RawConnectionResult::Canceled { connection_id });
558        self.pending_raw_connections.push(Box::pin(fut));
559        self.cancel_futures.insert(connection_id, handle);
560
561        Ok(())
562    }
563
564    fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
565        let negotiated = self
566            .opened
567            .remove(&connection_id)
568            .ok_or(Error::ConnectionDoesntExist(connection_id))?;
569
570        self.pending_connections.push(Box::pin(async move { Ok(negotiated) }));
571
572        Ok(())
573    }
574
575    fn cancel(&mut self, connection_id: ConnectionId) {
576        // Cancel the future if it exists.
577        // State clean-up happens inside the `poll_next`.
578        if let Some(handle) = self.cancel_futures.get(&connection_id) {
579            handle.abort();
580        }
581    }
582}
583
584impl Stream for TcpTransport {
585    type Item = TransportEvent;
586
587    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
588        if let Poll::Ready(event) = self.listener.poll_next_unpin(cx) {
589            return match event {
590                None => {
591                    tracing::error!(
592                        target: LOG_TARGET,
593                        "TCP listener terminated, ignore if the node is stopping",
594                    );
595
596                    Poll::Ready(None)
597                }
598                Some(Err(error)) => {
599                    tracing::error!(
600                        target: LOG_TARGET,
601                        ?error,
602                        "TCP listener terminated with error",
603                    );
604
605                    Poll::Ready(None)
606                }
607                Some(Ok((connection, address))) => {
608                    let connection_id = self.context.next_connection_id();
609                    tracing::trace!(
610                        target: LOG_TARGET,
611                        ?connection_id,
612                        ?address,
613                        "pending inbound TCP connection",
614                    );
615
616                    self.pending_inbound_connections.insert(
617                        connection_id,
618                        PendingInboundConnection {
619                            connection,
620                            address,
621                        },
622                    );
623
624                    Poll::Ready(Some(TransportEvent::PendingInboundConnection {
625                        connection_id,
626                    }))
627                }
628            };
629        }
630
631        while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) {
632            tracing::trace!(target: LOG_TARGET, ?result, "raw connection result");
633
634            match result {
635                RawConnectionResult::Connected { negotiated, errors } => {
636                    let Some(handle) = self.cancel_futures.remove(&negotiated.connection_id())
637                    else {
638                        tracing::warn!(
639                            target: LOG_TARGET,
640                            connection_id = ?negotiated.connection_id(),
641                            address = ?negotiated.endpoint().address(),
642                            ?errors,
643                            "raw connection without a cancel handle",
644                        );
645                        continue;
646                    };
647
648                    if !handle.is_aborted() {
649                        let connection_id = negotiated.connection_id();
650                        let address = negotiated.endpoint().address().clone();
651
652                        self.opened.insert(connection_id, negotiated);
653
654                        return Poll::Ready(Some(TransportEvent::ConnectionOpened {
655                            connection_id,
656                            address,
657                            errors,
658                        }));
659                    }
660                }
661
662                RawConnectionResult::Failed {
663                    connection_id,
664                    errors,
665                } => {
666                    let Some(handle) = self.cancel_futures.remove(&connection_id) else {
667                        tracing::warn!(
668                            target: LOG_TARGET,
669                            ?connection_id,
670                            ?errors,
671                            "raw connection without a cancel handle",
672                        );
673                        continue;
674                    };
675
676                    if !handle.is_aborted() {
677                        return Poll::Ready(Some(TransportEvent::OpenFailure {
678                            connection_id,
679                            errors,
680                        }));
681                    }
682                }
683                RawConnectionResult::Canceled { connection_id } => {
684                    if self.cancel_futures.remove(&connection_id).is_none() {
685                        tracing::warn!(
686                            target: LOG_TARGET,
687                            ?connection_id,
688                            "raw cancelled connection without a cancel handle",
689                        );
690                    }
691                }
692            }
693        }
694
695        while let Poll::Ready(Some(connection)) = self.pending_connections.poll_next_unpin(cx) {
696            match connection {
697                Ok(connection) => {
698                    let peer = connection.peer();
699                    let endpoint = connection.endpoint();
700                    self.pending_dials.remove(&connection.connection_id());
701                    self.pending_open.insert(connection.connection_id(), connection);
702
703                    return Poll::Ready(Some(TransportEvent::ConnectionEstablished {
704                        peer,
705                        endpoint,
706                    }));
707                }
708                Err((connection_id, error)) => {
709                    if let Some(address) = self.pending_dials.remove(&connection_id) {
710                        return Poll::Ready(Some(TransportEvent::DialFailure {
711                            connection_id,
712                            address,
713                            error,
714                        }));
715                    } else {
716                        tracing::debug!(target: LOG_TARGET, ?error, ?connection_id, "Pending inbound connection failed");
717                    }
718                }
719            }
720        }
721
722        Poll::Pending
723    }
724}
725
726#[cfg(test)]
727mod tests {
728    use super::*;
729    use crate::{
730        codec::ProtocolCodec,
731        crypto::ed25519::Keypair,
732        executor::DefaultExecutor,
733        protocol::SubstreamKeepAlive,
734        transport::manager::{ProtocolContext, SupportedTransport, TransportManagerBuilder},
735        types::protocol::ProtocolName,
736        BandwidthSink, PeerId,
737    };
738    use multiaddr::Protocol;
739    use std::sync::Arc;
740    use tokio::sync::mpsc::channel;
741
742    #[tokio::test]
743    async fn connect_and_accept_works() {
744        let _ = tracing_subscriber::fmt()
745            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
746            .try_init();
747
748        let keypair1 = Keypair::generate();
749        let (tx1, _rx1) = channel(64);
750        let (event_tx1, _event_rx1) = channel(64);
751        let bandwidth_sink = BandwidthSink::new();
752
753        let handle1 = crate::transport::manager::TransportHandle {
754            executor: Arc::new(DefaultExecutor {}),
755            next_substream_id: Default::default(),
756            next_connection_id: Default::default(),
757            keypair: keypair1.clone(),
758            tx: event_tx1,
759            bandwidth_sink: bandwidth_sink.clone(),
760
761            protocols: HashMap::from_iter([(
762                ProtocolName::from("/notif/1"),
763                ProtocolContext {
764                    tx: tx1,
765                    codec: ProtocolCodec::Identity(32),
766                    fallback_names: Vec::new(),
767                    keep_alive: SubstreamKeepAlive::Yes,
768                },
769            )]),
770        };
771        let transport_config1 = Config {
772            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
773            ..Default::default()
774        };
775        let resolver = Arc::new(TokioResolver::builder_tokio().unwrap().build());
776
777        let (mut transport1, listen_addresses) =
778            TcpTransport::new(handle1, transport_config1, resolver.clone()).unwrap();
779        let listen_address = listen_addresses[0].clone();
780
781        let keypair2 = Keypair::generate();
782        let (tx2, _rx2) = channel(64);
783        let (event_tx2, _event_rx2) = channel(64);
784
785        let handle2 = crate::transport::manager::TransportHandle {
786            executor: Arc::new(DefaultExecutor {}),
787            next_substream_id: Default::default(),
788            next_connection_id: Default::default(),
789            keypair: keypair2.clone(),
790            tx: event_tx2,
791            bandwidth_sink: bandwidth_sink.clone(),
792
793            protocols: HashMap::from_iter([(
794                ProtocolName::from("/notif/1"),
795                ProtocolContext {
796                    tx: tx2,
797                    codec: ProtocolCodec::Identity(32),
798                    fallback_names: Vec::new(),
799                    keep_alive: SubstreamKeepAlive::Yes,
800                },
801            )]),
802        };
803        let transport_config2 = Config {
804            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
805            ..Default::default()
806        };
807
808        let (mut transport2, _) = TcpTransport::new(handle2, transport_config2, resolver).unwrap();
809        transport2.dial(ConnectionId::new(), listen_address).unwrap();
810
811        let (tx, mut from_transport2) = channel(64);
812        tokio::spawn(async move {
813            let event = transport2.next().await;
814            tx.send(event).await.unwrap();
815        });
816
817        let event = transport1.next().await.unwrap();
818        match event {
819            TransportEvent::PendingInboundConnection { connection_id } => {
820                transport1.accept_pending(connection_id).unwrap();
821            }
822            _ => panic!("unexpected event"),
823        }
824
825        let event = transport1.next().await;
826        assert!(std::matches!(
827            event,
828            Some(TransportEvent::ConnectionEstablished { .. })
829        ));
830
831        let event = from_transport2.recv().await.unwrap();
832        assert!(std::matches!(
833            event,
834            Some(TransportEvent::ConnectionEstablished { .. })
835        ));
836    }
837
838    #[tokio::test]
839    async fn connect_and_reject_works() {
840        let _ = tracing_subscriber::fmt()
841            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
842            .try_init();
843
844        let keypair1 = Keypair::generate();
845        let (tx1, _rx1) = channel(64);
846        let (event_tx1, _event_rx1) = channel(64);
847        let bandwidth_sink = BandwidthSink::new();
848
849        let handle1 = crate::transport::manager::TransportHandle {
850            executor: Arc::new(DefaultExecutor {}),
851            next_substream_id: Default::default(),
852            next_connection_id: Default::default(),
853            keypair: keypair1.clone(),
854            tx: event_tx1,
855            bandwidth_sink: bandwidth_sink.clone(),
856
857            protocols: HashMap::from_iter([(
858                ProtocolName::from("/notif/1"),
859                ProtocolContext {
860                    tx: tx1,
861                    codec: ProtocolCodec::Identity(32),
862                    fallback_names: Vec::new(),
863                    keep_alive: SubstreamKeepAlive::Yes,
864                },
865            )]),
866        };
867        let transport_config1 = Config {
868            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
869            ..Default::default()
870        };
871        let resolver = Arc::new(TokioResolver::builder_tokio().unwrap().build());
872
873        let (mut transport1, listen_addresses) =
874            TcpTransport::new(handle1, transport_config1, resolver.clone()).unwrap();
875        let listen_address = listen_addresses[0].clone();
876
877        let keypair2 = Keypair::generate();
878        let (tx2, _rx2) = channel(64);
879        let (event_tx2, _event_rx2) = channel(64);
880
881        let handle2 = crate::transport::manager::TransportHandle {
882            executor: Arc::new(DefaultExecutor {}),
883            next_substream_id: Default::default(),
884            next_connection_id: Default::default(),
885            keypair: keypair2.clone(),
886            tx: event_tx2,
887            bandwidth_sink: bandwidth_sink.clone(),
888
889            protocols: HashMap::from_iter([(
890                ProtocolName::from("/notif/1"),
891                ProtocolContext {
892                    tx: tx2,
893                    codec: ProtocolCodec::Identity(32),
894                    fallback_names: Vec::new(),
895                    keep_alive: SubstreamKeepAlive::Yes,
896                },
897            )]),
898        };
899        let transport_config2 = Config {
900            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
901            ..Default::default()
902        };
903
904        let (mut transport2, _) = TcpTransport::new(handle2, transport_config2, resolver).unwrap();
905        transport2.dial(ConnectionId::new(), listen_address).unwrap();
906
907        let (tx, mut from_transport2) = channel(64);
908        tokio::spawn(async move {
909            let event = transport2.next().await;
910            tx.send(event).await.unwrap();
911        });
912
913        // Reject connection.
914        let event = transport1.next().await.unwrap();
915        match event {
916            TransportEvent::PendingInboundConnection { connection_id } => {
917                transport1.reject_pending(connection_id).unwrap();
918            }
919            _ => panic!("unexpected event"),
920        }
921
922        let event = from_transport2.recv().await.unwrap();
923        assert!(std::matches!(
924            event,
925            Some(TransportEvent::DialFailure { .. })
926        ));
927    }
928
929    #[tokio::test]
930    async fn dial_failure() {
931        let _ = tracing_subscriber::fmt()
932            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
933            .try_init();
934
935        let keypair1 = Keypair::generate();
936        let (tx1, _rx1) = channel(64);
937        let (event_tx1, mut event_rx1) = channel(64);
938        let bandwidth_sink = BandwidthSink::new();
939
940        let handle1 = crate::transport::manager::TransportHandle {
941            executor: Arc::new(DefaultExecutor {}),
942            next_substream_id: Default::default(),
943            next_connection_id: Default::default(),
944            keypair: keypair1.clone(),
945            tx: event_tx1,
946            bandwidth_sink: bandwidth_sink.clone(),
947
948            protocols: HashMap::from_iter([(
949                ProtocolName::from("/notif/1"),
950                ProtocolContext {
951                    tx: tx1,
952                    codec: ProtocolCodec::Identity(32),
953                    fallback_names: Vec::new(),
954                    keep_alive: SubstreamKeepAlive::Yes,
955                },
956            )]),
957        };
958        let resolver = Arc::new(TokioResolver::builder_tokio().unwrap().build());
959        let (mut transport1, _) =
960            TcpTransport::new(handle1, Default::default(), resolver.clone()).unwrap();
961
962        tokio::spawn(async move {
963            while let Some(event) = transport1.next().await {
964                match event {
965                    TransportEvent::ConnectionEstablished { .. } => {}
966                    TransportEvent::ConnectionClosed { .. } => {}
967                    TransportEvent::DialFailure { .. } => {}
968                    TransportEvent::ConnectionOpened { .. } => {}
969                    TransportEvent::OpenFailure { .. } => {}
970                    TransportEvent::PendingInboundConnection { .. } => {}
971                }
972            }
973        });
974
975        let keypair2 = Keypair::generate();
976        let (tx2, _rx2) = channel(64);
977        let (event_tx2, _event_rx2) = channel(64);
978
979        let handle2 = crate::transport::manager::TransportHandle {
980            executor: Arc::new(DefaultExecutor {}),
981            next_substream_id: Default::default(),
982            next_connection_id: Default::default(),
983            keypair: keypair2.clone(),
984            tx: event_tx2,
985            bandwidth_sink: bandwidth_sink.clone(),
986
987            protocols: HashMap::from_iter([(
988                ProtocolName::from("/notif/1"),
989                ProtocolContext {
990                    tx: tx2,
991                    codec: ProtocolCodec::Identity(32),
992                    fallback_names: Vec::new(),
993                    keep_alive: SubstreamKeepAlive::Yes,
994                },
995            )]),
996        };
997
998        let (mut transport2, _) = TcpTransport::new(handle2, Default::default(), resolver).unwrap();
999
1000        let peer1: PeerId = PeerId::from_public_key(&keypair1.public().into());
1001        let peer2: PeerId = PeerId::from_public_key(&keypair2.public().into());
1002
1003        tracing::info!(target: LOG_TARGET, "peer1 {peer1}, peer2 {peer2}");
1004
1005        let address = Multiaddr::empty()
1006            .with(Protocol::Ip6(std::net::Ipv6Addr::new(
1007                0, 0, 0, 0, 0, 0, 0, 1,
1008            )))
1009            .with(Protocol::Tcp(8888))
1010            .with(Protocol::P2p(peer1.into()));
1011
1012        transport2.dial(ConnectionId::new(), address).unwrap();
1013
1014        // spawn the other connection in the background as it won't return anything
1015        tokio::spawn(async move {
1016            loop {
1017                let _ = event_rx1.recv().await;
1018            }
1019        });
1020
1021        assert!(std::matches!(
1022            transport2.next().await,
1023            Some(TransportEvent::DialFailure { .. })
1024        ));
1025    }
1026
1027    #[tokio::test]
1028    async fn dial_error_reported_for_outbound_connections() {
1029        let mut manager = TransportManagerBuilder::new().build();
1030        let handle = manager.transport_handle(Arc::new(DefaultExecutor {}));
1031        let resolver = Arc::new(TokioResolver::builder_tokio().unwrap().build());
1032        manager.register_transport(
1033            SupportedTransport::Tcp,
1034            Box::new(crate::transport::dummy::DummyTransport::new()),
1035        );
1036        let (mut transport, _) = TcpTransport::new(
1037            handle,
1038            Config {
1039                listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()],
1040                ..Default::default()
1041            },
1042            resolver,
1043        )
1044        .unwrap();
1045
1046        let keypair = Keypair::generate();
1047        let peer_id = PeerId::from_public_key(&keypair.public().into());
1048        let multiaddr = Multiaddr::empty()
1049            .with(Protocol::Ip4(std::net::Ipv4Addr::new(255, 254, 253, 252)))
1050            .with(Protocol::Tcp(8888))
1051            .with(Protocol::P2p(peer_id.into()));
1052        manager.dial_address(multiaddr.clone()).await.unwrap();
1053
1054        assert!(transport.pending_dials.is_empty());
1055
1056        match transport.dial(ConnectionId::from(0usize), multiaddr) {
1057            Ok(()) => {}
1058            _ => panic!("invalid result for `on_dial_peer()`"),
1059        }
1060
1061        assert!(!transport.pending_dials.is_empty());
1062        transport.pending_connections.push(Box::pin(async move {
1063            Err((ConnectionId::from(0usize), DialError::Timeout))
1064        }));
1065
1066        assert!(std::matches!(
1067            transport.next().await,
1068            Some(TransportEvent::DialFailure { .. })
1069        ));
1070        assert!(transport.pending_dials.is_empty());
1071    }
1072}