litep2p/transport/manager/
limits.rs1use crate::types::ConnectionId;
24
25use std::collections::HashSet;
26
/// Configuration for [`ConnectionLimits`].
///
/// Both limits are optional. `None`, which is what [`Default`] produces, leaves that
/// direction of traffic unlimited.
#[derive(Clone, Debug, Default)]
pub struct ConnectionLimitsConfig {
    /// Upper bound on simultaneously tracked inbound (listener-side) connections.
    max_incoming_connections: Option<usize>,

    /// Upper bound on simultaneously tracked outbound (dialer-side) connections.
    max_outgoing_connections: Option<usize>,
}
35
36impl ConnectionLimitsConfig {
37 pub fn max_incoming_connections(mut self, limit: Option<usize>) -> Self {
39 self.max_incoming_connections = limit;
40 self
41 }
42
43 pub fn max_outgoing_connections(mut self, limit: Option<usize>) -> Self {
45 self.max_outgoing_connections = limit;
46 self
47 }
48}
49
/// Error returned when accepting or dialing a connection would exceed a configured limit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionLimitsError {
    /// The configured maximum number of incoming connections has been reached.
    MaxIncomingConnectionsExceeded,

    /// The configured maximum number of outgoing connections has been reached.
    MaxOutgoingConnectionsExceeded,
}

impl std::fmt::Display for ConnectionLimitsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::MaxIncomingConnectionsExceeded =>
                "maximum number of incoming connections exceeded",
            Self::MaxOutgoingConnectionsExceeded =>
                "maximum number of outgoing connections exceeded",
        };
        f.write_str(message)
    }
}

// Lets the error be boxed as `dyn Error` and propagated with `?` by callers.
impl std::error::Error for ConnectionLimitsError {}
58
59#[derive(Debug, Clone)]
61pub struct ConnectionLimits {
62 config: ConnectionLimitsConfig,
64
65 incoming_connections: HashSet<ConnectionId>,
67 outgoing_connections: HashSet<ConnectionId>,
69}
70
71impl ConnectionLimits {
72 pub fn new(config: ConnectionLimitsConfig) -> Self {
74 let max_incoming_connections = config.max_incoming_connections.unwrap_or(0);
75 let max_outgoing_connections = config.max_outgoing_connections.unwrap_or(0);
76
77 Self {
78 config,
79 incoming_connections: HashSet::with_capacity(max_incoming_connections),
80 outgoing_connections: HashSet::with_capacity(max_outgoing_connections),
81 }
82 }
83
84 pub fn on_dial_address(&mut self) -> Result<usize, ConnectionLimitsError> {
93 if let Some(max_outgoing_connections) = self.config.max_outgoing_connections {
94 if self.outgoing_connections.len() >= max_outgoing_connections {
95 return Err(ConnectionLimitsError::MaxOutgoingConnectionsExceeded);
96 }
97
98 return Ok(max_outgoing_connections - self.outgoing_connections.len());
99 }
100
101 Ok(usize::MAX)
102 }
103
104 pub fn on_incoming(&mut self) -> Result<(), ConnectionLimitsError> {
106 if let Some(max_incoming_connections) = self.config.max_incoming_connections {
107 if self.incoming_connections.len() >= max_incoming_connections {
108 return Err(ConnectionLimitsError::MaxIncomingConnectionsExceeded);
109 }
110 }
111
112 Ok(())
113 }
114
115 pub fn on_connection_established(
117 &mut self,
118 connection_id: ConnectionId,
119 is_listener: bool,
120 ) -> Result<(), ConnectionLimitsError> {
121 if is_listener {
123 if let Some(max_incoming_connections) = self.config.max_incoming_connections {
124 if self.incoming_connections.len() >= max_incoming_connections {
125 return Err(ConnectionLimitsError::MaxIncomingConnectionsExceeded);
126 }
127 }
128 } else if let Some(max_outgoing_connections) = self.config.max_outgoing_connections {
129 if self.outgoing_connections.len() >= max_outgoing_connections {
130 return Err(ConnectionLimitsError::MaxOutgoingConnectionsExceeded);
131 }
132 }
133
134 if is_listener {
136 if self.config.max_incoming_connections.is_some() {
137 self.incoming_connections.insert(connection_id);
138 }
139 } else if self.config.max_outgoing_connections.is_some() {
140 self.outgoing_connections.insert(connection_id);
141 }
142
143 Ok(())
144 }
145
146 pub fn on_connection_closed(&mut self, connection_id: ConnectionId) {
148 self.incoming_connections.remove(&connection_id);
149 self.outgoing_connections.remove(&connection_id);
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ConnectionId;

    #[test]
    fn connection_limits() {
        let config = ConnectionLimitsConfig::default()
            .max_incoming_connections(Some(3))
            .max_outgoing_connections(Some(2));
        let mut limits = ConnectionLimits::new(config);

        let connection_id_in_1 = ConnectionId::random();
        let connection_id_in_2 = ConnectionId::random();
        let connection_id_out_1 = ConnectionId::random();
        let connection_id_out_2 = ConnectionId::random();
        let connection_id_in_3 = ConnectionId::random();
        let connection_id_out_3 = ConnectionId::random();

        // Establish incoming connections up to the limit.
        assert!(limits.on_connection_established(connection_id_in_1, true).is_ok());
        assert_eq!(limits.incoming_connections.len(), 1);

        assert!(limits.on_connection_established(connection_id_in_2, true).is_ok());
        assert_eq!(limits.incoming_connections.len(), 2);

        assert!(limits.on_connection_established(connection_id_in_3, true).is_ok());
        assert_eq!(limits.incoming_connections.len(), 3);

        assert_eq!(
            limits.on_connection_established(ConnectionId::random(), true).unwrap_err(),
            ConnectionLimitsError::MaxIncomingConnectionsExceeded
        );
        assert_eq!(limits.incoming_connections.len(), 3);

        // Outgoing connections are counted independently of incoming ones.
        assert!(limits.on_connection_established(connection_id_out_1, false).is_ok());
        assert_eq!(limits.incoming_connections.len(), 3);
        assert_eq!(limits.outgoing_connections.len(), 1);

        assert!(limits.on_connection_established(connection_id_out_2, false).is_ok());
        assert_eq!(limits.incoming_connections.len(), 3);
        assert_eq!(limits.outgoing_connections.len(), 2);

        assert_eq!(
            limits.on_connection_established(connection_id_out_3, false).unwrap_err(),
            ConnectionLimitsError::MaxOutgoingConnectionsExceeded
        );

        // Closing connections frees their slots.
        limits.on_connection_closed(connection_id_in_1);
        assert_eq!(limits.incoming_connections.len(), 2);
        assert_eq!(limits.outgoing_connections.len(), 2);

        limits.on_connection_closed(connection_id_out_1);
        assert_eq!(limits.incoming_connections.len(), 2);
        assert_eq!(limits.outgoing_connections.len(), 1);
    }

    #[test]
    fn dial_and_incoming_checks_report_remaining_capacity() {
        let config = ConnectionLimitsConfig::default()
            .max_incoming_connections(Some(1))
            .max_outgoing_connections(Some(2));
        let mut limits = ConnectionLimits::new(config);

        assert_eq!(limits.on_dial_address(), Ok(2));
        assert_eq!(limits.on_incoming(), Ok(()));

        // The checks are read-only: calling them does not consume a slot.
        assert_eq!(limits.on_dial_address(), Ok(2));

        let outgoing = ConnectionId::random();
        assert!(limits.on_connection_established(outgoing, false).is_ok());
        assert_eq!(limits.on_dial_address(), Ok(1));

        let incoming = ConnectionId::random();
        assert!(limits.on_connection_established(incoming, true).is_ok());
        assert_eq!(
            limits.on_incoming(),
            Err(ConnectionLimitsError::MaxIncomingConnectionsExceeded)
        );

        // Closing an unknown connection changes nothing.
        limits.on_connection_closed(ConnectionId::random());
        assert_eq!(limits.incoming_connections.len(), 1);
        assert_eq!(limits.outgoing_connections.len(), 1);

        limits.on_connection_closed(incoming);
        assert_eq!(limits.on_incoming(), Ok(()));

        limits.on_connection_closed(outgoing);
        assert_eq!(limits.on_dial_address(), Ok(2));
    }

    #[test]
    fn unlimited_directions_are_not_tracked() {
        let config = ConnectionLimitsConfig::default().max_incoming_connections(Some(1));
        let mut limits = ConnectionLimits::new(config);

        // Outgoing has no limit: dialing always reports unbounded capacity.
        assert_eq!(limits.on_dial_address(), Ok(usize::MAX));
        for _ in 0..8 {
            assert!(limits.on_connection_established(ConnectionId::random(), false).is_ok());
        }
        assert!(limits.outgoing_connections.is_empty());
        assert_eq!(limits.on_dial_address(), Ok(usize::MAX));

        // Incoming is still limited.
        assert!(limits.on_connection_established(ConnectionId::random(), true).is_ok());
        assert_eq!(
            limits.on_connection_established(ConnectionId::random(), true).unwrap_err(),
            ConnectionLimitsError::MaxIncomingConnectionsExceeded
        );

        // A fully default config accepts everything and tracks nothing.
        let mut unlimited = ConnectionLimits::new(ConnectionLimitsConfig::default());
        assert_eq!(unlimited.on_incoming(), Ok(()));
        assert!(unlimited.on_connection_established(ConnectionId::random(), true).is_ok());
        assert!(unlimited.incoming_connections.is_empty());
        assert!(unlimited.outgoing_connections.is_empty());
    }
}