mod error;
pub(crate) mod pool;
mod supported_protocols;
pub use error::ConnectionError;
pub(crate) use error::{
PendingConnectionError, PendingInboundConnectionError, PendingOutboundConnectionError,
};
pub use supported_protocols::SupportedProtocols;
use crate::handler::{
AddressChange, ConnectionEvent, ConnectionHandler, DialUpgradeError, FullyNegotiatedInbound,
FullyNegotiatedOutbound, ListenUpgradeError, ProtocolSupport, ProtocolsAdded, ProtocolsChange,
UpgradeInfoSend,
};
use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend};
use crate::{
ConnectionHandlerEvent, KeepAlive, Stream, StreamProtocol, StreamUpgradeError,
SubstreamProtocol,
};
use futures::future::BoxFuture;
use futures::stream::FuturesUnordered;
use futures::FutureExt;
use futures::StreamExt;
use futures_timer::Delay;
use instant::Instant;
use libp2p_core::connection::ConnectedPoint;
use libp2p_core::multiaddr::Multiaddr;
use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerEvent, StreamMuxerExt, SubstreamBox};
use libp2p_core::upgrade;
use libp2p_core::upgrade::{NegotiationError, ProtocolError};
use libp2p_core::Endpoint;
use libp2p_identity::PeerId;
use std::cmp::max;
use std::collections::HashSet;
use std::fmt::{Display, Formatter};
use std::future::Future;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::Waker;
use std::time::Duration;
use std::{fmt, io, mem, pin::Pin, task::Context, task::Poll};
/// Process-wide counter handing out the value for the next [`ConnectionId`]; starts at 1.
static NEXT_CONNECTION_ID: AtomicUsize = AtomicUsize::new(1);

/// Identifier of a connection, backed by a plain `usize`.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct ConnectionId(usize);

impl ConnectionId {
    /// Wraps `id` as a [`ConnectionId`] as-is.
    ///
    /// Nothing is checked: the value is not coordinated with the counter used by
    /// [`ConnectionId::next`], so uniqueness is the caller's responsibility.
    pub fn new_unchecked(id: usize) -> Self {
        ConnectionId(id)
    }

    /// Returns a fresh [`ConnectionId`] by atomically bumping the global counter.
    pub(crate) fn next() -> Self {
        let id = NEXT_CONNECTION_ID.fetch_add(1, Ordering::SeqCst);
        ConnectionId(id)
    }
}

impl Display for ConnectionId {
    /// Writes the wrapped number in decimal.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let ConnectionId(raw) = self;
        f.write_fmt(format_args!("{}", raw))
    }
}
/// Pairs the [`ConnectedPoint`] of a connection with the [`PeerId`] associated with it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct Connected {
/// The endpoint of the connection.
pub(crate) endpoint: ConnectedPoint,
/// The peer identity associated with the connection (presumably the remote peer; confirm at call sites).
pub(crate) peer_id: PeerId,
}
/// Event produced by polling a [`Connection`].
#[derive(Debug, Clone)]
pub(crate) enum Event<T> {
/// Event emitted by the connection handler via `ConnectionHandlerEvent::NotifyBehaviour`.
Handler(T),
/// The stream muxer reported a new address for the connection.
AddressChange(Multiaddr),
}
/// Drives a stream muxer together with a [`ConnectionHandler`].
///
/// Substreams opened on the muxer are negotiated here and handed to the handler once
/// the upgrade completed.
pub(crate) struct Connection<THandler>
where
THandler: ConnectionHandler,
{
/// The stream muxer that yields inbound and outbound substreams.
muxing: StreamMuxerBox,
/// The handler that receives connection events and produces handler events.
handler: THandler,
/// In-flight upgrades of inbound substreams.
negotiating_in: FuturesUnordered<
StreamUpgrade<
THandler::InboundOpenInfo,
<THandler::InboundProtocol as InboundUpgradeSend>::Output,
<THandler::InboundProtocol as InboundUpgradeSend>::Error,
>,
>,
/// In-flight upgrades of outbound substreams.
negotiating_out: FuturesUnordered<
StreamUpgrade<
THandler::OutboundOpenInfo,
<THandler::OutboundProtocol as OutboundUpgradeSend>::Output,
<THandler::OutboundProtocol as OutboundUpgradeSend>::Error,
>,
>,
/// The currently planned shutdown; recomputed in `poll` from the handler's keep-alive.
shutdown: Shutdown,
/// Optional protocol version forwarded to `StreamUpgrade::new_outbound` for outbound upgrades.
substream_upgrade_protocol_override: Option<upgrade::Version>,
/// New inbound substreams are only accepted from the muxer while fewer than this many
/// inbound upgrades are in flight.
max_negotiating_inbound_streams: usize,
/// Outbound substream requests of the handler that still wait for the muxer to open a substream.
requested_substreams: FuturesUnordered<
SubstreamRequested<THandler::OutboundOpenInfo, THandler::OutboundProtocol>,
>,
/// The inbound protocols of the handler as last reported to it via `LocalProtocolsChange`.
local_supported_protocols: HashSet<StreamProtocol>,
/// The protocols the handler reported as supported by the remote via `ReportRemoteProtocols`.
remote_supported_protocols: HashSet<StreamProtocol>,
/// Idle timeout passed to `compute_new_shutdown` when planning a shutdown.
idle_timeout: Duration,
}
impl<THandler> fmt::Debug for Connection<THandler>
where
    THandler: ConnectionHandler + fmt::Debug,
    THandler::OutboundOpenInfo: fmt::Debug,
{
    /// Renders the connection as a `Connection` struct that exposes only its handler.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Connection");
        builder.field("handler", &self.handler);
        builder.finish()
    }
}
// `Connection` is unconditionally `Unpin`; `poll` relies on this when it calls `self.get_mut()`.
impl<THandler> Unpin for Connection<THandler> where THandler: ConnectionHandler {}
impl<THandler> Connection<THandler>
where
THandler: ConnectionHandler,
{
/// Builds a new `Connection` from the given stream muxer and connection handler.
///
/// The handler's inbound protocols are gathered up front and, when there are any,
/// reported to the handler as a `LocalProtocolsChange::Added` event before the
/// connection is returned.
pub(crate) fn new(
muxer: StreamMuxerBox,
mut handler: THandler,
substream_upgrade_protocol_override: Option<upgrade::Version>,
max_negotiating_inbound_streams: usize,
idle_timeout: Duration,
) -> Self {
let initial_protocols = gather_supported_protocols(&handler);
if !initial_protocols.is_empty() {
handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(
ProtocolsChange::Added(ProtocolsAdded::from_set(&initial_protocols)),
));
}
Connection {
muxing: muxer,
handler,
negotiating_in: Default::default(),
negotiating_out: Default::default(),
shutdown: Shutdown::None,
substream_upgrade_protocol_override,
max_negotiating_inbound_streams,
requested_substreams: Default::default(),
local_supported_protocols: initial_protocols,
remote_supported_protocols: Default::default(),
idle_timeout,
}
}
/// Forwards an event from the behaviour to the connection handler.
pub(crate) fn on_behaviour_event(&mut self, event: THandler::FromBehaviour) {
self.handler.on_behaviour_event(event);
}
/// Consumes the connection and returns its handler together with the future
/// produced by closing the stream muxer.
pub(crate) fn close(self) -> (THandler, impl Future<Output = io::Result<()>>) {
(self.handler, self.muxing.close())
}
/// Polls the handler, the pending substream upgrades and the stream muxer for events.
///
/// Returns `Poll::Ready(Ok(..))` with a handler event or an address change,
/// `Poll::Ready(Err(..))` when the handler asks to close, the keep-alive expired or the
/// muxer failed, and `Poll::Pending` once none of the sources can make progress.
#[allow(deprecated)]
pub(crate) fn poll(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError<THandler::Error>>> {
let Self {
requested_substreams,
muxing,
handler,
negotiating_out,
negotiating_in,
shutdown,
max_negotiating_inbound_streams,
substream_upgrade_protocol_override,
local_supported_protocols: supported_protocols,
remote_supported_protocols,
idle_timeout,
} = self.get_mut();
// Every `continue` below restarts at the top so that earlier sources are polled again
// after any state change.
loop {
// Outbound substream requests: `Ok(())` means the request was already extracted,
// `Err(info)` means its timeout fired before a substream became available.
match requested_substreams.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(()))) => continue,
Poll::Ready(Some(Err(info))) => {
handler.on_connection_event(ConnectionEvent::DialUpgradeError(
DialUpgradeError {
info,
error: StreamUpgradeError::Timeout,
},
));
continue;
}
Poll::Ready(None) | Poll::Pending => {}
}
// Poll the connection handler.
match handler.poll(cx) {
Poll::Pending => {}
Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }) => {
// The timeout starts ticking here, at request time, not when the substream opens.
let timeout = *protocol.timeout();
let (upgrade, user_data) = protocol.into_upgrade();
requested_substreams.push(SubstreamRequested::new(user_data, timeout, upgrade));
continue; }
Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)) => {
return Poll::Ready(Ok(Event::Handler(event)));
}
#[allow(deprecated)]
Poll::Ready(ConnectionHandlerEvent::Close(err)) => {
return Poll::Ready(Err(ConnectionError::Handler(err)));
}
Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
ProtocolSupport::Added(protocols),
)) => {
// Only notify the handler (and update our set) when `ProtocolsChange::add`
// reports an actual change.
if let Some(added) =
ProtocolsChange::add(remote_supported_protocols, &protocols)
{
handler.on_connection_event(ConnectionEvent::RemoteProtocolsChange(added));
remote_supported_protocols.extend(protocols);
}
continue;
}
Poll::Ready(ConnectionHandlerEvent::ReportRemoteProtocols(
ProtocolSupport::Removed(protocols),
)) => {
// Same as above, for removals.
if let Some(removed) =
ProtocolsChange::remove(remote_supported_protocols, &protocols)
{
handler
.on_connection_event(ConnectionEvent::RemoteProtocolsChange(removed));
remote_supported_protocols.retain(|p| !protocols.contains(p));
}
continue;
}
}
// The handler is pending: poll the outbound upgrades and report results to the handler.
match negotiating_out.poll_next_unpin(cx) {
Poll::Pending | Poll::Ready(None) => {}
Poll::Ready(Some((info, Ok(protocol)))) => {
handler.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
FullyNegotiatedOutbound { protocol, info },
));
continue;
}
Poll::Ready(Some((info, Err(error)))) => {
handler.on_connection_event(ConnectionEvent::DialUpgradeError(
DialUpgradeError { info, error },
));
continue;
}
}
// Then poll the inbound upgrades. Only `Apply` errors are forwarded to the handler;
// I/O errors, failed negotiations and timeouts are just logged.
match negotiating_in.poll_next_unpin(cx) {
Poll::Pending | Poll::Ready(None) => {}
Poll::Ready(Some((info, Ok(protocol)))) => {
handler.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
FullyNegotiatedInbound { protocol, info },
));
continue;
}
Poll::Ready(Some((info, Err(StreamUpgradeError::Apply(error))))) => {
handler.on_connection_event(ConnectionEvent::ListenUpgradeError(
ListenUpgradeError { info, error },
));
continue;
}
Poll::Ready(Some((_, Err(StreamUpgradeError::Io(e))))) => {
log::debug!("failed to upgrade inbound stream: {e}");
continue;
}
Poll::Ready(Some((_, Err(StreamUpgradeError::NegotiationFailed)))) => {
log::debug!("no protocol could be agreed upon for inbound stream");
continue;
}
Poll::Ready(Some((_, Err(StreamUpgradeError::Timeout)))) => {
log::debug!("inbound stream upgrade timed out");
continue;
}
}
// Recompute the planned shutdown from the handler's current keep-alive.
if let Some(new_shutdown) =
compute_new_shutdown(handler.connection_keep_alive(), shutdown, *idle_timeout)
{
*shutdown = new_shutdown;
}
// The planned shutdown is only acted upon while no substream is being requested
// or negotiated.
if negotiating_in.is_empty()
&& negotiating_out.is_empty()
&& requested_substreams.is_empty()
{
match shutdown {
Shutdown::None => {}
Shutdown::Asap => return Poll::Ready(Err(ConnectionError::KeepAliveTimeout)),
Shutdown::Later(delay, _) => match Future::poll(Pin::new(delay), cx) {
Poll::Ready(_) => {
return Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
}
Poll::Pending => {}
},
}
}
// Poll the muxer itself; a muxer error is returned through `?`.
match muxing.poll_unpin(cx)? {
Poll::Pending => {}
Poll::Ready(StreamMuxerEvent::AddressChange(address)) => {
handler.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
new_address: &address,
}));
return Poll::Ready(Ok(Event::AddressChange(address)));
}
}
// Only ask the muxer for an outbound substream while a request is waiting for one.
if let Some(requested_substream) = requested_substreams.iter_mut().next() {
match muxing.poll_outbound_unpin(cx)? {
Poll::Pending => {}
Poll::Ready(substream) => {
// `extract` also wakes the request future so that it resolves on the next poll.
let (user_data, timeout, upgrade) = requested_substream.extract();
negotiating_out.push(StreamUpgrade::new_outbound(
substream,
user_data,
timeout,
upgrade,
*substream_upgrade_protocol_override,
));
continue; }
}
}
// Accept new inbound substreams only while below the configured limit of
// concurrently negotiating inbound streams.
if negotiating_in.len() < *max_negotiating_inbound_streams {
match muxing.poll_inbound_unpin(cx)? {
Poll::Pending => {}
Poll::Ready(substream) => {
let protocol = handler.listen_protocol();
negotiating_in.push(StreamUpgrade::new_inbound(substream, protocol));
continue; }
}
}
// Detect changes in the handler's inbound protocols and report them to the handler.
let new_protocols = gather_supported_protocols(handler);
let changes = ProtocolsChange::from_full_sets(supported_protocols, &new_protocols);
if !changes.is_empty() {
for change in changes {
handler.on_connection_event(ConnectionEvent::LocalProtocolsChange(change));
}
*supported_protocols = new_protocols;
continue; }
// Nothing can make progress.
return Poll::Pending; }
}
/// Test helper: polls the connection once with a waker that does nothing.
#[cfg(test)]
#[allow(deprecated)]
fn poll_noop_waker(
&mut self,
) -> Poll<Result<Event<THandler::ToBehaviour>, ConnectionError<THandler::Error>>> {
Pin::new(self).poll(&mut Context::from_waker(futures::task::noop_waker_ref()))
}
}
/// Collects the protocols advertised by the handler's current listen protocol.
///
/// Entries that cannot be converted into a [`StreamProtocol`] are silently skipped.
fn gather_supported_protocols(handler: &impl ConnectionHandler) -> HashSet<StreamProtocol> {
    let listen_protocol = handler.listen_protocol();
    let mut supported = HashSet::new();

    for info in listen_protocol.upgrade().protocol_info() {
        if let Ok(protocol) = StreamProtocol::try_from_owned(info.as_ref().to_owned()) {
            supported.insert(protocol);
        }
    }

    supported
}
/// Computes the next [`Shutdown`] state from the handler's keep-alive, the current
/// shutdown state and the idle timeout.
///
/// Returns `None` when the current shutdown state should be left untouched.
fn compute_new_shutdown(
handler_keep_alive: KeepAlive,
current_shutdown: &Shutdown,
idle_timeout: Duration,
) -> Option<Shutdown> {
// The order of the match arms matters: earlier arms shadow the more general ones below.
#[allow(deprecated)]
match (current_shutdown, handler_keep_alive) {
// A timer is already running: only replace it when the handler moved its deadline.
(Shutdown::Later(_, deadline), KeepAlive::Until(t)) => {
let now = Instant::now();
if *deadline != t {
let deadline = t;
// NOTE(review): `Instant::now()` is sampled a second time here instead of reusing `now`.
if let Some(new_duration) = deadline.checked_duration_since(Instant::now()) {
// Never shut down earlier than the idle timeout allows.
let effective_keep_alive = max(new_duration, idle_timeout);
let safe_keep_alive = checked_add_fraction(now, effective_keep_alive);
return Some(Shutdown::Later(Delay::new(safe_keep_alive), deadline));
}
}
None
}
// No timer running yet: start one unless the requested instant already passed.
(_, KeepAlive::Until(earliest_shutdown)) => {
let now = Instant::now();
if let Some(requested) = earliest_shutdown.checked_duration_since(now) {
let effective_keep_alive = max(requested, idle_timeout);
let safe_keep_alive = checked_add_fraction(now, effective_keep_alive);
// The handler's original instant is stored (not `now + safe_keep_alive`) so that
// the arm above can compare it against later `KeepAlive::Until` values.
return Some(Shutdown::Later(
Delay::new(safe_keep_alive),
earliest_shutdown,
));
}
None
}
// Without an idle timeout, `KeepAlive::No` means shutting down right away.
(_, KeepAlive::No) if idle_timeout == Duration::ZERO => Some(Shutdown::Asap),
(Shutdown::Later(_, _), KeepAlive::No) => {
// Leave the already running shutdown timer untouched.
None
}
// Start the idle timer.
(_, KeepAlive::No) => {
let now = Instant::now();
let safe_keep_alive = checked_add_fraction(now, idle_timeout);
// `checked_add_fraction` only returns durations for which `now + duration` is representable.
Some(Shutdown::Later(
Delay::new(safe_keep_alive),
now + safe_keep_alive,
))
}
(_, KeepAlive::Yes) => Some(Shutdown::None),
}
}
/// Returns a duration that can be added to `start` without overflowing.
///
/// Starting from `duration`, the value is halved (and a debug message logged) until
/// `start + duration` is representable; the first representable value is returned.
fn checked_add_fraction(start: Instant, mut duration: Duration) -> Duration {
    loop {
        if start.checked_add(duration).is_some() {
            return duration;
        }

        log::debug!("{start:?} + {duration:?} cannot be presented, halving duration");
        duration /= 2;
    }
}
/// Borrowed address pair describing an incoming connection.
#[derive(Debug, Copy, Clone)]
pub(crate) struct IncomingInfo<'a> {
    /// Local address of the incoming connection.
    pub(crate) local_addr: &'a Multiaddr,
    /// Address to send data back to the remote.
    pub(crate) send_back_addr: &'a Multiaddr,
}

impl<'a> IncomingInfo<'a> {
    /// Builds an owned [`ConnectedPoint::Listener`] by cloning both addresses.
    pub(crate) fn create_connected_point(&self) -> ConnectedPoint {
        let IncomingInfo {
            local_addr,
            send_back_addr,
        } = *self;

        ConnectedPoint::Listener {
            local_addr: local_addr.clone(),
            send_back_addr: send_back_addr.clone(),
        }
    }
}
/// Future that upgrades a substream, bounded by a timeout.
///
/// Resolves to the user data paired with the upgrade result (see the `Future` impl).
struct StreamUpgrade<UserData, TOk, TErr> {
/// User data returned alongside the result; taken out once the future resolves.
user_data: Option<UserData>,
/// Timer bounding the upgrade; when it fires the future resolves with `StreamUpgradeError::Timeout`.
timeout: Delay,
/// The boxed negotiation-and-upgrade future.
upgrade: BoxFuture<'static, Result<TOk, StreamUpgradeError<TErr>>>,
}
impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
    /// Creates the upgrade future for an outbound substream.
    ///
    /// The protocol is negotiated as dialer using `version_override` when it is set and
    /// differs from the default version (the default otherwise); the negotiated stream is
    /// then passed to `upgrade`. `timeout` is the already running timer bounding the upgrade.
    fn new_outbound<Upgrade>(
        substream: SubstreamBox,
        user_data: UserData,
        timeout: Delay,
        upgrade: Upgrade,
        version_override: Option<upgrade::Version>,
    ) -> Self
    where
        Upgrade: OutboundUpgradeSend<Output = TOk, Error = TErr>,
    {
        let default_version = upgrade::Version::default();
        let effective_version = match version_override {
            Some(requested) if requested != default_version => {
                log::debug!(
                    "Substream upgrade protocol override: {:?} -> {:?}",
                    default_version,
                    requested
                );
                requested
            }
            Some(_) | None => default_version,
        };
        let protocols = upgrade.protocol_info();

        Self {
            user_data: Some(user_data),
            timeout,
            upgrade: Box::pin(async move {
                let negotiated = multistream_select::dialer_select_proto(
                    substream,
                    protocols,
                    effective_version,
                )
                .await;
                let (info, stream) = negotiated.map_err(to_stream_upgrade_error::<TErr>)?;

                upgrade
                    .upgrade_outbound(Stream::new(stream), info)
                    .await
                    .map_err(StreamUpgradeError::Apply)
            }),
        }
    }
}
impl<UserData, TOk, TErr> StreamUpgrade<UserData, TOk, TErr> {
    /// Creates the upgrade future for an inbound substream.
    ///
    /// The protocol is negotiated as listener among the protocols offered by `protocol`;
    /// the negotiated stream is then passed to the upgrade. The timer is started here from
    /// the timeout configured on `protocol`.
    fn new_inbound<Upgrade>(
        substream: SubstreamBox,
        protocol: SubstreamProtocol<Upgrade, UserData>,
    ) -> Self
    where
        Upgrade: InboundUpgradeSend<Output = TOk, Error = TErr>,
    {
        let upgrade_timeout = *protocol.timeout();
        let (upgrade, user_data) = protocol.into_upgrade();
        let protocols = upgrade.protocol_info();

        Self {
            user_data: Some(user_data),
            timeout: Delay::new(upgrade_timeout),
            upgrade: Box::pin(async move {
                let negotiated =
                    multistream_select::listener_select_proto(substream, protocols).await;
                let (info, stream) = negotiated.map_err(to_stream_upgrade_error::<TErr>)?;

                upgrade
                    .upgrade_inbound(Stream::new(stream), info)
                    .await
                    .map_err(StreamUpgradeError::Apply)
            }),
        }
    }
}
/// Maps a protocol negotiation error onto the matching [`StreamUpgradeError`].
///
/// A failed negotiation becomes `NegotiationFailed`; protocol errors become `Io`, either
/// by unwrapping the contained I/O error or by wrapping the error as `io::ErrorKind::Other`.
fn to_stream_upgrade_error<T>(e: NegotiationError) -> StreamUpgradeError<T> {
    match e {
        NegotiationError::Failed => StreamUpgradeError::NegotiationFailed,
        NegotiationError::ProtocolError(protocol_error) => {
            let io_error = match protocol_error {
                ProtocolError::IoError(inner) => inner,
                other => io::Error::new(io::ErrorKind::Other, other),
            };
            StreamUpgradeError::Io(io_error)
        }
    }
}
// `StreamUpgrade` is unconditionally `Unpin`; its `Future` impl mutates fields through `Pin<&mut Self>`.
impl<UserData, TOk, TErr> Unpin for StreamUpgrade<UserData, TOk, TErr> {}
impl<UserData, TOk, TErr> Future for StreamUpgrade<UserData, TOk, TErr> {
    type Output = (UserData, Result<TOk, StreamUpgradeError<TErr>>);

    /// Resolves with `StreamUpgradeError::Timeout` once the timer fired, otherwise with the
    /// result of the upgrade future. The timer is always checked first.
    ///
    /// # Panics
    ///
    /// Panics when polled again after having resolved.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let outcome = if self.timeout.poll_unpin(cx).is_ready() {
            Err(StreamUpgradeError::Timeout)
        } else {
            match self.upgrade.poll_unpin(cx) {
                Poll::Ready(result) => result,
                Poll::Pending => return Poll::Pending,
            }
        };

        let user_data = self
            .user_data
            .take()
            .expect("Future not to be polled again once ready.");

        Poll::Ready((user_data, outcome))
    }
}
/// State of an outbound substream request issued by the handler.
enum SubstreamRequested<UserData, Upgrade> {
/// The request is waiting for a substream to be opened.
Waiting {
/// User data handed back on timeout or on extraction.
user_data: UserData,
/// Timer started when the request was created.
timeout: Delay,
/// The upgrade to apply once a substream is available.
upgrade: Upgrade,
/// Waker stored by the most recent pending `poll`; woken by `extract` so that the
/// future is polled again and can resolve.
extracted_waker: Option<Waker>,
},
/// The request was extracted or has already resolved.
Done,
}
impl<UserData, Upgrade> SubstreamRequested<UserData, Upgrade> {
    /// Creates a waiting request whose timer starts immediately and runs for `timeout`.
    fn new(user_data: UserData, timeout: Duration, upgrade: Upgrade) -> Self {
        let timeout = Delay::new(timeout);

        SubstreamRequested::Waiting {
            user_data,
            timeout,
            upgrade,
            extracted_waker: None,
        }
    }

    /// Takes the request's contents, leaving `Done` behind, and wakes the task that last
    /// polled this request (if any).
    ///
    /// # Panics
    ///
    /// Panics when the request is already `Done`.
    fn extract(&mut self) -> (UserData, Delay, Upgrade) {
        let previous = mem::replace(self, SubstreamRequested::Done);

        if let SubstreamRequested::Waiting {
            user_data,
            timeout,
            upgrade,
            extracted_waker,
        } = previous
        {
            if let Some(waker) = extracted_waker {
                waker.wake();
            }
            return (user_data, timeout, upgrade);
        }

        panic!("cannot extract twice")
    }
}
// `SubstreamRequested` is unconditionally `Unpin`; its `Future` impl calls `self.get_mut()`.
impl<UserData, Upgrade> Unpin for SubstreamRequested<UserData, Upgrade> {}
impl<UserData, Upgrade> Future for SubstreamRequested<UserData, Upgrade> {
/// `Ok(())` once the request was extracted, `Err(user_data)` when the timeout fired first.
type Output = Result<(), UserData>;
/// Polls the request's timer and records the caller's waker while the request is still waiting.
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
// Temporarily move the state out; it is put back below when the request keeps waiting.
match mem::replace(this, Self::Done) {
SubstreamRequested::Waiting {
user_data,
upgrade,
mut timeout,
..
} => match timeout.poll_unpin(cx) {
// Timed out: the state stays `Done` and the user data is handed back.
Poll::Ready(()) => Poll::Ready(Err(user_data)),
Poll::Pending => {
// Restore the waiting state, replacing any previously stored waker with the current one.
*this = Self::Waiting {
user_data,
upgrade,
timeout,
extracted_waker: Some(cx.waker().clone()),
};
Poll::Pending
}
},
// Already extracted (or resolved before).
SubstreamRequested::Done => Poll::Ready(Ok(())),
}
}
}
/// The planned shutdown of a connection.
#[derive(Debug)]
enum Shutdown {
/// No shutdown is planned.
None,
/// Shut down as soon as no substream is being requested or negotiated.
Asap,
/// Shut down once the timer fires. The `Instant` is the deadline that
/// `compute_new_shutdown` compares against `KeepAlive::Until` values.
Later(Delay, Instant),
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dummy;
    use futures::future;
    use futures::AsyncRead;
    use futures::AsyncWrite;
    use libp2p_core::upgrade::{DeniedUpgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo};
    use libp2p_core::StreamMuxer;
    use quickcheck::*;
    use std::sync::{Arc, Weak};
    use std::time::Instant;
    use void::Void;

    /// The connection must never negotiate more inbound streams than configured, even if
    /// the muxer always has another inbound substream ready.
    #[test]
    fn max_negotiating_inbound_streams() {
        let _ = env_logger::try_init();

        fn prop(max_negotiating_inbound_streams: u8) {
            let max_negotiating_inbound_streams: usize = max_negotiating_inbound_streams.into();

            // Every substream handed out by the muxer holds a `Weak` to this counter.
            let alive_substream_counter = Arc::new(());
            let mut connection = Connection::new(
                StreamMuxerBox::new(DummyStreamMuxer {
                    counter: alive_substream_counter.clone(),
                }),
                MockConnectionHandler::new(Duration::from_secs(10)),
                None,
                max_negotiating_inbound_streams,
                Duration::ZERO,
            );

            let result = connection.poll_noop_waker();

            assert!(result.is_pending());
            assert_eq!(
                Arc::weak_count(&alive_substream_counter),
                max_negotiating_inbound_streams,
                "Expect no more than the maximum number of allowed streams"
            );
        }

        QuickCheck::new().quickcheck(prop as fn(_));
    }

    /// The upgrade timeout of an outbound stream starts when the handler requests the
    /// stream, not when the muxer eventually opens it.
    #[test]
    fn outbound_stream_timeout_starts_on_request() {
        let upgrade_timeout = Duration::from_secs(1);
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            MockConnectionHandler::new(upgrade_timeout),
            None,
            2,
            Duration::ZERO,
        );

        connection.handler.open_new_outbound();
        let _ = connection.poll_noop_waker();

        std::thread::sleep(upgrade_timeout + Duration::from_secs(1));

        let _ = connection.poll_noop_waker();

        assert!(matches!(
            connection.handler.error.unwrap(),
            StreamUpgradeError::Timeout
        ))
    }

    /// Changes to the handler's inbound protocols are reported back to the handler as
    /// incremental additions and removals.
    #[test]
    fn propagates_changes_to_supported_inbound_protocols() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // First, start listening on a single protocol.
        connection.handler.listen_on(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.local_added, vec![vec!["/foo"]]);
        assert!(connection.handler.local_removed.is_empty());

        // Second, listen on two protocols.
        connection.handler.listen_on(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocols"
        );
        assert!(connection.handler.local_removed.is_empty());

        // Third, stop listening on the first protocol.
        connection.handler.listen_on(&["/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.local_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.local_removed, vec![vec!["/foo"]]);
    }

    /// The handler is only notified about remote protocol changes that actually change
    /// the tracked set.
    #[test]
    fn only_propagtes_actual_changes_to_remote_protocols_to_handler() {
        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            ConfigurableProtocolConnectionHandler::default(),
            None,
            0,
            Duration::ZERO,
        );

        // First, remote supports a single protocol.
        connection.handler.remote_adds_support_for(&["/foo"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(connection.handler.remote_added, vec![vec!["/foo"]]);
        assert!(connection.handler.remote_removed.is_empty());

        // Second, it adds a protocol but also still includes the first one.
        connection
            .handler
            .remote_adds_support_for(&["/foo", "/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]],
            "expect to only receive an event for the newly added protocol"
        );
        assert!(connection.handler.remote_removed.is_empty());

        // Third, removing a protocol that was never added is a no-op.
        connection.handler.remote_removes_support_for(&["/baz"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert!(connection.handler.remote_removed.is_empty());

        // Fourth, removing a protocol that was added produces an event.
        connection.handler.remote_removes_support_for(&["/bar"]);
        let _ = connection.poll_noop_waker();

        assert_eq!(
            connection.handler.remote_added,
            vec![vec!["/foo"], vec!["/bar"]]
        );
        assert_eq!(connection.handler.remote_removed, vec![vec!["/bar"]]);
    }

    /// With `KeepAlive::No`, the connection is closed once the idle timeout elapsed.
    #[tokio::test]
    async fn idle_timeout_with_keep_alive_no() {
        let idle_timeout = Duration::from_millis(100);

        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            dummy::ConnectionHandler,
            None,
            0,
            idle_timeout,
        );

        assert!(connection.poll_noop_waker().is_pending());

        tokio::time::sleep(idle_timeout).await;

        assert!(matches!(
            connection.poll_noop_waker(),
            Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
        ));
    }

    /// A `KeepAlive::Until` later than the idle timeout keeps the connection alive until
    /// that instant.
    #[tokio::test]
    async fn idle_timeout_with_keep_alive_until_greater_than_idle_timeout() {
        let idle_timeout = Duration::from_millis(100);

        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            KeepAliveUntilConnectionHandler {
                until: Instant::now() + idle_timeout * 2,
            },
            None,
            0,
            idle_timeout,
        );

        assert!(connection.poll_noop_waker().is_pending());

        tokio::time::sleep(idle_timeout).await;

        assert!(
            connection.poll_noop_waker().is_pending(),
            "`KeepAlive::Until` is greater than idle-timeout, continue sleeping"
        );

        tokio::time::sleep(idle_timeout).await;

        assert!(matches!(
            connection.poll_noop_waker(),
            Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
        ));
    }

    /// A `KeepAlive::Until` earlier than the idle timeout is extended to the idle timeout.
    #[tokio::test]
    async fn idle_timeout_with_keep_alive_until_less_than_idle_timeout() {
        let idle_timeout = Duration::from_millis(100);

        let mut connection = Connection::new(
            StreamMuxerBox::new(PendingStreamMuxer),
            KeepAliveUntilConnectionHandler {
                until: Instant::now() + idle_timeout / 2,
            },
            None,
            0,
            idle_timeout,
        );

        assert!(connection.poll_noop_waker().is_pending());

        tokio::time::sleep(idle_timeout / 2).await;

        assert!(
            connection.poll_noop_waker().is_pending(),
            "`KeepAlive::Until` is less than idle-timeout, honor idle-timeout"
        );

        tokio::time::sleep(idle_timeout / 2).await;

        assert!(matches!(
            connection.poll_noop_waker(),
            Poll::Ready(Err(ConnectionError::KeepAliveTimeout))
        ));
    }

    /// `checked_add_fraction` must shrink even the largest duration to an addable one.
    #[test]
    fn checked_add_fraction_can_add_u64_max() {
        let _ = env_logger::try_init();
        let start = Instant::now();

        let duration = checked_add_fraction(start, Duration::from_secs(u64::MAX));

        assert!(start.checked_add(duration).is_some())
    }

    /// `compute_new_shutdown` must not panic for arbitrary inputs.
    #[test]
    fn compute_new_shutdown_does_not_panic() {
        let _ = env_logger::try_init();

        /// Wrapper so that `Shutdown` can be generated by quickcheck.
        #[derive(Debug)]
        struct ArbitraryShutdown(Shutdown);

        impl Clone for ArbitraryShutdown {
            fn clone(&self) -> Self {
                let shutdown = match self.0 {
                    Shutdown::None => Shutdown::None,
                    Shutdown::Asap => Shutdown::Asap,
                    // `Delay` is not `Clone`; the compared deadline is what matters here,
                    // so a fresh timer is created.
                    Shutdown::Later(_, instant) => {
                        Shutdown::Later(Delay::new(Duration::from_secs(1)), instant)
                    }
                };

                ArbitraryShutdown(shutdown)
            }
        }

        impl Arbitrary for ArbitraryShutdown {
            fn arbitrary(g: &mut Gen) -> Self {
                let shutdown = match g.gen_range(1u8..4) {
                    1 => Shutdown::None,
                    2 => Shutdown::Asap,
                    3 => Shutdown::Later(
                        Delay::new(Duration::from_secs(u32::arbitrary(g) as u64)),
                        Instant::now()
                            .checked_add(Duration::arbitrary(g))
                            .unwrap_or_else(Instant::now),
                    ),
                    _ => unreachable!(),
                };

                Self(shutdown)
            }
        }

        fn prop(
            handler_keep_alive: KeepAlive,
            current_shutdown: ArbitraryShutdown,
            idle_timeout: Duration,
        ) {
            compute_new_shutdown(handler_keep_alive, &current_shutdown.0, idle_timeout);
        }

        QuickCheck::new().quickcheck(prop as fn(_, _, _));
    }

    /// Handler that never produces events and always answers `KeepAlive::Until(until)`.
    struct KeepAliveUntilConnectionHandler {
        until: Instant,
    }

    impl ConnectionHandler for KeepAliveUntilConnectionHandler {
        type FromBehaviour = Void;
        type ToBehaviour = Void;
        type Error = Void;
        type InboundProtocol = DeniedUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = Void;

        fn listen_protocol(
            &self,
        ) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
            SubstreamProtocol::new(DeniedUpgrade, ())
        }

        fn connection_keep_alive(&self) -> KeepAlive {
            #[allow(deprecated)]
            KeepAlive::Until(self.until)
        }

        #[allow(deprecated)]
        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<
            ConnectionHandlerEvent<
                Self::OutboundProtocol,
                Self::OutboundOpenInfo,
                Self::ToBehaviour,
                Self::Error,
            >,
        > {
            Poll::Pending
        }

        fn on_behaviour_event(&mut self, _: Self::FromBehaviour) {}

        fn on_connection_event(
            &mut self,
            _: ConnectionEvent<
                Self::InboundProtocol,
                Self::OutboundProtocol,
                Self::InboundOpenInfo,
                Self::OutboundOpenInfo,
            >,
        ) {
        }
    }

    /// Muxer that always has an inbound substream ready and never opens an outbound one.
    struct DummyStreamMuxer {
        /// Each inbound substream holds a `Weak` to this counter.
        counter: Arc<()>,
    }

    impl StreamMuxer for DummyStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Void;

        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Ready(Ok(PendingSubstream(Arc::downgrade(&self.counter))))
        }

        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// Muxer on which every operation is pending forever.
    struct PendingStreamMuxer;

    impl StreamMuxer for PendingStreamMuxer {
        type Substream = PendingSubstream;
        type Error = Void;

        fn poll_inbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_outbound(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<Self::Substream, Self::Error>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Pending
        }

        fn poll(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
        ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
            Poll::Pending
        }
    }

    /// Substream on which every read and write is pending forever.
    struct PendingSubstream(Weak<()>);

    impl AsyncRead for PendingSubstream {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }
    }

    impl AsyncWrite for PendingSubstream {
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<std::io::Result<usize>> {
            Poll::Pending
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }

        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Pending
        }
    }

    /// Handler that can request a single outbound stream and records dial upgrade errors.
    struct MockConnectionHandler {
        outbound_requested: bool,
        error: Option<StreamUpgradeError<Void>>,
        upgrade_timeout: Duration,
    }

    impl MockConnectionHandler {
        fn new(upgrade_timeout: Duration) -> Self {
            Self {
                outbound_requested: false,
                error: None,
                upgrade_timeout,
            }
        }

        /// Makes the next `poll` emit an `OutboundSubstreamRequest`.
        fn open_new_outbound(&mut self) {
            self.outbound_requested = true;
        }
    }

    /// Handler whose inbound protocols and reported remote protocols can be scripted, and
    /// which records every protocol change event it receives.
    #[derive(Default)]
    struct ConfigurableProtocolConnectionHandler {
        events: Vec<ConnectionHandlerEvent<DeniedUpgrade, (), Void, Void>>,
        active_protocols: HashSet<StreamProtocol>,
        local_added: Vec<Vec<StreamProtocol>>,
        local_removed: Vec<Vec<StreamProtocol>>,
        remote_added: Vec<Vec<StreamProtocol>>,
        remote_removed: Vec<Vec<StreamProtocol>>,
    }

    impl ConfigurableProtocolConnectionHandler {
        /// Replaces the set of inbound protocols offered by `listen_protocol`.
        fn listen_on(&mut self, protocols: &[&'static str]) {
            self.active_protocols = protocols.iter().copied().map(StreamProtocol::new).collect();
        }

        /// Queues a `ReportRemoteProtocols(Added(..))` event for the next `poll`.
        fn remote_adds_support_for(&mut self, protocols: &[&'static str]) {
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Added(
                        protocols.iter().copied().map(StreamProtocol::new).collect(),
                    ),
                ));
        }

        /// Queues a `ReportRemoteProtocols(Removed(..))` event for the next `poll`.
        fn remote_removes_support_for(&mut self, protocols: &[&'static str]) {
            self.events
                .push(ConnectionHandlerEvent::ReportRemoteProtocols(
                    ProtocolSupport::Removed(
                        protocols.iter().copied().map(StreamProtocol::new).collect(),
                    ),
                ));
        }
    }

    impl ConnectionHandler for MockConnectionHandler {
        type FromBehaviour = Void;
        type ToBehaviour = Void;
        type Error = Void;
        type InboundProtocol = DeniedUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        fn listen_protocol(
            &self,
        ) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
            SubstreamProtocol::new(DeniedUpgrade, ()).with_timeout(self.upgrade_timeout)
        }

        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<
                Self::InboundProtocol,
                Self::OutboundProtocol,
                Self::InboundOpenInfo,
                Self::OutboundOpenInfo,
            >,
        ) {
            match event {
                ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
                    protocol,
                    ..
                }) => void::unreachable(protocol),
                ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
                    protocol,
                    ..
                }) => void::unreachable(protocol),
                ConnectionEvent::DialUpgradeError(DialUpgradeError { error, .. }) => {
                    self.error = Some(error)
                }
                ConnectionEvent::AddressChange(_)
                | ConnectionEvent::ListenUpgradeError(_)
                | ConnectionEvent::LocalProtocolsChange(_)
                | ConnectionEvent::RemoteProtocolsChange(_) => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            void::unreachable(event)
        }

        fn connection_keep_alive(&self) -> KeepAlive {
            KeepAlive::Yes
        }

        #[allow(deprecated)]
        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<
            ConnectionHandlerEvent<
                Self::OutboundProtocol,
                Self::OutboundOpenInfo,
                Self::ToBehaviour,
                Self::Error,
            >,
        > {
            if self.outbound_requested {
                self.outbound_requested = false;
                return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
                    protocol: SubstreamProtocol::new(DeniedUpgrade, ())
                        .with_timeout(self.upgrade_timeout),
                });
            }

            Poll::Pending
        }
    }

    impl ConnectionHandler for ConfigurableProtocolConnectionHandler {
        type FromBehaviour = Void;
        type ToBehaviour = Void;
        type Error = Void;
        type InboundProtocol = ManyProtocolsUpgrade;
        type OutboundProtocol = DeniedUpgrade;
        type InboundOpenInfo = ();
        type OutboundOpenInfo = ();

        fn listen_protocol(
            &self,
        ) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
            SubstreamProtocol::new(
                ManyProtocolsUpgrade {
                    protocols: Vec::from_iter(self.active_protocols.clone()),
                },
                (),
            )
        }

        fn on_connection_event(
            &mut self,
            event: ConnectionEvent<
                Self::InboundProtocol,
                Self::OutboundProtocol,
                Self::InboundOpenInfo,
                Self::OutboundOpenInfo,
            >,
        ) {
            match event {
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.local_added.push(added.cloned().collect())
                }
                ConnectionEvent::LocalProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.local_removed.push(removed.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Added(added)) => {
                    self.remote_added.push(added.cloned().collect())
                }
                ConnectionEvent::RemoteProtocolsChange(ProtocolsChange::Removed(removed)) => {
                    self.remote_removed.push(removed.cloned().collect())
                }
                _ => {}
            }
        }

        fn on_behaviour_event(&mut self, event: Self::FromBehaviour) {
            void::unreachable(event)
        }

        fn connection_keep_alive(&self) -> KeepAlive {
            KeepAlive::Yes
        }

        #[allow(deprecated)]
        fn poll(
            &mut self,
            _: &mut Context<'_>,
        ) -> Poll<
            ConnectionHandlerEvent<
                Self::OutboundProtocol,
                Self::OutboundOpenInfo,
                Self::ToBehaviour,
                Self::Error,
            >,
        > {
            if let Some(event) = self.events.pop() {
                return Poll::Ready(event);
            }

            Poll::Pending
        }
    }

    /// Upgrade that advertises an arbitrary list of protocols and returns the stream untouched.
    struct ManyProtocolsUpgrade {
        protocols: Vec<StreamProtocol>,
    }

    impl UpgradeInfo for ManyProtocolsUpgrade {
        type Info = StreamProtocol;
        type InfoIter = std::vec::IntoIter<Self::Info>;

        fn protocol_info(&self) -> Self::InfoIter {
            self.protocols.clone().into_iter()
        }
    }

    impl<C> InboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Void;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        fn upgrade_inbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ready(Ok(stream))
        }
    }

    impl<C> OutboundUpgrade<C> for ManyProtocolsUpgrade {
        type Output = C;
        type Error = Void;
        type Future = future::Ready<Result<Self::Output, Self::Error>>;

        fn upgrade_outbound(self, stream: C, _: Self::Info) -> Self::Future {
            future::ready(Ok(stream))
        }
    }
}
/// Endpoint description that mirrors [`ConnectedPoint`] but, for the dialer side, keeps only
/// the role override (see the `From<ConnectedPoint>` conversion).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum PendingPoint {
/// The local node is the dialer.
Dialer {
/// Role override carried over from `ConnectedPoint::Dialer`.
role_override: Endpoint,
},
/// The local node is the listener.
Listener {
/// Local address of the connection.
local_addr: Multiaddr,
/// Address to send data back to the remote.
send_back_addr: Multiaddr,
},
}
impl From<ConnectedPoint> for PendingPoint {
fn from(endpoint: ConnectedPoint) -> Self {
match endpoint {
ConnectedPoint::Dialer { role_override, .. } => PendingPoint::Dialer { role_override },
ConnectedPoint::Listener {
local_addr,
send_back_addr,
} => PendingPoint::Listener {
local_addr,
send_back_addr,
},
}
}
}