// libp2p_swarm/connection/pool/concurrent_dial.rs

// Copyright 2021 Protocol Labs.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use crate::{transport::TransportError, Multiaddr};
use futures::{
    future::{BoxFuture, Future},
    ready,
    stream::{FuturesUnordered, StreamExt},
};
use libp2p_core::muxing::StreamMuxerBox;
use libp2p_identity::PeerId;
use std::{
    num::NonZeroU8,
    pin::Pin,
    task::{Context, Poll},
};

35type Dial = BoxFuture<
36    'static,
37    (
38        Multiaddr,
39        Result<(PeerId, StreamMuxerBox), TransportError<std::io::Error>>,
40    ),
41>;
42
43pub(crate) struct ConcurrentDial {
44    dials: FuturesUnordered<Dial>,
45    pending_dials: Box<dyn Iterator<Item = Dial> + Send>,
46    errors: Vec<(Multiaddr, TransportError<std::io::Error>)>,
47}
48
49impl Unpin for ConcurrentDial {}
50
51impl ConcurrentDial {
52    pub(crate) fn new(pending_dials: Vec<Dial>, concurrency_factor: NonZeroU8) -> Self {
53        let mut pending_dials = pending_dials.into_iter();
54
55        let dials = FuturesUnordered::new();
56        for dial in pending_dials.by_ref() {
57            dials.push(dial);
58            if dials.len() == concurrency_factor.get() as usize {
59                break;
60            }
61        }
62
63        Self {
64            dials,
65            errors: Default::default(),
66            pending_dials: Box::new(pending_dials),
67        }
68    }
69}
70
71impl Future for ConcurrentDial {
72    type Output = Result<
73        // Either one dial succeeded, returning the negotiated [`PeerId`], the address, the
74        // muxer and the addresses and errors of the dials that failed before.
75        (
76            Multiaddr,
77            (PeerId, StreamMuxerBox),
78            Vec<(Multiaddr, TransportError<std::io::Error>)>,
79        ),
80        // Or all dials failed, thus returning the address and error for each dial.
81        Vec<(Multiaddr, TransportError<std::io::Error>)>,
82    >;
83
84    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
85        loop {
86            match ready!(self.dials.poll_next_unpin(cx)) {
87                Some((addr, Ok(output))) => {
88                    let errors = std::mem::take(&mut self.errors);
89                    return Poll::Ready(Ok((addr, output, errors)));
90                }
91                Some((addr, Err(e))) => {
92                    self.errors.push((addr, e));
93                    if let Some(dial) = self.pending_dials.next() {
94                        self.dials.push(dial)
95                    }
96                }
97                None => {
98                    return Poll::Ready(Err(std::mem::take(&mut self.errors)));
99                }
100            }
101        }
102    }
103}