// libp2p_swarm/connection/pool/concurrent_dial.rs
1// Copyright 2021 Protocol Labs.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::{transport::TransportError, Multiaddr};
22use futures::{
23 future::{BoxFuture, Future},
24 ready,
25 stream::{FuturesUnordered, StreamExt},
26};
27use libp2p_core::muxing::StreamMuxerBox;
28use libp2p_identity::PeerId;
29use std::{
30 num::NonZeroU8,
31 pin::Pin,
32 task::{Context, Poll},
33};
34
/// A single boxed dial attempt.
///
/// It resolves to a [`Multiaddr`] paired with the outcome of the attempt. The outcome is either
/// the negotiated [`PeerId`] and muxer, or the [`TransportError`] that made the attempt fail.
/// NOTE(review): the address is presumably the one that was dialed; confirm at the call site
/// that builds these futures.
type Dial = BoxFuture<
    'static,
    (
        Multiaddr,
        Result<(PeerId, StreamMuxerBox), TransportError<std::io::Error>>,
    ),
>;
42
/// Drives a list of [`Dial`]s with bounded concurrency.
///
/// As a future, it resolves with the first dial that succeeds, or with every collected error
/// once no dial is left in flight.
pub(crate) struct ConcurrentDial {
    /// Dials currently being driven.
    ///
    /// `new` fills this with at most `concurrency_factor` entries. `poll` only adds one after
    /// another has completed with an error, so the bound is kept.
    dials: FuturesUnordered<Dial>,
    /// Dials not yet moved into `dials`. One is taken each time an in-flight dial fails.
    pending_dials: Box<dyn Iterator<Item = Dial> + Send>,
    /// Address and error of every dial that has failed so far.
    errors: Vec<(Multiaddr, TransportError<std::io::Error>)>,
}

// `poll` accesses fields through `Pin<&mut Self>`, which requires `Self: Unpin`.
// NOTE(review): every field (`FuturesUnordered`, a `Box`, a `Vec`) appears to be `Unpin`
// already, so the auto trait would likely apply without this explicit impl. Confirm before
// removing it.
impl Unpin for ConcurrentDial {}
50
51impl ConcurrentDial {
52 pub(crate) fn new(pending_dials: Vec<Dial>, concurrency_factor: NonZeroU8) -> Self {
53 let mut pending_dials = pending_dials.into_iter();
54
55 let dials = FuturesUnordered::new();
56 for dial in pending_dials.by_ref() {
57 dials.push(dial);
58 if dials.len() == concurrency_factor.get() as usize {
59 break;
60 }
61 }
62
63 Self {
64 dials,
65 errors: Default::default(),
66 pending_dials: Box::new(pending_dials),
67 }
68 }
69}
70
71impl Future for ConcurrentDial {
72 type Output = Result<
73 // Either one dial succeeded, returning the negotiated [`PeerId`], the address, the
74 // muxer and the addresses and errors of the dials that failed before.
75 (
76 Multiaddr,
77 (PeerId, StreamMuxerBox),
78 Vec<(Multiaddr, TransportError<std::io::Error>)>,
79 ),
80 // Or all dials failed, thus returning the address and error for each dial.
81 Vec<(Multiaddr, TransportError<std::io::Error>)>,
82 >;
83
84 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
85 loop {
86 match ready!(self.dials.poll_next_unpin(cx)) {
87 Some((addr, Ok(output))) => {
88 let errors = std::mem::take(&mut self.errors);
89 return Poll::Ready(Ok((addr, output, errors)));
90 }
91 Some((addr, Err(e))) => {
92 self.errors.push((addr, e));
93 if let Some(dial) = self.pending_dials.next() {
94 self.dials.push(dial)
95 }
96 }
97 None => {
98 return Poll::Ready(Err(std::mem::take(&mut self.errors)));
99 }
100 }
101 }
102 }
103}