libp2p_core/transport/
timeout.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! Transports with timeouts on the connection setup.
22//!
23//! The connection setup includes all protocol upgrades applied on the
24//! underlying `Transport`.
25// TODO: add example
26
27use crate::transport::DialOpts;
28use crate::{
29    transport::{ListenerId, TransportError, TransportEvent},
30    Multiaddr, Transport,
31};
32use futures::prelude::*;
33use futures_timer::Delay;
34use std::{error, fmt, io, pin::Pin, task::Context, task::Poll, time::Duration};
35
36/// A `TransportTimeout` is a `Transport` that wraps another `Transport` and adds
37/// timeouts to all inbound and outbound connection attempts.
38///
39/// **Note**: `listen_on` is never subject to a timeout, only the setup of each
40/// individual accepted connection.
41#[derive(Debug, Copy, Clone)]
42#[pin_project::pin_project]
43pub struct TransportTimeout<InnerTrans> {
44    #[pin]
45    inner: InnerTrans,
46    outgoing_timeout: Duration,
47    incoming_timeout: Duration,
48}
49
50impl<InnerTrans> TransportTimeout<InnerTrans> {
51    /// Wraps around a `Transport` to add timeouts to all the sockets created by it.
52    pub fn new(trans: InnerTrans, timeout: Duration) -> Self {
53        TransportTimeout {
54            inner: trans,
55            outgoing_timeout: timeout,
56            incoming_timeout: timeout,
57        }
58    }
59
60    /// Wraps around a `Transport` to add timeouts to the outgoing connections.
61    pub fn with_outgoing_timeout(trans: InnerTrans, timeout: Duration) -> Self {
62        TransportTimeout {
63            inner: trans,
64            outgoing_timeout: timeout,
65            incoming_timeout: Duration::from_secs(100 * 365 * 24 * 3600), // 100 years
66        }
67    }
68
69    /// Wraps around a `Transport` to add timeouts to the ingoing connections.
70    pub fn with_ingoing_timeout(trans: InnerTrans, timeout: Duration) -> Self {
71        TransportTimeout {
72            inner: trans,
73            outgoing_timeout: Duration::from_secs(100 * 365 * 24 * 3600), // 100 years
74            incoming_timeout: timeout,
75        }
76    }
77}
78
79impl<InnerTrans> Transport for TransportTimeout<InnerTrans>
80where
81    InnerTrans: Transport,
82    InnerTrans::Error: 'static,
83{
84    type Output = InnerTrans::Output;
85    type Error = TransportTimeoutError<InnerTrans::Error>;
86    type ListenerUpgrade = Timeout<InnerTrans::ListenerUpgrade>;
87    type Dial = Timeout<InnerTrans::Dial>;
88
89    fn listen_on(
90        &mut self,
91        id: ListenerId,
92        addr: Multiaddr,
93    ) -> Result<(), TransportError<Self::Error>> {
94        self.inner
95            .listen_on(id, addr)
96            .map_err(|err| err.map(TransportTimeoutError::Other))
97    }
98
99    fn remove_listener(&mut self, id: ListenerId) -> bool {
100        self.inner.remove_listener(id)
101    }
102
103    fn dial(
104        &mut self,
105        addr: Multiaddr,
106        opts: DialOpts,
107    ) -> Result<Self::Dial, TransportError<Self::Error>> {
108        let dial = self
109            .inner
110            .dial(addr, opts)
111            .map_err(|err| err.map(TransportTimeoutError::Other))?;
112        Ok(Timeout {
113            inner: dial,
114            timer: Delay::new(self.outgoing_timeout),
115        })
116    }
117
118    fn poll(
119        self: Pin<&mut Self>,
120        cx: &mut Context<'_>,
121    ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
122        let this = self.project();
123        let timeout = *this.incoming_timeout;
124        this.inner.poll(cx).map(|event| {
125            event
126                .map_upgrade(move |inner_fut| Timeout {
127                    inner: inner_fut,
128                    timer: Delay::new(timeout),
129                })
130                .map_err(TransportTimeoutError::Other)
131        })
132    }
133}
134
135/// Wraps around a `Future`. Turns the error type from `TimeoutError<Err>` to
136/// `TransportTimeoutError<Err>`.
137// TODO: can be replaced with `impl Future` once `impl Trait` are fully stable in Rust
138//       (https://github.com/rust-lang/rust/issues/34511)
139#[pin_project::pin_project]
140#[must_use = "futures do nothing unless polled"]
141pub struct Timeout<InnerFut> {
142    #[pin]
143    inner: InnerFut,
144    timer: Delay,
145}
146
147impl<InnerFut> Future for Timeout<InnerFut>
148where
149    InnerFut: TryFuture,
150{
151    type Output = Result<InnerFut::Ok, TransportTimeoutError<InnerFut::Error>>;
152
153    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
154        // It is debatable whether we should poll the inner future first or the timer first.
155        // For example, if you start dialing with a timeout of 10 seconds, then after 15 seconds
156        // the dialing succeeds on the wire, then after 20 seconds you poll, then depending on
157        // which gets polled first, the outcome will be success or failure.
158
159        let mut this = self.project();
160
161        match TryFuture::try_poll(this.inner, cx) {
162            Poll::Pending => {}
163            Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)),
164            Poll::Ready(Err(err)) => return Poll::Ready(Err(TransportTimeoutError::Other(err))),
165        }
166
167        match Pin::new(&mut this.timer).poll(cx) {
168            Poll::Pending => Poll::Pending,
169            Poll::Ready(()) => Poll::Ready(Err(TransportTimeoutError::Timeout)),
170        }
171    }
172}
173
/// Errors surfaced by the `TransportTimeout` layer.
///
/// `TErr` is the error type of the wrapped transport.
#[derive(Debug)]
pub enum TransportTimeoutError<TErr> {
    /// The deadline elapsed before the connection setup finished.
    Timeout,
    /// The timer itself reported an I/O error.
    TimerError(io::Error),
    /// An error produced by the wrapped transport, passed through unchanged.
    Other(TErr),
}
184
185impl<TErr> fmt::Display for TransportTimeoutError<TErr>
186where
187    TErr: fmt::Display,
188{
189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        match self {
191            TransportTimeoutError::Timeout => write!(f, "Timeout has been reached"),
192            TransportTimeoutError::TimerError(err) => write!(f, "Error in the timer: {err}"),
193            TransportTimeoutError::Other(err) => write!(f, "{err}"),
194        }
195    }
196}
197
198impl<TErr> error::Error for TransportTimeoutError<TErr>
199where
200    TErr: error::Error + 'static,
201{
202    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
203        match self {
204            TransportTimeoutError::Timeout => None,
205            TransportTimeoutError::TimerError(err) => Some(err),
206            TransportTimeoutError::Other(err) => Some(err),
207        }
208    }
209}