1mod error;
22
23pub(crate) mod pool;
24mod supported_protocols;
25
26pub use error::ConnectionError;
27pub(crate) use error::{
28 PendingConnectionError, PendingInboundConnectionError, PendingOutboundConnectionError,
29};
30pub use supported_protocols::SupportedProtocols;
31
32use crate::handler::{
33 AddressChange, ConnectionEvent, ConnectionHandler, DialUpgradeError, FullyNegotiatedInbound,
34 FullyNegotiatedOutbound, ListenUpgradeError, ProtocolSupport, ProtocolsAdded, ProtocolsChange,
35 UpgradeInfoSend,
36};
37use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend};
38use crate::{
39 ConnectionHandlerEvent, KeepAlive, Stream, StreamProtocol, StreamUpgradeError,
40 SubstreamProtocol,
41};
42use futures::future::BoxFuture;
43use futures::stream::FuturesUnordered;
44use futures::FutureExt;
45use futures::StreamExt;
46use futures_timer::Delay;
47use instant::Instant;
48use libp2p_core::connection::ConnectedPoint;
49use libp2p_core::multiaddr::Multiaddr;
50use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerEvent, StreamMuxerExt, SubstreamBox};
51use libp2p_core::upgrade;
52use libp2p_core::upgrade::{NegotiationError, ProtocolError};
53use libp2p_core::Endpoint;
54use libp2p_identity::PeerId;
55use std::cmp::max;
56use std::collections::HashSet;
57use std::fmt::{Display, Formatter};
58use std::future::Future;
59use std::sync::atomic::{AtomicUsize, Ordering};
60use std::task::Waker;
61use std::time::Duration;
62use std::{fmt, io, mem, pin::Pin, task::Context, task::Poll};
63
/// Process-wide counter from which [`ConnectionId::next`] allocates fresh identifiers.
static NEXT_CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);

/// Identifier of a connection.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ConnectionId(usize);

impl ConnectionId {
    /// Wraps the raw `id` as-is, without consulting the global counter.
    ///
    /// Nothing here prevents the result from colliding with an id handed out by
    /// [`ConnectionId::next`]; avoiding reuse is the caller's responsibility.
    pub fn new_unchecked(id: usize) -> Self {
        ConnectionId(id)
    }

    /// Allocates the next identifier from the process-wide counter (which starts at 1).
    pub(crate) fn next() -> Self {
        let id = NEXT_CONNECTION_ID.fetch_add(1, Ordering::SeqCst);
        ConnectionId(id)
    }
}

impl Display for ConnectionId {
    /// Renders the raw numeric identifier.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let ConnectionId(raw) = self;
        write!(f, "{}", raw)
    }
}
92
/// Information about an established connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct Connected {
    /// The endpoint of the connection (e.g. `ConnectedPoint::Listener` for inbound connections).
    pub(crate) endpoint: ConnectedPoint,
    /// NOTE(review): presumably the identity of the remote peer — confirm against callers.
    pub(crate) peer_id: PeerId,
}
101
/// Successful (non-error) output of [`Connection::poll`].
#[derive(Debug, Clone)]
pub(crate) enum Event<T> {
    /// An event the connection handler wants delivered to the behaviour
    /// (its `NotifyBehaviour` output).
    Handler(T),
    /// The stream muxer reported a new remote address for this connection.
    AddressChange(Multiaddr),
}
110
/// A multiplexed connection driven by a [`ConnectionHandler`].
///
/// Substreams are negotiated here (multistream-select plus the handler's upgrade) before the
/// result is handed to the handler.
pub(crate) struct Connection<THandler>
where
    THandler: ConnectionHandler,
{
    /// The stream muxer of the underlying connection.
    muxing: StreamMuxerBox,
    /// The handler driving this connection.
    handler: THandler,
    /// Inbound substreams currently being negotiated.
    negotiating_in: FuturesUnordered<
        StreamUpgrade<
            THandler::InboundOpenInfo,
            <THandler::InboundProtocol as InboundUpgradeSend>::Output,
            <THandler::InboundProtocol as InboundUpgradeSend>::Error,
        >,
    >,
    /// Outbound substreams currently being negotiated.
    negotiating_out: FuturesUnordered<
        StreamUpgrade<
            THandler::OutboundOpenInfo,
            <THandler::OutboundProtocol as OutboundUpgradeSend>::Output,
            <THandler::OutboundProtocol as OutboundUpgradeSend>::Error,
        >,
    >,
    /// The currently planned shutdown of the connection.
    shutdown: Shutdown,
    /// Optional multistream-select version override for outbound negotiations.
    substream_upgrade_protocol_override: Option<upgrade::Version>,
    /// Upper bound on concurrently negotiating inbound substreams; the muxer is not polled
    /// for new inbound substreams while this many are in flight.
    max_negotiating_inbound_streams: usize,
    /// Outbound substreams requested by the handler but not yet opened by the muxer.
    ///
    /// Each request carries its own timeout, which starts when the request is made.
    requested_substreams: FuturesUnordered<
        SubstreamRequested<THandler::OutboundOpenInfo, THandler::OutboundProtocol>,
    >,

    /// Protocols the handler listens on, as last reported to it.
    local_supported_protocols: HashSet<StreamProtocol>,
    /// Protocols the remote is known to support, as reported via the handler.
    remote_supported_protocols: HashSet<StreamProtocol>,
    /// Minimum time an idle connection is kept alive before shutting down.
    idle_timeout: Duration,
}
161
162impl<THandler> fmt::Debug for Connection<THandler>
163where
164 THandler: ConnectionHandler + fmt::Debug,
165 THandler::OutboundOpenInfo: fmt::Debug,
166{
167 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
168 f.debug_struct("Connection")
169 .field("handler", &self.handler)
170 .finish()
171 }
172}
173
// `poll` calls `Pin::get_mut` and `poll_noop_waker` calls `Pin::new`, both of which require
// `Connection` to be `Unpin` regardless of the handler type.
impl<THandler> Unpin for Connection<THandler> where THandler: ConnectionHandler {}
175
impl<THandler> Connection<THandler>
where
    THandler: ConnectionHandler,
{
    /// Builds a new `Connection` from the given muxer and handler.
    ///
    /// If the handler already listens on any protocols, it is immediately told about them via
    /// a `LocalProtocolsChange` event.
    pub(crate) fn new(
        muxer: StreamMuxerBox,
        mut handler: THandler,
        substream_upgrade_protocol_override: Option<upgrade::Version>,
        max_negotiating_inbound_streams: usize,
        idle_timeout: Duration,
    ) -> Self {
        let initial_protocols = gather_supported_protocols(&handler);
        if !initial_protocols.is_empty() {
            handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(
                ProtocolsChange::Added(ProtocolsAdded::from_set(&initial_protocols)),
            ));
        }

        Connection {
            muxing: muxer,
            handler,
            negotiating_in: Default::default(),
            negotiating_out: Default::default(),
            shutdown: Shutdown::None,
            substream_upgrade_protocol_override,
            max_negotiating_inbound_streams,
            requested_substreams: Default::default(),
            // Remember what was just reported so `poll` only reports later changes.
            local_supported_protocols: initial_protocols,
            remote_supported_protocols: Default::default(),
            idle_timeout,
        }
    }

    /// Forwards an event from the behaviour to the connection handler.
    pub(crate) fn on_behaviour_event(&mut self, event: THandler::FromBehaviour) {
        self.handler.on_behaviour_event(event);
    }

    /// Consumes the connection, returning the handler and a future that closes the muxer.
    pub(crate) fn close(self) -> (THandler, impl Future<Output = io::Result<()>>) {
        (self.handler, self.muxing.close())
    }

    /// Drives the handler, the muxer and all substream negotiations.
    ///
    /// Returns `Ready(Ok(_))` for handler events and address changes, and `Ready(Err(_))` when
    /// the connection must be closed (handler-requested close, keep-alive timeout, or a muxer
    /// error). Returns `Pending` once no component can make further progress.
    // NOTE(review): presumably allowed for the deprecated `ConnectionHandlerEvent::Close` and
    // `KeepAlive` APIs used below.
    #[allow(deprecated)]
    pub(crate) fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError<THandler::Error>>> {
        // Destructure once so that each field can be borrowed independently inside the loop.
        let Self {
            requested_substreams,
            muxing,
            handler,
            negotiating_out,
            negotiating_in,
            shutdown,
            max_negotiating_inbound_streams,
            substream_upgrade_protocol_override,
            local_supported_protocols: supported_protocols,
            remote_supported_protocols,
            idle_timeout,
        } = self.get_mut();

        loop {
            // Reap outbound requests. `Ok(())` means the request was already turned into a
            // negotiation; `Err(info)` means its timeout fired before the muxer opened a stream.
            match requested_substreams.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(()))) => continue,
                Poll::Ready(Some(Err(info))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError {
                            info,
                            error: StreamUpgradeError::Timeout,
                        },
                    ));
                    continue;
                }
                Poll::Ready(None) | Poll::Pending => {}
            }

            // Poll the handler until it has nothing more to do.
            match handler.poll(cx) {
                Poll::Pending => {}
                Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
                    let timeout = *protocol.timeout();
                    let (upgrade, user_data) = protocol.into_upgrade();

                    // The request's timeout starts here, at request time, not when the muxer
                    // eventually opens the substream.
                    requested_substreams.push(SubstreamRequested::new(user_data, timeout, upgrade));
                    continue; // Poll handler until exhausted.
                }
                Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)) => {
                    return Poll::Ready(Ok(Event::Handler(event)));
                }
                #[allow(deprecated)]
                Poll::Ready(ConnectionHandlerEvent::Close(err)) => {
                    return Poll::Ready(Err(ConnectionError::Handler(err)));
                }
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Added(protocols),
                )) => {
                    // NOTE(review): `ProtocolsChange::add` presumably returns `None` when
                    // nothing new is added, so the handler only hears about actual changes.
                    if let Some(added) =
                        ProtocolsChange::add(remote_supported_protocols, &protocols)
                    {
                        handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(added));
                        remote_supported_protocols.extend(protocols);
                    }

                    continue;
                }
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Removed(protocols),
                )) => {
                    // Same as above: only report removals of protocols that were known.
                    if let Some(removed) =
                        ProtocolsChange::remove(remote_supported_protocols, &protocols)
                    {
                        handler
                            .on_connection_event(ConnectionEvent::RemoteProtocolsChange(removed));
                        remote_supported_protocols.retain(|p| !protocols.contains(p));
                    }

                    continue;
                }
            }

            // The handler is idle: make progress on outbound negotiations.
            match negotiating_out.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
                        FullyNegotiatedOutbound { protocol, info },
                    ));
                    continue;
                }
                Poll::Ready(Some((info, Err(error)))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError { info, error },
                    ));
                    continue;
                }
            }

            // Then on inbound negotiations. Only `Apply` errors reach the handler; the other
            // inbound failures are merely logged.
            match negotiating_in.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
                        FullyNegotiatedInbound { protocol, info },
                    ));
                    continue;
                }
                Poll::Ready(Some((info, Err(StreamUpgradeError::Apply(error))))) => {
                    handler.on_connection_event(ConnectionEvent::ListenUpgradeError(
                        ListenUpgradeError { info, error },
                    ));
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::Io(e))))) => {
                    log::debug!("failed to upgrade inbound stream: {e}");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::NegotiationFailed)))) => {
                    log::debug!("no protocol could be agreed upon for inbound stream");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::Timeout)))) => {
                    log::debug!("inbound stream upgrade timed out");
                    continue;
                }
            }

            // Re-evaluate the planned shutdown from the handler's current keep-alive.
            if let Some(new_shutdown) =
                compute_new_shutdown(handler.connection_keep_alive(), shutdown, *idle_timeout)
            {
                *shutdown = new_shutdown;
            }

            // Shutdown is postponed as long as any substream is requested or negotiating.
            if negotiating_in.is_empty()
                && negotiating_out.is_empty()
                && requested_substreams.is_empty()
            {
                match shutdown {
                    Shutdown::None => {}
                    Shutdown::Asap => return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)),
                    Shutdown::Later(delay, _) => match Future::poll(Pin::new(delay), cx) {
                        Poll::Ready(_) => {
                            return Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
                        }
                        Poll::Pending => {}
                    },
                }
            }

            // Muxer-level events; `?` turns a muxer error into a `ConnectionError`.
            match muxing.poll_unpin(cx)? {
                Poll::Pending => {}
                Poll::Ready(StreamMuxerEvent::AddressChange(address)) => {
                    handler.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
                        new_address: &address,
                    }));
                    return Poll::Ready(Ok(Event::AddressChange(address)));
                }
            }

            // Only ask the muxer for an outbound substream while one is requested. The new
            // substream goes to whichever request the iterator yields first, and that request's
            // already-running timeout is carried over into the negotiation.
            if let Some(requested_substream) = requested_substreams.iter_mut().next() {
                match muxing.poll_outbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        let (user_data, timeout, upgrade) = requested_substream.extract();

                        negotiating_out.push(StreamUpgrade::new_outbound(
                            substream,
                            user_data,
                            timeout,
                            upgrade,
                            *substream_upgrade_protocol_override,
                        ));

                        continue; // Go back to the top, handler can potentially make progress again.
                    }
                }
            }

            // Accept new inbound substreams only while below the negotiation limit.
            if negotiating_in.len() < *max_negotiating_inbound_streams {
                match muxing.poll_inbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        let protocol = handler.listen_protocol();

                        negotiating_in.push(StreamUpgrade::new_inbound(substream, protocol));

                        continue; // Go back to the top, handler can potentially make progress again.
                    }
                }
            }

            // Report any change in the protocols the handler listens on.
            let new_protocols = gather_supported_protocols(handler);
            let changes = ProtocolsChange::from_full_sets(supported_protocols, &new_protocols);

            if !changes.is_empty() {
                for change in changes {
                    handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(change));
                }

                *supported_protocols = new_protocols;

                continue; // Go back to the top, handler can potentially make progress again.
            }

            return Poll::Pending; // Nothing can make progress, return `Pending`.
        }
    }

    /// Test helper: polls once with a waker that does nothing.
    #[cfg(test)]
    #[allow(deprecated)]
    fn poll_noop_waker(
        &mut self,
    ) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError<THandler::Error>>> {
        Pin::new(self).poll(&mut Context::from_waker(futures::task::noop_waker_ref()))
    }
}
441
442fn gather_supported_protocols(handler: &impl ConnectionHandler) -> HashSet<StreamProtocol> {
443 handler
444 .listen_protocol()
445 .upgrade()
446 .protocol_info()
447 .filter_map(|i| StreamProtocol::try_from_owned(i.as_ref().to_owned()).ok())
448 .collect()
449}
450
451fn compute_new_shutdown(
452 handler_keep_alive: KeepAlive,
453 current_shutdown: &Shutdown,
454 idle_timeout: Duration,
455) -> Option<Shutdown> {
456 #[allow(deprecated)]
457 match (current_shutdown, handler_keep_alive) {
458 (Shutdown::Later(_, deadline), KeepAlive::Until(t)) => {
459 let now = Instant::now();
460
461 if *deadline != t {
462 let deadline = t;
463 if let Some(new_duration) = deadline.checked_duration_since(Instant::now()) {
464 let effective_keep_alive = max(new_duration, idle_timeout);
465
466 let safe_keep_alive = checked_add_fraction(now, effective_keep_alive);
467 return Some(Shutdown::Later(Delay::new(safe_keep_alive), deadline));
468 }
469 }
470 None
471 }
472 (_, KeepAlive::Until(earliest_shutdown)) => {
473 let now = Instant::now();
474
475 if let Some(requested) = earliest_shutdown.checked_duration_since(now) {
476 let effective_keep_alive = max(requested, idle_timeout);
477
478 let safe_keep_alive = checked_add_fraction(now, effective_keep_alive);
479
480 return Some(Shutdown::Later(
483 Delay::new(safe_keep_alive),
484 earliest_shutdown,
485 ));
486 }
487 None
488 }
489 (_, KeepAlive::No) if idle_timeout == Duration::ZERO => Some(Shutdown::Asap),
490 (Shutdown::Later(_, _), KeepAlive::No) => {
491 None
493 }
494 (_, KeepAlive::No) => {
495 let now = Instant::now();
496 let safe_keep_alive = checked_add_fraction(now, idle_timeout);
497
498 Some(Shutdown::Later(
499 Delay::new(safe_keep_alive),
500 now + safe_keep_alive,
501 ))
502 }
503 (_, KeepAlive::Yes) => Some(Shutdown::None),
504 }
505}
506
507fn checked_add_fraction(start: Instant, mut duration: Duration) -> Duration {
512 while start.checked_add(duration).is_none() {
513 log::debug!("{start:?} + {duration:?} cannot be presented, halving duration");
514
515 duration /= 2;
516 }
517
518 duration
519}
520
521#[derive(Debug, Copy, Clone)]
523pub(crate) struct IncomingInfo<'a> {
524 pub(crate) local_addr: &'a Multiaddr,
526 pub(crate) send_back_addr: &'a Multiaddr,
528}
529
530impl<'a> IncomingInfo<'a> {
531 pub(crate) fn create_connected_point(&self) -> ConnectedPoint {
533 ConnectedPoint::Listener {
534 local_addr: self.local_addr.clone(),
535 send_back_addr: self.send_back_addr.clone(),
536 }
537 }
538}
539
/// An in-flight substream negotiation (multistream-select followed by the protocol upgrade),
/// bounded by a timeout.
struct StreamUpgrade<UserData, TOk, TErr> {
    /// Returned alongside the result; taken (set to `None`) when the future completes.
    user_data: Option<UserData>,
    /// If this fires before `upgrade` completes, the future resolves with
    /// `StreamUpgradeError::Timeout`.
    timeout: Delay,
    /// Negotiates a protocol on the substream and then applies the upgrade.
    upgrade: BoxFuture<'static, Result<TOk, StreamUpgradeError<TErr>>>,
}
545
546impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
547 fn new_outbound<Upgrade>(
548 substream: SubstreamBox,
549 user_data: UserData,
550 timeout: Delay,
551 upgrade: Upgrade,
552 version_override: Option<upgrade::Version>,
553 ) -> Self
554 where
555 Upgrade: OutboundUpgradeSend<Output = TOk, Error = TErr>,
556 {
557 let effective_version = match version_override {
558 Some(version_override) if version_override != upgrade::Version::default() => {
559 log::debug!(
560 "Substream upgrade protocol override: {:?} -> {:?}",
561 upgrade::Version::default(),
562 version_override
563 );
564
565 version_override
566 }
567 _ => upgrade::Version::default(),
568 };
569 let protocols = upgrade.protocol_info();
570
571 Self {
572 user_data: Some(user_data),
573 timeout,
574 upgrade: Box::pin(async move {
575 let (info, stream) = multistream_select::dialer_select_proto(
576 substream,
577 protocols,
578 effective_version,
579 )
580 .await
581 .map_err(to_stream_upgrade_error)?;
582
583 let output = upgrade
584 .upgrade_outbound(Stream::new(stream), info)
585 .await
586 .map_err(StreamUpgradeError::Apply)?;
587
588 Ok(output)
589 }),
590 }
591 }
592}
593
594impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
595 fn new_inbound<Upgrade>(
596 substream: SubstreamBox,
597 protocol: SubstreamProtocol<Upgrade, UserData>,
598 ) -> Self
599 where
600 Upgrade: InboundUpgradeSend<Output = TOk, Error = TErr>,
601 {
602 let timeout = *protocol.timeout();
603 let (upgrade, open_info) = protocol.into_upgrade();
604 let protocols = upgrade.protocol_info();
605
606 Self {
607 user_data: Some(open_info),
608 timeout: Delay::new(timeout),
609 upgrade: Box::pin(async move {
610 let (info, stream) =
611 multistream_select::listener_select_proto(substream, protocols)
612 .await
613 .map_err(to_stream_upgrade_error)?;
614
615 let output = upgrade
616 .upgrade_inbound(Stream::new(stream), info)
617 .await
618 .map_err(StreamUpgradeError::Apply)?;
619
620 Ok(output)
621 }),
622 }
623 }
624}
625
626fn to_stream_upgrade_error<T>(e: NegotiationError) -> StreamUpgradeError<T> {
627 match e {
628 NegotiationError::Failed => StreamUpgradeError::NegotiationFailed,
629 NegotiationError::ProtocolError(ProtocolError::IoError(e)) => StreamUpgradeError::Io(e),
630 NegotiationError::ProtocolError(other) => {
631 StreamUpgradeError::Io(io::Error::new(io::ErrorKind::Other, other))
632 }
633 }
634}
635
// No field is structurally pinned: `upgrade` is already boxed and pinned, and `timeout` is
// polled via `poll_unpin`. `Future::poll` mutates fields through `Pin<&mut Self>`, which requires
// `Unpin`, so implement it regardless of the generic parameters.
impl<UserData, TOk, TErr> Unpin for StreamUpgrade<UserData, TOk, TErr> {}
637
638impl<UserData, TOk, TErr> Future for StreamUpgrade<UserData, TOk, TErr> {
639 type Output = (UserData, Result<TOk, StreamUpgradeError<TErr>>);
640
641 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
642 match self.timeout.poll_unpin(cx) {
643 Poll::Ready(()) => {
644 return Poll::Ready((
645 self.user_data
646 .take()
647 .expect("Future not to be polled again once ready."),
648 Err(StreamUpgradeError::Timeout),
649 ))
650 }
651
652 Poll::Pending => {}
653 }
654
655 let result = futures::ready!(self.upgrade.poll_unpin(cx));
656 let user_data = self
657 .user_data
658 .take()
659 .expect("Future not to be polled again once ready.");
660
661 Poll::Ready((user_data, result))
662 }
663}
664
/// An outbound substream requested by the handler but not yet opened by the muxer.
///
/// As a future it resolves to `Err(user_data)` if its timeout fires while still waiting. It
/// resolves to `Ok(())` once its data has been taken out via `extract`.
enum SubstreamRequested<UserData, Upgrade> {
    /// Still waiting for the muxer to provide a substream.
    Waiting {
        user_data: UserData,
        /// Started when the request was created; handed over to the negotiation once the
        /// substream is opened.
        timeout: Delay,
        upgrade: Upgrade,
        /// Waker from the most recent poll. `extract` wakes it so that the owning
        /// `FuturesUnordered` polls this future again, and it can then resolve with `Ok(())`
        /// and be removed.
        extracted_waker: Option<Waker>,
    },
    /// The data has been extracted, or the future has resolved.
    Done,
}
678
679impl<UserData, Upgrade> SubstreamRequested<UserData, Upgrade> {
680 fn new(user_data: UserData, timeout: Duration, upgrade: Upgrade) -> Self {
681 Self::Waiting {
682 user_data,
683 timeout: Delay::new(timeout),
684 upgrade,
685 extracted_waker: None,
686 }
687 }
688
689 fn extract(&mut self) -> (UserData, Delay, Upgrade) {
690 match mem::replace(self, Self::Done) {
691 SubstreamRequested::Waiting {
692 user_data,
693 timeout,
694 upgrade,
695 extracted_waker: waker,
696 } => {
697 if let Some(waker) = waker {
698 waker.wake();
699 }
700
701 (user_data, timeout, upgrade)
702 }
703 SubstreamRequested::Done => panic!("cannot extract twice"),
704 }
705 }
706}
707
// `poll` uses `Pin::get_mut`, which requires `Unpin`. The `Delay` is polled via `poll_unpin`, so
// no field needs structural pinning regardless of `UserData`/`Upgrade`.
impl<UserData, Upgrade> Unpin for SubstreamRequested<UserData, Upgrade> {}
709
710impl<UserData, Upgrade> Future for SubstreamRequested<UserData, Upgrade> {
711 type Output = Result<(), UserData>;
712
713 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
714 let this = self.get_mut();
715
716 match mem::replace(this, Self::Done) {
717 SubstreamRequested::Waiting {
718 user_data,
719 upgrade,
720 mut timeout,
721 ..
722 } => match timeout.poll_unpin(cx) {
723 Poll::Ready(()) => Poll::Ready(Err(user_data)),
724 Poll::Pending => {
725 *this = Self::Waiting {
726 user_data,
727 upgrade,
728 timeout,
729 extracted_waker: Some(cx.waker().clone()),
730 };
731 Poll::Pending
732 }
733 },
734 SubstreamRequested::Done => Poll::Ready(Ok(())),
735 }
736 }
737}
738
/// The planned shutdown of a connection.
#[derive(Debug)]
enum Shutdown {
    /// No shutdown is planned.
    None,
    /// Shut down as soon as no substream is being requested or negotiated.
    Asap,
    /// Shut down once the `Delay` fires (and no substream is being requested or negotiated).
    ///
    /// The `Instant` is the deadline this state was derived from. `compute_new_shutdown`
    /// compares it with the handler's `KeepAlive::Until` value to detect a changed deadline.
    Later(Delay, Instant),
}
757
758#[cfg(test)]
759mod tests {
760 use super::*;
761 use crate::dummy;
762 use futures::future;
763 use futures::AsyncRead;
764 use futures::AsyncWrite;
765 use libp2p_core::upgrade::{DeniedUpgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo};
766 use libp2p_core::StreamMuxer;
767 use quickcheck::*;
768 use std::sync::{Arc, Weak};
769 use std::time::Instant;
770 use void::Void;
771
772 #[test]
773 fn max_negotiating_inbound_streams() {
774 let _ = env_logger::try_init();
775
776 fn prop(max_negotiating_inbound_streams: u8) {
777 let max_negotiating_inbound_streams: usize = max_negotiating_inbound_streams.into();
778
779 let alive_substream_counter = Arc::new(());
780 let mut connection = Connection::new(
781 StreamMuxerBox::new(DummyStreamMuxer {
782 counter: alive_substream_counter.clone(),
783 }),
784 MockConnectionHandler::new(Duration::from_secs(10)),
785 None,
786 max_negotiating_inbound_streams,
787 Duration::ZERO,
788 );
789
790 let result = connection.poll_noop_waker();
791
792 assert!(result.is_pending());
793 assert_eq!(
794 Arc::weak_count(&alive_substream_counter),
795 max_negotiating_inbound_streams,
796 "Expect no more than the maximum number of allowed streams"
797 );
798 }
799
800 QuickCheck::new().quickcheck(prop as fn(_));
801 }
802
803 #[test]
804 fn outbound_stream_timeout_starts_on_request() {
805 let upgrade_timeout = Duration::from_secs(1);
806 let mut connection = Connection::new(
807 StreamMuxerBox::new(PendingStreamMuxer),
808 MockConnectionHandler::new(upgrade_timeout),
809 None,
810 2,
811 Duration::ZERO,
812 );
813
814 connection.handler.open_new_outbound();
815 let _ = connection.poll_noop_waker();
816
817 std::thread::sleep(upgrade_timeout + Duration::from_secs(1));
818
819 let _ = connection.poll_noop_waker();
820
821 assert!(matches!(
822 connection.handler.error.unwrap(),
823 StreamUpgradeError::Timeout
824 ))
825 }
826
827 #[test]
828 fn propagates_changes_to_supported_inbound_protocols() {
829 let mut connection = Connection::new(
830 StreamMuxerBox::new(PendingStreamMuxer),
831 ConfigurableProtocolConnectionHandler::default(),
832 None,
833 0,
834 Duration::ZERO,
835 );
836
837 connection.handler.listen_on(&["/foo"]);
839 let _ = connection.poll_noop_waker();
840
841 assert_eq!(connection.handler.local_added, vec![vec!["/foo"]]);
842 assert!(connection.handler.local_removed.is_empty());
843
844 connection.handler.listen_on(&["/foo", "/bar"]);
846 let _ = connection.poll_noop_waker();
847
848 assert_eq!(
849 connection.handler.local_added,
850 vec![vec!["/foo"], vec!["/bar"]],
851 "expect to only receive an event for the newly added protocols"
852 );
853 assert!(connection.handler.local_removed.is_empty());
854
855 connection.handler.listen_on(&["/bar"]);
857 let _ = connection.poll_noop_waker();
858
859 assert_eq!(
860 connection.handler.local_added,
861 vec![vec!["/foo"], vec!["/bar"]]
862 );
863 assert_eq!(connection.handler.local_removed, vec![vec!["/foo"]]);
864 }
865
866 #[test]
867 fn only_propagtes_actual_changes_to_remote_protocols_to_handler() {
868 let mut connection = Connection::new(
869 StreamMuxerBox::new(PendingStreamMuxer),
870 ConfigurableProtocolConnectionHandler::default(),
871 None,
872 0,
873 Duration::ZERO,
874 );
875
876 connection.handler.remote_adds_support_for(&["/foo"]);
878 let _ = connection.poll_noop_waker();
879
880 assert_eq!(connection.handler.remote_added, vec![vec!["/foo"]]);
881 assert!(connection.handler.remote_removed.is_empty());
882
883 connection
885 .handler
886 .remote_adds_support_for(&["/foo", "/bar"]);
887 let _ = connection.poll_noop_waker();
888
889 assert_eq!(
890 connection.handler.remote_added,
891 vec![vec!["/foo"], vec!["/bar"]],
892 "expect to only receive an event for the newly added protocol"
893 );
894 assert!(connection.handler.remote_removed.is_empty());
895
896 connection.handler.remote_removes_support_for(&["/baz"]);
898 let _ = connection.poll_noop_waker();
899
900 assert_eq!(
901 connection.handler.remote_added,
902 vec![vec!["/foo"], vec!["/bar"]]
903 );
904 assert!(&connection.handler.remote_removed.is_empty());
905
906 connection.handler.remote_removes_support_for(&["/bar"]);
908 let _ = connection.poll_noop_waker();
909
910 assert_eq!(
911 connection.handler.remote_added,
912 vec![vec!["/foo"], vec!["/bar"]]
913 );
914 assert_eq!(connection.handler.remote_removed, vec![vec!["/bar"]]);
915 }
916
917 #[tokio::test]
918 async fn idle_timeout_with_keep_alive_no() {
919 let idle_timeout = Duration::from_millis(100);
920
921 let mut connection = Connection::new(
922 StreamMuxerBox::new(PendingStreamMuxer),
923 dummy::ConnectionHandler,
924 None,
925 0,
926 idle_timeout,
927 );
928
929 assert!(connection.poll_noop_waker().is_pending());
930
931 tokio::time::sleep(idle_timeout).await;
932
933 assert!(matches!(
934 connection.poll_noop_waker(),
935 Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
936 ));
937 }
938
939 #[tokio::test]
940 async fn idle_timeout_with_keep_alive_until_greater_than_idle_timeout() {
941 let idle_timeout = Duration::from_millis(100);
942
943 let mut connection = Connection::new(
944 StreamMuxerBox::new(PendingStreamMuxer),
945 KeepAliveUntilConnectionHandler {
946 until: Instant::now() + idle_timeout * 2,
947 },
948 None,
949 0,
950 idle_timeout,
951 );
952
953 assert!(connection.poll_noop_waker().is_pending());
954
955 tokio::time::sleep(idle_timeout).await;
956
957 assert!(
958 connection.poll_noop_waker().is_pending(),
959 "`KeepAlive::Until` is greater than idle-timeout, continue sleeping"
960 );
961
962 tokio::time::sleep(idle_timeout).await;
963
964 assert!(matches!(
965 connection.poll_noop_waker(),
966 Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
967 ));
968 }
969
970 #[tokio::test]
971 async fn idle_timeout_with_keep_alive_until_less_than_idle_timeout() {
972 let idle_timeout = Duration::from_millis(100);
973
974 let mut connection = Connection::new(
975 StreamMuxerBox::new(PendingStreamMuxer),
976 KeepAliveUntilConnectionHandler {
977 until: Instant::now() + idle_timeout / 2,
978 },
979 None,
980 0,
981 idle_timeout,
982 );
983
984 assert!(connection.poll_noop_waker().is_pending());
985
986 tokio::time::sleep(idle_timeout / 2).await;
987
988 assert!(
989 connection.poll_noop_waker().is_pending(),
990 "`KeepAlive::Until` is less than idle-timeout, honor idle-timeout"
991 );
992
993 tokio::time::sleep(idle_timeout / 2).await;
994
995 assert!(matches!(
996 connection.poll_noop_waker(),
997 Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
998 ));
999 }
1000
1001 #[test]
1002 fn checked_add_fraction_can_add_u64_max() {
1003 let _ = env_logger::try_init();
1004 let start = Instant::now();
1005
1006 let duration = checked_add_fraction(start, Duration::from_secs(u64::MAX));
1007
1008 assert!(start.checked_add(duration).is_some())
1009 }
1010
1011 #[test]
1012 fn compute_new_shutdown_does_not_panic() {
1013 let _ = env_logger::try_init();
1014
1015 #[derive(Debug)]
1016 struct ArbitraryShutdown(Shutdown);
1017
1018 impl Clone for ArbitraryShutdown {
1019 fn clone(&self) -> Self {
1020 let shutdown = match self.0 {
1021 Shutdown::None => Shutdown::None,
1022 Shutdown::Asap => Shutdown::Asap,
1023 Shutdown::Later(_, instant) => Shutdown::Later(
1024 Delay::new(Duration::from_secs(1)),
1027 instant,
1028 ),
1029 };
1030
1031 ArbitraryShutdown(shutdown)
1032 }
1033 }
1034
1035 impl Arbitrary for ArbitraryShutdown {
1036 fn arbitrary(g: &mut Gen) -> Self {
1037 let shutdown = match g.gen_range(1u8..4) {
1038 1 => Shutdown::None,
1039 2 => Shutdown::Asap,
1040 3 => Shutdown::Later(
1041 Delay::new(Duration::from_secs(u32::arbitrary(g) as u64)),
1042 Instant::now()
1043 .checked_add(Duration::arbitrary(g))
1044 .unwrap_or(Instant::now()),
1045 ),
1046 _ => unreachable!(),
1047 };
1048
1049 Self(shutdown)
1050 }
1051 }
1052
1053 fn prop(
1054 handler_keep_alive: KeepAlive,
1055 current_shutdown: ArbitraryShutdown,
1056 idle_timeout: Duration,
1057 ) {
1058 compute_new_shutdown(handler_keep_alive, ¤t_shutdown.0, idle_timeout);
1059 }
1060
1061 QuickCheck::new().quickcheck(prop as fn(_, _, _));
1062 }
1063
1064 struct KeepAliveUntilConnectionHandler {
1065 until: Instant,
1066 }
1067
1068 impl ConnectionHandler for KeepAliveUntilConnectionHandler {
1069 type FromBehaviour = Void;
1070 type ToBehaviour = Void;
1071 type Error = Void;
1072 type InboundProtocol = DeniedUpgrade;
1073 type OutboundProtocol = DeniedUpgrade;
1074 type InboundOpenInfo = ();
1075 type OutboundOpenInfo = Void;
1076
1077 fn listen_protocol(
1078 &self,
1079 ) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
1080 SubstreamProtocol::new(DeniedUpgrade, ())
1081 }
1082
1083 fn connection_keep_alive(&self) -> KeepAlive {
1084 #[allow(deprecated)]
1085 KeepAlive::Until(self.until)
1086 }
1087
1088 #[allow(deprecated)]
1089 fn poll(
1090 &mut self,
1091 _: &mut Context<'_>,
1092 ) -> Poll<
1093 ConnectionHandlerEvent<
1094 Self::OutboundProtocol,
1095 Self::OutboundOpenInfo,
1096 Self::ToBehaviour,
1097 Self::Error,
1098 >,
1099 > {
1100 Poll::Pending
1101 }
1102
1103 fn on_behaviour_event(&mut self, _: Self::FromBehaviour) {}
1104
1105 fn on_connection_event(
1106 &mut self,
1107 _: ConnectionEvent<
1108 Self::InboundProtocol,
1109 Self::OutboundProtocol,
1110 Self::InboundOpenInfo,
1111 Self::OutboundOpenInfo,
1112 >,
1113 ) {
1114 }
1115 }
1116
1117 struct DummyStreamMuxer {
1118 counter: Arc<()>,
1119 }
1120
1121 impl StreamMuxer for DummyStreamMuxer {
1122 type Substream = PendingSubstream;
1123 type Error = Void;
1124
1125 fn poll_inbound(
1126 self: Pin<&mut Self>,
1127 _: &mut Context<'_>,
1128 ) -> Poll<Result<Self::Substream, Self::Error>> {
1129 Poll::Ready(Ok(PendingSubstream(Arc::downgrade(&self.counter))))
1130 }
1131
1132 fn poll_outbound(
1133 self: Pin<&mut Self>,
1134 _: &mut Context<'_>,
1135 ) -> Poll<Result<Self::Substream, Self::Error>> {
1136 Poll::Pending
1137 }
1138
1139 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1140 Poll::Ready(Ok(()))
1141 }
1142
1143 fn poll(
1144 self: Pin<&mut Self>,
1145 _: &mut Context<'_>,
1146 ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
1147 Poll::Pending
1148 }
1149 }
1150
1151 struct PendingStreamMuxer;
1153
1154 impl StreamMuxer for PendingStreamMuxer {
1155 type Substream = PendingSubstream;
1156 type Error = Void;
1157
1158 fn poll_inbound(
1159 self: Pin<&mut Self>,
1160 _: &mut Context<'_>,
1161 ) -> Poll<Result<Self::Substream, Self::Error>> {
1162 Poll::Pending
1163 }
1164
1165 fn poll_outbound(
1166 self: Pin<&mut Self>,
1167 _: &mut Context<'_>,
1168 ) -> Poll<Result<Self::Substream, Self::Error>> {
1169 Poll::Pending
1170 }
1171
1172 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1173 Poll::Pending
1174 }
1175
1176 fn poll(
1177 self: Pin<&mut Self>,
1178 _: &mut Context<'_>,
1179 ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
1180 Poll::Pending
1181 }
1182 }
1183
1184 struct PendingSubstream(Weak<()>);
1185
1186 impl AsyncRead for PendingSubstream {
1187 fn poll_read(
1188 self: Pin<&mut Self>,
1189 _cx: &mut Context<'_>,
1190 _buf: &mut [u8],
1191 ) -> Poll<std::io::Result<usize>> {
1192 Poll::Pending
1193 }
1194 }
1195
1196 impl AsyncWrite for PendingSubstream {
1197 fn poll_write(
1198 self: Pin<&mut Self>,
1199 _cx: &mut Context<'_>,
1200 _buf: &[u8],
1201 ) -> Poll<std::io::Result<usize>> {
1202 Poll::Pending
1203 }
1204
1205 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1206 Poll::Pending
1207 }
1208
1209 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
1210 Poll::Pending
1211 }
1212 }
1213
1214 struct MockConnectionHandler {
1215 outbound_requested: bool,
1216 error: Option<StreamUpgradeError<Void>>,
1217 upgrade_timeout: Duration,
1218 }
1219
1220 impl MockConnectionHandler {
1221 fn new(upgrade_timeout: Duration) -> Self {
1222 Self {
1223 outbound_requested: false,
1224 error: None,
1225 upgrade_timeout,
1226 }
1227 }
1228
1229 fn open_new_outbound(&mut self) {
1230 self.outbound_requested = true;
1231 }
1232 }
1233
1234 #[derive(Default)]
1235 struct ConfigurableProtocolConnectionHandler {
1236 events: Vec<ConnectionHandlerEvent<DeniedUpgrade, (), Void, Void>>,
1237 active_protocols: HashSet<StreamProtocol>,
1238 local_added: Vec<Vec<StreamProtocol>>,
1239 local_removed: Vec<Vec<StreamProtocol>>,
1240 remote_added: Vec<Vec<StreamProtocol>>,
1241 remote_removed: Vec<Vec<StreamProtocol>>,
1242 }
1243
1244 impl ConfigurableProtocolConnectionHandler {
1245 fn listen_on(&mut self, protocols: &[&'static str]) {
1246 self.active_protocols = protocols.iter().copied().map(StreamProtocol::new).collect();
1247 }
1248
1249 fn remote_adds_support_for(&mut self, protocols: &[&'static str]) {
1250 self.events
1251 .push(ConnectionHandlerEvent::ReportRemoteProtocols(
1252 ProtocolSupport::Added(
1253 protocols.iter().copied().map(StreamProtocol::new).collect(),
1254 ),
1255 ));
1256 }
1257
1258 fn remote_removes_support_for(&mut self, protocols: &[&'static str]) {
1259 self.events
1260 .push(ConnectionHandlerEvent::ReportRemoteProtocols(
1261 ProtocolSupport::Removed(
1262 protocols.iter().copied().map(StreamProtocol::new).collect(),
1263 ),
1264 ));
1265 }
1266 }
1267
1268 impl ConnectionHandler for MockConnectionHandler {
1269 type FromBehaviour = Void;
1270 type ToBehaviour = Void;
1271 type Error = Void;
1272 type InboundProtocol = DeniedUpgrade;
1273 type OutboundProtocol = DeniedUpgrade;
1274 type InboundOpenInfo = ();
1275 type OutboundOpenInfo = ();
1276
1277 fn listen_protocol(
1278 &self,
1279 ) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
1280 SubstreamProtocol::new(DeniedUpgrade, ()).with_timeout(self.upgrade_timeout)
1281 }
1282
1283 fn on_connection_event(
1284 &mut self,
1285 event: ConnectionEvent<
1286 Self::InboundProtocol,
1287 Self::OutboundProtocol,
1288 Self::InboundOpenInfo,
1289 Self::OutboundOpenInfo,
1290 >,
1291 ) {
1292 match event {
1293 ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
1294 protocol,
1295 ..
1296 }) => void::unreachable(protocol),
1297 ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
1298 protocol,
1299 ..
1300 }) => void::unreachable(protocol),
1301 ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
1302 self.error = Some(error)
1303 }
1304 ConnectionEvent::AddressChange(_)
1305 | ConnectionEvent::ListenUpgradeError(_)
1306 | ConnectionEvent::LocalProtocolsChange(_)
1307 | ConnectionEvent::RemoteProtocolsChange(_) => {}
1308 }
1309 }
1310
1311 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
1312 void::unreachable(event)
1313 }
1314
1315 fn connection_keep_alive(&self) -> KeepAlive {
1316 KeepAlive::Yes
1317 }
1318
1319 #[allow(deprecated)]
1320 fn poll(
1321 &mut self,
1322 _: &mut Context<'_>,
1323 ) -> Poll<
1324 ConnectionHandlerEvent<
1325 Self::OutboundProtocol,
1326 Self::OutboundOpenInfo,
1327 Self::ToBehaviour,
1328 Self::Error,
1329 >,
1330 > {
1331 if self.outbound_requested {
1332 self.outbound_requested = false;
1333 return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
1334 protocol: SubstreamProtocol::new(DeniedUpgrade, ())
1335 .with_timeout(self.upgrade_timeout),
1336 });
1337 }
1338
1339 Poll::Pending
1340 }
1341 }
1342
1343 impl ConnectionHandler for ConfigurableProtocolConnectionHandler {
1344 type FromBehaviour = Void;
1345 type ToBehaviour = Void;
1346 type Error = Void;
1347 type InboundProtocol = ManyProtocolsUpgrade;
1348 type OutboundProtocol = DeniedUpgrade;
1349 type InboundOpenInfo = ();
1350 type OutboundOpenInfo = ();
1351
1352 fn listen_protocol(
1353 &self,
1354 ) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
1355 SubstreamProtocol::new(
1356 ManyProtocolsUpgrade {
1357 protocols: Vec::from_iter(self.active_protocols.clone()),
1358 },
1359 (),
1360 )
1361 }
1362
1363 fn on_connection_event(
1364 &mut self,
1365 event: ConnectionEvent<
1366 Self::InboundProtocol,
1367 Self::OutboundProtocol,
1368 Self::InboundOpenInfo,
1369 Self::OutboundOpenInfo,
1370 >,
1371 ) {
1372 match event {
1373 ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Added(added)) => {
1374 self.local_added.push(added.cloned().collect())
1375 }
1376 ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Removed(removed)) => {
1377 self.local_removed.push(removed.cloned().collect())
1378 }
1379 ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Added(added)) => {
1380 self.remote_added.push(added.cloned().collect())
1381 }
1382 ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Removed(removed)) => {
1383 self.remote_removed.push(removed.cloned().collect())
1384 }
1385 _ => {}
1386 }
1387 }
1388
1389 fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
1390 void::unreachable(event)
1391 }
1392
1393 fn connection_keep_alive(&self) -> KeepAlive {
1394 KeepAlive::Yes
1395 }
1396
1397 #[allow(deprecated)]
1398 fn poll(
1399 &mut self,
1400 _: &mut Context<'_>,
1401 ) -> Poll<
1402 ConnectionHandlerEvent<
1403 Self::OutboundProtocol,
1404 Self::OutboundOpenInfo,
1405 Self::ToBehaviour,
1406 Self::Error,
1407 >,
1408 > {
1409 if let Some(event) = self.events.pop() {
1410 return Poll::Ready(event);
1411 }
1412
1413 Poll::Pending
1414 }
1415 }
1416
    /// Test upgrade that advertises every protocol in `protocols` and, once one is
    /// negotiated, resolves immediately with the stream unchanged.
    struct ManyProtocolsUpgrade {
        /// Protocols returned (cloned) by `protocol_info`.
        protocols: Vec<StreamProtocol>,
    }
1420
1421 impl UpgradeInfo for ManyProtocolsUpgrade {
1422 type Info = StreamProtocol;
1423 type InfoIter = std::vec::IntoIter<Self::Info>;
1424
1425 fn protocol_info(&self) -> Self::InfoIter {
1426 self.protocols.clone().into_iter()
1427 }
1428 }
1429
1430 impl<C> InboundUpgrade<C> for ManyProtocolsUpgrade {
1431 type Output = C;
1432 type Error = Void;
1433 type Future = future::Ready<Result<Self::Output, Self::Error>>;
1434
1435 fn upgrade_inbound(self, stream: C, _: Self::Info) -> Self::Future {
1436 future::ready(Ok(stream))
1437 }
1438 }
1439
1440 impl<C> OutboundUpgrade<C> for ManyProtocolsUpgrade {
1441 type Output = C;
1442 type Error = Void;
1443 type Future = future::Ready<Result<Self::Output, Self::Error>>;
1444
1445 fn upgrade_outbound(self, stream: C, _: Self::Info) -> Self::Future {
1446 future::ready(Ok(stream))
1447 }
1448 }
1449}
1450
/// Reduced form of a [`ConnectedPoint`]: the dialer side keeps only its
/// `role_override`, while the listener side keeps both addresses.
///
/// NOTE(review): the name suggests this describes a connection that is not yet
/// fully established — confirm against the code in `pool` that uses it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum PendingPoint {
    /// Built from [`ConnectedPoint::Dialer`]; any other dialer fields are dropped
    /// (see the `..` pattern in the `From<ConnectedPoint>` impl).
    Dialer {
        /// Copied from the `role_override` field of [`ConnectedPoint::Dialer`].
        role_override: Endpoint,
    },
    /// Built from [`ConnectedPoint::Listener`], keeping both of its addresses.
    Listener {
        /// Copied from the `local_addr` field of [`ConnectedPoint::Listener`].
        local_addr: Multiaddr,
        /// Copied from the `send_back_addr` field of [`ConnectedPoint::Listener`].
        send_back_addr: Multiaddr,
    },
}
1471
1472impl From<ConnectedPoint> for PendingPoint {
1473 fn from(endpoint: ConnectedPoint) -> Self {
1474 match endpoint {
1475 ConnectedPoint::Dialer { role_override, .. } => PendingPoint::Dialer { role_override },
1476 ConnectedPoint::Listener {
1477 local_addr,
1478 send_back_addr,
1479 } => PendingPoint::Listener {
1480 local_addr,
1481 send_back_addr,
1482 },
1483 }
1484 }
1485}