litep2p/transport/tcp/
connection.rs

1// Copyright 2023 litep2p developers
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::{
22    config::Role,
23    crypto::{
24        ed25519::Keypair,
25        noise::{self, NoiseSocket},
26    },
27    error::{Error, NegotiationError, SubstreamError},
28    multistream_select::{dialer_select_proto, listener_select_proto, Negotiated, Version},
29    protocol::{Direction, Permit, ProtocolCommand, ProtocolSet},
30    substream,
31    transport::{
32        common::listener::{AddressType, DnsType},
33        tcp::substream::Substream,
34        Endpoint,
35    },
36    types::{protocol::ProtocolName, ConnectionId, SubstreamId},
37    BandwidthSink, PeerId,
38};
39
40use futures::{
41    future::BoxFuture,
42    stream::{FuturesUnordered, StreamExt},
43    AsyncRead, AsyncWrite,
44};
45use multiaddr::{Multiaddr, Protocol};
46use tokio::net::TcpStream;
47use tokio_util::compat::{
48    Compat, FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt,
49};
50
51use std::{
52    borrow::Cow,
53    fmt,
54    net::SocketAddr,
55    sync::{
56        atomic::{AtomicUsize, Ordering},
57        Arc,
58    },
59    time::Duration,
60};
61
/// Logging target for the file.
///
/// Passed as `target:` to the `tracing` macros in this module so the TCP connection logs can be
/// filtered independently (e.g. `RUST_LOG=litep2p::tcp::connection=trace`).
const LOG_TARGET: &str = "litep2p::tcp::connection";
64
/// Substream whose protocol has been negotiated with `multistream-select` but which has not yet
/// been handed over to the protocol.
///
/// Produced by [`TcpConnection::open_substream`] (outbound) and
/// [`TcpConnection::accept_substream`] (inbound), and consumed by the connection event loop,
/// which wraps it into a TCP substream and reports it to the negotiated protocol.
#[derive(Debug)]
pub struct NegotiatedSubstream {
    /// Substream direction: `Direction::Outbound(substream_id)` for substreams opened by the local
    /// node, `Direction::Inbound` for substreams accepted from the remote.
    direction: Direction,

    /// Substream ID.
    substream_id: SubstreamId,

    /// Name of the protocol that was negotiated. For outbound substreams this may be one of the
    /// fallback names rather than the primary protocol.
    protocol: ProtocolName,

    /// Yamux substream, already unwrapped from the `multistream-select` `Negotiated` wrapper.
    io: crate::yamux::Stream,

    /// Permit obtained from the protocol set (supplied with the open command for outbound
    /// substreams, `try_get_permit()` for inbound ones); it is moved into the resulting
    /// TCP substream.
    permit: Permit,
}
82
/// TCP connection error.
///
/// Produced by the pending substream futures of the connection event loop when opening or
/// accepting a substream fails. The failure is reported back to a protocol only when both
/// `protocol` and `substream_id` are set, which is the case for outbound substreams; otherwise
/// it is only logged.
#[derive(Debug)]
enum ConnectionError {
    /// Timeout
    ///
    /// The substream was not opened and negotiated within the substream open timeout.
    Timeout {
        /// Protocol, if the substream was opened for a known protocol.
        protocol: Option<ProtocolName>,

        /// Substream ID, if known.
        substream_id: Option<SubstreamId>,
    },

    /// Failed to negotiate connection/substream.
    FailedToNegotiate {
        /// Protocol, if the substream was opened for a known protocol.
        protocol: Option<ProtocolName>,

        /// Substream ID, if known.
        substream_id: Option<SubstreamId>,

        /// Error.
        error: SubstreamError,
    },
}
107
/// Connection context for an opened connection that hasn't yet started its event loop.
///
/// Returned by [`TcpConnection::open_connection`] / [`TcpConnection::accept_connection`] once
/// `noise` and `yamux` have been negotiated, and turned into a running connection with
/// [`TcpConnection::new`].
pub struct NegotiatedConnection {
    /// Yamux connection.
    connection: crate::yamux::ControlledConnection<NoiseSocket<Compat<TcpStream>>>,

    /// Yamux control, used to open outbound substreams.
    control: crate::yamux::Control,

    /// Remote peer ID, as authenticated by the noise handshake.
    peer: PeerId,

    /// Endpoint (dialer or listener) together with the remote address and connection ID.
    endpoint: Endpoint,

    /// Timeout applied to `multistream-select` negotiations and to opening/accepting substreams.
    substream_open_timeout: Duration,
}
125
126impl NegotiatedConnection {
127    /// Get `ConnectionId` of the negotiated connection.
128    pub fn connection_id(&self) -> ConnectionId {
129        self.endpoint.connection_id()
130    }
131
132    /// Get `PeerId` of the negotiated connection.
133    pub fn peer(&self) -> PeerId {
134        self.peer
135    }
136
137    /// Get `Endpoint` of the negotiated connection.
138    pub fn endpoint(&self) -> Endpoint {
139        self.endpoint.clone()
140    }
141}
142
/// TCP connection.
///
/// Running state of a negotiated (`noise` + `yamux`) TCP connection, driven by
/// [`TcpConnection::start`].
pub struct TcpConnection {
    /// Protocol context, used to report connection and substream events to the protocols.
    protocol_set: ProtocolSet,

    /// Yamux connection, polled for inbound substreams.
    connection: crate::yamux::ControlledConnection<NoiseSocket<Compat<TcpStream>>>,

    /// Yamux control, cloned into every outbound substream open request.
    control: crate::yamux::Control,

    /// Remote peer ID.
    peer: PeerId,

    /// Endpoint.
    endpoint: Endpoint,

    /// Substream open timeout.
    substream_open_timeout: Duration,

    /// Next substream ID, shared counter used to allocate IDs for inbound substreams.
    next_substream_id: Arc<AtomicUsize>,

    /// Bandwidth sink, cloned into every substream created over this connection.
    bandwidth_sink: BandwidthSink,

    /// Pending substreams: inbound and outbound substreams that are still being negotiated.
    pending_substreams:
        FuturesUnordered<BoxFuture<'static, Result<NegotiatedSubstream, ConnectionError>>>,
}
173
174impl fmt::Debug for TcpConnection {
175    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176        f.debug_struct("TcpConnection")
177            .field("peer", &self.peer)
178            .field("next_substream_id", &self.next_substream_id)
179            .finish()
180    }
181}
182
183impl TcpConnection {
184    /// Create new [`TcpConnection`] from [`NegotiatedConnection`].
185    pub(super) fn new(
186        context: NegotiatedConnection,
187        protocol_set: ProtocolSet,
188        bandwidth_sink: BandwidthSink,
189        next_substream_id: Arc<AtomicUsize>,
190    ) -> Self {
191        let NegotiatedConnection {
192            connection,
193            control,
194            peer,
195            endpoint,
196            substream_open_timeout,
197        } = context;
198
199        Self {
200            protocol_set,
201            connection,
202            control,
203            peer,
204            endpoint,
205            bandwidth_sink,
206            next_substream_id,
207            pending_substreams: FuturesUnordered::new(),
208            substream_open_timeout,
209        }
210    }
211
212    /// Open connection to remote peer at `address`.
213    // TODO: this function can be removed
214    pub(super) async fn open_connection(
215        connection_id: ConnectionId,
216        keypair: Keypair,
217        stream: TcpStream,
218        address: AddressType,
219        peer: Option<PeerId>,
220        yamux_config: crate::yamux::Config,
221        max_read_ahead_factor: usize,
222        max_write_buffer_size: usize,
223        connection_open_timeout: Duration,
224        substream_open_timeout: Duration,
225    ) -> Result<NegotiatedConnection, NegotiationError> {
226        tracing::debug!(
227            target: LOG_TARGET,
228            ?address,
229            ?peer,
230            "open connection to remote peer",
231        );
232
233        match tokio::time::timeout(connection_open_timeout, async move {
234            Self::negotiate_connection(
235                stream,
236                peer,
237                connection_id,
238                keypair,
239                Role::Dialer,
240                address,
241                yamux_config,
242                max_read_ahead_factor,
243                max_write_buffer_size,
244                substream_open_timeout,
245            )
246            .await
247        })
248        .await
249        {
250            Err(_) => {
251                tracing::trace!(target: LOG_TARGET, ?connection_id, "connection timed out during negotiation");
252                Err(NegotiationError::Timeout)
253            }
254            Ok(result) => result,
255        }
256    }
257
258    /// Open substream for `protocol`.
259    pub(super) async fn open_substream(
260        mut control: crate::yamux::Control,
261        substream_id: SubstreamId,
262        permit: Permit,
263        protocol: ProtocolName,
264        fallback_names: Vec<ProtocolName>,
265        open_timeout: Duration,
266    ) -> Result<NegotiatedSubstream, SubstreamError> {
267        tracing::debug!(target: LOG_TARGET, ?protocol, ?substream_id, "open substream");
268
269        let stream = match control.open_stream().await {
270            Ok(stream) => {
271                tracing::trace!(target: LOG_TARGET, ?substream_id, "substream opened");
272                stream
273            }
274            Err(error) => {
275                tracing::debug!(
276                    target: LOG_TARGET,
277                    ?substream_id,
278                    ?error,
279                    "failed to open substream"
280                );
281                return Err(SubstreamError::YamuxError(
282                    error,
283                    Direction::Outbound(substream_id),
284                ));
285            }
286        };
287
288        // TODO: protocols don't change after they've been initialized so this should be done only
289        // once
290        let protocols = std::iter::once(&*protocol)
291            .chain(fallback_names.iter().map(|protocol| &**protocol))
292            .collect();
293
294        let (io, protocol) =
295            Self::negotiate_protocol(stream, &Role::Dialer, protocols, open_timeout).await?;
296
297        Ok(NegotiatedSubstream {
298            io: io.inner(),
299            substream_id,
300            direction: Direction::Outbound(substream_id),
301            protocol,
302            permit,
303        })
304    }
305
306    /// Accept a new connection.
307    pub(super) async fn accept_connection(
308        stream: TcpStream,
309        connection_id: ConnectionId,
310        keypair: Keypair,
311        address: SocketAddr,
312        yamux_config: crate::yamux::Config,
313        max_read_ahead_factor: usize,
314        max_write_buffer_size: usize,
315        connection_open_timeout: Duration,
316        substream_open_timeout: Duration,
317    ) -> Result<NegotiatedConnection, NegotiationError> {
318        tracing::debug!(target: LOG_TARGET, ?address, "accept connection");
319
320        match tokio::time::timeout(connection_open_timeout, async move {
321            Self::negotiate_connection(
322                stream,
323                None,
324                connection_id,
325                keypair,
326                Role::Listener,
327                AddressType::Socket(address),
328                yamux_config,
329                max_read_ahead_factor,
330                max_write_buffer_size,
331                substream_open_timeout,
332            )
333            .await
334        })
335        .await
336        {
337            Err(_) => Err(NegotiationError::Timeout),
338            Ok(result) => result,
339        }
340    }
341
342    /// Accept substream.
343    pub(super) async fn accept_substream(
344        stream: crate::yamux::Stream,
345        permit: Permit,
346        substream_id: SubstreamId,
347        protocols: Vec<ProtocolName>,
348        open_timeout: Duration,
349    ) -> Result<NegotiatedSubstream, NegotiationError> {
350        tracing::trace!(
351            target: LOG_TARGET,
352            ?substream_id,
353            "accept inbound substream",
354        );
355
356        let protocols = protocols.iter().map(|protocol| &**protocol).collect::<Vec<&str>>();
357        let (io, protocol) =
358            Self::negotiate_protocol(stream, &Role::Listener, protocols, open_timeout).await?;
359
360        tracing::trace!(
361            target: LOG_TARGET,
362            ?substream_id,
363            "substream accepted and negotiated",
364        );
365
366        Ok(NegotiatedSubstream {
367            io: io.inner(),
368            substream_id,
369            direction: Direction::Inbound,
370            protocol,
371            permit,
372        })
373    }
374
375    /// Negotiate protocol.
376    async fn negotiate_protocol<S: AsyncRead + AsyncWrite + Unpin>(
377        stream: S,
378        role: &Role,
379        protocols: Vec<&str>,
380        substream_open_timeout: Duration,
381    ) -> Result<(Negotiated<S>, ProtocolName), NegotiationError> {
382        tracing::trace!(target: LOG_TARGET, ?protocols, "negotiating protocols");
383
384        match tokio::time::timeout(substream_open_timeout, async move {
385            match role {
386                Role::Dialer => dialer_select_proto(stream, protocols, Version::V1).await,
387                Role::Listener => listener_select_proto(stream, protocols).await,
388            }
389        })
390        .await
391        {
392            Err(_) => Err(NegotiationError::Timeout),
393            Ok(Err(error)) => Err(NegotiationError::MultistreamSelectError(error)),
394            Ok(Ok((protocol, socket))) => {
395                tracing::trace!(target: LOG_TARGET, ?protocol, "protocol negotiated");
396
397                Ok((socket, ProtocolName::from(protocol.to_string())))
398            }
399        }
400    }
401
402    /// Negotiate noise + yamux for the connection.
403    pub(super) async fn negotiate_connection(
404        stream: TcpStream,
405        dialed_peer: Option<PeerId>,
406        connection_id: ConnectionId,
407        keypair: Keypair,
408        role: Role,
409        address: AddressType,
410        yamux_config: crate::yamux::Config,
411        max_read_ahead_factor: usize,
412        max_write_buffer_size: usize,
413        substream_open_timeout: Duration,
414    ) -> Result<NegotiatedConnection, NegotiationError> {
415        tracing::trace!(
416            target: LOG_TARGET,
417            ?role,
418            "negotiate connection",
419        );
420
421        let stream = TokioAsyncReadCompatExt::compat(stream).into_inner();
422        let stream = TokioAsyncWriteCompatExt::compat_write(stream);
423
424        // negotiate `noise`
425        let (stream, _) =
426            Self::negotiate_protocol(stream, &role, vec!["/noise"], substream_open_timeout).await?;
427
428        tracing::trace!(
429            target: LOG_TARGET,
430            "`multistream-select` and `noise` negotiated",
431        );
432
433        // perform noise handshake
434        let (stream, peer) = noise::handshake(
435            stream.inner(),
436            &keypair,
437            role,
438            max_read_ahead_factor,
439            max_write_buffer_size,
440        )
441        .await?;
442
443        if let Some(dialed_peer) = dialed_peer {
444            if dialed_peer != peer {
445                tracing::debug!(target: LOG_TARGET, ?dialed_peer, ?peer, "peer id mismatch");
446                return Err(NegotiationError::PeerIdMismatch(dialed_peer, peer));
447            }
448        }
449
450        tracing::trace!(target: LOG_TARGET, "noise handshake done");
451        let stream: NoiseSocket<Compat<TcpStream>> = stream;
452
453        // negotiate `yamux`
454        let (stream, _) =
455            Self::negotiate_protocol(stream, &role, vec!["/yamux/1.0.0"], substream_open_timeout)
456                .await?;
457        tracing::trace!(target: LOG_TARGET, "`yamux` negotiated");
458
459        let connection = crate::yamux::Connection::new(stream.inner(), yamux_config, role.into());
460        let (control, connection) = crate::yamux::Control::new(connection);
461
462        let address = match address {
463            AddressType::Socket(address) => Multiaddr::empty()
464                .with(Protocol::from(address.ip()))
465                .with(Protocol::Tcp(address.port())),
466            AddressType::Dns {
467                address,
468                port,
469                dns_type,
470            } => match dns_type {
471                DnsType::Dns => Multiaddr::empty()
472                    .with(Protocol::Dns(Cow::Owned(address)))
473                    .with(Protocol::Tcp(port)),
474                DnsType::Dns4 => Multiaddr::empty()
475                    .with(Protocol::Dns4(Cow::Owned(address)))
476                    .with(Protocol::Tcp(port)),
477                DnsType::Dns6 => Multiaddr::empty()
478                    .with(Protocol::Dns6(Cow::Owned(address)))
479                    .with(Protocol::Tcp(port)),
480            },
481        };
482        let endpoint = match role {
483            Role::Dialer => Endpoint::dialer(address, connection_id),
484            Role::Listener => Endpoint::listener(address, connection_id),
485        };
486
487        Ok(NegotiatedConnection {
488            peer,
489            control,
490            connection,
491            endpoint,
492            substream_open_timeout,
493        })
494    }
495
496    /// Start connection event loop.
497    pub(crate) async fn start(mut self) -> crate::Result<()> {
498        self.protocol_set
499            .report_connection_established(self.peer, self.endpoint.clone())
500            .await?;
501
502        loop {
503            tokio::select! {
504                substream = self.connection.next() => match substream {
505                    Some(Ok(stream)) => {
506                        let substream_id = {
507                            let substream_id = self.next_substream_id.fetch_add(1usize, Ordering::Relaxed);
508                            SubstreamId::from(substream_id)
509                        };
510                        let protocols = self.protocol_set.protocols();
511                        let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?;
512                        let open_timeout = self.substream_open_timeout;
513
514                        self.pending_substreams.push(Box::pin(async move {
515                            match tokio::time::timeout(
516                                open_timeout,
517                                Self::accept_substream(stream, permit, substream_id, protocols, open_timeout),
518                            )
519                            .await
520                            {
521                                Ok(Ok(substream)) => Ok(substream),
522                                Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate {
523                                    protocol: None,
524                                    substream_id: None,
525                                    error: SubstreamError::NegotiationError(error),
526                                }),
527                                Err(_) => Err(ConnectionError::Timeout {
528                                    protocol: None,
529                                    substream_id: None
530                                }),
531                            }
532                        }));
533                    },
534                    Some(Err(error)) => {
535                        tracing::debug!(
536                            target: LOG_TARGET,
537                            peer = ?self.peer,
538                            ?error,
539                            "connection closed with error",
540                        );
541                        self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await?;
542
543                        return Ok(())
544                    }
545                    None => {
546                        tracing::debug!(target: LOG_TARGET, peer = ?self.peer, "connection closed");
547                        self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await?;
548
549                        return Ok(())
550                    }
551                },
552                // TODO: move this to a function
553                substream = self.pending_substreams.select_next_some(), if !self.pending_substreams.is_empty() => {
554                    match substream {
555                        // TODO: return error to protocol
556                        Err(error) => {
557                            tracing::debug!(
558                                target: LOG_TARGET,
559                                ?error,
560                                "failed to accept/open substream",
561                            );
562
563                            let (protocol, substream_id, error) = match error {
564                                ConnectionError::Timeout { protocol, substream_id } => {
565                                    (protocol, substream_id, SubstreamError::NegotiationError(NegotiationError::Timeout))
566                                }
567                                ConnectionError::FailedToNegotiate { protocol, substream_id, error } => {
568                                    (protocol, substream_id, error)
569                                }
570                            };
571
572                            match (protocol, substream_id) {
573                                (Some(protocol), Some(substream_id)) => {
574                                    if let Err(error) = self.protocol_set
575                                        .report_substream_open_failure(protocol, substream_id, error)
576                                        .await
577                                    {
578                                        tracing::error!(
579                                            target: LOG_TARGET,
580                                            ?error,
581                                            "failed to register opened substream to protocol"
582                                        );
583                                    }
584                                }
585                                _ => {}
586                            }
587                        }
588                        Ok(substream) => {
589                            let protocol = substream.protocol.clone();
590                            let direction = substream.direction;
591                            let substream_id = substream.substream_id;
592                            let socket = FuturesAsyncReadCompatExt::compat(substream.io);
593                            let bandwidth_sink = self.bandwidth_sink.clone();
594
595                            let substream = substream::Substream::new_tcp(
596                                self.peer,
597                                substream_id,
598                                Substream::new(socket, bandwidth_sink, substream.permit),
599                                self.protocol_set.protocol_codec(&protocol)
600                            );
601
602                            if let Err(error) = self.protocol_set
603                                .report_substream_open(self.peer, protocol, direction, substream)
604                                .await
605                            {
606                                tracing::error!(
607                                    target: LOG_TARGET,
608                                    ?error,
609                                    "failed to register opened substream to protocol",
610                                );
611                            }
612                        }
613                    }
614                }
615                protocol = self.protocol_set.next() => match protocol {
616                    Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit }) => {
617                        let control = self.control.clone();
618                        let open_timeout = self.substream_open_timeout;
619
620                        tracing::trace!(
621                            target: LOG_TARGET,
622                            ?protocol,
623                            ?substream_id,
624                            "open substream",
625                        );
626
627                        self.pending_substreams.push(Box::pin(async move {
628                            match tokio::time::timeout(
629                                open_timeout,
630                                Self::open_substream(
631                                    control,
632                                    substream_id,
633                                    permit,
634                                    protocol.clone(),
635                                    fallback_names,
636                                    open_timeout,
637                                ),
638                            )
639                            .await
640                            {
641                                Ok(Ok(substream)) => Ok(substream),
642                                Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate {
643                                    protocol: Some(protocol),
644                                    substream_id: Some(substream_id),
645                                    error,
646                                }),
647                                Err(_) => Err(ConnectionError::Timeout {
648                                    protocol: Some(protocol),
649                                    substream_id: Some(substream_id)
650                                }),
651                            }
652                        }));
653                    }
654                    Some(ProtocolCommand::ForceClose) => {
655                        tracing::debug!(
656                            target: LOG_TARGET,
657                            peer = ?self.peer,
658                            connection_id = ?self.endpoint.connection_id(),
659                            "force closing connection",
660                        );
661
662                        return self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await
663                    }
664                    None => {
665                        tracing::debug!(target: LOG_TARGET, "protocols have disconnected, closing connection");
666                        return self.protocol_set.report_connection_closed(self.peer, self.endpoint.connection_id()).await
667                    }
668                }
669            }
670        }
671    }
672}
673
674#[cfg(test)]
675mod tests {
676    use crate::transport::tcp::TcpTransport;
677
678    use super::*;
679    use tokio::{io::AsyncWriteExt, net::TcpListener};
680
681    #[tokio::test]
682    async fn multistream_select_not_supported_dialer() {
683        let _ = tracing_subscriber::fmt()
684            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
685            .try_init();
686
687        let listener = TcpListener::bind("[::1]:0").await.unwrap();
688        let address = listener.local_addr().unwrap();
689
690        tokio::spawn(async move {
691            let (mut stream, _) = listener.accept().await.unwrap();
692            let _ = stream.write_all(&vec![0x12u8; 256]).await;
693        });
694
695        let (_, stream) = TcpTransport::dial_peer(
696            Multiaddr::empty()
697                .with(Protocol::from(address.ip()))
698                .with(Protocol::Tcp(address.port())),
699            Default::default(),
700            Duration::from_secs(10),
701            false,
702        )
703        .await
704        .unwrap();
705
706        match TcpConnection::open_connection(
707            ConnectionId::from(0usize),
708            Keypair::generate(),
709            stream,
710            AddressType::Socket(address),
711            None,
712            Default::default(),
713            5,
714            2,
715            Duration::from_secs(10),
716            Duration::from_secs(10),
717        )
718        .await
719        {
720            Ok(_) => panic!("connection was supposed to fail"),
721            Err(NegotiationError::MultistreamSelectError(
722                crate::multistream_select::NegotiationError::ProtocolError(
723                    crate::multistream_select::ProtocolError::InvalidMessage,
724                ),
725            )) => {}
726            Err(error) => panic!("invalid error: {error:?}"),
727        }
728    }
729
730    #[tokio::test]
731    async fn multistream_select_not_supported_listener() {
732        let _ = tracing_subscriber::fmt()
733            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
734            .try_init();
735
736        let listener = TcpListener::bind("[::1]:0").await.unwrap();
737        let address = listener.local_addr().unwrap();
738
739        let (Ok(mut dialer), Ok((stream, dialer_address))) =
740            tokio::join!(TcpStream::connect(address.clone()), listener.accept(),)
741        else {
742            panic!("failed to establish connection");
743        };
744
745        tokio::spawn(async move {
746            let _ = dialer.write_all(&vec![0x12u8; 256]).await;
747        });
748
749        match TcpConnection::accept_connection(
750            stream,
751            ConnectionId::from(0usize),
752            Keypair::generate(),
753            dialer_address,
754            Default::default(),
755            5,
756            2,
757            Duration::from_secs(10),
758            Duration::from_secs(10),
759        )
760        .await
761        {
762            Ok(_) => panic!("connection was supposed to fail"),
763            Err(NegotiationError::MultistreamSelectError(
764                crate::multistream_select::NegotiationError::ProtocolError(
765                    crate::multistream_select::ProtocolError::InvalidMessage,
766                ),
767            )) => {}
768            Err(error) => panic!("invalid error: {error:?}"),
769        }
770    }
771
772    #[tokio::test]
773    async fn noise_not_supported_dialer() {
774        let _ = tracing_subscriber::fmt()
775            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
776            .try_init();
777
778        let listener = TcpListener::bind("[::1]:0").await.unwrap();
779        let address = listener.local_addr().unwrap();
780
781        tokio::spawn(async move {
782            let (stream, _) = listener.accept().await.unwrap();
783            let stream = TokioAsyncReadCompatExt::compat(stream).into_inner();
784            let stream = TokioAsyncWriteCompatExt::compat_write(stream);
785
786            // attempt to negotiate yamux, skipping noise entirely
787            assert!(listener_select_proto(stream, vec!["/yamux/1.0.0"]).await.is_err());
788        });
789
790        let (_, stream) = TcpTransport::dial_peer(
791            Multiaddr::empty()
792                .with(Protocol::from(address.ip()))
793                .with(Protocol::Tcp(address.port())),
794            Default::default(),
795            Duration::from_secs(10),
796            false,
797        )
798        .await
799        .unwrap();
800
801        match TcpConnection::open_connection(
802            ConnectionId::from(0usize),
803            Keypair::generate(),
804            stream,
805            AddressType::Socket(address),
806            None,
807            Default::default(),
808            5,
809            2,
810            Duration::from_secs(10),
811            Duration::from_secs(10),
812        )
813        .await
814        {
815            Ok(_) => panic!("connection was supposed to fail"),
816            Err(NegotiationError::MultistreamSelectError(
817                crate::multistream_select::NegotiationError::Failed,
818            )) => {}
819            Err(error) => panic!("invalid error: {error:?}"),
820        }
821    }
822
823    #[tokio::test]
824    async fn noise_not_supported_listener() {
825        let _ = tracing_subscriber::fmt()
826            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
827            .try_init();
828
829        let listener = TcpListener::bind("[::1]:0").await.unwrap();
830        let address = listener.local_addr().unwrap();
831
832        let (Ok(dialer), Ok((listener, dialer_address))) =
833            tokio::join!(TcpStream::connect(address.clone()), listener.accept(),)
834        else {
835            panic!("failed to establish connection");
836        };
837
838        tokio::spawn(async move {
839            let dialer = TokioAsyncReadCompatExt::compat(dialer).into_inner();
840            let dialer = TokioAsyncWriteCompatExt::compat_write(dialer);
841
842            // attempt to negotiate yamux, skipping noise entirely
843            assert!(dialer_select_proto(dialer, vec!["/yamux/1.0.0"], Version::V1).await.is_err());
844        });
845
846        match TcpConnection::accept_connection(
847            listener,
848            ConnectionId::from(0usize),
849            Keypair::generate(),
850            dialer_address,
851            Default::default(),
852            5,
853            2,
854            Duration::from_secs(10),
855            Duration::from_secs(10),
856        )
857        .await
858        {
859            Ok(_) => panic!("connection was supposed to fail"),
860            Err(NegotiationError::MultistreamSelectError(
861                crate::multistream_select::NegotiationError::Failed,
862            )) => {}
863            Err(error) => panic!("invalid error: {error:?}"),
864        }
865    }
866
867    #[tokio::test]
868    async fn noise_timeout_listener() {
869        let _ = tracing_subscriber::fmt()
870            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
871            .try_init();
872
873        let listener = TcpListener::bind("[::1]:0").await.unwrap();
874        let address = listener.local_addr().unwrap();
875
876        let (Ok(dialer), Ok((listener, dialer_address))) =
877            tokio::join!(TcpStream::connect(address.clone()), listener.accept(),)
878        else {
879            panic!("failed to establish connection");
880        };
881
882        tokio::spawn(async move {
883            let dialer = TokioAsyncReadCompatExt::compat(dialer).into_inner();
884            let dialer = TokioAsyncWriteCompatExt::compat_write(dialer);
885
886            // attempt to negotiate yamux, skipping noise entirely
887            let (_protocol, _socket) =
888                dialer_select_proto(dialer, vec!["/noise"], Version::V1).await.unwrap();
889
890            tokio::time::sleep(std::time::Duration::from_secs(60)).await;
891        });
892
893        match TcpConnection::accept_connection(
894            listener,
895            ConnectionId::from(0usize),
896            Keypair::generate(),
897            dialer_address,
898            Default::default(),
899            5,
900            2,
901            Duration::from_secs(10),
902            Duration::from_secs(10),
903        )
904        .await
905        {
906            Ok(_) => panic!("connection was supposed to fail"),
907            Err(NegotiationError::Timeout) => {}
908            Err(error) => panic!("invalid error: {error:?}"),
909        }
910    }
911
912    #[tokio::test]
913    async fn noise_timeout_dialer() {
914        let _ = tracing_subscriber::fmt()
915            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
916            .try_init();
917
918        let listener = TcpListener::bind("[::1]:0").await.unwrap();
919        let address = listener.local_addr().unwrap();
920
921        tokio::spawn(async move {
922            let (stream, _) = listener.accept().await.unwrap();
923            let stream = TokioAsyncReadCompatExt::compat(stream).into_inner();
924            let stream = TokioAsyncWriteCompatExt::compat_write(stream);
925
926            // negotiate noise but never actually send any handshake data
927            let (_protocol, _socket) = listener_select_proto(stream, vec!["/noise"]).await.unwrap();
928
929            tokio::time::sleep(std::time::Duration::from_secs(60)).await;
930        });
931
932        let (_, stream) = TcpTransport::dial_peer(
933            Multiaddr::empty()
934                .with(Protocol::from(address.ip()))
935                .with(Protocol::Tcp(address.port())),
936            Default::default(),
937            Duration::from_secs(10),
938            false,
939        )
940        .await
941        .unwrap();
942
943        match TcpConnection::open_connection(
944            ConnectionId::from(0usize),
945            Keypair::generate(),
946            stream,
947            AddressType::Socket(address),
948            None,
949            Default::default(),
950            5,
951            2,
952            Duration::from_secs(10),
953            Duration::from_secs(10),
954        )
955        .await
956        {
957            Ok(_) => panic!("connection was supposed to fail"),
958            Err(NegotiationError::Timeout) => {}
959            Err(error) => panic!("invalid error: {error:?}"),
960        }
961    }
962
963    #[tokio::test]
964    async fn multistream_select_timeout_dialer() {
965        let _ = tracing_subscriber::fmt()
966            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
967            .try_init();
968
969        let listener = TcpListener::bind("[::1]:0").await.unwrap();
970        let address = listener.local_addr().unwrap();
971
972        tokio::spawn(async move {
973            let _stream = listener.accept().await.unwrap();
974
975            tokio::time::sleep(std::time::Duration::from_secs(60)).await;
976        });
977
978        let (_, stream) = TcpTransport::dial_peer(
979            Multiaddr::empty()
980                .with(Protocol::from(address.ip()))
981                .with(Protocol::Tcp(address.port())),
982            Default::default(),
983            Duration::from_secs(10),
984            false,
985        )
986        .await
987        .unwrap();
988
989        match TcpConnection::open_connection(
990            ConnectionId::from(0usize),
991            Keypair::generate(),
992            stream,
993            AddressType::Socket(address),
994            None,
995            Default::default(),
996            5,
997            2,
998            Duration::from_secs(10),
999            Duration::from_secs(10),
1000        )
1001        .await
1002        {
1003            Ok(_) => panic!("connection was supposed to fail"),
1004            Err(NegotiationError::Timeout) => {}
1005            Err(error) => panic!("invalid error: {error:?}"),
1006        }
1007    }
1008
1009    #[tokio::test]
1010    async fn multistream_select_timeout_listener() {
1011        let _ = tracing_subscriber::fmt()
1012            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1013            .try_init();
1014
1015        let listener = TcpListener::bind("[::1]:0").await.unwrap();
1016        let address = listener.local_addr().unwrap();
1017
1018        let (Ok(_dialer), Ok((listener, dialer_address))) =
1019            tokio::join!(TcpStream::connect(address.clone()), listener.accept(),)
1020        else {
1021            panic!("failed to establish connection");
1022        };
1023
1024        tokio::spawn(async move {
1025            let _stream = TcpStream::connect(address).await.unwrap();
1026
1027            tokio::time::sleep(std::time::Duration::from_secs(60)).await;
1028        });
1029
1030        match TcpConnection::accept_connection(
1031            listener,
1032            ConnectionId::from(0usize),
1033            Keypair::generate(),
1034            dialer_address,
1035            Default::default(),
1036            5,
1037            2,
1038            Duration::from_secs(10),
1039            Duration::from_secs(10),
1040        )
1041        .await
1042        {
1043            Ok(_) => panic!("connection was supposed to fail"),
1044            Err(NegotiationError::Timeout) => {}
1045            Err(error) => panic!("invalid error: {error:?}"),
1046        }
1047    }
1048
1049    #[tokio::test]
1050    async fn yamux_not_supported_dialer() {
1051        let _ = tracing_subscriber::fmt()
1052            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1053            .try_init();
1054
1055        let listener = TcpListener::bind("[::1]:0").await.unwrap();
1056        let address = listener.local_addr().unwrap();
1057
1058        let (Ok(dialer), Ok((listener, dialer_address))) =
1059            tokio::join!(TcpStream::connect(address.clone()), listener.accept(),)
1060        else {
1061            panic!("failed to establish connection");
1062        };
1063
1064        tokio::spawn(async move {
1065            let dialer = TokioAsyncReadCompatExt::compat(dialer).into_inner();
1066            let dialer = TokioAsyncWriteCompatExt::compat_write(dialer);
1067
1068            // negotiate noise
1069            let (_protocol, stream) =
1070                dialer_select_proto(dialer, vec!["/noise"], Version::V1).await.unwrap();
1071
1072            let keypair = Keypair::generate();
1073
1074            // do a noise handshake
1075            let (stream, _peer) =
1076                noise::handshake(stream.inner(), &keypair, Role::Dialer, 5, 2).await.unwrap();
1077            let stream: NoiseSocket<Compat<TcpStream>> = stream;
1078
1079            // after the handshake, try to negotiate some random protocol instead of yamux
1080            assert!(
1081                dialer_select_proto(stream, vec!["/unsupported/1"], Version::V1).await.is_err()
1082            );
1083        });
1084
1085        match TcpConnection::accept_connection(
1086            listener,
1087            ConnectionId::from(0usize),
1088            Keypair::generate(),
1089            dialer_address,
1090            Default::default(),
1091            5,
1092            2,
1093            Duration::from_secs(10),
1094            Duration::from_secs(10),
1095        )
1096        .await
1097        {
1098            Ok(_) => panic!("connection was supposed to fail"),
1099            Err(NegotiationError::MultistreamSelectError(
1100                crate::multistream_select::NegotiationError::Failed,
1101            )) => {}
1102            Err(error) => panic!("{error:?}"),
1103        }
1104    }
1105
1106    #[tokio::test]
1107    async fn yamux_not_supported_listener() {
1108        let _ = tracing_subscriber::fmt()
1109            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1110            .try_init();
1111
1112        let listener = TcpListener::bind("[::1]:0").await.unwrap();
1113        let address = listener.local_addr().unwrap();
1114
1115        tokio::spawn(async move {
1116            let (stream, _) = listener.accept().await.unwrap();
1117            let stream = TokioAsyncReadCompatExt::compat(stream).into_inner();
1118            let stream = TokioAsyncWriteCompatExt::compat_write(stream);
1119
1120            // negotiate noise
1121            let (_protocol, stream) = listener_select_proto(stream, vec!["/noise"]).await.unwrap();
1122
1123            // do a noise handshake
1124            let keypair = Keypair::generate();
1125            let (stream, _peer) =
1126                noise::handshake(stream.inner(), &keypair, Role::Listener, 5, 2).await.unwrap();
1127            let stream: NoiseSocket<Compat<TcpStream>> = stream;
1128
1129            // after the handshake, try to negotiate some random protocol instead of yamux
1130            assert!(listener_select_proto(stream, vec!["/unsupported/1"]).await.is_err());
1131        });
1132
1133        let (_, stream) = TcpTransport::dial_peer(
1134            Multiaddr::empty()
1135                .with(Protocol::from(address.ip()))
1136                .with(Protocol::Tcp(address.port())),
1137            Default::default(),
1138            Duration::from_secs(10),
1139            false,
1140        )
1141        .await
1142        .unwrap();
1143
1144        match TcpConnection::open_connection(
1145            ConnectionId::from(0usize),
1146            Keypair::generate(),
1147            stream,
1148            AddressType::Socket(address),
1149            None,
1150            Default::default(),
1151            5,
1152            2,
1153            Duration::from_secs(10),
1154            Duration::from_secs(10),
1155        )
1156        .await
1157        {
1158            Ok(_) => panic!("connection was supposed to fail"),
1159            Err(NegotiationError::MultistreamSelectError(
1160                crate::multistream_select::NegotiationError::Failed,
1161            )) => {}
1162            Err(error) => panic!("{error:?}"),
1163        }
1164    }
1165
1166    #[tokio::test]
1167    async fn yamux_timeout_dialer() {
1168        let _ = tracing_subscriber::fmt()
1169            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1170            .try_init();
1171
1172        let listener = TcpListener::bind("[::1]:0").await.unwrap();
1173        let address = listener.local_addr().unwrap();
1174
1175        let (Ok(dialer), Ok((listener, dialer_address))) =
1176            tokio::join!(TcpStream::connect(address.clone()), listener.accept())
1177        else {
1178            panic!("failed to establish connection");
1179        };
1180
1181        tokio::spawn(async move {
1182            let dialer = TokioAsyncReadCompatExt::compat(dialer).into_inner();
1183            let dialer = TokioAsyncWriteCompatExt::compat_write(dialer);
1184
1185            // negotiate noise
1186            let (_protocol, stream) =
1187                dialer_select_proto(dialer, vec!["/noise"], Version::V1).await.unwrap();
1188
1189            // do a noise handshake
1190            let keypair = Keypair::generate();
1191            let (stream, _peer) =
1192                noise::handshake(stream.inner(), &keypair, Role::Dialer, 5, 2).await.unwrap();
1193            let _stream: NoiseSocket<Compat<TcpStream>> = stream;
1194
1195            tokio::time::sleep(std::time::Duration::from_secs(60)).await;
1196        });
1197
1198        match TcpConnection::accept_connection(
1199            listener,
1200            ConnectionId::from(0usize),
1201            Keypair::generate(),
1202            dialer_address,
1203            Default::default(),
1204            5,
1205            2,
1206            Duration::from_secs(10),
1207            Duration::from_secs(10),
1208        )
1209        .await
1210        {
1211            Ok(_) => panic!("connection was supposed to fail"),
1212            Err(NegotiationError::Timeout) => {}
1213            Err(error) => panic!("invalid error: {error:?}"),
1214        }
1215    }
1216
1217    #[tokio::test]
1218    async fn yamux_timeout_listener() {
1219        let _ = tracing_subscriber::fmt()
1220            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
1221            .try_init();
1222
1223        let listener = TcpListener::bind("[::1]:0").await.unwrap();
1224        let address = listener.local_addr().unwrap();
1225
1226        tokio::spawn(async move {
1227            let (stream, _) = listener.accept().await.unwrap();
1228            let stream = TokioAsyncReadCompatExt::compat(stream).into_inner();
1229            let stream = TokioAsyncWriteCompatExt::compat_write(stream);
1230
1231            // negotiate noise
1232            let (_protocol, stream) = listener_select_proto(stream, vec!["/noise"]).await.unwrap();
1233
1234            // do a noise handshake
1235            let keypair = Keypair::generate();
1236            let (stream, _peer) =
1237                noise::handshake(stream.inner(), &keypair, Role::Listener, 5, 2).await.unwrap();
1238            let _stream: NoiseSocket<Compat<TcpStream>> = stream;
1239
1240            tokio::time::sleep(std::time::Duration::from_secs(60)).await;
1241        });
1242
1243        let (_, stream) = TcpTransport::dial_peer(
1244            Multiaddr::empty()
1245                .with(Protocol::from(address.ip()))
1246                .with(Protocol::Tcp(address.port())),
1247            Default::default(),
1248            Duration::from_secs(10),
1249            false,
1250        )
1251        .await
1252        .unwrap();
1253
1254        match TcpConnection::open_connection(
1255            ConnectionId::from(0usize),
1256            Keypair::generate(),
1257            stream,
1258            AddressType::Socket(address),
1259            None,
1260            Default::default(),
1261            5,
1262            2,
1263            Duration::from_secs(10),
1264            Duration::from_secs(10),
1265        )
1266        .await
1267        {
1268            Ok(_) => panic!("connection was supposed to fail"),
1269            Err(NegotiationError::Timeout) => {}
1270            Err(error) => panic!("invalid error: {error:?}"),
1271        }
1272    }
1273}