1use crate::{
25 error::{DialError, Error},
26 transport::{
27 common::listener::{DialAddresses, GetSocketAddr, SocketListener, TcpAddress},
28 manager::TransportHandle,
29 tcp::{
30 config::Config,
31 connection::{NegotiatedConnection, TcpConnection},
32 },
33 Transport, TransportBuilder, TransportEvent,
34 },
35 types::ConnectionId,
36 utils::futures_stream::FuturesStream,
37};
38
39use futures::{
40 future::BoxFuture,
41 stream::{AbortHandle, FuturesUnordered, Stream, StreamExt},
42 TryFutureExt,
43};
44use hickory_resolver::TokioResolver;
45use multiaddr::Multiaddr;
46use socket2::{Domain, Socket, Type};
47use tokio::net::TcpStream;
48
49use std::{
50 collections::HashMap,
51 net::SocketAddr,
52 pin::Pin,
53 sync::Arc,
54 task::{Context, Poll},
55 time::Duration,
56};
57
58pub(crate) use substream::Substream;
59
60mod connection;
61mod substream;
62
63pub mod config;
64
/// Logging target passed to the `tracing` macros of the TCP transport.
const LOG_TARGET: &str = "litep2p::tcp";
67
/// Inbound TCP connection produced by the listener that has not yet been accepted or
/// rejected through `accept_pending()` / `reject_pending()`.
struct PendingInboundConnection {
    /// Socket of the inbound connection.
    connection: TcpStream,
    /// Address reported by the listener together with the socket.
    address: SocketAddr,
}
75
/// Outcome of a multi-address connection attempt started by `TcpTransport::open()`.
#[derive(Debug)]
enum RawConnectionResult {
    /// One of the addresses was dialed and its connection was opened successfully.
    Connected {
        /// The connection that succeeded.
        negotiated: NegotiatedConnection,
        /// Errors of the addresses that failed before the successful one completed.
        errors: Vec<(Multiaddr, DialError)>,
    },

    /// All addresses failed.
    Failed {
        /// ID of the connection attempt.
        connection_id: ConnectionId,
        /// One error per failed address.
        errors: Vec<(Multiaddr, DialError)>,
    },

    /// The attempt was aborted through its `AbortHandle` before it completed.
    Canceled { connection_id: ConnectionId },
}
93
/// TCP transport.
pub(crate) struct TcpTransport {
    /// Handle to the transport manager (keypair, executor, ID allocators, protocol set).
    context: TransportHandle,

    /// Transport configuration.
    config: Config,

    /// Listener that yields inbound TCP connections.
    listener: SocketListener,

    /// Addresses of dials started by `dial()` that have not finished yet, keyed by
    /// connection ID.
    pending_dials: HashMap<ConnectionId, Multiaddr>,

    /// Local addresses that outbound sockets may be bound to when dialing.
    dial_addresses: DialAddresses,

    /// Inbound connections waiting for `accept_pending()` / `reject_pending()`.
    pending_inbound_connections: HashMap<ConnectionId, PendingInboundConnection>,

    /// Connections (inbound and outbound) whose negotiation is in progress.
    pending_connections:
        FuturesStream<BoxFuture<'static, Result<NegotiatedConnection, (ConnectionId, DialError)>>>,

    /// Multi-address connection attempts started by `open()`.
    pending_raw_connections: FuturesStream<BoxFuture<'static, RawConnectionResult>>,

    /// Connections opened by `open()` that are waiting for `negotiate()`.
    opened: HashMap<ConnectionId, NegotiatedConnection>,

    /// Abort handles of the futures stored in `pending_raw_connections`.
    cancel_futures: HashMap<ConnectionId, AbortHandle>,

    /// Negotiated connections waiting for `accept()` / `reject()`.
    pending_open: HashMap<ConnectionId, NegotiatedConnection>,

    /// DNS resolver used when dialing.
    resolver: Arc<TokioResolver>,
}
136
137impl TcpTransport {
138 fn on_inbound_connection(
140 &mut self,
141 connection_id: ConnectionId,
142 connection: TcpStream,
143 address: SocketAddr,
144 ) {
145 let yamux_config = self.config.yamux_config.clone();
146 let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
147 let max_write_buffer_size = self.config.noise_write_buffer_size;
148 let connection_open_timeout = self.config.connection_open_timeout;
149 let substream_open_timeout = self.config.substream_open_timeout;
150 let keypair = self.context.keypair.clone();
151
152 tracing::trace!(
153 target: LOG_TARGET,
154 ?connection_id,
155 ?address,
156 "accept connection",
157 );
158
159 self.pending_connections.push(Box::pin(async move {
160 TcpConnection::accept_connection(
161 connection,
162 connection_id,
163 keypair,
164 address,
165 yamux_config,
166 max_read_ahead_factor,
167 max_write_buffer_size,
168 connection_open_timeout,
169 substream_open_timeout,
170 )
171 .await
172 .map_err(|error| (connection_id, error.into()))
173 }));
174 }
175
176 async fn dial_peer(
178 address: Multiaddr,
179 dial_addresses: DialAddresses,
180 connection_open_timeout: Duration,
181 nodelay: bool,
182 resolver: Arc<TokioResolver>,
183 ) -> Result<(Multiaddr, TcpStream), DialError> {
184 let (socket_address, _) = TcpAddress::multiaddr_to_socket_address(&address)?;
185
186 let remote_address =
187 match tokio::time::timeout(connection_open_timeout, socket_address.lookup_ip(resolver))
188 .await
189 {
190 Err(_) => {
191 tracing::debug!(
192 target: LOG_TARGET,
193 ?address,
194 ?connection_open_timeout,
195 "failed to resolve address within timeout",
196 );
197 return Err(DialError::Timeout);
198 }
199 Ok(Err(error)) => return Err(error.into()),
200 Ok(Ok(address)) => address,
201 };
202
203 let domain = match remote_address.is_ipv4() {
204 true => Domain::IPV4,
205 false => Domain::IPV6,
206 };
207 let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?;
208 if remote_address.is_ipv6() {
209 socket.set_only_v6(true)?;
210 }
211 socket.set_nonblocking(true)?;
212 socket.set_nodelay(nodelay)?;
213
214 match dial_addresses.local_dial_address(&remote_address.ip()) {
215 Ok(Some(dial_address)) => {
216 socket.set_reuse_address(true)?;
217 #[cfg(unix)]
218 socket.set_reuse_port(true)?;
219 socket.bind(&dial_address.into())?;
220 }
221 Ok(None) => {}
222 Err(()) => {
223 tracing::debug!(
224 target: LOG_TARGET,
225 ?remote_address,
226 "tcp listener not enabled for remote address, using ephemeral port",
227 );
228 }
229 }
230
231 let future = async move {
232 match socket.connect(&remote_address.into()) {
233 Ok(()) => {}
234 Err(err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {}
235 Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {}
236 Err(err) => return Err(err),
237 }
238
239 let stream = TcpStream::try_from(Into::<std::net::TcpStream>::into(socket))?;
240 stream.writable().await?;
241
242 if let Some(e) = stream.take_error()? {
243 return Err(e);
244 }
245
246 Ok((address, stream))
247 };
248
249 match tokio::time::timeout(connection_open_timeout, future).await {
250 Err(_) => {
251 tracing::debug!(
252 target: LOG_TARGET,
253 ?connection_open_timeout,
254 "failed to connect within timeout",
255 );
256 Err(DialError::Timeout)
257 }
258 Ok(Err(error)) => Err(error.into()),
259 Ok(Ok((address, stream))) => {
260 tracing::debug!(
261 target: LOG_TARGET,
262 ?address,
263 "connected",
264 );
265
266 Ok((address, stream))
267 }
268 }
269 }
270}
271
272impl TransportBuilder for TcpTransport {
273 type Config = Config;
274 type Transport = TcpTransport;
275
276 fn new(
278 context: TransportHandle,
279 mut config: Self::Config,
280 resolver: Arc<TokioResolver>,
281 ) -> crate::Result<(Self, Vec<Multiaddr>)> {
282 tracing::debug!(
283 target: LOG_TARGET,
284 listen_addresses = ?config.listen_addresses,
285 "start tcp transport",
286 );
287
288 let (listener, listen_addresses, dial_addresses) = SocketListener::new::<TcpAddress>(
290 std::mem::take(&mut config.listen_addresses),
291 config.reuse_port,
292 config.nodelay,
293 );
294
295 Ok((
296 Self {
297 listener,
298 config,
299 context,
300 dial_addresses,
301 opened: HashMap::new(),
302 pending_open: HashMap::new(),
303 pending_dials: HashMap::new(),
304 pending_inbound_connections: HashMap::new(),
305 pending_connections: FuturesStream::new(),
306 pending_raw_connections: FuturesStream::new(),
307 cancel_futures: HashMap::new(),
308 resolver,
309 },
310 listen_addresses,
311 ))
312 }
313}
314
315impl Transport for TcpTransport {
316 fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> {
317 tracing::debug!(target: LOG_TARGET, ?connection_id, ?address, "open connection");
318
319 let (socket_address, peer) = TcpAddress::multiaddr_to_socket_address(&address)?;
320 let yamux_config = self.config.yamux_config.clone();
321 let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
322 let max_write_buffer_size = self.config.noise_write_buffer_size;
323 let connection_open_timeout = self.config.connection_open_timeout;
324 let substream_open_timeout = self.config.substream_open_timeout;
325 let dial_addresses = self.dial_addresses.clone();
326 let keypair = self.context.keypair.clone();
327 let nodelay = self.config.nodelay;
328 let resolver = self.resolver.clone();
329
330 self.pending_dials.insert(connection_id, address.clone());
331 self.pending_connections.push(Box::pin(async move {
332 let (_, stream) = TcpTransport::dial_peer(
333 address,
334 dial_addresses,
335 connection_open_timeout,
336 nodelay,
337 resolver,
338 )
339 .await
340 .map_err(|error| (connection_id, error))?;
341
342 TcpConnection::open_connection(
343 connection_id,
344 keypair,
345 stream,
346 socket_address,
347 peer,
348 yamux_config,
349 max_read_ahead_factor,
350 max_write_buffer_size,
351 connection_open_timeout,
352 substream_open_timeout,
353 )
354 .await
355 .map_err(|error| (connection_id, error.into()))
356 }));
357
358 Ok(())
359 }
360
361 fn accept_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
362 let pending = self.pending_inbound_connections.remove(&connection_id).ok_or_else(|| {
363 tracing::error!(
364 target: LOG_TARGET,
365 ?connection_id,
366 "Cannot accept non existent pending connection",
367 );
368
369 Error::ConnectionDoesntExist(connection_id)
370 })?;
371
372 self.on_inbound_connection(connection_id, pending.connection, pending.address);
373
374 Ok(())
375 }
376
377 fn reject_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
378 self.pending_inbound_connections.remove(&connection_id).map_or_else(
379 || {
380 tracing::error!(
381 target: LOG_TARGET,
382 ?connection_id,
383 "Cannot reject non existent pending connection",
384 );
385
386 Err(Error::ConnectionDoesntExist(connection_id))
387 },
388 |_| Ok(()),
389 )
390 }
391
392 fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
393 let context = self
394 .pending_open
395 .remove(&connection_id)
396 .ok_or(Error::ConnectionDoesntExist(connection_id))?;
397 let protocol_set = self.context.protocol_set(connection_id);
398 let bandwidth_sink = self.context.bandwidth_sink.clone();
399 let next_substream_id = self.context.next_substream_id.clone();
400
401 tracing::trace!(
402 target: LOG_TARGET,
403 ?connection_id,
404 "start connection",
405 );
406
407 self.context.executor.run(Box::pin(async move {
408 if let Err(error) =
409 TcpConnection::new(context, protocol_set, bandwidth_sink, next_substream_id)
410 .start()
411 .await
412 {
413 tracing::debug!(
414 target: LOG_TARGET,
415 ?connection_id,
416 ?error,
417 "connection exited with error",
418 );
419 }
420 }));
421
422 Ok(())
423 }
424
425 fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
426 self.pending_open
427 .remove(&connection_id)
428 .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(()))
429 }
430
431 fn open(
432 &mut self,
433 connection_id: ConnectionId,
434 addresses: Vec<Multiaddr>,
435 ) -> crate::Result<()> {
436 let num_addresses = addresses.len();
437 let mut futures: FuturesUnordered<_> = addresses
438 .into_iter()
439 .map(|address| {
440 let yamux_config = self.config.yamux_config.clone();
441 let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
442 let max_write_buffer_size = self.config.noise_write_buffer_size;
443 let connection_open_timeout = self.config.connection_open_timeout;
444 let substream_open_timeout = self.config.substream_open_timeout;
445 let dial_addresses = self.dial_addresses.clone();
446 let keypair = self.context.keypair.clone();
447 let nodelay = self.config.nodelay;
448 let resolver = self.resolver.clone();
449
450 async move {
451 let (address, stream) = TcpTransport::dial_peer(
452 address.clone(),
453 dial_addresses,
454 connection_open_timeout,
455 nodelay,
456 resolver,
457 )
458 .await
459 .map_err(|error| (address, error))?;
460
461 let open_address = address.clone();
462 let (socket_address, peer) = TcpAddress::multiaddr_to_socket_address(&address)
463 .map_err(|error| (address, error.into()))?;
464
465 TcpConnection::open_connection(
466 connection_id,
467 keypair,
468 stream,
469 socket_address,
470 peer,
471 yamux_config,
472 max_read_ahead_factor,
473 max_write_buffer_size,
474 connection_open_timeout,
475 substream_open_timeout,
476 )
477 .await
478 .map_err(|error| (open_address, error.into()))
479 }
480 })
481 .collect();
482
483 let future = async move {
485 let mut errors = Vec::with_capacity(num_addresses);
486 while let Some(result) = futures.next().await {
487 match result {
488 Ok(negotiated) => return RawConnectionResult::Connected { negotiated, errors },
489 Err(error) => {
490 tracing::debug!(
491 target: LOG_TARGET,
492 ?connection_id,
493 ?error,
494 "failed to open connection",
495 );
496 errors.push(error)
497 }
498 }
499 }
500
501 RawConnectionResult::Failed {
502 connection_id,
503 errors,
504 }
505 };
506
507 let (fut, handle) = futures::future::abortable(future);
508 let fut = fut.unwrap_or_else(move |_| RawConnectionResult::Canceled { connection_id });
509 self.pending_raw_connections.push(Box::pin(fut));
510 self.cancel_futures.insert(connection_id, handle);
511
512 Ok(())
513 }
514
515 fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
516 let negotiated = self
517 .opened
518 .remove(&connection_id)
519 .ok_or(Error::ConnectionDoesntExist(connection_id))?;
520
521 self.pending_connections.push(Box::pin(async move { Ok(negotiated) }));
522
523 Ok(())
524 }
525
526 fn cancel(&mut self, connection_id: ConnectionId) {
527 if let Some(handle) = self.cancel_futures.get(&connection_id) {
530 handle.abort();
531 }
532 }
533}
534
impl Stream for TcpTransport {
    type Item = TransportEvent;

    /// Poll the transport for the next event.
    ///
    /// Sources are polled in a fixed order: the listener, then the raw connection attempts
    /// started by `open()`, then the connections being negotiated. The stream ends
    /// (`Poll::Ready(None)`) when the listener terminates, with or without an error.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // 1. Inbound connections from the listener.
        if let Poll::Ready(event) = self.listener.poll_next_unpin(cx) {
            return match event {
                None => {
                    tracing::error!(
                        target: LOG_TARGET,
                        "TCP listener terminated, ignore if the node is stopping",
                    );

                    Poll::Ready(None)
                }
                Some(Err(error)) => {
                    tracing::error!(
                        target: LOG_TARGET,
                        ?error,
                        "TCP listener terminated with error",
                    );

                    Poll::Ready(None)
                }
                Some(Ok((connection, address))) => {
                    // Park the socket until `accept_pending()` / `reject_pending()` is called.
                    let connection_id = self.context.next_connection_id();
                    tracing::trace!(
                        target: LOG_TARGET,
                        ?connection_id,
                        ?address,
                        "pending inbound TCP connection",
                    );

                    self.pending_inbound_connections.insert(
                        connection_id,
                        PendingInboundConnection {
                            connection,
                            address,
                        },
                    );

                    Poll::Ready(Some(TransportEvent::PendingInboundConnection {
                        connection_id,
                    }))
                }
            };
        }

        // 2. Results of the multi-address attempts started by `open()`.
        while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) {
            tracing::trace!(target: LOG_TARGET, ?result, "raw connection result");

            match result {
                RawConnectionResult::Connected { negotiated, errors } => {
                    let Some(handle) = self.cancel_futures.remove(&negotiated.connection_id())
                    else {
                        tracing::warn!(
                            target: LOG_TARGET,
                            connection_id = ?negotiated.connection_id(),
                            address = ?negotiated.endpoint().address(),
                            ?errors,
                            "raw connection without a cancel handle",
                        );
                        continue;
                    };

                    // A result that arrives after `cancel()` is dropped silently.
                    if !handle.is_aborted() {
                        let connection_id = negotiated.connection_id();
                        let address = negotiated.endpoint().address().clone();

                        // Keep the connection until `negotiate()` picks it up.
                        self.opened.insert(connection_id, negotiated);

                        return Poll::Ready(Some(TransportEvent::ConnectionOpened {
                            connection_id,
                            address,
                        }));
                    }
                }

                RawConnectionResult::Failed {
                    connection_id,
                    errors,
                } => {
                    let Some(handle) = self.cancel_futures.remove(&connection_id) else {
                        tracing::warn!(
                            target: LOG_TARGET,
                            ?connection_id,
                            ?errors,
                            "raw connection without a cancel handle",
                        );
                        continue;
                    };

                    // Failures of canceled attempts are not reported.
                    if !handle.is_aborted() {
                        return Poll::Ready(Some(TransportEvent::OpenFailure {
                            connection_id,
                            errors,
                        }));
                    }
                }
                RawConnectionResult::Canceled { connection_id } => {
                    // Only clean up the handle; no event is emitted for a canceled attempt.
                    if self.cancel_futures.remove(&connection_id).is_none() {
                        tracing::warn!(
                            target: LOG_TARGET,
                            ?connection_id,
                            "raw cancelled connection without a cancel handle",
                        );
                    }
                }
            }
        }

        // 3. Connections whose negotiation has finished.
        while let Poll::Ready(Some(connection)) = self.pending_connections.poll_next_unpin(cx) {
            match connection {
                Ok(connection) => {
                    let peer = connection.peer();
                    let endpoint = connection.endpoint();
                    self.pending_dials.remove(&connection.connection_id());
                    // Keep the connection until `accept()` / `reject()` is called.
                    self.pending_open.insert(connection.connection_id(), connection);

                    return Poll::Ready(Some(TransportEvent::ConnectionEstablished {
                        peer,
                        endpoint,
                    }));
                }
                Err((connection_id, error)) => {
                    // Only connections registered by `dial()` have an entry in
                    // `pending_dials`; anything else is logged without emitting an event.
                    if let Some(address) = self.pending_dials.remove(&connection_id) {
                        return Poll::Ready(Some(TransportEvent::DialFailure {
                            connection_id,
                            address,
                            error,
                        }));
                    } else {
                        tracing::debug!(target: LOG_TARGET, ?error, ?connection_id, "Pending inbound connection failed");
                    }
                }
            }
        }

        Poll::Pending
    }
}
675
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        codec::ProtocolCodec,
        crypto::ed25519::Keypair,
        executor::DefaultExecutor,
        transport::manager::{ProtocolContext, SupportedTransport, TransportManagerBuilder},
        types::protocol::ProtocolName,
        BandwidthSink, PeerId,
    };
    use multiaddr::Protocol;
    use multihash::Multihash;
    use std::sync::Arc;
    use tokio::sync::mpsc::channel;

    /// Dialing a listening transport and accepting the pending inbound connection results in
    /// `ConnectionEstablished` on both sides.
    #[tokio::test]
    async fn connect_and_accept_works() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let keypair1 = Keypair::generate();
        let (tx1, _rx1) = channel(64);
        let (event_tx1, _event_rx1) = channel(64);
        let bandwidth_sink = BandwidthSink::new();

        let handle1 = crate::transport::manager::TransportHandle {
            executor: Arc::new(DefaultExecutor {}),
            next_substream_id: Default::default(),
            next_connection_id: Default::default(),
            keypair: keypair1.clone(),
            tx: event_tx1,
            bandwidth_sink: bandwidth_sink.clone(),

            protocols: HashMap::from_iter([(
                ProtocolName::from("/notif/1"),
                ProtocolContext {
                    tx: tx1,
                    codec: ProtocolCodec::Identity(32),
                    fallback_names: Vec::new(),
                },
            )]),
        };
        let transport_config1 = Config {
            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
            ..Default::default()
        };
        let resolver = Arc::new(TokioResolver::builder_tokio().unwrap().build());

        let (mut transport1, listen_addresses) =
            TcpTransport::new(handle1, transport_config1, resolver.clone()).unwrap();
        let listen_address = listen_addresses[0].clone();

        let keypair2 = Keypair::generate();
        let (tx2, _rx2) = channel(64);
        let (event_tx2, _event_rx2) = channel(64);

        let handle2 = crate::transport::manager::TransportHandle {
            executor: Arc::new(DefaultExecutor {}),
            next_substream_id: Default::default(),
            next_connection_id: Default::default(),
            keypair: keypair2.clone(),
            tx: event_tx2,
            bandwidth_sink: bandwidth_sink.clone(),

            protocols: HashMap::from_iter([(
                ProtocolName::from("/notif/1"),
                ProtocolContext {
                    tx: tx2,
                    codec: ProtocolCodec::Identity(32),
                    fallback_names: Vec::new(),
                },
            )]),
        };
        let transport_config2 = Config {
            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
            ..Default::default()
        };

        let (mut transport2, _) = TcpTransport::new(handle2, transport_config2, resolver).unwrap();
        transport2.dial(ConnectionId::new(), listen_address).unwrap();

        // Poll the dialer in the background and forward its first event.
        let (tx, mut from_transport2) = channel(64);
        tokio::spawn(async move {
            let event = transport2.next().await;
            tx.send(event).await.unwrap();
        });

        let event = transport1.next().await.unwrap();
        match event {
            TransportEvent::PendingInboundConnection { connection_id } => {
                transport1.accept_pending(connection_id).unwrap();
            }
            _ => panic!("unexpected event"),
        }

        let event = transport1.next().await;
        assert!(std::matches!(
            event,
            Some(TransportEvent::ConnectionEstablished { .. })
        ));

        let event = from_transport2.recv().await.unwrap();
        assert!(std::matches!(
            event,
            Some(TransportEvent::ConnectionEstablished { .. })
        ));
    }

    /// Rejecting the pending inbound connection makes the dialer report `DialFailure`.
    #[tokio::test]
    async fn connect_and_reject_works() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let keypair1 = Keypair::generate();
        let (tx1, _rx1) = channel(64);
        let (event_tx1, _event_rx1) = channel(64);
        let bandwidth_sink = BandwidthSink::new();

        let handle1 = crate::transport::manager::TransportHandle {
            executor: Arc::new(DefaultExecutor {}),
            next_substream_id: Default::default(),
            next_connection_id: Default::default(),
            keypair: keypair1.clone(),
            tx: event_tx1,
            bandwidth_sink: bandwidth_sink.clone(),

            protocols: HashMap::from_iter([(
                ProtocolName::from("/notif/1"),
                ProtocolContext {
                    tx: tx1,
                    codec: ProtocolCodec::Identity(32),
                    fallback_names: Vec::new(),
                },
            )]),
        };
        let transport_config1 = Config {
            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
            ..Default::default()
        };
        let resolver = Arc::new(TokioResolver::builder_tokio().unwrap().build());

        let (mut transport1, listen_addresses) =
            TcpTransport::new(handle1, transport_config1, resolver.clone()).unwrap();
        let listen_address = listen_addresses[0].clone();

        let keypair2 = Keypair::generate();
        let (tx2, _rx2) = channel(64);
        let (event_tx2, _event_rx2) = channel(64);

        let handle2 = crate::transport::manager::TransportHandle {
            executor: Arc::new(DefaultExecutor {}),
            next_substream_id: Default::default(),
            next_connection_id: Default::default(),
            keypair: keypair2.clone(),
            tx: event_tx2,
            bandwidth_sink: bandwidth_sink.clone(),

            protocols: HashMap::from_iter([(
                ProtocolName::from("/notif/1"),
                ProtocolContext {
                    tx: tx2,
                    codec: ProtocolCodec::Identity(32),
                    fallback_names: Vec::new(),
                },
            )]),
        };
        let transport_config2 = Config {
            listen_addresses: vec!["/ip6/::1/tcp/0".parse().unwrap()],
            ..Default::default()
        };

        let (mut transport2, _) = TcpTransport::new(handle2, transport_config2, resolver).unwrap();
        transport2.dial(ConnectionId::new(), listen_address).unwrap();

        // Poll the dialer in the background and forward its first event.
        let (tx, mut from_transport2) = channel(64);
        tokio::spawn(async move {
            let event = transport2.next().await;
            tx.send(event).await.unwrap();
        });

        let event = transport1.next().await.unwrap();
        match event {
            TransportEvent::PendingInboundConnection { connection_id } => {
                transport1.reject_pending(connection_id).unwrap();
            }
            _ => panic!("unexpected event"),
        }

        let event = from_transport2.recv().await.unwrap();
        assert!(std::matches!(
            event,
            Some(TransportEvent::DialFailure { .. })
        ));
    }

    /// Dialing an address that the first transport is not listening on yields `DialFailure`.
    #[tokio::test]
    async fn dial_failure() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let keypair1 = Keypair::generate();
        let (tx1, _rx1) = channel(64);
        let (event_tx1, mut event_rx1) = channel(64);
        let bandwidth_sink = BandwidthSink::new();

        let handle1 = crate::transport::manager::TransportHandle {
            executor: Arc::new(DefaultExecutor {}),
            next_substream_id: Default::default(),
            next_connection_id: Default::default(),
            keypair: keypair1.clone(),
            tx: event_tx1,
            bandwidth_sink: bandwidth_sink.clone(),

            protocols: HashMap::from_iter([(
                ProtocolName::from("/notif/1"),
                ProtocolContext {
                    tx: tx1,
                    codec: ProtocolCodec::Identity(32),
                    fallback_names: Vec::new(),
                },
            )]),
        };
        let resolver = Arc::new(TokioResolver::builder_tokio().unwrap().build());
        let (mut transport1, _) =
            TcpTransport::new(handle1, Default::default(), resolver.clone()).unwrap();

        // Keep `transport1` polled; every event it produces is ignored.
        tokio::spawn(async move {
            while let Some(event) = transport1.next().await {
                match event {
                    TransportEvent::ConnectionEstablished { .. } => {}
                    TransportEvent::ConnectionClosed { .. } => {}
                    TransportEvent::DialFailure { .. } => {}
                    TransportEvent::ConnectionOpened { .. } => {}
                    TransportEvent::OpenFailure { .. } => {}
                    TransportEvent::PendingInboundConnection { .. } => {}
                }
            }
        });

        let keypair2 = Keypair::generate();
        let (tx2, _rx2) = channel(64);
        let (event_tx2, _event_rx2) = channel(64);

        let handle2 = crate::transport::manager::TransportHandle {
            executor: Arc::new(DefaultExecutor {}),
            next_substream_id: Default::default(),
            next_connection_id: Default::default(),
            keypair: keypair2.clone(),
            tx: event_tx2,
            bandwidth_sink: bandwidth_sink.clone(),

            protocols: HashMap::from_iter([(
                ProtocolName::from("/notif/1"),
                ProtocolContext {
                    tx: tx2,
                    codec: ProtocolCodec::Identity(32),
                    fallback_names: Vec::new(),
                },
            )]),
        };

        let (mut transport2, _) = TcpTransport::new(handle2, Default::default(), resolver).unwrap();

        let peer1: PeerId = PeerId::from_public_key(&keypair1.public().into());
        let peer2: PeerId = PeerId::from_public_key(&keypair2.public().into());

        tracing::info!(target: LOG_TARGET, "peer1 {peer1}, peer2 {peer2}");

        let address = Multiaddr::empty()
            .with(Protocol::Ip6(std::net::Ipv6Addr::new(
                0, 0, 0, 0, 0, 0, 0, 1,
            )))
            .with(Protocol::Tcp(8888))
            .with(Protocol::P2p(
                Multihash::from_bytes(&peer1.to_bytes()).unwrap(),
            ));

        transport2.dial(ConnectionId::new(), address).unwrap();

        // Drain `event_rx1`. Stop once the channel is closed: `recv()` then returns `None`
        // immediately, and an unconditional loop would spin forever.
        tokio::spawn(async move {
            while event_rx1.recv().await.is_some() {}
        });

        assert!(std::matches!(
            transport2.next().await,
            Some(TransportEvent::DialFailure { .. })
        ));
    }

    /// An error for a connection registered by `dial()` is reported as `DialFailure` and
    /// clears the corresponding `pending_dials` entry.
    #[tokio::test]
    async fn dial_error_reported_for_outbound_connections() {
        let mut manager = TransportManagerBuilder::new().build();
        let handle = manager.transport_handle(Arc::new(DefaultExecutor {}));
        let resolver = Arc::new(TokioResolver::builder_tokio().unwrap().build());
        manager.register_transport(
            SupportedTransport::Tcp,
            Box::new(crate::transport::dummy::DummyTransport::new()),
        );
        let (mut transport, _) = TcpTransport::new(
            handle,
            Config {
                listen_addresses: vec!["/ip4/127.0.0.1/tcp/0".parse().unwrap()],
                ..Default::default()
            },
            resolver,
        )
        .unwrap();

        let keypair = Keypair::generate();
        let peer_id = PeerId::from_public_key(&keypair.public().into());
        let multiaddr = Multiaddr::empty()
            .with(Protocol::Ip4(std::net::Ipv4Addr::new(255, 254, 253, 252)))
            .with(Protocol::Tcp(8888))
            .with(Protocol::P2p(
                Multihash::from_bytes(&peer_id.to_bytes()).unwrap(),
            ));
        manager.dial_address(multiaddr.clone()).await.unwrap();

        assert!(transport.pending_dials.is_empty());

        match transport.dial(ConnectionId::from(0usize), multiaddr) {
            Ok(()) => {}
            _ => panic!("invalid result for `on_dial_peer()`"),
        }

        assert!(!transport.pending_dials.is_empty());
        // Inject a failure for the connection ID registered above.
        transport.pending_connections.push(Box::pin(async move {
            Err((ConnectionId::from(0usize), DialError::Timeout))
        }));

        assert!(std::matches!(
            transport.next().await,
            Some(TransportEvent::DialFailure { .. })
        ));
        assert!(transport.pending_dials.is_empty());
    }
}