1use crate::peer_store::{PeerStoreProvider, ProtocolHandle as ProtocolHandleT};
45
46use futures::{channel::oneshot, future::Either, FutureExt, StreamExt};
47use libp2p::PeerId;
48use log::{debug, error, trace, warn};
49use sc_utils::mpsc::{tracing_unbounded, TracingUnboundedReceiver, TracingUnboundedSender};
50use sp_arithmetic::traits::SaturatedConversion;
51use std::{
52	collections::{HashMap, HashSet},
53	sync::Arc,
54	time::{Duration, Instant},
55};
56use wasm_timer::Delay;
57
/// Log target used by every log statement of the protocol controller.
pub const LOG_TARGET: &str = "peerset";
60
/// Identifier of a protocol set, wrapping the set's numeric index.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SetId(usize);

impl SetId {
	/// Builds a `SetId` from its raw index. Being `const`, it can be used in constant contexts.
	pub const fn from(id: usize) -> Self {
		SetId(id)
	}
}

impl From<usize> for SetId {
	fn from(raw: usize) -> SetId {
		SetId(raw)
	}
}

impl From<SetId> for usize {
	fn from(set_id: SetId) -> usize {
		let SetId(raw) = set_id;
		raw
	}
}
87
88#[derive(Debug)]
90pub struct ProtoSetConfig {
91	pub in_peers: u32,
93
94	pub out_peers: u32,
96
97	pub reserved_nodes: HashSet<PeerId>,
102
103	pub reserved_only: bool,
105}
106
107#[derive(Debug, PartialEq)]
109pub enum Message {
110	Connect {
113		set_id: SetId,
115		peer_id: PeerId,
117	},
118
119	Drop {
121		set_id: SetId,
123		peer_id: PeerId,
125	},
126
127	Accept(IncomingIndex),
129
130	Reject(IncomingIndex),
132}
133
/// Opaque index identifying a pending incoming connection. It is answered with
/// `Message::Accept` or `Message::Reject`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct IncomingIndex(pub u64);

impl From<u64> for IncomingIndex {
	fn from(raw: u64) -> IncomingIndex {
		IncomingIndex(raw)
	}
}
143
144#[derive(Debug)]
146enum Action {
147	AddReservedPeer(PeerId),
149	RemoveReservedPeer(PeerId),
151	SetReservedPeers(HashSet<PeerId>),
153	SetReservedOnly(bool),
155	DisconnectPeer(PeerId),
157	GetReservedPeers(oneshot::Sender<Vec<PeerId>>),
159}
160
161#[derive(Debug)]
163enum Event {
164	IncomingConnection(PeerId, IncomingIndex),
166	Dropped(PeerId),
168}
169
170#[derive(Debug, Clone)]
173pub struct ProtocolHandle {
174	actions_tx: TracingUnboundedSender<Action>,
176	events_tx: TracingUnboundedSender<Event>,
178}
179
180impl ProtocolHandle {
181	pub fn add_reserved_peer(&self, peer_id: PeerId) {
189		let _ = self.actions_tx.unbounded_send(Action::AddReservedPeer(peer_id));
190	}
191
192	pub fn remove_reserved_peer(&self, peer_id: PeerId) {
196		let _ = self.actions_tx.unbounded_send(Action::RemoveReservedPeer(peer_id));
197	}
198
199	pub fn set_reserved_peers(&self, peer_ids: HashSet<PeerId>) {
201		let _ = self.actions_tx.unbounded_send(Action::SetReservedPeers(peer_ids));
202	}
203
204	pub fn set_reserved_only(&self, reserved: bool) {
207		let _ = self.actions_tx.unbounded_send(Action::SetReservedOnly(reserved));
208	}
209
210	pub fn disconnect_peer(&self, peer_id: PeerId) {
213		let _ = self.actions_tx.unbounded_send(Action::DisconnectPeer(peer_id));
214	}
215
216	pub fn reserved_peers(&self, pending_response: oneshot::Sender<Vec<PeerId>>) {
218		let _ = self.actions_tx.unbounded_send(Action::GetReservedPeers(pending_response));
219	}
220
221	pub fn incoming_connection(&self, peer_id: PeerId, incoming_index: IncomingIndex) {
223		let _ = self
224			.events_tx
225			.unbounded_send(Event::IncomingConnection(peer_id, incoming_index));
226	}
227
228	pub fn dropped(&self, peer_id: PeerId) {
230		let _ = self.events_tx.unbounded_send(Event::Dropped(peer_id));
231	}
232}
233
234impl ProtocolHandleT for ProtocolHandle {
235	fn disconnect_peer(&self, peer_id: sc_network_types::PeerId) {
236		let _ = self.actions_tx.unbounded_send(Action::DisconnectPeer(peer_id.into()));
237	}
238}
239
/// Direction of a connection relative to the local node.
#[derive(Clone, Copy, Debug)]
enum Direction {
	/// Connection accepted from a remote peer.
	Inbound,
	/// Connection initiated (requested) by us.
	Outbound,
}

/// Connection state of a reserved node.
///
/// The `Default` is derived (`#[default]` on `NotConnected`) instead of being hand-written,
/// as recommended by clippy's `derivable_impls` lint.
#[derive(Clone, Debug, Default)]
enum PeerState {
	/// Connected, or a connection was requested, in the given direction.
	Connected(Direction),
	/// Not connected.
	#[default]
	NotConnected,
}

impl PeerState {
	/// Returns `true` for [`PeerState::Connected`], whatever the direction.
	fn is_connected(&self) -> bool {
		matches!(self, PeerState::Connected(_))
	}
}
268
269#[derive(Debug)]
271pub struct ProtocolController {
272	set_id: SetId,
275	actions_rx: TracingUnboundedReceiver<Action>,
277	events_rx: TracingUnboundedReceiver<Event>,
279	num_in: u32,
281	num_out: u32,
283	max_in: u32,
285	max_out: u32,
287	nodes: HashMap<PeerId, Direction>,
289	reserved_nodes: HashMap<PeerId, PeerState>,
291	reserved_only: bool,
293	next_periodic_alloc_slots: Instant,
295	to_notifications: TracingUnboundedSender<Message>,
297	peer_store: Arc<dyn PeerStoreProvider>,
300}
301
302impl ProtocolController {
303	pub fn new(
305		set_id: SetId,
306		config: ProtoSetConfig,
307		to_notifications: TracingUnboundedSender<Message>,
308		peer_store: Arc<dyn PeerStoreProvider>,
309	) -> (ProtocolHandle, ProtocolController) {
310		let (actions_tx, actions_rx) = tracing_unbounded("mpsc_api_protocol", 10_000);
311		let (events_tx, events_rx) = tracing_unbounded("mpsc_notifications_protocol", 10_000);
312		let handle = ProtocolHandle { actions_tx, events_tx };
313		peer_store.register_protocol(Arc::new(handle.clone()));
314		let reserved_nodes =
315			config.reserved_nodes.iter().map(|p| (*p, PeerState::NotConnected)).collect();
316		let controller = ProtocolController {
317			set_id,
318			actions_rx,
319			events_rx,
320			num_in: 0,
321			num_out: 0,
322			max_in: config.in_peers,
323			max_out: config.out_peers,
324			nodes: HashMap::new(),
325			reserved_nodes,
326			reserved_only: config.reserved_only,
327			next_periodic_alloc_slots: Instant::now(),
328			to_notifications,
329			peer_store,
330		};
331		(handle, controller)
332	}
333
334	pub async fn run(mut self) {
337		while self.next_action().await {}
338	}
339
340	pub async fn next_action(&mut self) -> bool {
344		let either = loop {
345			let mut next_alloc_slots = Delay::new_at(self.next_periodic_alloc_slots).fuse();
346
347			futures::select_biased! {
349				event = self.events_rx.next() => match event {
350					Some(event) => break Either::Left(event),
351					None => return false,
352				},
353				action = self.actions_rx.next() => match action {
354					Some(action) => break Either::Right(action),
355					None => return false,
356				},
357				_ = next_alloc_slots => {
358					self.alloc_slots();
359					self.next_periodic_alloc_slots = Instant::now() + Duration::new(1, 0);
360				},
361			}
362		};
363
364		match either {
365			Either::Left(event) => self.process_event(event),
366			Either::Right(action) => self.process_action(action),
367		}
368
369		true
370	}
371
372	fn process_event(&mut self, event: Event) {
374		match event {
375			Event::IncomingConnection(peer_id, index) =>
376				self.on_incoming_connection(peer_id, index),
377			Event::Dropped(peer_id) => self.on_peer_dropped(peer_id),
378		}
379	}
380
381	fn process_action(&mut self, action: Action) {
383		match action {
384			Action::AddReservedPeer(peer_id) => self.on_add_reserved_peer(peer_id),
385			Action::RemoveReservedPeer(peer_id) => self.on_remove_reserved_peer(peer_id),
386			Action::SetReservedPeers(peer_ids) => self.on_set_reserved_peers(peer_ids),
387			Action::SetReservedOnly(reserved_only) => self.on_set_reserved_only(reserved_only),
388			Action::DisconnectPeer(peer_id) => self.on_disconnect_peer(peer_id),
389			Action::GetReservedPeers(pending_response) =>
390				self.on_get_reserved_peers(pending_response),
391		}
392	}
393
394	fn accept_connection(&mut self, peer_id: PeerId, incoming_index: IncomingIndex) {
396		trace!(
397			target: LOG_TARGET,
398			"Accepting {peer_id} ({incoming_index:?}) on {:?} ({}/{} num_in/max_in).",
399			self.set_id,
400			self.num_in,
401			self.max_in,
402		);
403
404		let _ = self.to_notifications.unbounded_send(Message::Accept(incoming_index));
405	}
406
407	fn reject_connection(&mut self, peer_id: PeerId, incoming_index: IncomingIndex) {
409		trace!(
410			target: LOG_TARGET,
411			"Rejecting {peer_id} ({incoming_index:?}) on {:?} ({}/{} num_in/max_in).",
412			self.set_id,
413			self.num_in,
414			self.max_in,
415		);
416
417		let _ = self.to_notifications.unbounded_send(Message::Reject(incoming_index));
418	}
419
420	fn start_connection(&mut self, peer_id: PeerId) {
422		trace!(
423			target: LOG_TARGET,
424			"Connecting to {peer_id} on {:?} ({}/{} num_out/max_out).",
425			self.set_id,
426			self.num_out,
427			self.max_out,
428		);
429
430		let _ = self
431			.to_notifications
432			.unbounded_send(Message::Connect { set_id: self.set_id, peer_id });
433	}
434
435	fn drop_connection(&mut self, peer_id: PeerId) {
437		trace!(
438			target: LOG_TARGET,
439			"Dropping {peer_id} on {:?} ({}/{} num_in/max_in, {}/{} num_out/max_out).",
440			self.set_id,
441			self.num_in,
442			self.max_in,
443			self.num_out,
444			self.max_out,
445		);
446
447		let _ = self
448			.to_notifications
449			.unbounded_send(Message::Drop { set_id: self.set_id, peer_id });
450	}
451
452	fn report_disconnect(&mut self, peer_id: PeerId) {
455		self.peer_store.report_disconnect(peer_id.into());
456	}
457
458	fn is_banned(&self, peer_id: &PeerId) -> bool {
460		self.peer_store.is_banned(&peer_id.into())
461	}
462
463	fn on_add_reserved_peer(&mut self, peer_id: PeerId) {
466		if self.reserved_nodes.contains_key(&peer_id) {
467			debug!(
468				target: LOG_TARGET,
469				"Trying to add an already reserved node {peer_id} as reserved on {:?}.",
470				self.set_id,
471			);
472			return
473		}
474
475		let state = match self.nodes.remove(&peer_id) {
477			Some(direction) => {
478				trace!(
479					target: LOG_TARGET,
480					"Marking previously connected node {} ({:?}) as reserved on {:?}.",
481					peer_id,
482					direction,
483					self.set_id
484				);
485				PeerState::Connected(direction)
486			},
487			None => {
488				trace!(target: LOG_TARGET, "Adding reserved node {peer_id} on {:?}.", self.set_id,);
489				PeerState::NotConnected
490			},
491		};
492
493		self.reserved_nodes.insert(peer_id, state.clone());
494
495		match state {
497			PeerState::Connected(Direction::Inbound) => self.num_in -= 1,
498			PeerState::Connected(Direction::Outbound) => self.num_out -= 1,
499			PeerState::NotConnected => self.alloc_slots(),
500		}
501	}
502
503	fn on_remove_reserved_peer(&mut self, peer_id: PeerId) {
506		let state = match self.reserved_nodes.remove(&peer_id) {
507			Some(state) => state,
508			None => {
509				warn!(
510					target: LOG_TARGET,
511					"Trying to remove unknown reserved node {peer_id} from {:?}.", self.set_id,
512				);
513				return
514			},
515		};
516
517		if let PeerState::Connected(direction) = state {
518			let disconnect = self.reserved_only ||
520				match direction {
521					Direction::Inbound => self.num_in >= self.max_in,
522					Direction::Outbound => self.num_out >= self.max_out,
523				};
524
525			if disconnect {
526				trace!(
528					target: LOG_TARGET,
529					"Disconnecting previously reserved node {peer_id} ({direction:?}) on {:?}.",
530					self.set_id,
531				);
532				self.drop_connection(peer_id);
533			} else {
534				trace!(
536					target: LOG_TARGET,
537					"Making a connected reserved node {peer_id} ({:?}) on {:?} a regular one.",
538					direction,
539					self.set_id,
540				);
541
542				match direction {
543					Direction::Inbound => self.num_in += 1,
544					Direction::Outbound => self.num_out += 1,
545				}
546
547				let prev = self.nodes.insert(peer_id, direction);
549				assert!(prev.is_none(), "Corrupted state: reserved node was also non-reserved.");
550			}
551		} else {
552			trace!(
553				target: LOG_TARGET,
554				"Removed disconnected reserved node {peer_id} from {:?}.",
555				self.set_id,
556			);
557		}
558	}
559
560	fn on_set_reserved_peers(&mut self, peer_ids: HashSet<PeerId>) {
562		let current = self.reserved_nodes.keys().cloned().collect();
564		let to_insert = peer_ids.difference(¤t).cloned().collect::<Vec<_>>();
565		let to_remove = current.difference(&peer_ids).cloned().collect::<Vec<_>>();
566
567		for node in to_insert {
568			self.on_add_reserved_peer(node);
569		}
570
571		for node in to_remove {
572			self.on_remove_reserved_peer(node);
573		}
574	}
575
576	fn on_set_reserved_only(&mut self, reserved_only: bool) {
579		trace!(target: LOG_TARGET, "Set reserved only to `{reserved_only}` on {:?}", self.set_id);
580
581		self.reserved_only = reserved_only;
582
583		if !reserved_only {
584			return self.alloc_slots()
585		}
586
587		self.nodes
589			.iter()
590			.map(|(k, v)| (*k, *v))
591			.collect::<Vec<(_, _)>>()
592			.iter()
593			.for_each(|(peer_id, direction)| {
594				match direction {
596					Direction::Inbound => self.num_in -= 1,
597					Direction::Outbound => self.num_out -= 1,
598				}
599				self.drop_connection(*peer_id)
600			});
601		self.nodes.clear();
602	}
603
604	fn on_get_reserved_peers(&self, pending_response: oneshot::Sender<Vec<PeerId>>) {
606		let _ = pending_response.send(self.reserved_nodes.keys().cloned().collect());
607	}
608
609	fn on_disconnect_peer(&mut self, peer_id: PeerId) {
611		if self.reserved_nodes.contains_key(&peer_id) {
613			debug!(
614				target: LOG_TARGET,
615				"Ignoring request to disconnect reserved peer {peer_id} from {:?}.", self.set_id,
616			);
617			return
618		}
619
620		match self.nodes.remove(&peer_id) {
621			Some(direction) => {
622				trace!(
623					target: LOG_TARGET,
624					"Disconnecting peer {peer_id} ({direction:?}) from {:?}.",
625					self.set_id
626				);
627				match direction {
628					Direction::Inbound => self.num_in -= 1,
629					Direction::Outbound => self.num_out -= 1,
630				}
631				self.drop_connection(peer_id);
632			},
633			None => {
634				debug!(
635					target: LOG_TARGET,
636					"Trying to disconnect unknown peer {peer_id} from {:?}.", self.set_id,
637				);
638			},
639		}
640	}
641
642	fn on_incoming_connection(&mut self, peer_id: PeerId, incoming_index: IncomingIndex) {
654		trace!(
655			target: LOG_TARGET,
656			"Incoming connection from peer {peer_id} ({incoming_index:?}) on {:?}.",
657			self.set_id,
658		);
659
660		if self.reserved_only && !self.reserved_nodes.contains_key(&peer_id) {
661			self.reject_connection(peer_id, incoming_index);
662			return
663		}
664
665		if let Some(state) = self.reserved_nodes.get_mut(&peer_id) {
667			match state {
668				PeerState::Connected(ref mut direction) => {
669					*direction = Direction::Inbound;
672					self.accept_connection(peer_id, incoming_index);
673				},
674				PeerState::NotConnected =>
675					if self.peer_store.is_banned(&peer_id.into()) {
676						self.reject_connection(peer_id, incoming_index);
677					} else {
678						*state = PeerState::Connected(Direction::Inbound);
679						self.accept_connection(peer_id, incoming_index);
680					},
681			}
682			return
683		}
684
685		if let Some(direction) = self.nodes.remove(&peer_id) {
688			trace!(
689				target: LOG_TARGET,
690				"Handling incoming connection from peer {} we think we already connected as {:?} on {:?}.",
691				peer_id,
692				direction,
693				self.set_id
694			);
695			match direction {
696				Direction::Inbound => self.num_in -= 1,
697				Direction::Outbound => self.num_out -= 1,
698			}
699		}
700
701		if self.num_in >= self.max_in {
702			self.reject_connection(peer_id, incoming_index);
703			return
704		}
705
706		if self.is_banned(&peer_id) {
707			self.reject_connection(peer_id, incoming_index);
708			return
709		}
710
711		self.num_in += 1;
712		self.nodes.insert(peer_id, Direction::Inbound);
713		self.accept_connection(peer_id, incoming_index);
714	}
715
716	fn on_peer_dropped(&mut self, peer_id: PeerId) {
718		self.on_peer_dropped_inner(peer_id).unwrap_or_else(|peer_id| {
719			trace!(
723				target: LOG_TARGET,
724				"Received `Action::Dropped` for not connected peer {peer_id} on {:?}.",
725				self.set_id,
726			)
727		});
728	}
729
730	fn on_peer_dropped_inner(&mut self, peer_id: PeerId) -> Result<(), PeerId> {
733		if self.drop_reserved_peer(&peer_id)? || self.drop_regular_peer(&peer_id) {
734			self.report_disconnect(peer_id);
736			Ok(())
737		} else {
738			Err(peer_id)
740		}
741	}
742
743	fn drop_reserved_peer(&mut self, peer_id: &PeerId) -> Result<bool, PeerId> {
747		let Some(state) = self.reserved_nodes.get_mut(peer_id) else { return Ok(false) };
748
749		if let PeerState::Connected(direction) = state {
750			trace!(
751				target: LOG_TARGET,
752				"Reserved peer {peer_id} ({direction:?}) dropped from {:?}.",
753				self.set_id,
754			);
755			*state = PeerState::NotConnected;
756			Ok(true)
757		} else {
758			Err(*peer_id)
759		}
760	}
761
762	fn drop_regular_peer(&mut self, peer_id: &PeerId) -> bool {
765		let Some(direction) = self.nodes.remove(peer_id) else { return false };
766
767		trace!(
768			target: LOG_TARGET,
769			"Peer {peer_id} ({direction:?}) dropped from {:?}.",
770			self.set_id,
771		);
772
773		match direction {
774			Direction::Inbound => self.num_in -= 1,
775			Direction::Outbound => self.num_out -= 1,
776		}
777
778		true
779	}
780
781	fn alloc_slots(&mut self) {
784		self.reserved_nodes
786			.iter_mut()
787			.filter_map(|(peer_id, state)| {
788				(!state.is_connected() && !self.peer_store.is_banned(&peer_id.into())).then(|| {
789					*state = PeerState::Connected(Direction::Outbound);
790					peer_id
791				})
792			})
793			.cloned()
794			.collect::<Vec<_>>()
795			.into_iter()
796			.for_each(|peer_id| {
797				self.start_connection(peer_id);
798			});
799
800		if self.reserved_only || self.num_out >= self.max_out {
802			return
803		}
804
805		let available_slots = (self.max_out - self.num_out).saturated_into();
807
808		let ignored = self
811			.reserved_nodes
812			.keys()
813			.map(From::from)
814			.collect::<HashSet<sc_network_types::PeerId>>()
815			.union(
816				&self.nodes.keys().map(From::from).collect::<HashSet<sc_network_types::PeerId>>(),
817			)
818			.cloned()
819			.collect();
820
821		let candidates = self
822			.peer_store
823			.outgoing_candidates(available_slots, ignored)
824			.into_iter()
825			.filter_map(|peer_id| {
826				(!self.reserved_nodes.contains_key(&peer_id.into()) &&
827					!self.nodes.contains_key(&peer_id.into()))
828				.then_some(peer_id)
829				.or_else(|| {
830					error!(
831						target: LOG_TARGET,
832						"`PeerStore` returned a node we asked to ignore: {peer_id}.",
833					);
834					debug_assert!(false, "`PeerStore` returned a node we asked to ignore.");
835					None
836				})
837			})
838			.collect::<Vec<_>>();
839
840		if candidates.len() > available_slots {
841			error!(
842				target: LOG_TARGET,
843				"`PeerStore` returned more nodes than there are slots available.",
844			);
845			debug_assert!(false, "`PeerStore` returned more nodes than there are slots available.");
846		}
847
848		candidates.into_iter().take(available_slots).for_each(|peer_id| {
849			self.num_out += 1;
850			self.nodes.insert(peer_id.into(), Direction::Outbound);
851			self.start_connection(peer_id.into());
852		})
853	}
854}
855
856#[cfg(test)]
857mod tests {
858	use super::*;
859	use crate::{
860		peer_store::{PeerStoreProvider, ProtocolHandle as ProtocolHandleT},
861		ReputationChange,
862	};
863	use libp2p::PeerId;
864	use sc_network_common::role::ObservedRole;
865	use sc_utils::mpsc::{tracing_unbounded, TryRecvError};
866	use std::collections::HashSet;
867
868	mockall::mock! {
869		#[derive(Debug)]
870		pub PeerStoreHandle {}
871
872		impl PeerStoreProvider for PeerStoreHandle {
873			fn is_banned(&self, peer_id: &sc_network_types::PeerId) -> bool;
874			fn register_protocol(&self, protocol_handle: Arc<dyn ProtocolHandleT>);
875			fn report_disconnect(&self, peer_id: sc_network_types::PeerId);
876			fn set_peer_role(&self, peer_id: &sc_network_types::PeerId, role: ObservedRole);
877			fn report_peer(&self, peer_id: sc_network_types::PeerId, change: ReputationChange);
878			fn peer_reputation(&self, peer_id: &sc_network_types::PeerId) -> i32;
879			fn peer_role(&self, peer_id: &sc_network_types::PeerId) -> Option<ObservedRole>;
880			fn outgoing_candidates(&self, count: usize, ignored: HashSet<sc_network_types::PeerId>) -> Vec<sc_network_types::PeerId>;
881			fn add_known_peer(&self, peer_id: sc_network_types::PeerId);
882		}
883	}
884
885	#[test]
886	fn reserved_nodes_are_connected_dropped_and_accepted() {
887		let reserved1 = PeerId::random();
888		let reserved2 = PeerId::random();
889
890		let config = ProtoSetConfig {
892			in_peers: 0,
893			out_peers: 0,
894			reserved_nodes: std::iter::once(reserved1).collect(),
895			reserved_only: true,
896		};
897		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
898
899		let mut peer_store = MockPeerStoreHandle::new();
900		peer_store.expect_register_protocol().once().return_const(());
901		peer_store.expect_is_banned().times(4).return_const(false);
902		peer_store.expect_report_disconnect().times(2).return_const(());
903
904		let (_handle, mut controller) =
905			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
906
907		controller.on_add_reserved_peer(reserved2);
909
910		controller.alloc_slots();
913
914		let mut messages = Vec::new();
915		while let Some(message) = rx.try_recv().ok() {
916			messages.push(message);
917		}
918		assert_eq!(messages.len(), 2);
919		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: reserved1 }));
920		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: reserved2 }));
921
922		assert_eq!(controller.num_out, 0);
924		assert_eq!(controller.num_in, 0);
925
926		controller.on_peer_dropped(reserved1);
928		controller.on_peer_dropped(reserved2);
929
930		let incoming1 = IncomingIndex(1);
932		controller.on_incoming_connection(reserved1, incoming1);
933		assert_eq!(rx.try_recv().unwrap(), Message::Accept(incoming1));
934		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
935
936		let incoming2 = IncomingIndex(2);
938		controller.on_incoming_connection(reserved2, incoming2);
939		assert_eq!(rx.try_recv().unwrap(), Message::Accept(incoming2));
940		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
941
942		assert_eq!(controller.num_out, 0);
944		assert_eq!(controller.num_in, 0);
945	}
946
947	#[test]
948	fn banned_reserved_nodes_are_not_connected_and_not_accepted() {
949		let reserved1 = PeerId::random();
950		let reserved2 = PeerId::random();
951
952		let config = ProtoSetConfig {
954			in_peers: 0,
955			out_peers: 0,
956			reserved_nodes: std::iter::once(reserved1).collect(),
957			reserved_only: true,
958		};
959		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
960
961		let mut peer_store = MockPeerStoreHandle::new();
962		peer_store.expect_register_protocol().once().return_const(());
963		peer_store.expect_is_banned().times(6).return_const(true);
964
965		let (_handle, mut controller) =
966			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
967
968		controller.on_add_reserved_peer(reserved2);
970
971		controller.alloc_slots();
973
974		assert_eq!(controller.num_out, 0);
976		assert_eq!(controller.num_in, 0);
977
978		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
980
981		let incoming1 = IncomingIndex(1);
983		controller.on_incoming_connection(reserved1, incoming1);
984		assert_eq!(rx.try_recv().unwrap(), Message::Reject(incoming1));
985		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
986
987		let incoming2 = IncomingIndex(2);
989		controller.on_incoming_connection(reserved2, incoming2);
990		assert_eq!(rx.try_recv().unwrap(), Message::Reject(incoming2));
991		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
992
993		assert_eq!(controller.num_out, 0);
995		assert_eq!(controller.num_in, 0);
996	}
997
998	#[test]
999	fn we_try_to_reconnect_to_dropped_reserved_nodes() {
1000		let reserved1 = PeerId::random();
1001		let reserved2 = PeerId::random();
1002
1003		let config = ProtoSetConfig {
1005			in_peers: 0,
1006			out_peers: 0,
1007			reserved_nodes: std::iter::once(reserved1).collect(),
1008			reserved_only: true,
1009		};
1010		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1011
1012		let mut peer_store = MockPeerStoreHandle::new();
1013		peer_store.expect_register_protocol().once().return_const(());
1014		peer_store.expect_is_banned().times(4).return_const(false);
1015		peer_store.expect_report_disconnect().times(2).return_const(());
1016
1017		let (_handle, mut controller) =
1018			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1019
1020		controller.on_add_reserved_peer(reserved2);
1022
1023		controller.alloc_slots();
1025
1026		let mut messages = Vec::new();
1027		while let Some(message) = rx.try_recv().ok() {
1028			messages.push(message);
1029		}
1030
1031		assert_eq!(messages.len(), 2);
1032		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: reserved1 }));
1033		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: reserved2 }));
1034
1035		controller.on_peer_dropped(reserved1);
1037		controller.on_peer_dropped(reserved2);
1038
1039		controller.alloc_slots();
1041
1042		let mut messages = Vec::new();
1043		while let Some(message) = rx.try_recv().ok() {
1044			messages.push(message);
1045		}
1046
1047		assert_eq!(messages.len(), 2);
1048		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: reserved1 }));
1049		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: reserved2 }));
1050
1051		assert_eq!(controller.num_out, 0);
1053		assert_eq!(controller.num_in, 0);
1054	}
1055
1056	#[test]
1057	fn nodes_supplied_by_peer_store_are_connected() {
1058		let peer1 = PeerId::random();
1059		let peer2 = PeerId::random();
1060		let candidates = vec![peer1.into(), peer2.into()];
1061
1062		let config = ProtoSetConfig {
1063			in_peers: 0,
1064			out_peers: 2,
1066			reserved_nodes: HashSet::new(),
1067			reserved_only: false,
1068		};
1069		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1070
1071		let mut peer_store = MockPeerStoreHandle::new();
1072		peer_store.expect_register_protocol().once().return_const(());
1073		peer_store.expect_outgoing_candidates().once().return_const(candidates);
1074
1075		let (_handle, mut controller) =
1076			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1077
1078		controller.alloc_slots();
1080
1081		let mut messages = Vec::new();
1082		while let Some(message) = rx.try_recv().ok() {
1083			messages.push(message);
1084		}
1085
1086		assert_eq!(messages.len(), 2);
1088		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: peer1 }));
1089		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: peer2 }));
1090
1091		assert_eq!(controller.num_out, 2);
1093		assert_eq!(controller.num_in, 0);
1094
1095		controller.alloc_slots();
1097		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1098
1099		assert_eq!(controller.num_out, 2);
1101		assert_eq!(controller.num_in, 0);
1102	}
1103
1104	#[test]
1105	fn both_reserved_nodes_and_nodes_supplied_by_peer_store_are_connected() {
1106		let reserved1 = PeerId::random();
1107		let reserved2 = PeerId::random();
1108		let regular1 = PeerId::random();
1109		let regular2 = PeerId::random();
1110		let outgoing_candidates = vec![regular1.into(), regular2.into()];
1111		let reserved_nodes = [reserved1, reserved2].iter().cloned().collect();
1112
1113		let config =
1114			ProtoSetConfig { in_peers: 10, out_peers: 10, reserved_nodes, reserved_only: false };
1115		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1116
1117		let mut peer_store = MockPeerStoreHandle::new();
1118		peer_store.expect_register_protocol().once().return_const(());
1119		peer_store.expect_is_banned().times(2).return_const(false);
1120		peer_store.expect_outgoing_candidates().once().return_const(outgoing_candidates);
1121
1122		let (_handle, mut controller) =
1123			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1124
1125		controller.alloc_slots();
1127
1128		let mut messages = Vec::new();
1129		while let Some(message) = rx.try_recv().ok() {
1130			messages.push(message);
1131		}
1132		assert_eq!(messages.len(), 4);
1133		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: reserved1 }));
1134		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: reserved2 }));
1135		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: regular1 }));
1136		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: regular2 }));
1137		assert_eq!(controller.num_out, 2);
1138		assert_eq!(controller.num_in, 0);
1139	}
1140
1141	#[test]
1142	fn if_slots_are_freed_we_try_to_allocate_them_again() {
1143		let peer1 = PeerId::random();
1144		let peer2 = PeerId::random();
1145		let peer3 = PeerId::random();
1146		let candidates1 = vec![peer1.into(), peer2.into()];
1147		let candidates2 = vec![peer3.into()];
1148
1149		let config = ProtoSetConfig {
1150			in_peers: 0,
1151			out_peers: 2,
1153			reserved_nodes: HashSet::new(),
1154			reserved_only: false,
1155		};
1156		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1157
1158		let mut peer_store = MockPeerStoreHandle::new();
1159		peer_store.expect_register_protocol().once().return_const(());
1160		peer_store.expect_outgoing_candidates().once().return_const(candidates1);
1161		peer_store.expect_outgoing_candidates().once().return_const(candidates2);
1162		peer_store.expect_report_disconnect().times(2).return_const(());
1163
1164		let (_handle, mut controller) =
1165			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1166
1167		controller.alloc_slots();
1169
1170		let mut messages = Vec::new();
1171		while let Some(message) = rx.try_recv().ok() {
1172			messages.push(message);
1173		}
1174
1175		assert_eq!(messages.len(), 2);
1177		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: peer1 }));
1178		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: peer2 }));
1179
1180		assert_eq!(controller.num_out, 2);
1182		assert_eq!(controller.num_in, 0);
1183
1184		controller.alloc_slots();
1186		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1187
1188		assert_eq!(controller.num_out, 2);
1190		assert_eq!(controller.num_in, 0);
1191
1192		controller.on_peer_dropped(peer1);
1194		controller.on_peer_dropped(peer2);
1195
1196		assert_eq!(controller.num_out, 0);
1198		assert_eq!(controller.num_in, 0);
1199
1200		controller.alloc_slots();
1202
1203		let mut messages = Vec::new();
1204		while let Some(message) = rx.try_recv().ok() {
1205			messages.push(message);
1206		}
1207
1208		assert_eq!(messages.len(), 1);
1210		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: peer3 }));
1211
1212		assert_eq!(controller.num_out, 1);
1214		assert_eq!(controller.num_in, 0);
1215	}
1216
1217	#[test]
1218	fn in_reserved_only_mode_no_peers_are_requested_from_peer_store_and_connected() {
1219		let config = ProtoSetConfig {
1220			in_peers: 0,
1221			out_peers: 2,
1223			reserved_nodes: HashSet::new(),
1224			reserved_only: true,
1225		};
1226		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1227
1228		let mut peer_store = MockPeerStoreHandle::new();
1229		peer_store.expect_register_protocol().once().return_const(());
1230
1231		let (_handle, mut controller) =
1232			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1233
1234		controller.alloc_slots();
1236
1237		assert_eq!(controller.num_out, 0);
1239		assert_eq!(controller.num_in, 0);
1240		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1241	}
1242
1243	#[test]
1244	fn in_reserved_only_mode_no_regular_peers_are_accepted() {
1245		let config = ProtoSetConfig {
1246			in_peers: 2,
1248			out_peers: 0,
1249			reserved_nodes: HashSet::new(),
1250			reserved_only: true,
1251		};
1252		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1253
1254		let mut peer_store = MockPeerStoreHandle::new();
1255		peer_store.expect_register_protocol().once().return_const(());
1256
1257		let (_handle, mut controller) =
1258			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1259
1260		let peer = PeerId::random();
1261		let incoming_index = IncomingIndex(1);
1262		controller.on_incoming_connection(peer, incoming_index);
1263
1264		let mut messages = Vec::new();
1265		while let Some(message) = rx.try_recv().ok() {
1266			messages.push(message);
1267		}
1268
1269		assert_eq!(messages.len(), 1);
1271		assert!(messages.contains(&Message::Reject(incoming_index)));
1272		assert_eq!(controller.num_out, 0);
1273		assert_eq!(controller.num_in, 0);
1274	}
1275
1276	#[test]
1277	fn disabling_reserved_only_mode_allows_to_connect_to_peers() {
1278		let peer1 = PeerId::random();
1279		let peer2 = PeerId::random();
1280		let candidates = vec![peer1.into(), peer2.into()];
1281
1282		let config = ProtoSetConfig {
1283			in_peers: 0,
1284			out_peers: 10,
1286			reserved_nodes: HashSet::new(),
1287			reserved_only: true,
1288		};
1289		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1290
1291		let mut peer_store = MockPeerStoreHandle::new();
1292		peer_store.expect_register_protocol().once().return_const(());
1293		peer_store.expect_outgoing_candidates().once().return_const(candidates);
1294
1295		let (_handle, mut controller) =
1296			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1297
1298		controller.alloc_slots();
1300
1301		assert_eq!(controller.num_out, 0);
1303		assert_eq!(controller.num_in, 0);
1304		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1305
1306		controller.on_set_reserved_only(false);
1308
1309		let mut messages = Vec::new();
1310		while let Some(message) = rx.try_recv().ok() {
1311			messages.push(message);
1312		}
1313
1314		assert_eq!(messages.len(), 2);
1315		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: peer1 }));
1316		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: peer2 }));
1317		assert_eq!(controller.num_out, 2);
1318		assert_eq!(controller.num_in, 0);
1319	}
1320
1321	#[test]
1322	fn enabling_reserved_only_mode_disconnects_regular_peers() {
1323		let reserved1 = PeerId::random();
1324		let reserved2 = PeerId::random();
1325		let regular1 = PeerId::random();
1326		let regular2 = PeerId::random();
1327		let outgoing_candidates = vec![regular1.into()];
1328
1329		let config = ProtoSetConfig {
1330			in_peers: 10,
1331			out_peers: 10,
1332			reserved_nodes: [reserved1, reserved2].iter().cloned().collect(),
1333			reserved_only: false,
1334		};
1335		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1336
1337		let mut peer_store = MockPeerStoreHandle::new();
1338		peer_store.expect_register_protocol().once().return_const(());
1339		peer_store.expect_is_banned().times(3).return_const(false);
1340		peer_store.expect_outgoing_candidates().once().return_const(outgoing_candidates);
1341
1342		let (_handle, mut controller) =
1343			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1344		assert_eq!(controller.num_out, 0);
1345		assert_eq!(controller.num_in, 0);
1346
1347		controller.alloc_slots();
1349
1350		let mut messages = Vec::new();
1351		while let Some(message) = rx.try_recv().ok() {
1352			messages.push(message);
1353		}
1354		assert_eq!(messages.len(), 3);
1355		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: reserved1 }));
1356		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: reserved2 }));
1357		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: regular1 }));
1358		assert_eq!(controller.num_out, 1);
1359		assert_eq!(controller.num_in, 0);
1360
1361		let incoming_index = IncomingIndex(1);
1363		controller.on_incoming_connection(regular2, incoming_index);
1364		assert_eq!(rx.try_recv().unwrap(), Message::Accept(incoming_index));
1365		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1366		assert_eq!(controller.num_out, 1);
1367		assert_eq!(controller.num_in, 1);
1368
1369		controller.on_set_reserved_only(true);
1371
1372		let mut messages = Vec::new();
1373		while let Some(message) = rx.try_recv().ok() {
1374			messages.push(message);
1375		}
1376		assert_eq!(messages.len(), 2);
1377		assert!(messages.contains(&Message::Drop { set_id: SetId::from(0), peer_id: regular1 }));
1378		assert!(messages.contains(&Message::Drop { set_id: SetId::from(0), peer_id: regular2 }));
1379		assert_eq!(controller.nodes.len(), 0);
1380		assert_eq!(controller.num_out, 0);
1381		assert_eq!(controller.num_in, 0);
1382	}
1383
1384	#[test]
1385	fn removed_disconnected_reserved_node_is_forgotten() {
1386		let reserved1 = PeerId::random();
1387		let reserved2 = PeerId::random();
1388
1389		let config = ProtoSetConfig {
1390			in_peers: 10,
1391			out_peers: 10,
1392			reserved_nodes: [reserved1, reserved2].iter().cloned().collect(),
1393			reserved_only: false,
1394		};
1395		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1396
1397		let mut peer_store = MockPeerStoreHandle::new();
1398		peer_store.expect_register_protocol().once().return_const(());
1399
1400		let (_handle, mut controller) =
1401			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1402		assert_eq!(controller.reserved_nodes.len(), 2);
1403		assert_eq!(controller.nodes.len(), 0);
1404		assert_eq!(controller.num_out, 0);
1405		assert_eq!(controller.num_in, 0);
1406
1407		controller.on_remove_reserved_peer(reserved1);
1408		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1409		assert_eq!(controller.reserved_nodes.len(), 1);
1410		assert!(!controller.reserved_nodes.contains_key(&reserved1));
1411		assert_eq!(controller.nodes.len(), 0);
1412		assert_eq!(controller.num_out, 0);
1413		assert_eq!(controller.num_in, 0);
1414	}
1415
1416	#[test]
1417	fn removed_connected_reserved_node_is_disconnected_in_reserved_only_mode() {
1418		let reserved1 = PeerId::random();
1419		let reserved2 = PeerId::random();
1420
1421		let config = ProtoSetConfig {
1422			in_peers: 10,
1423			out_peers: 10,
1424			reserved_nodes: [reserved1, reserved2].iter().cloned().collect(),
1425			reserved_only: true,
1426		};
1427		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1428
1429		let mut peer_store = MockPeerStoreHandle::new();
1430		peer_store.expect_register_protocol().once().return_const(());
1431		peer_store.expect_is_banned().times(2).return_const(false);
1432
1433		let (_handle, mut controller) =
1434			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1435
1436		controller.alloc_slots();
1438		let mut messages = Vec::new();
1439		while let Some(message) = rx.try_recv().ok() {
1440			messages.push(message);
1441		}
1442		assert_eq!(messages.len(), 2);
1443		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: reserved1 }));
1444		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: reserved2 }));
1445		assert_eq!(controller.reserved_nodes.len(), 2);
1446		assert!(controller.reserved_nodes.contains_key(&reserved1));
1447		assert!(controller.reserved_nodes.contains_key(&reserved2));
1448		assert!(controller.nodes.is_empty());
1449
1450		controller.on_remove_reserved_peer(reserved1);
1452		assert_eq!(
1453			rx.try_recv().unwrap(),
1454			Message::Drop { set_id: SetId::from(0), peer_id: reserved1 }
1455		);
1456		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1457		assert_eq!(controller.reserved_nodes.len(), 1);
1458		assert!(controller.reserved_nodes.contains_key(&reserved2));
1459		assert!(controller.nodes.is_empty());
1460	}
1461
1462	#[test]
1463	fn removed_connected_reserved_nodes_become_regular_in_non_reserved_mode() {
1464		let peer1 = PeerId::random();
1465		let peer2 = PeerId::random();
1466
1467		let config = ProtoSetConfig {
1468			in_peers: 10,
1469			out_peers: 10,
1470			reserved_nodes: [peer1, peer2].iter().cloned().collect(),
1471			reserved_only: false,
1472		};
1473		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1474
1475		let mut peer_store = MockPeerStoreHandle::new();
1476		peer_store.expect_register_protocol().once().return_const(());
1477		peer_store.expect_is_banned().times(2).return_const(false);
1478		peer_store
1479			.expect_outgoing_candidates()
1480			.once()
1481			.return_const(Vec::<sc_network_types::PeerId>::new());
1482
1483		let (_handle, mut controller) =
1484			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1485
1486		controller.on_incoming_connection(peer1, IncomingIndex(1));
1488		controller.alloc_slots();
1489		let mut messages = Vec::new();
1490		while let Some(message) = rx.try_recv().ok() {
1491			messages.push(message);
1492		}
1493		assert_eq!(messages.len(), 2);
1494		assert!(messages.contains(&Message::Accept(IncomingIndex(1))));
1495		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: peer2 }));
1496		assert_eq!(controller.num_out, 0);
1497		assert_eq!(controller.num_in, 0);
1498
1499		controller.on_remove_reserved_peer(peer1);
1501		controller.on_remove_reserved_peer(peer2);
1502		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1503		assert_eq!(controller.nodes.len(), 2);
1504		assert!(matches!(controller.nodes.get(&peer1), Some(Direction::Inbound)));
1505		assert!(matches!(controller.nodes.get(&peer2), Some(Direction::Outbound)));
1506		assert_eq!(controller.num_out, 1);
1507		assert_eq!(controller.num_in, 1);
1508	}
1509
1510	#[test]
1511	fn regular_nodes_stop_occupying_slots_when_become_reserved() {
1512		let peer1 = PeerId::random();
1513		let peer2 = PeerId::random();
1514		let outgoing_candidates = vec![peer1.into()];
1515
1516		let config = ProtoSetConfig {
1517			in_peers: 10,
1518			out_peers: 10,
1519			reserved_nodes: HashSet::new(),
1520			reserved_only: false,
1521		};
1522		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1523
1524		let mut peer_store = MockPeerStoreHandle::new();
1525		peer_store.expect_register_protocol().once().return_const(());
1526		peer_store.expect_is_banned().once().return_const(false);
1527		peer_store.expect_outgoing_candidates().once().return_const(outgoing_candidates);
1528
1529		let (_handle, mut controller) =
1530			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1531
1532		controller.alloc_slots();
1534		controller.on_incoming_connection(peer2, IncomingIndex(1));
1535		let mut messages = Vec::new();
1536		while let Some(message) = rx.try_recv().ok() {
1537			messages.push(message);
1538		}
1539		assert_eq!(messages.len(), 2);
1540		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: peer1 }));
1541		assert!(messages.contains(&Message::Accept(IncomingIndex(1))));
1542		assert_eq!(controller.num_in, 1);
1543		assert_eq!(controller.num_out, 1);
1544
1545		controller.on_add_reserved_peer(peer1);
1546		controller.on_add_reserved_peer(peer2);
1547		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1548		assert_eq!(controller.num_in, 0);
1549		assert_eq!(controller.num_out, 0);
1550	}
1551
1552	#[test]
1553	fn disconnecting_regular_peers_work() {
1554		let peer1 = PeerId::random();
1555		let peer2 = PeerId::random();
1556		let outgoing_candidates = vec![peer1.into()];
1557
1558		let config = ProtoSetConfig {
1559			in_peers: 10,
1560			out_peers: 10,
1561			reserved_nodes: HashSet::new(),
1562			reserved_only: false,
1563		};
1564		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1565
1566		let mut peer_store = MockPeerStoreHandle::new();
1567		peer_store.expect_register_protocol().once().return_const(());
1568		peer_store.expect_is_banned().once().return_const(false);
1569		peer_store.expect_outgoing_candidates().once().return_const(outgoing_candidates);
1570
1571		let (_handle, mut controller) =
1572			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1573
1574		controller.alloc_slots();
1576		controller.on_incoming_connection(peer2, IncomingIndex(1));
1577		let mut messages = Vec::new();
1578		while let Some(message) = rx.try_recv().ok() {
1579			messages.push(message);
1580		}
1581		assert_eq!(messages.len(), 2);
1582		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: peer1 }));
1583		assert!(messages.contains(&Message::Accept(IncomingIndex(1))));
1584		assert_eq!(controller.nodes.len(), 2);
1585		assert!(matches!(controller.nodes.get(&peer1), Some(Direction::Outbound)));
1586		assert!(matches!(controller.nodes.get(&peer2), Some(Direction::Inbound)));
1587		assert_eq!(controller.num_in, 1);
1588		assert_eq!(controller.num_out, 1);
1589
1590		controller.on_disconnect_peer(peer1);
1591		assert_eq!(
1592			rx.try_recv().unwrap(),
1593			Message::Drop { set_id: SetId::from(0), peer_id: peer1 }
1594		);
1595		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1596		assert_eq!(controller.nodes.len(), 1);
1597		assert!(!controller.nodes.contains_key(&peer1));
1598		assert_eq!(controller.num_in, 1);
1599		assert_eq!(controller.num_out, 0);
1600
1601		controller.on_disconnect_peer(peer2);
1602		assert_eq!(
1603			rx.try_recv().unwrap(),
1604			Message::Drop { set_id: SetId::from(0), peer_id: peer2 }
1605		);
1606		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1607		assert_eq!(controller.nodes.len(), 0);
1608		assert_eq!(controller.num_in, 0);
1609		assert_eq!(controller.num_out, 0);
1610	}
1611
1612	#[test]
1613	fn disconnecting_reserved_peers_is_a_noop() {
1614		let reserved1 = PeerId::random();
1615		let reserved2 = PeerId::random();
1616
1617		let config = ProtoSetConfig {
1618			in_peers: 10,
1619			out_peers: 10,
1620			reserved_nodes: [reserved1, reserved2].iter().cloned().collect(),
1621			reserved_only: false,
1622		};
1623		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1624
1625		let mut peer_store = MockPeerStoreHandle::new();
1626		peer_store.expect_register_protocol().once().return_const(());
1627		peer_store.expect_is_banned().times(2).return_const(false);
1628		peer_store.expect_outgoing_candidates().once().return_const(Vec::new());
1629
1630		let (_handle, mut controller) =
1631			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1632
1633		controller.on_incoming_connection(reserved1, IncomingIndex(1));
1635		controller.alloc_slots();
1636		let mut messages = Vec::new();
1637		while let Some(message) = rx.try_recv().ok() {
1638			messages.push(message);
1639		}
1640		assert_eq!(messages.len(), 2);
1641		assert!(messages.contains(&Message::Accept(IncomingIndex(1))));
1642		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: reserved2 }));
1643		assert!(matches!(
1644			controller.reserved_nodes.get(&reserved1),
1645			Some(PeerState::Connected(Direction::Inbound))
1646		));
1647		assert!(matches!(
1648			controller.reserved_nodes.get(&reserved2),
1649			Some(PeerState::Connected(Direction::Outbound))
1650		));
1651
1652		controller.on_disconnect_peer(reserved1);
1653		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1654		assert!(matches!(
1655			controller.reserved_nodes.get(&reserved1),
1656			Some(PeerState::Connected(Direction::Inbound))
1657		));
1658
1659		controller.on_disconnect_peer(reserved2);
1660		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1661		assert!(matches!(
1662			controller.reserved_nodes.get(&reserved2),
1663			Some(PeerState::Connected(Direction::Outbound))
1664		));
1665	}
1666
1667	#[test]
1668	fn dropping_regular_peers_work() {
1669		let peer1 = PeerId::random();
1670		let peer2 = PeerId::random();
1671		let outgoing_candidates = vec![peer1.into()];
1672
1673		let config = ProtoSetConfig {
1674			in_peers: 10,
1675			out_peers: 10,
1676			reserved_nodes: HashSet::new(),
1677			reserved_only: false,
1678		};
1679		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1680
1681		let mut peer_store = MockPeerStoreHandle::new();
1682		peer_store.expect_register_protocol().once().return_const(());
1683		peer_store.expect_is_banned().once().return_const(false);
1684		peer_store.expect_outgoing_candidates().once().return_const(outgoing_candidates);
1685		peer_store.expect_report_disconnect().times(2).return_const(());
1686
1687		let (_handle, mut controller) =
1688			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1689
1690		controller.alloc_slots();
1692		controller.on_incoming_connection(peer2, IncomingIndex(1));
1693		let mut messages = Vec::new();
1694		while let Some(message) = rx.try_recv().ok() {
1695			messages.push(message);
1696		}
1697		assert_eq!(messages.len(), 2);
1698		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: peer1 }));
1699		assert!(messages.contains(&Message::Accept(IncomingIndex(1))));
1700		assert_eq!(controller.nodes.len(), 2);
1701		assert!(matches!(controller.nodes.get(&peer1), Some(Direction::Outbound)));
1702		assert!(matches!(controller.nodes.get(&peer2), Some(Direction::Inbound)));
1703		assert_eq!(controller.num_in, 1);
1704		assert_eq!(controller.num_out, 1);
1705
1706		controller.on_peer_dropped(peer1);
1707		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1708		assert_eq!(controller.nodes.len(), 1);
1709		assert!(!controller.nodes.contains_key(&peer1));
1710		assert_eq!(controller.num_in, 1);
1711		assert_eq!(controller.num_out, 0);
1712
1713		controller.on_peer_dropped(peer2);
1714		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1715		assert_eq!(controller.nodes.len(), 0);
1716		assert_eq!(controller.num_in, 0);
1717		assert_eq!(controller.num_out, 0);
1718	}
1719
1720	#[test]
1721	fn incoming_request_for_connected_reserved_node_switches_it_to_inbound() {
1722		let reserved1 = PeerId::random();
1723		let reserved2 = PeerId::random();
1724
1725		let config = ProtoSetConfig {
1726			in_peers: 10,
1727			out_peers: 10,
1728			reserved_nodes: [reserved1, reserved2].iter().cloned().collect(),
1729			reserved_only: false,
1730		};
1731		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1732
1733		let mut peer_store = MockPeerStoreHandle::new();
1734		peer_store.expect_register_protocol().once().return_const(());
1735		peer_store.expect_is_banned().times(2).return_const(false);
1736		peer_store.expect_outgoing_candidates().once().return_const(Vec::new());
1737
1738		let (_handle, mut controller) =
1739			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1740
1741		controller.on_incoming_connection(reserved1, IncomingIndex(1));
1743		controller.alloc_slots();
1744		let mut messages = Vec::new();
1745		while let Some(message) = rx.try_recv().ok() {
1746			messages.push(message);
1747		}
1748		assert_eq!(messages.len(), 2);
1749		assert!(messages.contains(&Message::Accept(IncomingIndex(1))));
1750		assert!(messages.contains(&Message::Connect { set_id: SetId::from(0), peer_id: reserved2 }));
1751		assert!(matches!(
1752			controller.reserved_nodes.get(&reserved1),
1753			Some(PeerState::Connected(Direction::Inbound))
1754		));
1755		assert!(matches!(
1756			controller.reserved_nodes.get(&reserved2),
1757			Some(PeerState::Connected(Direction::Outbound))
1758		));
1759
1760		controller.on_incoming_connection(reserved1, IncomingIndex(2));
1762		assert_eq!(rx.try_recv().ok().unwrap(), Message::Accept(IncomingIndex(2)));
1763		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1764		assert!(matches!(
1765			controller.reserved_nodes.get(&reserved1),
1766			Some(PeerState::Connected(Direction::Inbound))
1767		));
1768
1769		controller.on_incoming_connection(reserved2, IncomingIndex(3));
1771		assert_eq!(rx.try_recv().ok().unwrap(), Message::Accept(IncomingIndex(3)));
1772		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1773		assert!(matches!(
1774			controller.reserved_nodes.get(&reserved2),
1775			Some(PeerState::Connected(Direction::Inbound))
1776		));
1777	}
1778
1779	#[test]
1780	fn incoming_request_for_connected_regular_node_switches_it_to_inbound() {
1781		let regular1 = PeerId::random();
1782		let regular2 = PeerId::random();
1783		let outgoing_candidates = vec![regular1.into()];
1784
1785		let config = ProtoSetConfig {
1786			in_peers: 10,
1787			out_peers: 10,
1788			reserved_nodes: HashSet::new(),
1789			reserved_only: false,
1790		};
1791		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1792
1793		let mut peer_store = MockPeerStoreHandle::new();
1794		peer_store.expect_register_protocol().once().return_const(());
1795		peer_store.expect_is_banned().times(3).return_const(false);
1796		peer_store.expect_outgoing_candidates().once().return_const(outgoing_candidates);
1797
1798		let (_handle, mut controller) =
1799			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1800		assert_eq!(controller.num_out, 0);
1801		assert_eq!(controller.num_in, 0);
1802
1803		controller.alloc_slots();
1805		assert_eq!(
1806			rx.try_recv().ok().unwrap(),
1807			Message::Connect { set_id: SetId::from(0), peer_id: regular1 }
1808		);
1809		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1810		assert!(matches!(controller.nodes.get(®ular1).unwrap(), Direction::Outbound,));
1811
1812		controller.on_incoming_connection(regular2, IncomingIndex(0));
1814		assert_eq!(rx.try_recv().ok().unwrap(), Message::Accept(IncomingIndex(0)));
1815		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1816		assert!(matches!(controller.nodes.get(®ular2).unwrap(), Direction::Inbound,));
1817
1818		controller.on_incoming_connection(regular1, IncomingIndex(1));
1820		assert_eq!(rx.try_recv().ok().unwrap(), Message::Accept(IncomingIndex(1)));
1821		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1822		assert!(matches!(controller.nodes.get(®ular1).unwrap(), Direction::Inbound,));
1823
1824		controller.on_incoming_connection(regular2, IncomingIndex(2));
1826		assert_eq!(rx.try_recv().ok().unwrap(), Message::Accept(IncomingIndex(2)));
1827		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1828		assert!(matches!(controller.nodes.get(®ular2).unwrap(), Direction::Inbound,));
1829	}
1830
1831	#[test]
1832	fn incoming_request_for_connected_node_is_rejected_if_its_banned() {
1833		let regular1 = PeerId::random();
1834		let regular2 = PeerId::random();
1835		let outgoing_candidates = vec![regular1.into()];
1836
1837		let config = ProtoSetConfig {
1838			in_peers: 10,
1839			out_peers: 10,
1840			reserved_nodes: HashSet::new(),
1841			reserved_only: false,
1842		};
1843		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1844
1845		let mut peer_store = MockPeerStoreHandle::new();
1846		peer_store.expect_register_protocol().once().return_const(());
1847		peer_store.expect_is_banned().once().return_const(false);
1848		peer_store.expect_is_banned().times(2).return_const(true);
1849		peer_store.expect_outgoing_candidates().once().return_const(outgoing_candidates);
1850
1851		let (_handle, mut controller) =
1852			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1853		assert_eq!(controller.num_out, 0);
1854		assert_eq!(controller.num_in, 0);
1855
1856		controller.alloc_slots();
1858		assert_eq!(
1859			rx.try_recv().ok().unwrap(),
1860			Message::Connect { set_id: SetId::from(0), peer_id: regular1 }
1861		);
1862		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1863		assert!(matches!(controller.nodes.get(®ular1).unwrap(), Direction::Outbound,));
1864
1865		controller.on_incoming_connection(regular2, IncomingIndex(0));
1867		assert_eq!(rx.try_recv().ok().unwrap(), Message::Accept(IncomingIndex(0)));
1868		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1869		assert!(matches!(controller.nodes.get(®ular2).unwrap(), Direction::Inbound,));
1870
1871		controller.on_incoming_connection(regular1, IncomingIndex(1));
1873		assert_eq!(rx.try_recv().ok().unwrap(), Message::Reject(IncomingIndex(1)));
1874		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1875		assert!(!controller.nodes.contains_key(®ular1));
1876
1877		controller.on_incoming_connection(regular2, IncomingIndex(2));
1879		assert_eq!(rx.try_recv().ok().unwrap(), Message::Reject(IncomingIndex(2)));
1880		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1881		assert!(!controller.nodes.contains_key(®ular2));
1882	}
1883
1884	#[test]
1885	fn incoming_request_for_connected_node_is_rejected_if_no_slots_available() {
1886		let regular1 = PeerId::random();
1887		let regular2 = PeerId::random();
1888		let outgoing_candidates = vec![regular1.into()];
1889
1890		let config = ProtoSetConfig {
1891			in_peers: 1,
1892			out_peers: 1,
1893			reserved_nodes: HashSet::new(),
1894			reserved_only: false,
1895		};
1896		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1897
1898		let mut peer_store = MockPeerStoreHandle::new();
1899		peer_store.expect_register_protocol().once().return_const(());
1900		peer_store.expect_is_banned().once().return_const(false);
1901		peer_store.expect_outgoing_candidates().once().return_const(outgoing_candidates);
1902
1903		let (_handle, mut controller) =
1904			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1905		assert_eq!(controller.num_out, 0);
1906		assert_eq!(controller.num_in, 0);
1907
1908		controller.alloc_slots();
1910		assert_eq!(
1911			rx.try_recv().ok().unwrap(),
1912			Message::Connect { set_id: SetId::from(0), peer_id: regular1 }
1913		);
1914		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1915		assert!(matches!(controller.nodes.get(®ular1).unwrap(), Direction::Outbound,));
1916
1917		controller.on_incoming_connection(regular2, IncomingIndex(0));
1919		assert_eq!(rx.try_recv().ok().unwrap(), Message::Accept(IncomingIndex(0)));
1920		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1921		assert!(matches!(controller.nodes.get(®ular2).unwrap(), Direction::Inbound,));
1922
1923		controller.max_in = 0;
1924
1925		controller.on_incoming_connection(regular1, IncomingIndex(1));
1927		assert_eq!(rx.try_recv().ok().unwrap(), Message::Reject(IncomingIndex(1)));
1928		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1929		assert!(!controller.nodes.contains_key(®ular1));
1930
1931		controller.on_incoming_connection(regular2, IncomingIndex(2));
1933		assert_eq!(rx.try_recv().ok().unwrap(), Message::Reject(IncomingIndex(2)));
1934		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1935		assert!(!controller.nodes.contains_key(®ular2));
1936	}
1937
1938	#[test]
1939	fn incoming_peers_that_exceed_slots_are_rejected() {
1940		let peer1 = PeerId::random();
1941		let peer2 = PeerId::random();
1942
1943		let config = ProtoSetConfig {
1944			in_peers: 1,
1945			out_peers: 10,
1946			reserved_nodes: HashSet::new(),
1947			reserved_only: false,
1948		};
1949		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1950
1951		let mut peer_store = MockPeerStoreHandle::new();
1952		peer_store.expect_register_protocol().once().return_const(());
1953		peer_store.expect_is_banned().once().return_const(false);
1954
1955		let (_handle, mut controller) =
1956			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1957
1958		controller.on_incoming_connection(peer1, IncomingIndex(1));
1960		assert_eq!(rx.try_recv().unwrap(), Message::Accept(IncomingIndex(1)));
1961		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1962
1963		controller.on_incoming_connection(peer2, IncomingIndex(2));
1965		assert_eq!(rx.try_recv().unwrap(), Message::Reject(IncomingIndex(2)));
1966		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1967	}
1968
1969	#[test]
1970	fn banned_regular_incoming_node_is_rejected() {
1971		let peer1 = PeerId::random();
1972
1973		let config = ProtoSetConfig {
1974			in_peers: 10,
1975			out_peers: 10,
1976			reserved_nodes: HashSet::new(),
1977			reserved_only: false,
1978		};
1979		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
1980
1981		let mut peer_store = MockPeerStoreHandle::new();
1982		peer_store.expect_register_protocol().once().return_const(());
1983		peer_store.expect_is_banned().once().return_const(true);
1984
1985		let (_handle, mut controller) =
1986			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
1987
1988		controller.on_incoming_connection(peer1, IncomingIndex(1));
1990		assert_eq!(rx.try_recv().unwrap(), Message::Reject(IncomingIndex(1)));
1991		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
1992	}
1993
1994	#[test]
1995	fn banned_reserved_incoming_node_is_rejected() {
1996		let reserved1 = PeerId::random();
1997
1998		let config = ProtoSetConfig {
1999			in_peers: 10,
2000			out_peers: 10,
2001			reserved_nodes: std::iter::once(reserved1).collect(),
2002			reserved_only: false,
2003		};
2004		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
2005
2006		let mut peer_store = MockPeerStoreHandle::new();
2007		peer_store.expect_register_protocol().once().return_const(());
2008		peer_store.expect_is_banned().once().return_const(true);
2009
2010		let (_handle, mut controller) =
2011			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
2012		assert!(controller.reserved_nodes.contains_key(&reserved1));
2013
2014		controller.on_incoming_connection(reserved1, IncomingIndex(1));
2016		assert_eq!(rx.try_recv().unwrap(), Message::Reject(IncomingIndex(1)));
2017		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
2018	}
2019
2020	#[test]
2021	fn we_dont_connect_to_banned_reserved_node() {
2022		let reserved1 = PeerId::random();
2023
2024		let config = ProtoSetConfig {
2025			in_peers: 10,
2026			out_peers: 10,
2027			reserved_nodes: std::iter::once(reserved1).collect(),
2028			reserved_only: false,
2029		};
2030		let (tx, mut rx) = tracing_unbounded("mpsc_test_to_notifications", 100);
2031
2032		let mut peer_store = MockPeerStoreHandle::new();
2033		peer_store.expect_register_protocol().once().return_const(());
2034		peer_store.expect_is_banned().once().return_const(true);
2035		peer_store.expect_outgoing_candidates().once().return_const(Vec::new());
2036
2037		let (_handle, mut controller) =
2038			ProtocolController::new(SetId::from(0), config, tx, Arc::new(peer_store));
2039		assert!(matches!(controller.reserved_nodes.get(&reserved1), Some(PeerState::NotConnected)));
2040
2041		controller.alloc_slots();
2043		assert!(matches!(controller.reserved_nodes.get(&reserved1), Some(PeerState::NotConnected)));
2044		assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty);
2045	}
2046}