use crate::types::ConnectionId;
use std::collections::HashSet;
/// Optional caps on the number of established connections, per direction.
///
/// Both limits default to `None`, in which case no limit is enforced for that direction.
#[derive(Debug, Clone, Default)]
pub struct ConnectionLimitsConfig {
    /// Maximum number of established inbound connections, or `None` for no cap.
    max_incoming_connections: Option<usize>,
    /// Maximum number of established outbound connections, or `None` for no cap.
    max_outgoing_connections: Option<usize>,
}

impl ConnectionLimitsConfig {
    /// Sets the maximum number of established inbound connections.
    ///
    /// Passing `None` removes the cap. Returns the updated configuration; the original value is
    /// consumed, so the result must be used.
    #[must_use]
    pub fn max_incoming_connections(mut self, limit: Option<usize>) -> Self {
        self.max_incoming_connections = limit;
        self
    }

    /// Sets the maximum number of established outbound connections.
    ///
    /// Passing `None` removes the cap. Returns the updated configuration; the original value is
    /// consumed, so the result must be used.
    #[must_use]
    pub fn max_outgoing_connections(mut self, limit: Option<usize>) -> Self {
        self.max_outgoing_connections = limit;
        self
    }
}
/// Reason a connection was refused by the connection-limit checks.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionLimitsError {
    /// The configured cap on established inbound connections has already been reached.
    MaxIncomingConnectionsExceeded,
    /// The configured cap on established outbound connections has already been reached.
    MaxOutgoingConnectionsExceeded,
}
/// Tracks established connections and enforces the optional inbound/outbound caps from a
/// [`ConnectionLimitsConfig`].
#[derive(Debug, Clone)]
pub struct ConnectionLimits {
    /// Limits enforced by this tracker.
    config: ConnectionLimitsConfig,
    /// Established inbound connections; only populated when an inbound limit is configured.
    incoming_connections: HashSet<ConnectionId>,
    /// Established outbound connections; only populated when an outbound limit is configured.
    outgoing_connections: HashSet<ConnectionId>,
}

impl ConnectionLimits {
    /// Upper bound on the number of slots reserved up front in each tracking set.
    ///
    /// A configured limit is a ceiling, not an expected size. Reserving the full limit eagerly
    /// wastes memory for large limits and panics with "capacity overflow" for values such as
    /// `usize::MAX`. The sets still grow on demand beyond this reservation.
    const MAX_PREALLOCATED_SLOTS: usize = 256;

    /// Creates a tracker with no connections recorded.
    pub fn new(config: ConnectionLimitsConfig) -> Self {
        let incoming_capacity = Self::initial_capacity(config.max_incoming_connections);
        let outgoing_capacity = Self::initial_capacity(config.max_outgoing_connections);
        Self {
            config,
            incoming_connections: HashSet::with_capacity(incoming_capacity),
            outgoing_connections: HashSet::with_capacity(outgoing_capacity),
        }
    }

    /// Number of slots to reserve up front for a set bounded by `limit` (0 when unbounded).
    fn initial_capacity(limit: Option<usize>) -> usize {
        limit.map_or(0, |limit| limit.min(Self::MAX_PREALLOCATED_SLOTS))
    }

    /// Checks whether another outbound connection may be dialed.
    ///
    /// Returns how many more outbound connections can be established before the limit is
    /// reached, or `usize::MAX` when no outbound limit is configured. Nothing is recorded here.
    /// The connection only counts once [`Self::on_connection_established`] accepts it.
    ///
    /// # Errors
    ///
    /// [`ConnectionLimitsError::MaxOutgoingConnectionsExceeded`] if the outbound limit has
    /// already been reached.
    pub fn on_dial_address(&mut self) -> Result<usize, ConnectionLimitsError> {
        match self.config.max_outgoing_connections {
            None => Ok(usize::MAX),
            Some(max_outgoing_connections) => {
                let established = self.outgoing_connections.len();
                if established >= max_outgoing_connections {
                    Err(ConnectionLimitsError::MaxOutgoingConnectionsExceeded)
                } else {
                    // Cannot underflow: `established < max_outgoing_connections` here.
                    Ok(max_outgoing_connections - established)
                }
            }
        }
    }

    /// Checks whether another inbound connection may be accepted.
    ///
    /// Nothing is recorded here. The connection only counts once
    /// [`Self::on_connection_established`] accepts it.
    ///
    /// # Errors
    ///
    /// [`ConnectionLimitsError::MaxIncomingConnectionsExceeded`] if the inbound limit has
    /// already been reached.
    pub fn on_incoming(&mut self) -> Result<(), ConnectionLimitsError> {
        match self.config.max_incoming_connections {
            Some(max_incoming_connections)
                if self.incoming_connections.len() >= max_incoming_connections =>
            {
                Err(ConnectionLimitsError::MaxIncomingConnectionsExceeded)
            }
            _ => Ok(()),
        }
    }

    /// Records a newly established connection, enforcing the limit for its direction.
    ///
    /// `is_listener` selects the direction: `true` records an inbound connection, `false` an
    /// outbound one. The connection is only tracked when a limit is configured for that
    /// direction.
    ///
    /// # Errors
    ///
    /// [`ConnectionLimitsError::MaxIncomingConnectionsExceeded`] or
    /// [`ConnectionLimitsError::MaxOutgoingConnectionsExceeded`] if the corresponding limit has
    /// already been reached. The connection is not recorded in that case.
    pub fn on_connection_established(
        &mut self,
        connection_id: ConnectionId,
        is_listener: bool,
    ) -> Result<(), ConnectionLimitsError> {
        // Delegate to the pre-checks so the admission rules live in exactly one place.
        if is_listener {
            self.on_incoming()?;
            if self.config.max_incoming_connections.is_some() {
                self.incoming_connections.insert(connection_id);
            }
        } else {
            self.on_dial_address()?;
            if self.config.max_outgoing_connections.is_some() {
                self.outgoing_connections.insert(connection_id);
            }
        }
        Ok(())
    }

    /// Forgets `connection_id` in whichever direction it was tracked.
    ///
    /// This is a no-op if the connection was never recorded.
    pub fn on_connection_closed(&mut self, connection_id: ConnectionId) {
        self.incoming_connections.remove(&connection_id);
        self.outgoing_connections.remove(&connection_id);
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for connection-limit admission checks and bookkeeping.
    use super::*;
    use crate::types::ConnectionId;

    #[test]
    fn connection_limits() {
        let config = ConnectionLimitsConfig::default()
            .max_incoming_connections(Some(3))
            .max_outgoing_connections(Some(2));
        let mut limits = ConnectionLimits::new(config);
        let connection_id_in_1 = ConnectionId::random();
        let connection_id_in_2 = ConnectionId::random();
        let connection_id_out_1 = ConnectionId::random();
        let connection_id_out_2 = ConnectionId::random();
        let connection_id_in_3 = ConnectionId::random();
        let connection_id_out_3 = ConnectionId::random();
        assert!(limits.on_connection_established(connection_id_in_1, true).is_ok());
        assert_eq!(limits.incoming_connections.len(), 1);
        assert!(limits.on_connection_established(connection_id_in_2, true).is_ok());
        assert_eq!(limits.incoming_connections.len(), 2);
        assert!(limits.on_connection_established(connection_id_in_3, true).is_ok());
        assert_eq!(limits.incoming_connections.len(), 3);
        assert_eq!(
            limits.on_connection_established(ConnectionId::random(), true).unwrap_err(),
            ConnectionLimitsError::MaxIncomingConnectionsExceeded
        );
        assert_eq!(limits.incoming_connections.len(), 3);
        assert!(limits.on_connection_established(connection_id_out_1, false).is_ok());
        assert_eq!(limits.incoming_connections.len(), 3);
        assert_eq!(limits.outgoing_connections.len(), 1);
        assert!(limits.on_connection_established(connection_id_out_2, false).is_ok());
        assert_eq!(limits.incoming_connections.len(), 3);
        assert_eq!(limits.outgoing_connections.len(), 2);
        assert_eq!(
            limits.on_connection_established(connection_id_out_3, false).unwrap_err(),
            ConnectionLimitsError::MaxOutgoingConnectionsExceeded
        );
        // A rejected connection must not be recorded.
        assert_eq!(limits.outgoing_connections.len(), 2);
        limits.on_connection_closed(connection_id_in_1);
        assert_eq!(limits.incoming_connections.len(), 2);
        assert_eq!(limits.outgoing_connections.len(), 2);
        limits.on_connection_closed(connection_id_out_1);
        assert_eq!(limits.incoming_connections.len(), 2);
        assert_eq!(limits.outgoing_connections.len(), 1);
    }

    #[test]
    fn pre_checks_track_remaining_capacity() {
        let config = ConnectionLimitsConfig::default()
            .max_incoming_connections(Some(1))
            .max_outgoing_connections(Some(2));
        let mut limits = ConnectionLimits::new(config);

        // Outbound: the dial pre-check reports how many slots remain.
        assert_eq!(limits.on_dial_address(), Ok(2));
        let out_1 = ConnectionId::random();
        assert!(limits.on_connection_established(out_1, false).is_ok());
        assert_eq!(limits.on_dial_address(), Ok(1));
        assert!(limits.on_connection_established(ConnectionId::random(), false).is_ok());
        assert_eq!(
            limits.on_dial_address(),
            Err(ConnectionLimitsError::MaxOutgoingConnectionsExceeded)
        );

        // Closing a connection frees its slot again.
        limits.on_connection_closed(out_1);
        assert_eq!(limits.on_dial_address(), Ok(1));

        // Inbound: the accept pre-check flips once the single slot is taken.
        assert_eq!(limits.on_incoming(), Ok(()));
        let in_1 = ConnectionId::random();
        assert!(limits.on_connection_established(in_1, true).is_ok());
        assert_eq!(
            limits.on_incoming(),
            Err(ConnectionLimitsError::MaxIncomingConnectionsExceeded)
        );
        limits.on_connection_closed(in_1);
        assert_eq!(limits.on_incoming(), Ok(()));
    }

    #[test]
    fn unlimited_directions_are_not_tracked() {
        let mut limits = ConnectionLimits::new(ConnectionLimitsConfig::default());

        assert_eq!(limits.on_dial_address(), Ok(usize::MAX));
        assert_eq!(limits.on_incoming(), Ok(()));
        for _ in 0..10 {
            assert!(limits.on_connection_established(ConnectionId::random(), true).is_ok());
            assert!(limits.on_connection_established(ConnectionId::random(), false).is_ok());
        }
        assert!(limits.incoming_connections.is_empty());
        assert!(limits.outgoing_connections.is_empty());
        assert_eq!(limits.on_dial_address(), Ok(usize::MAX));
    }
}