1mod error;
22
23pub(crate) mod pool;
24mod supported_protocols;
25
26pub use error::ConnectionError;
27pub(crate) use error::{
28 PendingConnectionError, PendingInboundConnectionError, PendingOutboundConnectionError,
29};
30use libp2p_core::transport::PortUse;
31pub use supported_protocols::SupportedProtocols;
32
33use crate::handler::{
34 AddressChange, ConnectionEvent, ConnectionHandler, DialUpgradeError, FullyNegotiatedInbound,
35 FullyNegotiatedOutbound, ListenUpgradeError, ProtocolSupport, ProtocolsChange, UpgradeInfoSend,
36};
37use crate::stream::ActiveStreamCounter;
38use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend};
39use crate::{
40 ConnectionHandlerEvent, Stream, StreamProtocol, StreamUpgradeError, SubstreamProtocol,
41};
42use futures::future::BoxFuture;
43use futures::stream::FuturesUnordered;
44use futures::StreamExt;
45use futures::{stream, FutureExt};
46use futures_timer::Delay;
47use libp2p_core::connection::ConnectedPoint;
48use libp2p_core::multiaddr::Multiaddr;
49use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerEvent, StreamMuxerExt, SubstreamBox};
50use libp2p_core::upgrade;
51use libp2p_core::upgrade::{NegotiationError, ProtocolError};
52use libp2p_core::Endpoint;
53use libp2p_identity::PeerId;
54use std::collections::{HashMap, HashSet};
55use std::fmt::{Display, Formatter};
56use std::future::Future;
57use std::sync::atomic::{AtomicUsize, Ordering};
58use std::task::Waker;
59use std::time::Duration;
60use std::{fmt, io, mem, pin::Pin, task::Context, task::Poll};
61use web_time::Instant;
62
/// Process-wide counter backing [`ConnectionId::next`]; the first id it hands out is `1`.
static NEXT_CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);

/// Identifier of a connection.
///
/// Ids obtained through [`ConnectionId::next`] come from a process-wide incrementing counter.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ConnectionId(usize);

impl ConnectionId {
    /// Wraps `id` as a [`ConnectionId`] without checking that it is unique.
    ///
    /// Nothing here prevents two connections from sharing an id built this way, hence
    /// "unchecked"; it is mainly useful for tests that need a specific id.
    pub fn new_unchecked(id: usize) -> Self {
        ConnectionId(id)
    }

    /// Allocates the next id from the process-wide counter.
    pub(crate) fn next() -> Self {
        let id = NEXT_CONNECTION_ID.fetch_add(1, Ordering::SeqCst);
        ConnectionId(id)
    }
}

impl Display for ConnectionId {
    /// Prints the raw numeric id.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let ConnectionId(raw) = self;
        write!(f, "{}", raw)
    }
}
91
92#[derive(Debug, Clone, PartialEq, Eq)]
94pub(crate) struct Connected {
95 pub(crate) endpoint: ConnectedPoint,
97 pub(crate) peer_id: PeerId,
99}
100
101#[derive(Debug, Clone)]
103pub(crate) enum Event<T> {
104 Handler(T),
106 AddressChange(Multiaddr),
108}
109
110pub(crate) struct Connection<THandler>
112where
113 THandler: ConnectionHandler,
114{
115 muxing: StreamMuxerBox,
117 handler: THandler,
119 negotiating_in: FuturesUnordered<
121 StreamUpgrade<
122 THandler::InboundOpenInfo,
123 <THandler::InboundProtocol as InboundUpgradeSend>::Output,
124 <THandler::InboundProtocol as InboundUpgradeSend>::Error,
125 >,
126 >,
127 negotiating_out: FuturesUnordered<
129 StreamUpgrade<
130 THandler::OutboundOpenInfo,
131 <THandler::OutboundProtocol as OutboundUpgradeSend>::Output,
132 <THandler::OutboundProtocol as OutboundUpgradeSend>::Error,
133 >,
134 >,
135 shutdown: Shutdown,
137 substream_upgrade_protocol_override: Option<upgrade::Version>,
139 max_negotiating_inbound_streams: usize,
148 requested_substreams: FuturesUnordered<
153 SubstreamRequested<THandler::OutboundOpenInfo, THandler::OutboundProtocol>,
154 >,
155
156 local_supported_protocols:
157 HashMap<AsStrHashEq<<THandler::InboundProtocol as UpgradeInfoSend>::Info>, bool>,
158 remote_supported_protocols: HashSet<StreamProtocol>,
159 protocol_buffer: Vec<StreamProtocol>,
160
161 idle_timeout: Duration,
162 stream_counter: ActiveStreamCounter,
163}
164
165impl<THandler> fmt::Debug for Connection<THandler>
166where
167 THandler: ConnectionHandler + fmt::Debug,
168 THandler::OutboundOpenInfo: fmt::Debug,
169{
170 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
171 f.debug_struct("Connection")
172 .field("handler", &self.handler)
173 .finish()
174 }
175}
176
177impl<THandler> Unpin for Connection<THandler> where THandler: ConnectionHandler {}
178
impl<THandler> Connection<THandler>
where
    THandler: ConnectionHandler,
{
    /// Builds a new `Connection` from a muxer and a handler.
    ///
    /// The handler's initial inbound protocols (taken from `listen_protocol`) are recorded
    /// and, when non-empty, reported to the handler as a `LocalProtocolsChange` before this
    /// returns.
    ///
    /// - `substream_upgrade_protocol_override`: multistream-select version to use for
    ///   outbound stream negotiation instead of the default, if any.
    /// - `max_negotiating_inbound_streams`: cap on inbound streams negotiated concurrently.
    /// - `idle_timeout`: how long the connection may stay idle, once the handler no longer
    ///   keeps it alive, before `poll` reports `ConnectionError::KeepAliveTimeout`.
    pub(crate) fn new(
        muxer: StreamMuxerBox,
        mut handler: THandler,
        substream_upgrade_protocol_override: Option<upgrade::Version>,
        max_negotiating_inbound_streams: usize,
        idle_timeout: Duration,
    ) -> Self {
        let initial_protocols = gather_supported_protocols(&handler);
        let mut buffer = Vec::new();

        if !initial_protocols.is_empty() {
            handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(
                ProtocolsChange::from_initial_protocols(
                    initial_protocols.keys().map(|e| &e.0),
                    &mut buffer,
                ),
            ));
        }

        Connection {
            muxing: muxer,
            handler,
            negotiating_in: Default::default(),
            negotiating_out: Default::default(),
            shutdown: Shutdown::None,
            substream_upgrade_protocol_override,
            max_negotiating_inbound_streams,
            requested_substreams: Default::default(),
            local_supported_protocols: initial_protocols,
            remote_supported_protocols: Default::default(),
            protocol_buffer: buffer,
            idle_timeout,
            stream_counter: ActiveStreamCounter::default(),
        }
    }

    /// Forwards an event from the behaviour to the handler.
    pub(crate) fn on_behaviour_event(&mut self, event: THandler::FromBehaviour) {
        self.handler.on_behaviour_event(event);
    }

    /// Consumes the connection and starts closing it.
    ///
    /// Returns a stream that drives the handler's `poll_close` (yielding its remaining events
    /// for the behaviour) and a future that closes the muxer. All other state, including
    /// in-flight negotiations and pending stream requests, is dropped.
    pub(crate) fn close(
        self,
    ) -> (
        impl futures::Stream<Item = THandler::ToBehaviour>,
        impl Future<Output = io::Result<()>>,
    ) {
        let Connection {
            mut handler,
            muxing,
            ..
        } = self;

        (
            stream::poll_fn(move |cx| handler.poll_close(cx)),
            muxing.close(),
        )
    }

    /// Drives the connection: handler, stream negotiations, idle shutdown and the muxer.
    ///
    /// Each loop iteration tries, in order:
    /// 1. pending outbound stream requests (timeouts are reported as `DialUpgradeError`),
    /// 2. the handler itself,
    /// 3. outbound, then inbound, stream negotiations,
    /// 4. the idle-shutdown check,
    /// 5. muxer events, then a new outbound stream, then a new inbound stream,
    /// 6. changes to the handler's locally supported protocols.
    ///
    /// When a step makes progress, the loop either restarts from the top or returns.
    /// `Pending` is returned only when no step can make progress.
    ///
    /// Returns `Ok(Event::Handler)` for events the handler addresses to the behaviour,
    /// `Ok(Event::AddressChange)` when the muxer reports a new address, and
    /// `Err(ConnectionError::KeepAliveTimeout)` once the idle timer expires. Muxer errors are
    /// propagated with `?`.
    #[tracing::instrument(level = "debug", name = "Connection::poll", skip(self, cx))]
    pub(crate) fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
        let Self {
            requested_substreams,
            muxing,
            handler,
            negotiating_out,
            negotiating_in,
            shutdown,
            max_negotiating_inbound_streams,
            substream_upgrade_protocol_override,
            local_supported_protocols: supported_protocols,
            remote_supported_protocols,
            protocol_buffer,
            idle_timeout,
            stream_counter,
            ..
        } = self.get_mut();

        loop {
            // Pending outbound requests: `Ok(())` means the request's data was already
            // extracted for an opened stream; `Err(info)` means its timeout fired first.
            match requested_substreams.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(()))) => continue,
                Poll::Ready(Some(Err(info))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError {
                            info,
                            error: StreamUpgradeError::Timeout,
                        },
                    ));
                    continue;
                }
                Poll::Ready(None) | Poll::Pending => {}
            }

            // Poll the handler until it has nothing more to do.
            match handler.poll(cx) {
                Poll::Pending => {}
                Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
                    let timeout = *protocol.timeout();
                    let (upgrade, user_data) = protocol.into_upgrade();

                    requested_substreams.push(SubstreamRequested::new(user_data, timeout, upgrade));
                    continue; // Poll handler until exhausted.
                }
                Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)) => {
                    return Poll::Ready(Ok(Event::Handler(event)));
                }
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Added(protocols),
                )) => {
                    // Notify only when something is actually new. The protocols left in
                    // `protocol_buffer` are committed to the set after the handler has seen
                    // the change.
                    if let Some(added) =
                        ProtocolsChange::add(remote_supported_protocols, protocols, protocol_buffer)
                    {
                        handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(added));
                        remote_supported_protocols.extend(protocol_buffer.drain(..));
                    }
                    continue;
                }
                Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Removed(protocols),
                )) => {
                    // NOTE(review): the set is not updated here. Presumably
                    // `ProtocolsChange::remove` removes the entries itself; confirm.
                    if let Some(removed) = ProtocolsChange::remove(
                        remote_supported_protocols,
                        protocols,
                        protocol_buffer,
                    ) {
                        handler
                            .on_connection_event(ConnectionEvent::RemoteProtocolsChange(removed));
                    }
                    continue;
                }
            }

            // Outbound negotiations: success and failure are both reported to the handler.
            match negotiating_out.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
                        FullyNegotiatedOutbound { protocol, info },
                    ));
                    continue;
                }
                Poll::Ready(Some((info, Err(error)))) => {
                    handler.on_connection_event(ConnectionEvent::DialUpgradeError(
                        DialUpgradeError { info, error },
                    ));
                    continue;
                }
            }

            // Inbound negotiations: only upgrade (`Apply`) errors reach the handler. I/O,
            // negotiation and timeout failures are only logged.
            match negotiating_in.poll_next_unpin(cx) {
                Poll::Pending | Poll::Ready(None) => {}
                Poll::Ready(Some((info, Ok(protocol)))) => {
                    handler.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
                        FullyNegotiatedInbound { protocol, info },
                    ));
                    continue;
                }
                Poll::Ready(Some((info, Err(StreamUpgradeError::Apply(error))))) => {
                    handler.on_connection_event(ConnectionEvent::ListenUpgradeError(
                        ListenUpgradeError { info, error },
                    ));
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::Io(e))))) => {
                    tracing::debug!("failed to upgrade inbound stream: {e}");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::NegotiationFailed)))) => {
                    tracing::debug!("no protocol could be agreed upon for inbound stream");
                    continue;
                }
                Poll::Ready(Some((_, Err(StreamUpgradeError::Timeout)))) => {
                    tracing::debug!("inbound stream upgrade timed out");
                    continue;
                }
            }

            // Idle handling. The shutdown plan is only evaluated while nothing is requested,
            // negotiating or active. Any such activity resets it to `Shutdown::None`.
            if negotiating_in.is_empty()
                && negotiating_out.is_empty()
                && requested_substreams.is_empty()
                && stream_counter.has_no_active_streams()
            {
                if let Some(new_timeout) =
                    compute_new_shutdown(handler.connection_keep_alive(), shutdown, *idle_timeout)
                {
                    *shutdown = new_timeout;
                }

                match shutdown {
                    Shutdown::None => {}
                    Shutdown::Asap => return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)),
                    Shutdown::Later(delay) => match Future::poll(Pin::new(delay), cx) {
                        Poll::Ready(_) => {
                            return Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
                        }
                        Poll::Pending => {}
                    },
                }
            } else {
                *shutdown = Shutdown::None;
            }

            // Muxer-level events. Errors end the connection via `?`.
            match muxing.poll_unpin(cx)? {
                Poll::Pending => {}
                Poll::Ready(StreamMuxerEvent::AddressChange(address)) => {
                    handler.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
                        new_address: &address,
                    }));
                    return Poll::Ready(Ok(Event::AddressChange(address)));
                }
            }

            // Open at most one outbound stream per iteration and assign it to the first entry
            // yielded by `iter_mut()`.
            if let Some(requested_substream) = requested_substreams.iter_mut().next() {
                match muxing.poll_outbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        let (user_data, timeout, upgrade) = requested_substream.extract();

                        negotiating_out.push(StreamUpgrade::new_outbound(
                            substream,
                            user_data,
                            timeout,
                            upgrade,
                            *substream_upgrade_protocol_override,
                            stream_counter.clone(),
                        ));

                        continue; // Go back to the top, handler can potentially make progress again.
                    }
                }
            }

            // Back-pressure: do not accept new inbound streams while at the negotiation limit.
            if negotiating_in.len() < *max_negotiating_inbound_streams {
                match muxing.poll_inbound_unpin(cx)? {
                    Poll::Pending => {}
                    Poll::Ready(substream) => {
                        let protocol = handler.listen_protocol();

                        negotiating_in.push(StreamUpgrade::new_inbound(
                            substream,
                            protocol,
                            stream_counter.clone(),
                        ));

                        continue; // Go back to the top, handler can potentially make progress again.
                    }
                }
            }

            // Re-read the handler's listen protocols and report any difference to the handler.
            // NOTE(review): this assumes `from_full_sets` updates `supported_protocols` in place;
            // confirm.
            let changes = ProtocolsChange::from_full_sets(
                supported_protocols,
                handler.listen_protocol().upgrade().protocol_info(),
                protocol_buffer,
            );

            if !changes.is_empty() {
                for change in changes {
                    handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(change));
                }
                continue; // Go back to the top, handler can potentially make progress again.
            }

            return Poll::Pending; // Nothing can make progress, return `Pending`.
        }
    }

    /// Test helper: polls the connection once with a no-op waker.
    #[cfg(test)]
    fn poll_noop_waker(&mut self) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError>> {
        Pin::new(self).poll(&mut Context::from_waker(futures::task::noop_waker_ref()))
    }
}
466
467fn gather_supported_protocols<C: ConnectionHandler>(
468 handler: &C,
469) -> HashMap<AsStrHashEq<<C::InboundProtocol as UpgradeInfoSend>::Info>, bool> {
470 handler
471 .listen_protocol()
472 .upgrade()
473 .protocol_info()
474 .map(|info| (AsStrHashEq(info), true))
475 .collect()
476}
477
478fn compute_new_shutdown(
479 handler_keep_alive: bool,
480 current_shutdown: &Shutdown,
481 idle_timeout: Duration,
482) -> Option<Shutdown> {
483 match (current_shutdown, handler_keep_alive) {
484 (_, false) if idle_timeout == Duration::ZERO => Some(Shutdown::Asap),
485 (Shutdown::Later(_), false) => None, (_, false) => {
487 let now = Instant::now();
488 let safe_keep_alive = checked_add_fraction(now, idle_timeout);
489
490 Some(Shutdown::Later(Delay::new(safe_keep_alive)))
491 }
492 (_, true) => Some(Shutdown::None),
493 }
494}
495
496fn checked_add_fraction(start: Instant, mut duration: Duration) -> Duration {
501 while start.checked_add(duration).is_none() {
502 tracing::debug!(start=?start, duration=?duration, "start + duration cannot be presented, halving duration");
503
504 duration /= 2;
505 }
506
507 duration
508}
509
510#[derive(Debug, Copy, Clone)]
512pub(crate) struct IncomingInfo<'a> {
513 pub(crate) local_addr: &'a Multiaddr,
515 pub(crate) send_back_addr: &'a Multiaddr,
517}
518
519impl<'a> IncomingInfo<'a> {
520 pub(crate) fn create_connected_point(&self) -> ConnectedPoint {
522 ConnectedPoint::Listener {
523 local_addr: self.local_addr.clone(),
524 send_back_addr: self.send_back_addr.clone(),
525 }
526 }
527}
528
529struct StreamUpgrade<UserData, TOk, TErr> {
530 user_data: Option<UserData>,
531 timeout: Delay,
532 upgrade: BoxFuture<'static, Result<TOk, StreamUpgradeError<TErr>>>,
533}
534
535impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
536 fn new_outbound<Upgrade>(
537 substream: SubstreamBox,
538 user_data: UserData,
539 timeout: Delay,
540 upgrade: Upgrade,
541 version_override: Option<upgrade::Version>,
542 counter: ActiveStreamCounter,
543 ) -> Self
544 where
545 Upgrade: OutboundUpgradeSend<Output = TOk, Error = TErr>,
546 {
547 let effective_version = match version_override {
548 Some(version_override) if version_override != upgrade::Version::default() => {
549 tracing::debug!(
550 "Substream upgrade protocol override: {:?} -> {:?}",
551 upgrade::Version::default(),
552 version_override
553 );
554
555 version_override
556 }
557 _ => upgrade::Version::default(),
558 };
559 let protocols = upgrade.protocol_info();
560
561 Self {
562 user_data: Some(user_data),
563 timeout,
564 upgrade: Box::pin(async move {
565 let (info, stream) = multistream_select::dialer_select_proto(
566 substream,
567 protocols,
568 effective_version,
569 )
570 .await
571 .map_err(to_stream_upgrade_error)?;
572
573 let output = upgrade
574 .upgrade_outbound(Stream::new(stream, counter), info)
575 .await
576 .map_err(StreamUpgradeError::Apply)?;
577
578 Ok(output)
579 }),
580 }
581 }
582}
583
584impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
585 fn new_inbound<Upgrade>(
586 substream: SubstreamBox,
587 protocol: SubstreamProtocol<Upgrade, UserData>,
588 counter: ActiveStreamCounter,
589 ) -> Self
590 where
591 Upgrade: InboundUpgradeSend<Output = TOk, Error = TErr>,
592 {
593 let timeout = *protocol.timeout();
594 let (upgrade, open_info) = protocol.into_upgrade();
595 let protocols = upgrade.protocol_info();
596
597 Self {
598 user_data: Some(open_info),
599 timeout: Delay::new(timeout),
600 upgrade: Box::pin(async move {
601 let (info, stream) =
602 multistream_select::listener_select_proto(substream, protocols)
603 .await
604 .map_err(to_stream_upgrade_error)?;
605
606 let output = upgrade
607 .upgrade_inbound(Stream::new(stream, counter), info)
608 .await
609 .map_err(StreamUpgradeError::Apply)?;
610
611 Ok(output)
612 }),
613 }
614 }
615}
616
617fn to_stream_upgrade_error<T>(e: NegotiationError) -> StreamUpgradeError<T> {
618 match e {
619 NegotiationError::Failed => StreamUpgradeError::NegotiationFailed,
620 NegotiationError::ProtocolError(ProtocolError::IoError(e)) => StreamUpgradeError::Io(e),
621 NegotiationError::ProtocolError(other) => {
622 StreamUpgradeError::Io(io::Error::new(io::ErrorKind::Other, other))
623 }
624 }
625}
626
627impl<UserData, TOk, TErr> Unpin for StreamUpgrade<UserData, TOk, TErr> {}
628
629impl<UserData, TOk, TErr> Future for StreamUpgrade<UserData, TOk, TErr> {
630 type Output = (UserData, Result<TOk, StreamUpgradeError<TErr>>);
631
632 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
633 match self.timeout.poll_unpin(cx) {
634 Poll::Ready(()) => {
635 return Poll::Ready((
636 self.user_data
637 .take()
638 .expect("Future not to be polled again once ready."),
639 Err(StreamUpgradeError::Timeout),
640 ))
641 }
642
643 Poll::Pending => {}
644 }
645
646 let result = futures::ready!(self.upgrade.poll_unpin(cx));
647 let user_data = self
648 .user_data
649 .take()
650 .expect("Future not to be polled again once ready.");
651
652 Poll::Ready((user_data, result))
653 }
654}
655
656enum SubstreamRequested<UserData, Upgrade> {
657 Waiting {
658 user_data: UserData,
659 timeout: Delay,
660 upgrade: Upgrade,
661 extracted_waker: Option<Waker>,
666 },
667 Done,
668}
669
670impl<UserData, Upgrade> SubstreamRequested<UserData, Upgrade> {
671 fn new(user_data: UserData, timeout: Duration, upgrade: Upgrade) -> Self {
672 Self::Waiting {
673 user_data,
674 timeout: Delay::new(timeout),
675 upgrade,
676 extracted_waker: None,
677 }
678 }
679
680 fn extract(&mut self) -> (UserData, Delay, Upgrade) {
681 match mem::replace(self, Self::Done) {
682 SubstreamRequested::Waiting {
683 user_data,
684 timeout,
685 upgrade,
686 extracted_waker: waker,
687 } => {
688 if let Some(waker) = waker {
689 waker.wake();
690 }
691
692 (user_data, timeout, upgrade)
693 }
694 SubstreamRequested::Done => panic!("cannot extract twice"),
695 }
696 }
697}
698
699impl<UserData, Upgrade> Unpin for SubstreamRequested<UserData, Upgrade> {}
700
701impl<UserData, Upgrade> Future for SubstreamRequested<UserData, Upgrade> {
702 type Output = Result<(), UserData>;
703
704 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
705 let this = self.get_mut();
706
707 match mem::replace(this, Self::Done) {
708 SubstreamRequested::Waiting {
709 user_data,
710 upgrade,
711 mut timeout,
712 ..
713 } => match timeout.poll_unpin(cx) {
714 Poll::Ready(()) => Poll::Ready(Err(user_data)),
715 Poll::Pending => {
716 *this = Self::Waiting {
717 user_data,
718 upgrade,
719 timeout,
720 extracted_waker: Some(cx.waker().clone()),
721 };
722 Poll::Pending
723 }
724 },
725 SubstreamRequested::Done => Poll::Ready(Ok(())),
726 }
727 }
728}
729
730#[derive(Debug)]
740enum Shutdown {
741 None,
743 Asap,
745 Later(Delay),
747}
748
/// Newtype that compares and hashes its contents through their `&str` view.
///
/// This lets protocol identifiers of any `AsRef<str>` type serve as `HashMap`/`HashSet` keys.
/// Two values are the same key exactly when their strings are equal.
pub(crate) struct AsStrHashEq<T>(pub(crate) T);

impl<T: AsRef<str>> Eq for AsStrHashEq<T> {}

impl<T: AsRef<str>> PartialEq for AsStrHashEq<T> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs): (&str, &str) = (self.0.as_ref(), other.0.as_ref());
        lhs == rhs
    }
}

impl<T: AsRef<str>> std::hash::Hash for AsStrHashEq<T> {
    // Hashes exactly the string that `eq` compares, which keeps `Hash` consistent with `Eq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let key: &str = self.0.as_ref();
        std::hash::Hash::hash(key, state)
    }
}
767
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dummy;
    use futures::future;
    use futures::AsyncRead;
    use futures::AsyncWrite;
    use libp2p_core::upgrade::{DeniedUpgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo};
    use libp2p_core::StreamMuxer;
    use quickcheck::*;
    use std::sync::{Arc, Weak};
    use std::time::Instant;
    use tracing_subscriber::EnvFilter;
    use void::Void;

    #[test]
    fn max_negotiating_inbound_streams() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        fn prop(max_negotiating_inbound_streams: u8) {
            let max_negotiating_inbound_streams: usize = max_negotiating_inbound_streams.into();

            // Every substream handed out by `DummyStreamMuxer` holds a `Weak` to this `Arc`,
            // so the weak count equals the number of substreams still alive.
            let alive_substream_counter = Arc::new(());
            let mut connection = Connection::new(
                StreamMuxerBox::new(DummyStreamMuxer {
                    counter: alive_substream_counter.clone(),
                }),
                MockConnectionHandler::new(Duration::from_secs(10)),
                None,
                max_negotiating_inbound_streams,
                Duration::ZERO,
            );

            let result = connection.poll_noop_waker();

            assert!(result.is_pending());
            assert_eq!(
                Arc::weak_count(&alive_substream_counter),
                max_negotiating_inbound_streams,
                "Expect no more than the maximum number of allowed streams"
            );
        }

        QuickCheck::new().quickcheck(prop as fn(_));
    }

    #[test]
    fn outbound_stream_timeout_starts_on_request() {
        let upgrade_timeout = Duration::from_secs(1);
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            MockConnectionHandler::new(upgrade_timeout),
            None,
            2,
            Duration::ZERO,
        );

        // `PendingStreamMuxer` never opens an outbound stream, so the request can only
        // resolve through its timeout.
        connection.handler.open_new_outbound();
        let _ = connection.poll_noop_waker();

        std::thread::sleep(upgrade_timeout + Duration::from_secs(1));

        let _ = connection.poll_noop_waker();

        assert!(matches!(
            connection.handler.error.unwrap(),
            StreamUpgradeError::Timeout
        ))
    }

    #[test]
    fn propagates_changes_to_supported_inbound_protocols() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // Start listening on a single protocol.
        connection.handler.listen_on(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.local_added, vec![vec!["/foo"]]);
        assert!(connection.handler.local_removed.is_empty());

        // Add a second protocol next to the first one.
        connection.handler.listen_on(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocols"
        );
        assert!(connection.handler.local_removed.is_empty());

        // Stop listening on the first protocol.
        connection.handler.listen_on(&["/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.local_removed, vec![vec!["/foo"]]);
    }

    #[test]
    fn only_propagates_actual_changes_to_remote_protocols_to_handler() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // The remote announces a first protocol.
        connection.handler.remote_adds_support_for(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.remote_added, vec![vec!["/foo"]]);
        assert!(connection.handler.remote_removed.is_empty());

        // Re-announcing "/foo" together with "/bar" must only report "/bar".
        connection
            .handler
            .remote_adds_support_for(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocol"
        );
        assert!(connection.handler.remote_removed.is_empty());

        // Removing a protocol that was never supported must not notify the handler.
        connection.handler.remote_removes_support_for(&["/baz"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert!(connection.handler.remote_removed.is_empty());

        // Removing a supported protocol is reported.
        connection.handler.remote_removes_support_for(&["/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.remote_removed, vec![vec!["/bar"]]);
    }

    #[tokio::test]
    async fn idle_timeout_with_keep_alive_no() {
        let idle_timeout = Duration::from_millis(100);

        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            dummy::ConnectionHandler,
            None,
            0,
            idle_timeout,
        );

        assert!(connection.poll_noop_waker().is_pending());

        tokio::time::sleep(idle_timeout).await;

        assert!(matches!(
            connection.poll_noop_waker(),
            Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
        ));
    }

    #[test]
    fn checked_add_fraction_can_add_u64_max() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();
        let start = Instant::now();

        let duration = checked_add_fraction(start, Duration::from_secs(u64::MAX));

        assert!(start.checked_add(duration).is_some())
    }

    #[test]
    fn compute_new_shutdown_does_not_panic() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::from_default_env())
            .try_init();

        #[derive(Debug)]
        struct ArbitraryShutdown(Shutdown);

        impl Clone for ArbitraryShutdown {
            fn clone(&self) -> Self {
                let shutdown = match self.0 {
                    Shutdown::None => Shutdown::None,
                    Shutdown::Asap => Shutdown::Asap,
                    Shutdown::Later(_) => Shutdown::Later(
                        // `Shutdown` is not `Clone`, and `compute_new_shutdown` never looks
                        // inside `Shutdown::Later`, so a fresh placeholder delay is enough.
                        Delay::new(Duration::from_secs(1)),
                    ),
                };

                ArbitraryShutdown(shutdown)
            }
        }

        impl Arbitrary for ArbitraryShutdown {
            fn arbitrary(g: &mut Gen) -> Self {
                let shutdown = match g.gen_range(1u8..4) {
                    1 => Shutdown::None,
                    2 => Shutdown::Asap,
                    3 => Shutdown::Later(Delay::new(Duration::from_secs(u32::arbitrary(g) as u64))),
                    _ => unreachable!(),
                };

                Self(shutdown)
            }
        }

        fn prop(
            handler_keep_alive: bool,
            current_shutdown: ArbitraryShutdown,
            idle_timeout: Duration,
        ) {
            compute_new_shutdown(handler_keep_alive, &current_shutdown.0, idle_timeout);
        }

        QuickCheck::new().quickcheck(prop as fn(_, _, _));
    }

    /// Muxer that yields an unlimited supply of inbound substreams, each tracked via `counter`.
    struct DummyStreamMuxer {
        counter: Arc<()>,
    }

    impl StreamMuxer for DummyStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Void;

        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Ready(Ok(PendingSubstream {
                _weak: Arc::downgrade(&self.counter),
            }))
        }

        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// Muxer that never makes progress on anything.
    struct PendingStreamMuxer;

    impl StreamMuxer for PendingStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Void;

        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }

        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// Substream whose I/O never completes; the `Weak` lets tests count live instances.
    struct PendingSubstream {
        _weak: Weak<()>,
    }

    impl AsyncRead for PendingSubstream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }
    }

    impl AsyncWrite for PendingSubstream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }
    }

    /// Handler that can request one outbound stream and records the resulting dial error.
    struct MockConnectionHandler {
        outbound_requested: bool,
        error: Option<StreamUpgradeError<Void>>,
        upgrade_timeout: Duration,
    }

    impl MockConnectionHandler {
        fn new(upgrade_timeout: Duration) -> Self {
            Self {
                outbound_requested: false,
                error: None,
                upgrade_timeout,
            }
        }

        fn open_new_outbound(&mut self) {
            self.outbound_requested = true;
        }
    }

    /// Handler whose listen protocols and emitted events are set by the test, and which
    /// records every protocol-change notification it receives.
    #[derive(Default)]
    struct ConfigurableProtocolConnectionHandler {
        events: Vec<ConnectionHandlerEvent<DeniedUpgrade, (), Void>>,
        active_protocols: HashSet<StreamProtocol>,
        local_added: Vec<Vec<StreamProtocol>>,
        local_removed: Vec<Vec<StreamProtocol>>,
        remote_added: Vec<Vec<StreamProtocol>>,
        remote_removed: Vec<Vec<StreamProtocol>>,
    }

    impl ConfigurableProtocolConnectionHandler {
        fn listen_on(&mut self, protocols: &[&'static str]) {
            self.active_protocols = protocols.iter().copied().map(StreamProtocol::new).collect();
        }

        fn remote_adds_support_for(&mut self, protocols: &[&'static str]) {
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Added(
                        protocols.iter().copied().map(StreamProtocol::new).collect(),
                    ),
                ));
        }

        fn remote_removes_support_for(&mut self, protocols: &[&'static str]) {
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Removed(
                        protocols.iter().copied().map(StreamProtocol::new).collect(),
                    ),
                ));
        }
    }

    impl ConnectionHandler for MockConnectionHandler {
        type FromBehaviour = Void;
        type ToBehaviour = Void;
        type InboundProtocol = DeniedUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        fn listen_protocol(
            &self,
        ) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
            SubstreamProtocol::new(DeniedUpgrade, ()).with_timeout(self.upgrade_timeout)
        }

        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<
                Self::InboundProtocol,
                Self::OutboundProtocol,
                Self::InboundOpenInfo,
                Self::OutboundOpenInfo,
            >,
        ) {
            match event {
                ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
                    protocol,
                    ..
                }) => void::unreachable(protocol),
                ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
                    protocol,
                    ..
                }) => void::unreachable(protocol),
                ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
                    self.error = Some(error)
                }
                ConnectionEvent::AddressChange(_)
                | ConnectionEvent::ListenUpgradeError(_)
                | ConnectionEvent::LocalProtocolsChange(_)
                | ConnectionEvent::RemoteProtocolsChange(_) => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            void::unreachable(event)
        }

        fn connection_keep_alive(&self) -> bool {
            true
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<
            ConnectionHandlerEvent<
                Self::OutboundProtocol,
                Self::OutboundOpenInfo,
                Self::ToBehaviour,
            >,
        > {
            if self.outbound_requested {
                self.outbound_requested = false;
                return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
                    protocol: SubstreamProtocol::new(DeniedUpgrade, ())
                        .with_timeout(self.upgrade_timeout),
                });
            }

            Poll::Pending
        }
    }

    impl ConnectionHandler for ConfigurableProtocolConnectionHandler {
        type FromBehaviour = Void;
        type ToBehaviour = Void;
        type InboundProtocol = ManyProtocolsUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        fn listen_protocol(
            &self,
        ) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
            SubstreamProtocol::new(
                ManyProtocolsUpgrade {
                    protocols: Vec::from_iter(self.active_protocols.clone()),
                },
                (),
            )
        }

        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<
                Self::InboundProtocol,
                Self::OutboundProtocol,
                Self::InboundOpenInfo,
                Self::OutboundOpenInfo,
            >,
        ) {
            match event {
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.local_added.push(added.cloned().collect())
                }
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.local_removed.push(removed.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.remote_added.push(added.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.remote_removed.push(removed.cloned().collect())
                }
                _ => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            void::unreachable(event)
        }

        fn connection_keep_alive(&self) -> bool {
            true
        }

        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<
            ConnectionHandlerEvent<
                Self::OutboundProtocol,
                Self::OutboundOpenInfo,
                Self::ToBehaviour,
            >,
        > {
            if let Some(event) = self.events.pop() {
                return Poll::Ready(event);
            }

            Poll::Pending
        }
    }

    /// Upgrade that advertises an arbitrary list of protocols and passes the stream through.
    struct ManyProtocolsUpgrade {
        protocols: Vec<StreamProtocol>,
    }

    impl UpgradeInfo for ManyProtocolsUpgrade {
        type Info = StreamProtocol;
        type InfoIter = std::vec::IntoIter<Self::Info>;

        fn protocol_info(&self) -> Self::InfoIter {
            self.protocols.clone().into_iter()
        }
    }

    impl<C> InboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Void;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        fn upgrade_inbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ready(Ok(stream))
        }
    }

    impl<C> OutboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Void;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        fn upgrade_outbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ready(Ok(stream))
        }
    }
}
1344
1345#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1347enum PendingPoint {
1348 Dialer {
1354 role_override: Endpoint,
1356 port_use: PortUse,
1357 },
1358 Listener {
1360 local_addr: Multiaddr,
1362 send_back_addr: Multiaddr,
1364 },
1365}
1366
1367impl From<ConnectedPoint> for PendingPoint {
1368 fn from(endpoint: ConnectedPoint) -> Self {
1369 match endpoint {
1370 ConnectedPoint::Dialer {
1371 role_override,
1372 port_use,
1373 ..
1374 } => PendingPoint::Dialer {
1375 role_override,
1376 port_use,
1377 },
1378 ConnectedPoint::Listener {
1379 local_addr,
1380 send_back_addr,
1381 } => PendingPoint::Listener {
1382 local_addr,
1383 send_back_addr,
1384 },
1385 }
1386 }
1387}