litep2p/transport/manager/
limits.rs1use crate::types::ConnectionId;
24
25use std::collections::HashSet;
26
/// Configuration for the connection limits enforced by [`ConnectionLimits`].
///
/// Both limits default to `None`, meaning no limit is enforced for that direction.
#[derive(Debug, Clone, Default)]
pub struct ConnectionLimitsConfig {
    /// Maximum number of established incoming connections, or `None` for no limit.
    max_incoming_connections: Option<usize>,

    /// Maximum number of established outgoing connections, or `None` for no limit.
    max_outgoing_connections: Option<usize>,
}

impl ConnectionLimitsConfig {
    /// Set the maximum number of incoming connections (`None` disables the limit).
    ///
    /// Consumes `self` and returns the updated configuration; the return value
    /// must be used, otherwise the new limit is silently discarded.
    #[must_use]
    pub fn max_incoming_connections(mut self, limit: Option<usize>) -> Self {
        self.max_incoming_connections = limit;
        self
    }

    /// Set the maximum number of outgoing connections (`None` disables the limit).
    ///
    /// Consumes `self` and returns the updated configuration; the return value
    /// must be used, otherwise the new limit is silently discarded.
    #[must_use]
    pub fn max_outgoing_connections(mut self, limit: Option<usize>) -> Self {
        self.max_outgoing_connections = limit;
        self
    }
}
49
/// Reason a connection (or dial attempt) was refused by [`ConnectionLimits`].
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum ConnectionLimitsError {
    /// The configured cap on incoming connections has been reached.
    MaxIncomingConnectionsExceeded,

    /// The configured cap on outgoing connections has been reached.
    MaxOutgoingConnectionsExceeded,
}
58
59#[derive(Debug, Clone)]
61pub struct ConnectionLimits {
62 config: ConnectionLimitsConfig,
64
65 incoming_connections: HashSet<ConnectionId>,
67 outgoing_connections: HashSet<ConnectionId>,
69}
70
71impl ConnectionLimits {
72 pub fn new(config: ConnectionLimitsConfig) -> Self {
74 let max_incoming_connections = config.max_incoming_connections.unwrap_or(0);
75 let max_outgoing_connections = config.max_outgoing_connections.unwrap_or(0);
76
77 Self {
78 config,
79 incoming_connections: HashSet::with_capacity(max_incoming_connections),
80 outgoing_connections: HashSet::with_capacity(max_outgoing_connections),
81 }
82 }
83
84 pub fn on_dial_address(&mut self) -> Result<usize, ConnectionLimitsError> {
93 if let Some(max_outgoing_connections) = self.config.max_outgoing_connections {
94 if self.outgoing_connections.len() >= max_outgoing_connections {
95 return Err(ConnectionLimitsError::MaxOutgoingConnectionsExceeded);
96 }
97
98 return Ok(max_outgoing_connections - self.outgoing_connections.len());
99 }
100
101 Ok(usize::MAX)
102 }
103
104 pub fn on_incoming(&mut self) -> Result<(), ConnectionLimitsError> {
106 if let Some(max_incoming_connections) = self.config.max_incoming_connections {
107 if self.incoming_connections.len() >= max_incoming_connections {
108 return Err(ConnectionLimitsError::MaxIncomingConnectionsExceeded);
109 }
110 }
111
112 Ok(())
113 }
114
115 pub fn can_accept_connection(
119 &mut self,
120 is_listener: bool,
121 ) -> Result<(), ConnectionLimitsError> {
122 if is_listener {
124 if let Some(max_incoming_connections) = self.config.max_incoming_connections {
125 if self.incoming_connections.len() >= max_incoming_connections {
126 return Err(ConnectionLimitsError::MaxIncomingConnectionsExceeded);
127 }
128 }
129 } else if let Some(max_outgoing_connections) = self.config.max_outgoing_connections {
130 if self.outgoing_connections.len() >= max_outgoing_connections {
131 return Err(ConnectionLimitsError::MaxOutgoingConnectionsExceeded);
132 }
133 }
134
135 Ok(())
136 }
137
138 pub fn accept_established_connection(
145 &mut self,
146 connection_id: ConnectionId,
147 is_listener: bool,
148 ) {
149 if is_listener {
150 if self.config.max_incoming_connections.is_some() {
151 self.incoming_connections.insert(connection_id);
152 }
153 } else if self.config.max_outgoing_connections.is_some() {
154 self.outgoing_connections.insert(connection_id);
155 }
156 }
157
158 pub fn on_connection_closed(&mut self, connection_id: ConnectionId) {
160 self.incoming_connections.remove(&connection_id);
161 self.outgoing_connections.remove(&connection_id);
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ConnectionId;

    /// End-to-end check of both directions: filling, rejecting and releasing slots.
    #[test]
    fn connection_limits() {
        let config = ConnectionLimitsConfig::default()
            .max_incoming_connections(Some(3))
            .max_outgoing_connections(Some(2));
        let mut limits = ConnectionLimits::new(config);

        let connection_id_in_1 = ConnectionId::random();
        let connection_id_in_2 = ConnectionId::random();
        let connection_id_in_3 = ConnectionId::random();
        let connection_id_out_1 = ConnectionId::random();
        let connection_id_out_2 = ConnectionId::random();

        // Fill the incoming slots one by one.
        assert!(limits.can_accept_connection(true).is_ok());
        limits.accept_established_connection(connection_id_in_1, true);
        assert_eq!(limits.incoming_connections.len(), 1);

        assert!(limits.can_accept_connection(true).is_ok());
        limits.accept_established_connection(connection_id_in_2, true);
        assert_eq!(limits.incoming_connections.len(), 2);

        assert!(limits.can_accept_connection(true).is_ok());
        limits.accept_established_connection(connection_id_in_3, true);
        assert_eq!(limits.incoming_connections.len(), 3);

        assert_eq!(
            limits.can_accept_connection(true).unwrap_err(),
            ConnectionLimitsError::MaxIncomingConnectionsExceeded
        );
        assert_eq!(limits.incoming_connections.len(), 3);

        // Outgoing slots are independent of the incoming ones.
        assert!(limits.can_accept_connection(false).is_ok());
        limits.accept_established_connection(connection_id_out_1, false);
        assert_eq!(limits.incoming_connections.len(), 3);
        assert_eq!(limits.outgoing_connections.len(), 1);

        assert!(limits.can_accept_connection(false).is_ok());
        limits.accept_established_connection(connection_id_out_2, false);
        assert_eq!(limits.incoming_connections.len(), 3);
        assert_eq!(limits.outgoing_connections.len(), 2);

        assert_eq!(
            limits.can_accept_connection(false).unwrap_err(),
            ConnectionLimitsError::MaxOutgoingConnectionsExceeded
        );

        // Closing a connection frees a slot in its own direction only.
        limits.on_connection_closed(connection_id_in_1);
        assert_eq!(limits.incoming_connections.len(), 2);
        assert_eq!(limits.outgoing_connections.len(), 2);

        limits.on_connection_closed(connection_id_out_1);
        assert_eq!(limits.incoming_connections.len(), 2);
        assert_eq!(limits.outgoing_connections.len(), 1);
    }

    /// `on_dial_address` reports how many outgoing slots remain.
    #[test]
    fn dial_reports_remaining_outgoing_slots() {
        let config = ConnectionLimitsConfig::default().max_outgoing_connections(Some(2));
        let mut limits = ConnectionLimits::new(config);

        assert_eq!(limits.on_dial_address(), Ok(2));
        limits.accept_established_connection(ConnectionId::random(), false);
        assert_eq!(limits.on_dial_address(), Ok(1));

        let last = ConnectionId::random();
        limits.accept_established_connection(last, false);
        assert_eq!(
            limits.on_dial_address(),
            Err(ConnectionLimitsError::MaxOutgoingConnectionsExceeded)
        );

        limits.on_connection_closed(last);
        assert_eq!(limits.on_dial_address(), Ok(1));
    }

    /// `on_incoming` rejects once the incoming limit is reached and recovers on close.
    #[test]
    fn on_incoming_respects_limit() {
        let config = ConnectionLimitsConfig::default().max_incoming_connections(Some(1));
        let mut limits = ConnectionLimits::new(config);

        assert!(limits.on_incoming().is_ok());
        let connection_id = ConnectionId::random();
        limits.accept_established_connection(connection_id, true);
        assert_eq!(
            limits.on_incoming(),
            Err(ConnectionLimitsError::MaxIncomingConnectionsExceeded)
        );

        limits.on_connection_closed(connection_id);
        assert!(limits.on_incoming().is_ok());
    }

    /// Without configured limits nothing is tracked and everything is accepted.
    #[test]
    fn unlimited_connections_are_not_tracked() {
        let mut limits = ConnectionLimits::new(ConnectionLimitsConfig::default());

        for _ in 0..8 {
            limits.accept_established_connection(ConnectionId::random(), true);
            limits.accept_established_connection(ConnectionId::random(), false);
        }

        assert!(limits.incoming_connections.is_empty());
        assert!(limits.outgoing_connections.is_empty());
        assert!(limits.can_accept_connection(true).is_ok());
        assert!(limits.can_accept_connection(false).is_ok());
        assert!(limits.on_incoming().is_ok());
        assert_eq!(limits.on_dial_address(), Ok(usize::MAX));
    }
}