1use crate::PeerId;
36use polkadot_primitives::{AuthorityDiscoveryId, SessionIndex, ValidatorIndex};
37use rand::{CryptoRng, Rng};
38use std::{
39 collections::{hash_map, HashMap, HashSet},
40 fmt::Debug,
41};
42
// Log target used by the `gum` logging calls in this module.
const LOG_TARGET: &str = "parachain::grid-topology";

/// Default `sample_rate` for [`RandomRouting`]: the numerator of the per-sample
/// probability `sample_rate / n_peers_total` used by `RandomRouting::sample`.
pub const DEFAULT_RANDOM_SAMPLE_RATE: usize = crate::MIN_GOSSIP_PEERS;

/// Default `target` for [`RandomRouting`]: the number of recorded random sends after
/// which `RandomRouting::sample` stops returning `true`.
pub const DEFAULT_RANDOM_CIRCULATION: usize = 4;
53
/// A single validator's entry in a session grid topology.
#[derive(Debug, Clone, PartialEq)]
pub struct TopologyPeerInfo {
    /// Network peer ids known for this validator.
    /// `SessionGridTopology::update_authority_ids` appends newly learned ids here.
    pub peer_ids: Vec<PeerId>,
    /// The validator's index; copied into `GridNeighbors::validator_indices_{x,y}`.
    pub validator_index: ValidatorIndex,
    /// The validator's authority discovery id, matched by `update_authority_ids`.
    pub discovery_id: AuthorityDiscoveryId,
}
66
/// Grid topology of a single session.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct SessionGridTopology {
    /// Indexed by `ValidatorIndex.0`; each value is that validator's position
    /// (grid cell) in `canonical_shuffling`, as used by `compute_grid_neighbors_for`.
    shuffled_indices: Vec<usize>,
    /// Validators in shuffled order; position `i` is cell `i` of the grid matrix.
    canonical_shuffling: Vec<TopologyPeerInfo>,
    /// Union of every `peer_ids` list in `canonical_shuffling`, for fast membership checks.
    peer_ids: HashSet<PeerId>,
}
79
80impl SessionGridTopology {
81 pub fn new(shuffled_indices: Vec<usize>, canonical_shuffling: Vec<TopologyPeerInfo>) -> Self {
83 let mut peer_ids = HashSet::new();
84 for peer_info in canonical_shuffling.iter() {
85 for peer_id in peer_info.peer_ids.iter() {
86 peer_ids.insert(*peer_id);
87 }
88 }
89 SessionGridTopology { shuffled_indices, canonical_shuffling, peer_ids }
90 }
91
92 pub fn update_authority_ids(
94 &mut self,
95 peer_id: PeerId,
96 ids: &HashSet<AuthorityDiscoveryId>,
97 ) -> bool {
98 let mut updated = false;
99 if !self.peer_ids.contains(&peer_id) {
100 for peer in self
101 .canonical_shuffling
102 .iter_mut()
103 .filter(|peer| ids.contains(&peer.discovery_id))
104 {
105 peer.peer_ids.push(peer_id);
106 self.peer_ids.insert(peer_id);
107 updated = true;
108 }
109 }
110 updated
111 }
112 pub fn compute_grid_neighbors_for(&self, v: ValidatorIndex) -> Option<GridNeighbors> {
116 if self.shuffled_indices.len() != self.canonical_shuffling.len() {
117 return None
118 }
119 let shuffled_val_index = *self.shuffled_indices.get(v.0 as usize)?;
120
121 let neighbors = matrix_neighbors(shuffled_val_index, self.shuffled_indices.len())?;
122
123 let mut grid_subset = GridNeighbors::empty();
124 for r_n in neighbors.row_neighbors {
125 let n = &self.canonical_shuffling[r_n];
126 grid_subset.validator_indices_x.insert(n.validator_index);
127 for p in &n.peer_ids {
128 grid_subset.peers_x.insert(*p);
129 }
130 }
131
132 for c_n in neighbors.column_neighbors {
133 let n = &self.canonical_shuffling[c_n];
134 grid_subset.validator_indices_y.insert(n.validator_index);
135 for p in &n.peer_ids {
136 grid_subset.peers_y.insert(*p);
137 }
138 }
139
140 Some(grid_subset)
141 }
142
143 pub fn is_validator(&self, peer: &PeerId) -> bool {
145 self.peer_ids.contains(peer)
146 }
147}
148
/// Row and column neighbours of one cell of the grid matrix.
struct MatrixNeighbors<R, C> {
    /// Cells in the same row, excluding the cell itself.
    row_neighbors: R,
    /// Cells in the same column, excluding the cell itself.
    column_neighbors: C,
}

/// Computes the neighbours of cell `val_index` in a grid of `len` cells.
///
/// Cells are laid out row by row with a row width of `floor(sqrt(len))`, so the
/// last row may be partial. For example, `len == 11` gives
///
/// ```text
/// 0  1  2
/// 3  4  5
/// 6  7  8
/// 9 10
/// ```
///
/// and cell 10 has row neighbour 9 and column neighbours 1, 4, 7.
/// Returns `None` if `val_index` is not in `0..len`.
fn matrix_neighbors(
    val_index: usize,
    len: usize,
) -> Option<MatrixNeighbors<impl Iterator<Item = usize>, impl Iterator<Item = usize>>> {
    if len <= val_index {
        return None
    }

    // `len >= 1` here, so the width is at least 1 (no division by zero, no `step_by(0)`).
    let width = (len as f64).sqrt() as usize;
    let (row, column) = (val_index / width, val_index % width);
    let row_start = row * width;
    let row_end = len.min(row_start + width);
    let not_self = move |cell: &usize| *cell != val_index;

    Some(MatrixNeighbors {
        row_neighbors: (row_start..row_end).filter(not_self),
        column_neighbors: (column..len).step_by(width).filter(not_self),
    })
}
183
/// Grid neighbours of one validator, split into an X dimension and a Y dimension.
///
/// `SessionGridTopology::compute_grid_neighbors_for` fills X from the validator's
/// grid row and Y from its grid column.
#[derive(Debug, Clone, PartialEq)]
pub struct GridNeighbors {
    /// Peer ids of the X-dimension neighbours.
    pub peers_x: HashSet<PeerId>,
    /// Validator indices of the X-dimension neighbours.
    pub validator_indices_x: HashSet<ValidatorIndex>,
    /// Peer ids of the Y-dimension neighbours.
    pub peers_y: HashSet<PeerId>,
    /// Validator indices of the Y-dimension neighbours.
    pub validator_indices_y: HashSet<ValidatorIndex>,
}
196
197impl GridNeighbors {
198 pub fn empty() -> Self {
201 GridNeighbors {
202 peers_x: HashSet::new(),
203 validator_indices_x: HashSet::new(),
204 peers_y: HashSet::new(),
205 validator_indices_y: HashSet::new(),
206 }
207 }
208
209 pub fn required_routing_by_index(
212 &self,
213 originator: ValidatorIndex,
214 local: bool,
215 ) -> RequiredRouting {
216 if local {
217 return RequiredRouting::GridXY
218 }
219
220 let grid_x = self.validator_indices_x.contains(&originator);
221 let grid_y = self.validator_indices_y.contains(&originator);
222
223 match (grid_x, grid_y) {
224 (false, false) => RequiredRouting::None,
225 (true, false) => RequiredRouting::GridY, (false, true) => RequiredRouting::GridX, (true, true) => RequiredRouting::GridXY, }
230 }
231
232 pub fn required_routing_by_peer_id(&self, originator: PeerId, local: bool) -> RequiredRouting {
235 if local {
236 return RequiredRouting::GridXY
237 }
238
239 let grid_x = self.peers_x.contains(&originator);
240 let grid_y = self.peers_y.contains(&originator);
241
242 match (grid_x, grid_y) {
243 (false, false) => RequiredRouting::None,
244 (true, false) => RequiredRouting::GridY, (false, true) => RequiredRouting::GridX, (true, true) => {
247 gum::debug!(
248 target: LOG_TARGET,
249 ?originator,
250 "Grid topology is unexpected, play it safe and send to X AND Y"
251 );
252 RequiredRouting::GridXY
253 }, }
256 }
257
258 pub fn route_to_peer(&self, required_routing: RequiredRouting, peer: &PeerId) -> bool {
262 match required_routing {
263 RequiredRouting::All => true,
264 RequiredRouting::GridX => self.peers_x.contains(peer),
265 RequiredRouting::GridY => self.peers_y.contains(peer),
266 RequiredRouting::GridXY => self.peers_x.contains(peer) || self.peers_y.contains(peer),
267 RequiredRouting::None | RequiredRouting::PendingTopology => false,
268 }
269 }
270
271 pub fn peers_diff(&self, other: &Self) -> Vec<PeerId> {
273 self.peers_x
274 .iter()
275 .chain(self.peers_y.iter())
276 .filter(|peer_id| !(other.peers_x.contains(peer_id) || other.peers_y.contains(peer_id)))
277 .cloned()
278 .collect::<Vec<_>>()
279 }
280
281 pub fn len(&self) -> usize {
283 self.peers_x.len().saturating_add(self.peers_y.len())
284 }
285}
286
/// A session topology together with the grid neighbours computed for `local_index`.
#[derive(Debug)]
pub struct SessionGridTopologyEntry {
    /// The session's full grid topology.
    topology: SessionGridTopology,
    /// Neighbours of `local_index` in `topology`. Empty when there is no local index
    /// or when the neighbours could not be computed.
    local_neighbors: GridNeighbors,
    /// Validator index whose neighbours are kept in `local_neighbors`, if any.
    local_index: Option<ValidatorIndex>,
}
294
295impl SessionGridTopologyEntry {
296 pub fn local_grid_neighbors(&self) -> &GridNeighbors {
298 &self.local_neighbors
299 }
300
301 pub fn local_grid_neighbors_mut(&mut self) -> &mut GridNeighbors {
303 &mut self.local_neighbors
304 }
305
306 pub fn get(&self) -> &SessionGridTopology {
308 &self.topology
309 }
310
311 pub fn is_validator(&self, peer: &PeerId) -> bool {
313 self.topology.is_validator(peer)
314 }
315
316 pub fn peers_to_route(&self, required_routing: RequiredRouting) -> Vec<PeerId> {
318 match required_routing {
319 RequiredRouting::All => self.topology.peer_ids.iter().copied().collect(),
320 RequiredRouting::GridX => self.local_neighbors.peers_x.iter().copied().collect(),
321 RequiredRouting::GridY => self.local_neighbors.peers_y.iter().copied().collect(),
322 RequiredRouting::GridXY => self
323 .local_neighbors
324 .peers_x
325 .iter()
326 .chain(self.local_neighbors.peers_y.iter())
327 .copied()
328 .collect(),
329 RequiredRouting::None | RequiredRouting::PendingTopology => Vec::new(),
330 }
331 }
332
333 pub fn update_authority_ids(
335 &mut self,
336 peer_id: PeerId,
337 ids: &HashSet<AuthorityDiscoveryId>,
338 ) -> bool {
339 let peer_id_updated = self.topology.update_authority_ids(peer_id, ids);
340 if peer_id_updated {
343 if let Some(local_index) = self.local_index.as_ref() {
344 if let Some(new_grid) = self.topology.compute_grid_neighbors_for(*local_index) {
345 self.local_neighbors = new_grid;
346 }
347 }
348 }
349 peer_id_updated
350 }
351}
352
/// Reference-counted storage of grid topologies keyed by session.
#[derive(Default)]
pub struct SessionGridTopologies {
    /// Per session: the topology entry (`None` until `insert_topology` is called) and
    /// a reference count. `dec_session_refs` removes the session once the count
    /// reaches zero.
    inner: HashMap<SessionIndex, (Option<SessionGridTopologyEntry>, usize)>,
}
358
359impl SessionGridTopologies {
360 pub fn get_topology(&self, session: SessionIndex) -> Option<&SessionGridTopologyEntry> {
362 self.inner.get(&session).and_then(|val| val.0.as_ref())
363 }
364
365 pub fn update_authority_ids(
367 &mut self,
368 peer_id: PeerId,
369 ids: &HashSet<AuthorityDiscoveryId>,
370 ) -> bool {
371 self.inner
372 .iter_mut()
373 .map(|(_, topology)| {
374 topology.0.as_mut().map(|topology| topology.update_authority_ids(peer_id, ids))
375 })
376 .any(|updated| updated.unwrap_or_default())
377 }
378
379 pub fn inc_session_refs(&mut self, session: SessionIndex) {
381 self.inner.entry(session).or_insert((None, 0)).1 += 1;
382 }
383
384 pub fn dec_session_refs(&mut self, session: SessionIndex) {
386 if let hash_map::Entry::Occupied(mut occupied) = self.inner.entry(session) {
387 occupied.get_mut().1 = occupied.get().1.saturating_sub(1);
388 if occupied.get().1 == 0 {
389 let _ = occupied.remove();
390 }
391 }
392 }
393
394 pub fn insert_topology(
396 &mut self,
397 session: SessionIndex,
398 topology: SessionGridTopology,
399 local_index: Option<ValidatorIndex>,
400 ) {
401 let entry = self.inner.entry(session).or_insert((None, 0));
402 if entry.0.is_none() {
403 let local_neighbors = local_index
404 .and_then(|l| topology.compute_grid_neighbors_for(l))
405 .unwrap_or_else(GridNeighbors::empty);
406
407 entry.0 = Some(SessionGridTopologyEntry { topology, local_neighbors, local_index });
408 }
409 }
410}
411
/// A topology entry tagged with the session index it belongs to.
#[derive(Debug)]
struct GridTopologySessionBound {
    /// The topology and its local neighbours.
    entry: SessionGridTopologyEntry,
    /// Session the entry belongs to.
    session_index: SessionIndex,
}
418
/// Keeps the current session's topology and, after the first update, the previous one.
#[derive(Debug)]
pub struct SessionBoundGridTopologyStorage {
    /// Most recently supplied topology.
    current_topology: GridTopologySessionBound,
    /// The topology that was current before the last `update_topology` call.
    prev_topology: Option<GridTopologySessionBound>,
}
425
426impl Default for SessionBoundGridTopologyStorage {
427 fn default() -> Self {
428 SessionBoundGridTopologyStorage {
431 current_topology: GridTopologySessionBound {
432 session_index: SessionIndex::max_value(),
435 entry: SessionGridTopologyEntry {
436 topology: SessionGridTopology {
437 shuffled_indices: Vec::new(),
438 canonical_shuffling: Vec::new(),
439 peer_ids: Default::default(),
440 },
441 local_neighbors: GridNeighbors::empty(),
442 local_index: None,
443 },
444 },
445 prev_topology: None,
446 }
447 }
448}
449
450impl SessionBoundGridTopologyStorage {
451 pub fn get_topology_or_fallback(&self, idx: SessionIndex) -> &SessionGridTopologyEntry {
455 self.get_topology(idx).unwrap_or(&self.current_topology.entry)
456 }
457
458 pub fn get_topology(&self, idx: SessionIndex) -> Option<&SessionGridTopologyEntry> {
461 if let Some(prev_topology) = &self.prev_topology {
462 if idx == prev_topology.session_index {
463 return Some(&prev_topology.entry)
464 }
465 }
466 if self.current_topology.session_index == idx {
467 return Some(&self.current_topology.entry)
468 }
469
470 None
471 }
472
473 pub fn update_topology(
475 &mut self,
476 session_index: SessionIndex,
477 topology: SessionGridTopology,
478 local_index: Option<ValidatorIndex>,
479 ) {
480 let local_neighbors = local_index
481 .and_then(|l| topology.compute_grid_neighbors_for(l))
482 .unwrap_or_else(GridNeighbors::empty);
483
484 let old_current = std::mem::replace(
485 &mut self.current_topology,
486 GridTopologySessionBound {
487 entry: SessionGridTopologyEntry { topology, local_neighbors, local_index },
488 session_index,
489 },
490 );
491 self.prev_topology.replace(old_current);
492 }
493
494 pub fn get_current_topology(&self) -> &SessionGridTopologyEntry {
496 &self.current_topology.entry
497 }
498
499 pub fn get_current_session_index(&self) -> SessionIndex {
501 self.current_topology.session_index
502 }
503
504 pub fn get_current_topology_mut(&mut self) -> &mut SessionGridTopologyEntry {
507 &mut self.current_topology.entry
508 }
509}
510
/// State for picking a bounded number of extra random recipients for a message.
#[derive(Debug, Clone, Copy)]
pub struct RandomRouting {
    /// Number of recorded sends after which `sample` always returns `false`.
    target: usize,
    /// Number of sends recorded so far via `inc_sent`.
    sent: usize,
    /// Numerator of the per-sample probability `sample_rate / n_peers_total`.
    sample_rate: usize,
}
521
522impl Default for RandomRouting {
523 fn default() -> Self {
524 RandomRouting {
525 target: DEFAULT_RANDOM_CIRCULATION,
526 sent: 0_usize,
527 sample_rate: DEFAULT_RANDOM_SAMPLE_RATE,
528 }
529 }
530}
531
532impl RandomRouting {
533 pub fn sample(&self, n_peers_total: usize, rng: &mut (impl CryptoRng + Rng)) -> bool {
536 if n_peers_total == 0 || self.sent >= self.target {
537 false
538 } else if self.sample_rate > n_peers_total {
539 true
540 } else {
541 rng.gen_ratio(self.sample_rate as _, n_peers_total as _)
542 }
543 }
544
545 pub fn inc_sent(&mut self) {
547 self.sent += 1
548 }
549
550 pub fn is_complete(&self) -> bool {
552 self.sent >= self.target
553 }
554}
555
/// Which peers a message has to be routed to.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RequiredRouting {
    /// Routing cannot be decided yet; presumably because the topology is not known
    /// yet (inferred from the name). Routes to no peers and counts as empty.
    PendingTopology,
    /// Route to every known validator peer.
    All,
    /// Route to the neighbours in both grid dimensions.
    GridXY,
    /// Route to the X-dimension neighbours only.
    GridX,
    /// Route to the Y-dimension neighbours only.
    GridY,
    /// Route to no peers.
    None,
}

impl RequiredRouting {
    /// Returns `true` for the variants that route to no peers:
    /// `PendingTopology` and `None`.
    pub fn is_empty(self) -> bool {
        matches!(self, RequiredRouting::PendingTopology | RequiredRouting::None)
    }

    /// Merges two requirements into the narrowest requirement that covers both.
    ///
    /// - `All` dominates everything.
    /// - Any mix of grid variants other than `GridX`+`GridX` or `GridY`+`GridY`
    ///   becomes `GridXY`.
    /// - An empty side defers to the other side.
    /// - When both sides are empty, `PendingTopology` wins over `None`.
    pub fn combine(self, other: Self) -> Self {
        match (self.is_empty(), other.is_empty()) {
            // Neither side routes anywhere: keep the "still pending" information.
            (true, true) =>
                if matches!(self, Self::PendingTopology) || matches!(other, Self::PendingTopology)
                {
                    Self::PendingTopology
                } else {
                    Self::None
                },
            // Exactly one side routes somewhere: it decides on its own.
            (true, false) => other,
            (false, true) => self,
            // Both sides route somewhere: take the widest requirement.
            (false, false) => match (self, other) {
                (Self::All, _) | (_, Self::All) => Self::All,
                (Self::GridX, Self::GridX) => Self::GridX,
                (Self::GridY, Self::GridY) => Self::GridY,
                // Remaining pairs involve `GridXY` or mix X with Y.
                _ => Self::GridXY,
            },
        }
    }
}
600
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha12Rng;

    // Fixed-seed RNG: the exact true/false sequences asserted below depend on this seed.
    fn dummy_rng() -> ChaCha12Rng {
        rand_chacha::ChaCha12Rng::seed_from_u64(12345)
    }

    // Checks `RequiredRouting::combine` on representative pairs of variants.
    #[test]
    fn test_required_routing_combine() {
        assert_eq!(RequiredRouting::All.combine(RequiredRouting::None), RequiredRouting::All);
        assert_eq!(RequiredRouting::All.combine(RequiredRouting::GridXY), RequiredRouting::All);
        assert_eq!(RequiredRouting::GridXY.combine(RequiredRouting::All), RequiredRouting::All);
        assert_eq!(RequiredRouting::None.combine(RequiredRouting::All), RequiredRouting::All);
        assert_eq!(RequiredRouting::None.combine(RequiredRouting::None), RequiredRouting::None);
        assert_eq!(
            RequiredRouting::PendingTopology.combine(RequiredRouting::GridX),
            RequiredRouting::GridX
        );

        assert_eq!(
            RequiredRouting::GridX.combine(RequiredRouting::PendingTopology),
            RequiredRouting::GridX
        );
        assert_eq!(RequiredRouting::GridX.combine(RequiredRouting::GridY), RequiredRouting::GridXY);
        assert_eq!(RequiredRouting::GridY.combine(RequiredRouting::GridX), RequiredRouting::GridXY);
        assert_eq!(
            RequiredRouting::GridXY.combine(RequiredRouting::GridXY),
            RequiredRouting::GridXY
        );
        assert_eq!(RequiredRouting::GridX.combine(RequiredRouting::GridX), RequiredRouting::GridX);
        assert_eq!(RequiredRouting::GridY.combine(RequiredRouting::GridY), RequiredRouting::GridY);

        assert_eq!(RequiredRouting::None.combine(RequiredRouting::GridY), RequiredRouting::GridY);
        assert_eq!(RequiredRouting::None.combine(RequiredRouting::GridX), RequiredRouting::GridX);
        assert_eq!(RequiredRouting::None.combine(RequiredRouting::GridXY), RequiredRouting::GridXY);

        assert_eq!(RequiredRouting::GridY.combine(RequiredRouting::None), RequiredRouting::GridY);
        assert_eq!(RequiredRouting::GridX.combine(RequiredRouting::None), RequiredRouting::GridX);
        assert_eq!(RequiredRouting::GridXY.combine(RequiredRouting::None), RequiredRouting::GridXY);

        assert_eq!(
            RequiredRouting::PendingTopology.combine(RequiredRouting::None),
            RequiredRouting::PendingTopology
        );

        assert_eq!(
            RequiredRouting::None.combine(RequiredRouting::PendingTopology),
            RequiredRouting::PendingTopology
        );
    }

    // Walks `RandomRouting::sample` through a seeded sequence. With sample_rate 8 of
    // 16 peers, each draw succeeds with probability 1/2. Once `target` (4) sends are
    // recorded, every further sample must be `false`.
    #[test]
    fn test_random_routing_sample() {
        let mut rng = dummy_rng();
        let mut random_routing = RandomRouting { target: 4, sent: 0, sample_rate: 8 };

        assert_eq!(random_routing.sample(16, &mut rng), true);
        random_routing.inc_sent();
        assert_eq!(random_routing.sample(16, &mut rng), false);
        assert_eq!(random_routing.sample(16, &mut rng), false);
        assert_eq!(random_routing.sample(16, &mut rng), true);
        random_routing.inc_sent();
        assert_eq!(random_routing.sample(16, &mut rng), true);
        random_routing.inc_sent();
        assert_eq!(random_routing.sample(16, &mut rng), false);
        assert_eq!(random_routing.sample(16, &mut rng), false);
        assert_eq!(random_routing.sample(16, &mut rng), false);
        assert_eq!(random_routing.sample(16, &mut rng), true);
        random_routing.inc_sent();

        for _ in 0..16 {
            assert_eq!(random_routing.sample(16, &mut rng), false);
        }
    }

    // Samples `iters` times, recording a send on every positive sample.
    // Returns the number of recorded sends.
    fn run_random_routing(
        random_routing: &mut RandomRouting,
        rng: &mut (impl CryptoRng + Rng),
        npeers: usize,
        iters: usize,
    ) -> usize {
        let mut ret = 0_usize;

        for _ in 0..iters {
            if random_routing.sample(npeers, rng) {
                random_routing.inc_sent();
                ret += 1;
            }
        }

        ret
    }

    // Over many iterations, the number of sends must stop exactly at `target`
    // (including a target of 0).
    #[test]
    fn test_random_routing_distribution() {
        let mut rng = dummy_rng();

        let mut random_routing = RandomRouting { target: 4, sent: 0, sample_rate: 8 };
        assert_eq!(run_random_routing(&mut random_routing, &mut rng, 100, 10000), 4);

        let mut random_routing = RandomRouting { target: 8, sent: 0, sample_rate: 100 };
        assert_eq!(run_random_routing(&mut random_routing, &mut rng, 100, 10000), 8);

        let mut random_routing = RandomRouting { target: 0, sent: 0, sample_rate: 100 };
        assert_eq!(run_random_routing(&mut random_routing, &mut rng, 100, 10000), 0);

        let mut random_routing = RandomRouting { target: 10, sent: 0, sample_rate: 10 };
        assert_eq!(run_random_routing(&mut random_routing, &mut rng, 10, 100), 10);
    }

    // Cases are (index, len, expected row neighbours, expected column neighbours).
    // Example: len 11 is laid out as rows [0 1 2] [3 4 5] [6 7 8] [9 10].
    #[test]
    fn test_matrix_neighbors() {
        for (our_index, len, expected_row, expected_column) in vec![
            (0usize, 1usize, vec![], vec![]),
            (1, 2, vec![], vec![0usize]),
            (0, 9, vec![1, 2], vec![3, 6]),
            (9, 10, vec![], vec![0, 3, 6]),
            (10, 11, vec![9], vec![1, 4, 7]),
            (7, 11, vec![6, 8], vec![1, 4, 10]),
        ]
        .into_iter()
        {
            let matrix = matrix_neighbors(our_index, len).unwrap();
            let mut row_result: Vec<_> = matrix.row_neighbors.collect();
            let mut column_result: Vec<_> = matrix.column_neighbors.collect();
            row_result.sort();
            column_result.sort();

            assert_eq!(row_result, expected_row);
            assert_eq!(column_result, expected_column);
        }
    }
}