1use libp2p_core::{ConnectedPoint, Endpoint, Multiaddr};
22use libp2p_identity::PeerId;
23use libp2p_swarm::{
24 behaviour::{ConnectionEstablished, DialFailure, ListenFailure},
25 dummy, ConnectionClosed, ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour,
26 PollParameters, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
27};
28use std::collections::{HashMap, HashSet};
29use std::fmt;
30use std::task::{Context, Poll};
31use void::Void;
32
33pub struct Behaviour {
62 limits: ConnectionLimits,
63
64 pending_inbound_connections: HashSet<ConnectionId>,
65 pending_outbound_connections: HashSet<ConnectionId>,
66 established_inbound_connections: HashSet<ConnectionId>,
67 established_outbound_connections: HashSet<ConnectionId>,
68 established_per_peer: HashMap<PeerId, HashSet<ConnectionId>>,
69}
70
71impl Behaviour {
72 pub fn new(limits: ConnectionLimits) -> Self {
73 Self {
74 limits,
75 pending_inbound_connections: Default::default(),
76 pending_outbound_connections: Default::default(),
77 established_inbound_connections: Default::default(),
78 established_outbound_connections: Default::default(),
79 established_per_peer: Default::default(),
80 }
81 }
82
83 fn check_limit(
84 &mut self,
85 limit: Option<u32>,
86 current: usize,
87 kind: Kind,
88 ) -> Result<(), ConnectionDenied> {
89 let limit = limit.unwrap_or(u32::MAX);
90 let current = current as u32;
91
92 if current >= limit {
93 return Err(ConnectionDenied::new(Exceeded { limit, kind }));
94 }
95
96 Ok(())
97 }
98}
99
/// Error used to deny a connection because a configured connection limit has
/// been reached.
///
/// It is carried inside a `ConnectionDenied` and can be recovered from it with
/// `downcast::<Exceeded>()`.
#[derive(Debug, Clone, Copy)]
pub struct Exceeded {
    limit: u32,
    kind: Kind,
}

impl Exceeded {
    /// The limit that was reached when the connection was denied.
    pub fn limit(&self) -> u32 {
        self.limit
    }
}

impl fmt::Display for Exceeded {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { limit, kind } = self;
        write!(
            f,
            "connection limit exceeded: at most {} {} are allowed",
            limit, kind
        )
    }
}

impl std::error::Error for Exceeded {}

/// Identifies which connection counter an [`Exceeded`] limit refers to.
#[derive(Debug, Clone, Copy)]
enum Kind {
    PendingIncoming,
    PendingOutgoing,
    EstablishedIncoming,
    EstablishedOutgoing,
    EstablishedPerPeer,
    EstablishedTotal,
}

impl Kind {
    /// Plural, human-readable name of the connections counted by this kind.
    fn description(self) -> &'static str {
        match self {
            Kind::PendingIncoming => "pending incoming connections",
            Kind::PendingOutgoing => "pending outgoing connections",
            Kind::EstablishedIncoming => "established incoming connections",
            Kind::EstablishedOutgoing => "established outgoing connections",
            Kind::EstablishedPerPeer => "established connections per peer",
            Kind::EstablishedTotal => "established connections",
        }
    }
}

impl fmt::Display for Kind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.description())
    }
}
147
/// The connection limits enforced by [`Behaviour`].
///
/// Every limit is optional. `None` (the [`Default`]) leaves the corresponding
/// counter unlimited, while `Some(n)` admits at most `n` connections of that kind.
#[derive(Debug, Clone, Default)]
pub struct ConnectionLimits {
    max_pending_incoming: Option<u32>,
    max_pending_outgoing: Option<u32>,
    max_established_incoming: Option<u32>,
    max_established_outgoing: Option<u32>,
    max_established_per_peer: Option<u32>,
    max_established_total: Option<u32>,
}

// The builder methods consume and return `self`; `#[must_use]` flags call sites
// that accidentally drop the updated value (and with it the configured limit).
impl ConnectionLimits {
    /// Sets the maximum number of concurrent incoming connections that are not
    /// yet established (`None` = unlimited).
    #[must_use]
    pub fn with_max_pending_incoming(mut self, limit: Option<u32>) -> Self {
        self.max_pending_incoming = limit;
        self
    }

    /// Sets the maximum number of concurrent outgoing connections (dials) that
    /// are not yet established (`None` = unlimited).
    #[must_use]
    pub fn with_max_pending_outgoing(mut self, limit: Option<u32>) -> Self {
        self.max_pending_outgoing = limit;
        self
    }

    /// Sets the maximum number of concurrent established incoming connections
    /// (`None` = unlimited).
    #[must_use]
    pub fn with_max_established_incoming(mut self, limit: Option<u32>) -> Self {
        self.max_established_incoming = limit;
        self
    }

    /// Sets the maximum number of concurrent established outgoing connections
    /// (`None` = unlimited).
    #[must_use]
    pub fn with_max_established_outgoing(mut self, limit: Option<u32>) -> Self {
        self.max_established_outgoing = limit;
        self
    }

    /// Sets the maximum number of concurrent established connections in total,
    /// counting incoming and outgoing connections together (`None` = unlimited).
    ///
    /// This is checked in addition to the per-direction limits.
    #[must_use]
    pub fn with_max_established(mut self, limit: Option<u32>) -> Self {
        self.max_established_total = limit;
        self
    }

    /// Sets the maximum number of concurrent established connections to a single
    /// peer, regardless of direction (`None` = unlimited).
    #[must_use]
    pub fn with_max_established_per_peer(mut self, limit: Option<u32>) -> Self {
        self.max_established_per_peer = limit;
        self
    }
}
202
203impl NetworkBehaviour for Behaviour {
204 type ConnectionHandler = dummy::ConnectionHandler;
205 type ToSwarm = Void;
206
207 fn handle_pending_inbound_connection(
208 &mut self,
209 connection_id: ConnectionId,
210 _: &Multiaddr,
211 _: &Multiaddr,
212 ) -> Result<(), ConnectionDenied> {
213 self.check_limit(
214 self.limits.max_pending_incoming,
215 self.pending_inbound_connections.len(),
216 Kind::PendingIncoming,
217 )?;
218
219 self.pending_inbound_connections.insert(connection_id);
220
221 Ok(())
222 }
223
224 fn handle_established_inbound_connection(
225 &mut self,
226 connection_id: ConnectionId,
227 peer: PeerId,
228 _: &Multiaddr,
229 _: &Multiaddr,
230 ) -> Result<THandler<Self>, ConnectionDenied> {
231 self.pending_inbound_connections.remove(&connection_id);
232
233 self.check_limit(
234 self.limits.max_established_incoming,
235 self.established_inbound_connections.len(),
236 Kind::EstablishedIncoming,
237 )?;
238 self.check_limit(
239 self.limits.max_established_per_peer,
240 self.established_per_peer
241 .get(&peer)
242 .map(|connections| connections.len())
243 .unwrap_or(0),
244 Kind::EstablishedPerPeer,
245 )?;
246 self.check_limit(
247 self.limits.max_established_total,
248 self.established_inbound_connections.len()
249 + self.established_outbound_connections.len(),
250 Kind::EstablishedTotal,
251 )?;
252
253 Ok(dummy::ConnectionHandler)
254 }
255
256 fn handle_pending_outbound_connection(
257 &mut self,
258 connection_id: ConnectionId,
259 _: Option<PeerId>,
260 _: &[Multiaddr],
261 _: Endpoint,
262 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
263 self.check_limit(
264 self.limits.max_pending_outgoing,
265 self.pending_outbound_connections.len(),
266 Kind::PendingOutgoing,
267 )?;
268
269 self.pending_outbound_connections.insert(connection_id);
270
271 Ok(vec![])
272 }
273
274 fn handle_established_outbound_connection(
275 &mut self,
276 connection_id: ConnectionId,
277 peer: PeerId,
278 _: &Multiaddr,
279 _: Endpoint,
280 ) -> Result<THandler<Self>, ConnectionDenied> {
281 self.pending_outbound_connections.remove(&connection_id);
282
283 self.check_limit(
284 self.limits.max_established_outgoing,
285 self.established_outbound_connections.len(),
286 Kind::EstablishedOutgoing,
287 )?;
288 self.check_limit(
289 self.limits.max_established_per_peer,
290 self.established_per_peer
291 .get(&peer)
292 .map(|connections| connections.len())
293 .unwrap_or(0),
294 Kind::EstablishedPerPeer,
295 )?;
296 self.check_limit(
297 self.limits.max_established_total,
298 self.established_inbound_connections.len()
299 + self.established_outbound_connections.len(),
300 Kind::EstablishedTotal,
301 )?;
302
303 Ok(dummy::ConnectionHandler)
304 }
305
306 fn on_swarm_event(&mut self, event: FromSwarm<Self::ConnectionHandler>) {
307 match event {
308 FromSwarm::ConnectionClosed(ConnectionClosed {
309 peer_id,
310 connection_id,
311 ..
312 }) => {
313 self.established_inbound_connections.remove(&connection_id);
314 self.established_outbound_connections.remove(&connection_id);
315 self.established_per_peer
316 .entry(peer_id)
317 .or_default()
318 .remove(&connection_id);
319 }
320 FromSwarm::ConnectionEstablished(ConnectionEstablished {
321 peer_id,
322 endpoint,
323 connection_id,
324 ..
325 }) => {
326 match endpoint {
327 ConnectedPoint::Listener { .. } => {
328 self.established_inbound_connections.insert(connection_id);
329 }
330 ConnectedPoint::Dialer { .. } => {
331 self.established_outbound_connections.insert(connection_id);
332 }
333 }
334
335 self.established_per_peer
336 .entry(peer_id)
337 .or_default()
338 .insert(connection_id);
339 }
340 FromSwarm::DialFailure(DialFailure { connection_id, .. }) => {
341 self.pending_outbound_connections.remove(&connection_id);
342 }
343 FromSwarm::AddressChange(_) => {}
344 FromSwarm::ListenFailure(ListenFailure { connection_id, .. }) => {
345 self.pending_inbound_connections.remove(&connection_id);
346 }
347 FromSwarm::NewListener(_) => {}
348 FromSwarm::NewListenAddr(_) => {}
349 FromSwarm::ExpiredListenAddr(_) => {}
350 FromSwarm::ListenerError(_) => {}
351 FromSwarm::ListenerClosed(_) => {}
352 FromSwarm::NewExternalAddrCandidate(_) => {}
353 FromSwarm::ExternalAddrExpired(_) => {}
354 FromSwarm::ExternalAddrConfirmed(_) => {}
355 }
356 }
357
358 fn on_connection_handler_event(
359 &mut self,
360 _id: PeerId,
361 _: ConnectionId,
362 event: THandlerOutEvent<Self>,
363 ) {
364 void::unreachable(event)
365 }
366
367 fn poll(
368 &mut self,
369 _: &mut Context<'_>,
370 _: &mut impl PollParameters,
371 ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
372 Poll::Pending
373 }
374}
375
#[cfg(test)]
mod tests {
    // `Behaviour` below refers to the test-local composed behaviour defined at the
    // bottom of this module; the limits behaviour under test is `super::Behaviour`.
    use super::*;
    use libp2p_swarm::{
        behaviour::toggle::Toggle, dial_opts::DialOpts, DialError, ListenError, Swarm, SwarmEvent,
    };
    use libp2p_swarm_test::SwarmExt;
    use quickcheck::*;

    // Dialing more than `max_pending_outgoing` times is denied with an `Exceeded`
    // error carrying the configured limit, and only `outgoing_limit` dials remain
    // pending.
    #[test]
    fn max_outgoing() {
        use rand::Rng;

        let outgoing_limit = rand::thread_rng().gen_range(1..10);

        let mut network = Swarm::new_ephemeral(|_| {
            Behaviour::new(
                ConnectionLimits::default().with_max_pending_outgoing(Some(outgoing_limit)),
            )
        });

        let addr: Multiaddr = "/memory/1234".parse().unwrap();
        let target = PeerId::random();

        // Use up the whole pending-outgoing budget; every one of these dials must be accepted.
        for _ in 0..outgoing_limit {
            network
                .dial(
                    DialOpts::peer_id(target)
                        .addresses(vec![addr.clone()])
                        .build(),
                )
                .expect("Unexpected connection limit.");
        }

        // One dial more than the limit must be denied by the limits behaviour.
        match network
            .dial(DialOpts::peer_id(target).addresses(vec![addr]).build())
            .expect_err("Unexpected dialing success.")
        {
            DialError::Denied { cause } => {
                let exceeded = cause
                    .downcast::<Exceeded>()
                    .expect("connection denied because of limit");

                assert_eq!(exceeded.limit(), outgoing_limit);
            }
            e => panic!("Unexpected error: {e:?}"),
        }

        let info = network.network_info();
        assert_eq!(info.num_peers(), 0);
        assert_eq!(
            info.connection_counters().num_pending_outgoing(),
            outgoing_limit
        );
    }

    // Property test: once `limit` inbound connections are established, the next
    // inbound connection is denied with an `Exceeded` carrying that same limit.
    #[test]
    fn max_established_incoming() {
        fn prop(Limit(limit): Limit) {
            let mut swarm1 = Swarm::new_ephemeral(|_| {
                Behaviour::new(
                    ConnectionLimits::default().with_max_established_incoming(Some(limit)),
                )
            });
            let mut swarm2 = Swarm::new_ephemeral(|_| {
                Behaviour::new(
                    ConnectionLimits::default().with_max_established_incoming(Some(limit)),
                )
            });

            async_std::task::block_on(async {
                let (listen_addr, _) = swarm1.listen().await;

                // Establish exactly `limit` connections from swarm2 to swarm1.
                for _ in 0..limit {
                    swarm2.connect(&mut swarm1).await;
                }

                // One more connection attempt, which swarm1 should deny.
                swarm2.dial(listen_addr).unwrap();

                // Keep swarm2 running in the background while we wait on swarm1.
                async_std::task::spawn(swarm2.loop_on_next());

                // Wait for swarm1 to report the denied inbound connection and extract its cause.
                let cause = swarm1
                    .wait(|event| match event {
                        SwarmEvent::IncomingConnectionError {
                            error: ListenError::Denied { cause },
                            ..
                        } => Some(cause),
                        _ => None,
                    })
                    .await;

                assert_eq!(cause.downcast::<Exceeded>().unwrap().limit, limit);
            });
        }

        // Limit values generated by quickcheck, restricted to 1..10.
        #[derive(Debug, Clone)]
        struct Limit(u32);

        impl Arbitrary for Limit {
            fn arbitrary(g: &mut Gen) -> Self {
                Self(g.gen_range(1..10))
            }
        }

        quickcheck(prop as fn(_));
    }

    // When another behaviour (here `ConnectionDenier`) denies an established inbound
    // connection, the limits behaviour must not count it as established.
    #[test]
    fn support_other_behaviour_denying_connection() {
        let mut swarm1 = Swarm::new_ephemeral(|_| {
            Behaviour::new_with_connection_denier(ConnectionLimits::default())
        });
        let mut swarm2 = Swarm::new_ephemeral(|_| Behaviour::new(ConnectionLimits::default()));

        async_std::task::block_on(async {
            let (listen_addr, _) = swarm1.listen().await;
            swarm2.dial(listen_addr).unwrap();
            async_std::task::spawn(swarm2.loop_on_next());

            // Wait for swarm1 to report the denied inbound connection.
            let cause = swarm1
                .wait(|event| match event {
                    SwarmEvent::IncomingConnectionError {
                        error: ListenError::Denied { cause },
                        ..
                    } => Some(cause),
                    _ => None,
                })
                .await;

            // The denial must come from `ConnectionDenier` (an `io::Error`), not from the limits.
            cause.downcast::<std::io::Error>().unwrap();

            assert_eq!(
                0,
                swarm1
                    .behaviour_mut()
                    .limits
                    .established_inbound_connections
                    .len(),
                "swarm1 connection limit behaviour to not count denied established connection as established connection"
            )
        });
    }

    // Test harness behaviour: the limits behaviour under test, composed with a
    // keep-alive behaviour and an optional `ConnectionDenier`.
    #[derive(libp2p_swarm_derive::NetworkBehaviour)]
    #[behaviour(prelude = "libp2p_swarm::derive_prelude")]
    struct Behaviour {
        limits: super::Behaviour,
        keep_alive: libp2p_swarm::keep_alive::Behaviour,
        connection_denier: Toggle<ConnectionDenier>,
    }

    impl Behaviour {
        // Limits + keep-alive, with the connection denier disabled.
        fn new(limits: ConnectionLimits) -> Self {
            Self {
                limits: super::Behaviour::new(limits),
                keep_alive: libp2p_swarm::keep_alive::Behaviour,
                connection_denier: None.into(),
            }
        }
        // Limits + keep-alive, with the connection denier enabled.
        fn new_with_connection_denier(limits: ConnectionLimits) -> Self {
            Self {
                limits: super::Behaviour::new(limits),
                keep_alive: libp2p_swarm::keep_alive::Behaviour,
                connection_denier: Some(ConnectionDenier {}).into(),
            }
        }
    }

    // Behaviour that denies every established connection (inbound and outbound)
    // with an `io::Error`.
    struct ConnectionDenier {}

    impl NetworkBehaviour for ConnectionDenier {
        type ConnectionHandler = dummy::ConnectionHandler;
        type ToSwarm = Void;

        fn handle_established_inbound_connection(
            &mut self,
            _connection_id: ConnectionId,
            _peer: PeerId,
            _local_addr: &Multiaddr,
            _remote_addr: &Multiaddr,
        ) -> Result<THandler<Self>, ConnectionDenied> {
            Err(ConnectionDenied::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                "ConnectionDenier",
            )))
        }

        fn handle_established_outbound_connection(
            &mut self,
            _connection_id: ConnectionId,
            _peer: PeerId,
            _addr: &Multiaddr,
            _role_override: Endpoint,
        ) -> Result<THandler<Self>, ConnectionDenied> {
            Err(ConnectionDenied::new(std::io::Error::new(
                std::io::ErrorKind::Other,
                "ConnectionDenier",
            )))
        }

        fn on_swarm_event(&mut self, _event: FromSwarm<Self::ConnectionHandler>) {}

        // The handler's output event type is `Void`, so this is never called.
        fn on_connection_handler_event(
            &mut self,
            _peer_id: PeerId,
            _connection_id: ConnectionId,
            event: THandlerOutEvent<Self>,
        ) {
            void::unreachable(event)
        }

        fn poll(
            &mut self,
            _cx: &mut Context<'_>,
            _params: &mut impl PollParameters,
        ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
            Poll::Pending
        }
    }
}