1use libp2p_core::{Endpoint, Multiaddr};
65use libp2p_identity::PeerId;
66use libp2p_swarm::{
67 dummy, CloseConnection, ConnectionDenied, ConnectionId, FromSwarm, NetworkBehaviour,
68 PollParameters, THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
69};
70use std::collections::{HashSet, VecDeque};
71use std::fmt;
72use std::task::{Context, Poll, Waker};
73use void::Void;
74
75#[derive(Default, Debug)]
77pub struct Behaviour<S> {
78 state: S,
79 close_connections: VecDeque<PeerId>,
80 waker: Option<Waker>,
81}
82
83#[derive(Default)]
85pub struct AllowedPeers {
86 peers: HashSet<PeerId>,
87}
88
89#[derive(Default)]
91pub struct BlockedPeers {
92 peers: HashSet<PeerId>,
93}
94
95impl Behaviour<AllowedPeers> {
96 pub fn allow_peer(&mut self, peer: PeerId) {
98 self.state.peers.insert(peer);
99 if let Some(waker) = self.waker.take() {
100 waker.wake()
101 }
102 }
103
104 pub fn disallow_peer(&mut self, peer: PeerId) {
108 self.state.peers.remove(&peer);
109 self.close_connections.push_back(peer);
110 if let Some(waker) = self.waker.take() {
111 waker.wake()
112 }
113 }
114}
115
116impl Behaviour<BlockedPeers> {
117 pub fn block_peer(&mut self, peer: PeerId) {
121 self.state.peers.insert(peer);
122 self.close_connections.push_back(peer);
123 if let Some(waker) = self.waker.take() {
124 waker.wake()
125 }
126 }
127
128 pub fn unblock_peer(&mut self, peer: PeerId) {
130 self.state.peers.remove(&peer);
131 if let Some(waker) = self.waker.take() {
132 waker.wake()
133 }
134 }
135}
136
137#[derive(Debug)]
139pub struct NotAllowed {
140 peer: PeerId,
141}
142
143impl fmt::Display for NotAllowed {
144 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145 write!(f, "peer {} is not in the allow list", self.peer)
146 }
147}
148
149impl std::error::Error for NotAllowed {}
150
151#[derive(Debug)]
153pub struct Blocked {
154 peer: PeerId,
155}
156
157impl fmt::Display for Blocked {
158 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159 write!(f, "peer {} is in the block list", self.peer)
160 }
161}
162
163impl std::error::Error for Blocked {}
164
/// A connection policy deciding, per remote peer, whether a connection may exist.
///
/// Implemented by [`AllowedPeers`] and [`BlockedPeers`] and consulted by the
/// `NetworkBehaviour` impl of [`Behaviour`]. The `'static` bound is presumably needed so
/// that `Behaviour<S>` satisfies `NetworkBehaviour`'s lifetime requirements — confirm.
trait Enforce: 'static {
    /// Returns `Ok(())` if a connection to `peer` is permitted, otherwise a
    /// [`ConnectionDenied`] carrying the reason (`NotAllowed` or `Blocked`).
    fn enforce(&self, peer: &PeerId) -> Result<(), ConnectionDenied>;
}
168
169impl Enforce for AllowedPeers {
170 fn enforce(&self, peer: &PeerId) -> Result<(), ConnectionDenied> {
171 if !self.peers.contains(peer) {
172 return Err(ConnectionDenied::new(NotAllowed { peer: *peer }));
173 }
174
175 Ok(())
176 }
177}
178
179impl Enforce for BlockedPeers {
180 fn enforce(&self, peer: &PeerId) -> Result<(), ConnectionDenied> {
181 if self.peers.contains(peer) {
182 return Err(ConnectionDenied::new(Blocked { peer: *peer }));
183 }
184
185 Ok(())
186 }
187}
188
189impl<S> NetworkBehaviour for Behaviour<S>
190where
191 S: Enforce,
192{
193 type ConnectionHandler = dummy::ConnectionHandler;
194 type ToSwarm = Void;
195
196 fn handle_established_inbound_connection(
197 &mut self,
198 _: ConnectionId,
199 peer: PeerId,
200 _: &Multiaddr,
201 _: &Multiaddr,
202 ) -> Result<THandler<Self>, ConnectionDenied> {
203 self.state.enforce(&peer)?;
204
205 Ok(dummy::ConnectionHandler)
206 }
207
208 fn handle_pending_outbound_connection(
209 &mut self,
210 _: ConnectionId,
211 peer: Option<PeerId>,
212 _: &[Multiaddr],
213 _: Endpoint,
214 ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
215 if let Some(peer) = peer {
216 self.state.enforce(&peer)?;
217 }
218
219 Ok(vec![])
220 }
221
222 fn handle_established_outbound_connection(
223 &mut self,
224 _: ConnectionId,
225 peer: PeerId,
226 _: &Multiaddr,
227 _: Endpoint,
228 ) -> Result<THandler<Self>, ConnectionDenied> {
229 self.state.enforce(&peer)?;
230
231 Ok(dummy::ConnectionHandler)
232 }
233
234 fn on_swarm_event(&mut self, event: FromSwarm<Self::ConnectionHandler>) {
235 match event {
236 FromSwarm::ConnectionClosed(_) => {}
237 FromSwarm::ConnectionEstablished(_) => {}
238 FromSwarm::AddressChange(_) => {}
239 FromSwarm::DialFailure(_) => {}
240 FromSwarm::ListenFailure(_) => {}
241 FromSwarm::NewListener(_) => {}
242 FromSwarm::NewListenAddr(_) => {}
243 FromSwarm::ExpiredListenAddr(_) => {}
244 FromSwarm::ListenerError(_) => {}
245 FromSwarm::ListenerClosed(_) => {}
246 FromSwarm::NewExternalAddrCandidate(_) => {}
247 FromSwarm::ExternalAddrExpired(_) => {}
248 FromSwarm::ExternalAddrConfirmed(_) => {}
249 }
250 }
251
252 fn on_connection_handler_event(
253 &mut self,
254 _id: PeerId,
255 _: ConnectionId,
256 event: THandlerOutEvent<Self>,
257 ) {
258 void::unreachable(event)
259 }
260
261 fn poll(
262 &mut self,
263 cx: &mut Context<'_>,
264 _: &mut impl PollParameters,
265 ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
266 if let Some(peer) = self.close_connections.pop_front() {
267 return Poll::Ready(ToSwarm::CloseConnection {
268 peer_id: peer,
269 connection: CloseConnection::All,
270 });
271 }
272
273 self.waker = Some(cx.waker().clone());
274 Poll::Pending
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use libp2p_swarm::{dial_opts::DialOpts, DialError, ListenError, Swarm, SwarmEvent};
    use libp2p_swarm_test::SwarmExt;

    /// Dialing a peer on our own block list is refused locally with a `Blocked` cause.
    #[async_std::test]
    async fn cannot_dial_blocked_peer() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::new());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::new());
        listener.listen().await;

        dialer
            .behaviour_mut()
            .list
            .block_peer(*listener.local_peer_id());

        let DialError::Denied { cause } = dial(&mut dialer, &listener).unwrap_err() else {
            panic!("unexpected dial error")
        };
        assert!(cause.downcast::<Blocked>().is_ok());
    }

    /// Blocking and then unblocking a peer lets the dial be initiated again.
    #[async_std::test]
    async fn can_dial_unblocked_peer() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::new());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::new());
        listener.listen().await;

        dialer
            .behaviour_mut()
            .list
            .block_peer(*listener.local_peer_id());
        dialer
            .behaviour_mut()
            .list
            .unblock_peer(*listener.local_peer_id());

        dial(&mut dialer, &listener).unwrap();
    }

    /// A listener that blocked the dialer reports the incoming connection as denied with a
    /// `Blocked` cause, while the dialer can still start the dial.
    #[async_std::test]
    async fn blocked_peer_cannot_dial_us() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::new());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::new());
        listener.listen().await;

        listener
            .behaviour_mut()
            .list
            .block_peer(*dialer.local_peer_id());
        dial(&mut dialer, &listener).unwrap();
        // Drive the dialer in the background so the connection attempt reaches the listener.
        async_std::task::spawn(dialer.loop_on_next());

        let cause = listener
            .wait(|e| match e {
                SwarmEvent::IncomingConnectionError {
                    error: ListenError::Denied { cause },
                    ..
                } => Some(cause),
                _ => None,
            })
            .await;
        assert!(cause.downcast::<Blocked>().is_ok());
    }

    /// Blocking a connected peer closes the connection on both sides.
    #[async_std::test]
    async fn connections_get_closed_upon_blocked() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::new());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<BlockedPeers>::new());
        listener.listen().await;
        dialer.connect(&mut listener).await;

        dialer
            .behaviour_mut()
            .list
            .block_peer(*listener.local_peer_id());

        // Each side is expected to emit exactly one event: the closing of the connection.
        let (
            [SwarmEvent::ConnectionClosed { peer_id: closed_dialer_peer, .. }],
            [SwarmEvent::ConnectionClosed { peer_id: closed_listener_peer, .. }]
        ) = libp2p_swarm_test::drive(&mut dialer, &mut listener).await else {
            panic!("unexpected events")
        };
        assert_eq!(closed_dialer_peer, *listener.local_peer_id());
        assert_eq!(closed_listener_peer, *dialer.local_peer_id());
    }

    /// With an empty allow list dialing is refused with `NotAllowed`; after allowing the
    /// peer the dial can be initiated.
    #[async_std::test]
    async fn cannot_dial_peer_unless_allowed() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::new());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::new());
        listener.listen().await;

        let DialError::Denied { cause } = dial(&mut dialer, &listener).unwrap_err() else {
            panic!("unexpected dial error")
        };
        assert!(cause.downcast::<NotAllowed>().is_ok());

        dialer
            .behaviour_mut()
            .list
            .allow_peer(*listener.local_peer_id());
        assert!(dial(&mut dialer, &listener).is_ok());
    }

    /// Allowing and then disallowing a peer makes dials to it fail with `NotAllowed` again.
    #[async_std::test]
    async fn cannot_dial_disallowed_peer() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::new());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::new());
        listener.listen().await;

        dialer
            .behaviour_mut()
            .list
            .allow_peer(*listener.local_peer_id());
        dialer
            .behaviour_mut()
            .list
            .disallow_peer(*listener.local_peer_id());

        let DialError::Denied { cause } = dial(&mut dialer, &listener).unwrap_err() else {
            panic!("unexpected dial error")
        };
        assert!(cause.downcast::<NotAllowed>().is_ok());
    }

    /// A dial without a peer id passes the pending-dial check but is denied with
    /// `NotAllowed` on both sides once the connection is established.
    #[async_std::test]
    async fn not_allowed_peer_cannot_dial_us() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::new());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::new());
        listener.listen().await;

        dialer
            .dial(
                DialOpts::unknown_peer_id()
                    .address(listener.external_addresses().next().cloned().unwrap())
                    .build(),
            )
            .unwrap();

        // Only the listener's fourth event is inspected; the first three are not checked here.
        let (
            [SwarmEvent::OutgoingConnectionError { error: DialError::Denied { cause: outgoing_cause }, .. }],
            [_, _, _, SwarmEvent::IncomingConnectionError { error: ListenError::Denied { cause: incoming_cause }, .. }],
        ) = libp2p_swarm_test::drive(&mut dialer, &mut listener).await else {
            panic!("unexpected events")
        };
        assert!(outgoing_cause.downcast::<NotAllowed>().is_ok());
        assert!(incoming_cause.downcast::<NotAllowed>().is_ok());
    }

    /// Disallowing a connected peer closes the connection on both sides.
    #[async_std::test]
    async fn connections_get_closed_upon_disallow() {
        let mut dialer = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::new());
        let mut listener = Swarm::new_ephemeral(|_| Behaviour::<AllowedPeers>::new());
        listener.listen().await;
        dialer
            .behaviour_mut()
            .list
            .allow_peer(*listener.local_peer_id());
        listener
            .behaviour_mut()
            .list
            .allow_peer(*dialer.local_peer_id());

        dialer.connect(&mut listener).await;

        dialer
            .behaviour_mut()
            .list
            .disallow_peer(*listener.local_peer_id());
        // Each side is expected to emit exactly one event: the closing of the connection.
        let (
            [SwarmEvent::ConnectionClosed { peer_id: closed_dialer_peer, .. }],
            [SwarmEvent::ConnectionClosed { peer_id: closed_listener_peer, .. }]
        ) = libp2p_swarm_test::drive(&mut dialer, &mut listener).await else {
            panic!("unexpected events")
        };
        assert_eq!(closed_dialer_peer, *listener.local_peer_id());
        assert_eq!(closed_listener_peer, *dialer.local_peer_id());
    }

    /// Dials `listener` by its peer id, using all of its external addresses.
    ///
    /// Returns the result of initiating the dial (not of the connection attempt itself).
    fn dial<S>(
        dialer: &mut Swarm<Behaviour<S>>,
        listener: &Swarm<Behaviour<S>>,
    ) -> Result<(), DialError>
    where
        S: Enforce,
    {
        dialer.dial(
            DialOpts::peer_id(*listener.local_peer_id())
                .addresses(listener.external_addresses().cloned().collect())
                .build(),
        )
    }

    /// Test behaviour combining the list under test with `keep_alive::Behaviour`,
    /// presumably so connections stay open while the list is exercised — confirm.
    #[derive(libp2p_swarm_derive::NetworkBehaviour)]
    #[behaviour(prelude = "libp2p_swarm::derive_prelude")]
    struct Behaviour<S> {
        list: super::Behaviour<S>,
        keep_alive: libp2p_swarm::keep_alive::Behaviour,
    }

    impl<S> Behaviour<S>
    where
        S: Default,
    {
        /// Builds the test behaviour with an empty list of policy `S` and nothing queued.
        fn new() -> Self {
            Self {
                list: super::Behaviour {
                    waker: None,
                    close_connections: VecDeque::new(),
                    state: S::default(),
                },
                keep_alive: libp2p_swarm::keep_alive::Behaviour,
            }
        }
    }
}