1use crate::{protocol, PROTOCOL_NAME};
22use futures::future::{BoxFuture, Either};
23use futures::prelude::*;
24use futures_timer::Delay;
25use libp2p_core::upgrade::ReadyUpgrade;
26use libp2p_swarm::handler::{
27 ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound,
28};
29use libp2p_swarm::{
30 ConnectionHandler, ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError,
31 SubstreamProtocol,
32};
33use std::collections::VecDeque;
34use std::{
35 error::Error,
36 fmt, io,
37 task::{Context, Poll},
38 time::Duration,
39};
40use void::Void;
41
/// Configuration of the ping handler.
///
/// Built with [`Config::new`] (or [`Default`]) and customised through the
/// consuming `with_*` methods.
#[derive(Debug, Clone)]
pub struct Config {
    /// Maximum time a single outbound ping may take before it is treated as
    /// a [`Failure::Timeout`].
    timeout: Duration,
    /// Delay the handler waits after a ping completes (successfully or not)
    /// before starting the next one.
    interval: Duration,
}

impl Config {
    /// Ping timeout used by [`Config::new`].
    const DEFAULT_TIMEOUT: Duration = Duration::from_secs(20);
    /// Ping interval used by [`Config::new`].
    const DEFAULT_INTERVAL: Duration = Duration::from_secs(15);

    /// Creates a configuration with a 20 second timeout and a 15 second
    /// interval.
    pub fn new() -> Self {
        Self {
            timeout: Self::DEFAULT_TIMEOUT,
            interval: Self::DEFAULT_INTERVAL,
        }
    }

    /// Returns the configuration with the ping timeout replaced by `d`.
    pub fn with_timeout(self, d: Duration) -> Self {
        Self { timeout: d, ..self }
    }

    /// Returns the configuration with the ping interval replaced by `d`.
    pub fn with_interval(self, d: Duration) -> Self {
        Self {
            interval: d,
            ..self
        }
    }
}

impl Default for Config {
    /// Equivalent to [`Config::new`].
    fn default() -> Self {
        Config::new()
    }
}
87
/// Reason an outbound ping did not succeed.
#[derive(Debug)]
pub enum Failure {
    /// The ping exchange did not finish within the configured timeout.
    Timeout,
    /// Negotiating the ping protocol with the remote failed, i.e. the remote
    /// does not speak it.
    Unsupported,
    /// Any other failure, e.g. an I/O error on the stream or a timeout while
    /// negotiating the protocol.
    Other {
        /// The underlying cause.
        error: Box<dyn std::error::Error + Send + Sync + 'static>,
    },
}

impl Failure {
    /// Wraps an arbitrary error into [`Failure::Other`].
    fn other(e: impl std::error::Error + Send + Sync + 'static) -> Self {
        let error: Box<dyn std::error::Error + Send + Sync + 'static> = Box::new(e);
        Failure::Other { error }
    }
}

impl fmt::Display for Failure {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Failure::Other { error } => write!(f, "Ping error: {error}"),
            Failure::Timeout => f.write_str("Ping timeout"),
            Failure::Unsupported => f.write_str("Ping protocol not supported"),
        }
    }
}

impl Error for Failure {
    /// Only [`Failure::Other`] carries an underlying cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Failure::Other { error } => {
                let cause: &(dyn Error + 'static) = error.as_ref();
                Some(cause)
            }
            Failure::Timeout | Failure::Unsupported => None,
        }
    }
}
127
/// Per-connection state of the ping protocol.
///
/// Answers inbound pings and periodically sends outbound pings, reporting
/// the outcome of each outbound ping as a `Result<Duration, Failure>`.
pub struct Handler {
    /// Timeout and interval settings for outbound pings.
    config: Config,
    /// Timer that must elapse before the next outbound ping is sent or a new
    /// outbound substream is requested.
    interval: Delay,
    /// Outbound failures waiting to be processed by `poll`; pushed at the
    /// front and consumed from the back.
    pending_errors: VecDeque<Failure>,
    /// Number of outbound failures since the last successful ping; reset to
    /// `0` on every success.
    failures: u32,
    /// State of the outbound ping; `None` means no substream exists and none
    /// has been requested yet.
    outbound: Option<OutboundState>,
    /// Future answering the next inbound ping, if the remote has opened a
    /// ping substream.
    inbound: Option<PongFuture>,
    /// Whether the handler is still pinging or has given up because
    /// protocol negotiation failed.
    state: State,
}
150
/// Activity state of a [`Handler`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Entered when negotiating the outbound ping protocol failed; the
    /// handler no longer does any work in `poll`.
    Inactive {
        /// Whether `Failure::Unsupported` has already been emitted to the
        /// behaviour. It is emitted exactly once, on the first `poll` after
        /// becoming inactive.
        reported: bool,
    },
    /// Normal operation: inbound pings are answered and outbound pings sent.
    Active,
}
163
164impl Handler {
165 pub fn new(config: Config) -> Self {
167 Handler {
168 config,
169 interval: Delay::new(Duration::new(0, 0)),
170 pending_errors: VecDeque::with_capacity(2),
171 failures: 0,
172 outbound: None,
173 inbound: None,
174 state: State::Active,
175 }
176 }
177
178 fn on_dial_upgrade_error(
179 &mut self,
180 DialUpgradeError { error, .. }: DialUpgradeError<
181 <Self as ConnectionHandler>::OutboundOpenInfo,
182 <Self as ConnectionHandler>::OutboundProtocol,
183 >,
184 ) {
185 self.outbound = None; self.interval.reset(Duration::new(0, 0));
198
199 let error = match error {
200 StreamUpgradeError::NegotiationFailed => {
201 debug_assert_eq!(self.state, State::Active);
202
203 self.state = State::Inactive { reported: false };
204 return;
205 }
206 StreamUpgradeError::Timeout => Failure::Other {
208 error: Box::new(std::io::Error::new(
209 std::io::ErrorKind::TimedOut,
210 "ping protocol negotiation timed out",
211 )),
212 },
213 StreamUpgradeError::Apply(e) => void::unreachable(e),
214 StreamUpgradeError::Io(e) => Failure::Other { error: Box::new(e) },
215 };
216
217 self.pending_errors.push_front(error);
218 }
219}
220
221impl ConnectionHandler for Handler {
222 type FromBehaviour = Void;
223 type ToBehaviour = Result<Duration, Failure>;
224 type InboundProtocol = ReadyUpgrade<StreamProtocol>;
225 type OutboundProtocol = ReadyUpgrade<StreamProtocol>;
226 type OutboundOpenInfo = ();
227 type InboundOpenInfo = ();
228
229 fn listen_protocol(&self) -> SubstreamProtocol<ReadyUpgrade<StreamProtocol>, ()> {
230 SubstreamProtocol::new(ReadyUpgrade::new(PROTOCOL_NAME), ())
231 }
232
233 fn on_behaviour_event(&mut self, _: Void) {}
234
235 #[tracing::instrument(level = "trace", name = "ConnectionHandler::poll", skip(self, cx))]
236 fn poll(
237 &mut self,
238 cx: &mut Context<'_>,
239 ) -> Poll<ConnectionHandlerEvent<ReadyUpgrade<StreamProtocol>, (), Result<Duration, Failure>>>
240 {
241 match self.state {
242 State::Inactive { reported: true } => {
243 return Poll::Pending; }
245 State::Inactive { reported: false } => {
246 self.state = State::Inactive { reported: true };
247 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Err(
248 Failure::Unsupported,
249 )));
250 }
251 State::Active => {}
252 }
253
254 if let Some(fut) = self.inbound.as_mut() {
256 match fut.poll_unpin(cx) {
257 Poll::Pending => {}
258 Poll::Ready(Err(e)) => {
259 tracing::debug!("Inbound ping error: {:?}", e);
260 self.inbound = None;
261 }
262 Poll::Ready(Ok(stream)) => {
263 tracing::trace!("answered inbound ping from peer");
264
265 self.inbound = Some(protocol::recv_ping(stream).boxed());
267 }
268 }
269 }
270
271 loop {
272 if let Some(error) = self.pending_errors.pop_back() {
274 tracing::debug!("Ping failure: {:?}", error);
275
276 self.failures += 1;
277
278 if self.failures > 1 {
284 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Err(error)));
285 }
286 }
287
288 match self.outbound.take() {
290 Some(OutboundState::Ping(mut ping)) => match ping.poll_unpin(cx) {
291 Poll::Pending => {
292 self.outbound = Some(OutboundState::Ping(ping));
293 break;
294 }
295 Poll::Ready(Ok((stream, rtt))) => {
296 tracing::debug!(?rtt, "ping succeeded");
297 self.failures = 0;
298 self.interval.reset(self.config.interval);
299 self.outbound = Some(OutboundState::Idle(stream));
300 return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Ok(rtt)));
301 }
302 Poll::Ready(Err(e)) => {
303 self.interval.reset(self.config.interval);
304 self.pending_errors.push_front(e);
305 }
306 },
307 Some(OutboundState::Idle(stream)) => match self.interval.poll_unpin(cx) {
308 Poll::Pending => {
309 self.outbound = Some(OutboundState::Idle(stream));
310 break;
311 }
312 Poll::Ready(()) => {
313 self.outbound = Some(OutboundState::Ping(
314 send_ping(stream, self.config.timeout).boxed(),
315 ));
316 }
317 },
318 Some(OutboundState::OpenStream) => {
319 self.outbound = Some(OutboundState::OpenStream);
320 break;
321 }
322 None => match self.interval.poll_unpin(cx) {
323 Poll::Pending => break,
324 Poll::Ready(()) => {
325 self.outbound = Some(OutboundState::OpenStream);
326 let protocol = SubstreamProtocol::new(ReadyUpgrade::new(PROTOCOL_NAME), ());
327 return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
328 protocol,
329 });
330 }
331 },
332 }
333 }
334
335 Poll::Pending
336 }
337
338 fn on_connection_event(
339 &mut self,
340 event: ConnectionEvent<
341 Self::InboundProtocol,
342 Self::OutboundProtocol,
343 Self::InboundOpenInfo,
344 Self::OutboundOpenInfo,
345 >,
346 ) {
347 match event {
348 ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
349 protocol: mut stream,
350 ..
351 }) => {
352 stream.ignore_for_keep_alive();
353 self.inbound = Some(protocol::recv_ping(stream).boxed());
354 }
355 ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
356 protocol: mut stream,
357 ..
358 }) => {
359 stream.ignore_for_keep_alive();
360 self.outbound = Some(OutboundState::Ping(
361 send_ping(stream, self.config.timeout).boxed(),
362 ));
363 }
364 ConnectionEvent::DialUpgradeError(dial_upgrade_error) => {
365 self.on_dial_upgrade_error(dial_upgrade_error)
366 }
367 _ => {}
368 }
369 }
370}
371
/// Boxed future of one outbound ping; resolves to the stream (for reuse) and
/// the measured duration, or to the [`Failure`] that ended the ping.
type PingFuture = BoxFuture<'static, Result<(Stream, Duration), Failure>>;
/// Boxed future answering one inbound ping; hands the stream back on success
/// so that the next ping can be awaited on it.
type PongFuture = BoxFuture<'static, Result<Stream, io::Error>>;
374
/// State of the outbound half of the ping protocol.
enum OutboundState {
    /// A new substream has been requested and is not yet available.
    OpenStream,
    /// The substream is open and unused, waiting for the interval timer to
    /// elapse before the next ping.
    Idle(Stream),
    /// A ping is in flight on the substream.
    Ping(PingFuture),
}
384
385async fn send_ping(stream: Stream, timeout: Duration) -> Result<(Stream, Duration), Failure> {
387 let ping = protocol::send_ping(stream);
388 futures::pin_mut!(ping);
389
390 match future::select(ping, Delay::new(timeout)).await {
391 Either::Left((Ok((stream, rtt)), _)) => Ok((stream, rtt)),
392 Either::Left((Err(e), _)) => Err(Failure::other(e)),
393 Either::Right(((), _)) => Err(Failure::Timeout),
394 }
395}