libp2p_core/transport/
timeout.rs1use crate::transport::DialOpts;
28use crate::{
29 transport::{ListenerId, TransportError, TransportEvent},
30 Multiaddr, Transport,
31};
32use futures::prelude::*;
33use futures_timer::Delay;
34use std::{error, fmt, io, pin::Pin, task::Context, task::Poll, time::Duration};
35
/// Wraps around a [`Transport`] and adds a timeout to the connections it produces.
///
/// Outgoing dials (see `dial`) are bounded by `outgoing_timeout`; listener upgrades
/// yielded by `poll` are bounded by `incoming_timeout`. When a deadline elapses first,
/// the wrapped future resolves to [`TransportTimeoutError::Timeout`].
#[derive(Debug, Copy, Clone)]
#[pin_project::pin_project]
pub struct TransportTimeout<InnerTrans> {
    /// The wrapped transport. It is structurally pinned so that `poll` can forward the
    /// pinned receiver to it through the projection.
    #[pin]
    inner: InnerTrans,
    /// Deadline given to every future returned by `dial`.
    outgoing_timeout: Duration,
    /// Deadline given to every listener upgrade yielded by `poll`.
    incoming_timeout: Duration,
}
49
50impl<InnerTrans> TransportTimeout<InnerTrans> {
51 pub fn new(trans: InnerTrans, timeout: Duration) -> Self {
53 TransportTimeout {
54 inner: trans,
55 outgoing_timeout: timeout,
56 incoming_timeout: timeout,
57 }
58 }
59
60 pub fn with_outgoing_timeout(trans: InnerTrans, timeout: Duration) -> Self {
62 TransportTimeout {
63 inner: trans,
64 outgoing_timeout: timeout,
65 incoming_timeout: Duration::from_secs(100 * 365 * 24 * 3600), }
67 }
68
69 pub fn with_ingoing_timeout(trans: InnerTrans, timeout: Duration) -> Self {
71 TransportTimeout {
72 inner: trans,
73 outgoing_timeout: Duration::from_secs(100 * 365 * 24 * 3600), incoming_timeout: timeout,
75 }
76 }
77}
78
79impl<InnerTrans> Transport for TransportTimeout<InnerTrans>
80where
81 InnerTrans: Transport,
82 InnerTrans::Error: 'static,
83{
84 type Output = InnerTrans::Output;
85 type Error = TransportTimeoutError<InnerTrans::Error>;
86 type ListenerUpgrade = Timeout<InnerTrans::ListenerUpgrade>;
87 type Dial = Timeout<InnerTrans::Dial>;
88
89 fn listen_on(
90 &mut self,
91 id: ListenerId,
92 addr: Multiaddr,
93 ) -> Result<(), TransportError<Self::Error>> {
94 self.inner
95 .listen_on(id, addr)
96 .map_err(|err| err.map(TransportTimeoutError::Other))
97 }
98
99 fn remove_listener(&mut self, id: ListenerId) -> bool {
100 self.inner.remove_listener(id)
101 }
102
103 fn dial(
104 &mut self,
105 addr: Multiaddr,
106 opts: DialOpts,
107 ) -> Result<Self::Dial, TransportError<Self::Error>> {
108 let dial = self
109 .inner
110 .dial(addr, opts)
111 .map_err(|err| err.map(TransportTimeoutError::Other))?;
112 Ok(Timeout {
113 inner: dial,
114 timer: Delay::new(self.outgoing_timeout),
115 })
116 }
117
118 fn poll(
119 self: Pin<&mut Self>,
120 cx: &mut Context<'_>,
121 ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
122 let this = self.project();
123 let timeout = *this.incoming_timeout;
124 this.inner.poll(cx).map(|event| {
125 event
126 .map_upgrade(move |inner_fut| Timeout {
127 inner: inner_fut,
128 timer: Delay::new(timeout),
129 })
130 .map_err(TransportTimeoutError::Other)
131 })
132 }
133}
134
/// Future that drives a wrapped fallible future while racing it against a timer.
///
/// It resolves to the wrapped future's result, with the error mapped into
/// [`TransportTimeoutError::Other`]. It resolves to [`TransportTimeoutError::Timeout`]
/// if the timer fires while the wrapped future is still pending.
#[pin_project::pin_project]
#[must_use = "futures do nothing unless polled"]
pub struct Timeout<InnerFut> {
    /// The wrapped future (a dial or a listener upgrade). It is structurally pinned and
    /// polled through the projection.
    #[pin]
    inner: InnerFut,
    /// The deadline timer. It is not pinned because `Delay` is `Unpin`; `poll` wraps it
    /// with `Pin::new`.
    timer: Delay,
}
146
147impl<InnerFut> Future for Timeout<InnerFut>
148where
149 InnerFut: TryFuture,
150{
151 type Output = Result<InnerFut::Ok, TransportTimeoutError<InnerFut::Error>>;
152
153 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
154 let mut this = self.project();
160
161 match TryFuture::try_poll(this.inner, cx) {
162 Poll::Pending => {}
163 Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)),
164 Poll::Ready(Err(err)) => return Poll::Ready(Err(TransportTimeoutError::Other(err))),
165 }
166
167 match Pin::new(&mut this.timer).poll(cx) {
168 Poll::Pending => Poll::Pending,
169 Poll::Ready(()) => Poll::Ready(Err(TransportTimeoutError::Timeout)),
170 }
171 }
172}
173
/// Error produced by the [`TransportTimeout`] wrapper and its [`Timeout`] futures.
#[derive(Debug)]
pub enum TransportTimeoutError<TErr> {
    /// The deadline elapsed before the wrapped future completed.
    Timeout,
    /// An error occurred in the timer.
    ///
    /// NOTE(review): nothing in this file constructs this variant. It appears to be
    /// kept for API compatibility; confirm before relying on it being produced.
    TimerError(io::Error),
    /// An error coming from the wrapped transport or future.
    Other(TErr),
}
184
185impl<TErr> fmt::Display for TransportTimeoutError<TErr>
186where
187 TErr: fmt::Display,
188{
189 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190 match self {
191 TransportTimeoutError::Timeout => write!(f, "Timeout has been reached"),
192 TransportTimeoutError::TimerError(err) => write!(f, "Error in the timer: {err}"),
193 TransportTimeoutError::Other(err) => write!(f, "{err}"),
194 }
195 }
196}
197
198impl<TErr> error::Error for TransportTimeoutError<TErr>
199where
200 TErr: error::Error + 'static,
201{
202 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
203 match self {
204 TransportTimeoutError::Timeout => None,
205 TransportTimeoutError::TimerError(err) => Some(err),
206 TransportTimeoutError::Other(err) => Some(err),
207 }
208 }
209}