libp2p_core/transport/
timeout.rs1use crate::{
28 transport::{ListenerId, TransportError, TransportEvent},
29 Multiaddr, Transport,
30};
31use futures::prelude::*;
32use futures_timer::Delay;
33use std::{error, fmt, io, pin::Pin, task::Context, task::Poll, time::Duration};
34
35#[derive(Debug, Copy, Clone)]
41#[pin_project::pin_project]
42pub struct TransportTimeout<InnerTrans> {
43 #[pin]
44 inner: InnerTrans,
45 outgoing_timeout: Duration,
46 incoming_timeout: Duration,
47}
48
49impl<InnerTrans> TransportTimeout<InnerTrans> {
50 pub fn new(trans: InnerTrans, timeout: Duration) -> Self {
52 TransportTimeout {
53 inner: trans,
54 outgoing_timeout: timeout,
55 incoming_timeout: timeout,
56 }
57 }
58
59 pub fn with_outgoing_timeout(trans: InnerTrans, timeout: Duration) -> Self {
61 TransportTimeout {
62 inner: trans,
63 outgoing_timeout: timeout,
64 incoming_timeout: Duration::from_secs(100 * 365 * 24 * 3600), }
66 }
67
68 pub fn with_ingoing_timeout(trans: InnerTrans, timeout: Duration) -> Self {
70 TransportTimeout {
71 inner: trans,
72 outgoing_timeout: Duration::from_secs(100 * 365 * 24 * 3600), incoming_timeout: timeout,
74 }
75 }
76}
77
78impl<InnerTrans> Transport for TransportTimeout<InnerTrans>
79where
80 InnerTrans: Transport,
81 InnerTrans::Error: 'static,
82{
83 type Output = InnerTrans::Output;
84 type Error = TransportTimeoutError<InnerTrans::Error>;
85 type ListenerUpgrade = Timeout<InnerTrans::ListenerUpgrade>;
86 type Dial = Timeout<InnerTrans::Dial>;
87
88 fn listen_on(
89 &mut self,
90 id: ListenerId,
91 addr: Multiaddr,
92 ) -> Result<(), TransportError<Self::Error>> {
93 self.inner
94 .listen_on(id, addr)
95 .map_err(|err| err.map(TransportTimeoutError::Other))
96 }
97
98 fn remove_listener(&mut self, id: ListenerId) -> bool {
99 self.inner.remove_listener(id)
100 }
101
102 fn dial(&mut self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
103 let dial = self
104 .inner
105 .dial(addr)
106 .map_err(|err| err.map(TransportTimeoutError::Other))?;
107 Ok(Timeout {
108 inner: dial,
109 timer: Delay::new(self.outgoing_timeout),
110 })
111 }
112
113 fn dial_as_listener(
114 &mut self,
115 addr: Multiaddr,
116 ) -> Result<Self::Dial, TransportError<Self::Error>> {
117 let dial = self
118 .inner
119 .dial_as_listener(addr)
120 .map_err(|err| err.map(TransportTimeoutError::Other))?;
121 Ok(Timeout {
122 inner: dial,
123 timer: Delay::new(self.outgoing_timeout),
124 })
125 }
126
127 fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
128 self.inner.address_translation(server, observed)
129 }
130
131 fn poll(
132 self: Pin<&mut Self>,
133 cx: &mut Context<'_>,
134 ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
135 let this = self.project();
136 let timeout = *this.incoming_timeout;
137 this.inner.poll(cx).map(|event| {
138 event
139 .map_upgrade(move |inner_fut| Timeout {
140 inner: inner_fut,
141 timer: Delay::new(timeout),
142 })
143 .map_err(TransportTimeoutError::Other)
144 })
145 }
146}
147
148#[pin_project::pin_project]
153#[must_use = "futures do nothing unless polled"]
154pub struct Timeout<InnerFut> {
155 #[pin]
156 inner: InnerFut,
157 timer: Delay,
158}
159
160impl<InnerFut> Future for Timeout<InnerFut>
161where
162 InnerFut: TryFuture,
163{
164 type Output = Result<InnerFut::Ok, TransportTimeoutError<InnerFut::Error>>;
165
166 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
167 let mut this = self.project();
173
174 match TryFuture::try_poll(this.inner, cx) {
175 Poll::Pending => {}
176 Poll::Ready(Ok(v)) => return Poll::Ready(Ok(v)),
177 Poll::Ready(Err(err)) => return Poll::Ready(Err(TransportTimeoutError::Other(err))),
178 }
179
180 match Pin::new(&mut this.timer).poll(cx) {
181 Poll::Pending => Poll::Pending,
182 Poll::Ready(()) => Poll::Ready(Err(TransportTimeoutError::Timeout)),
183 }
184 }
185}
186
/// Error type of [`TransportTimeout`] and the futures it returns.
#[derive(Debug)]
pub enum TransportTimeoutError<TErr> {
    /// The deadline elapsed before the wrapped operation completed.
    Timeout,
    /// An I/O error attributed to the timer.
    ///
    /// NOTE(review): nothing visible in this module constructs this variant.
    TimerError(io::Error),
    /// An error reported by the wrapped transport or future.
    Other(TErr),
}

impl<TErr> fmt::Display for TransportTimeoutError<TErr>
where
    TErr: fmt::Display,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Timeout => f.write_str("Timeout has been reached"),
            Self::TimerError(io_err) => write!(f, "Error in the timer: {io_err}"),
            // Goes through `write!` rather than `Display::fmt(inner, f)`, so width/fill
            // flags set on `f` are not forwarded to the inner error.
            Self::Other(inner) => write!(f, "{inner}"),
        }
    }
}

impl<TErr> error::Error for TransportTimeoutError<TErr>
where
    TErr: error::Error + 'static,
{
    /// Exposes the wrapped error (I/O or inner) as the source; a plain timeout has none.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        let cause: &(dyn error::Error + 'static) = match self {
            Self::Timeout => return None,
            Self::TimerError(io_err) => io_err,
            Self::Other(inner) => inner,
        };
        Some(cause)
    }
}