use std::{
collections::{HashMap, VecDeque},
future::Future,
num::NonZeroUsize,
ops::Range,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use bitvec::{bitvec, vec::BitVec};
use futures::FutureExt;
use polkadot_node_network_protocol::PeerId;
use polkadot_primitives::{AuthorityDiscoveryId, CandidateHash, GroupIndex, SessionIndex};
/// Upper bound on chained candidates per RCB, fixed at 3.
///
/// NOTE(review): "RCB" presumably abbreviates "relay chain block" — confirm against callers.
///
/// Built in a const context, so a zero value would be rejected at compile time.
pub const MAX_CHAINED_CANDIDATES_PER_RCB: NonZeroUsize =
    if let Some(limit) = NonZeroUsize::new(3) {
        limit
    } else {
        panic!("max candidates per rcb cannot be zero")
    };
/// Capacity constant for a validator groups buffer: three times
/// [`MAX_CHAINED_CANDIDATES_PER_RCB`].
///
/// NOTE(review): presumably passed to `ValidatorGroupsBuffer::with_capacity` by callers
/// outside this file — confirm.
pub const VALIDATORS_BUFFER_CAPACITY: NonZeroUsize = {
    let groups = 3 * MAX_CHAINED_CANDIDATES_PER_RCB.get();
    if let Some(capacity) = NonZeroUsize::new(groups) {
        capacity
    } else {
        panic!("buffer capacity must be non-zero")
    }
};
/// Bookkeeping record for one validator group held in a `ValidatorGroupsBuffer`.
///
/// A group is identified by the pair (`session_index`, `group_index`).
#[derive(Debug)]
struct ValidatorsGroupInfo {
// Number of validators this group contributed to the buffer's flat validator list.
len: usize,
// Session the group was advertised for.
session_index: SessionIndex,
// Index of the group within that session.
group_index: GroupIndex,
}
/// Bounded buffer of validator groups together with, per candidate, a bit mask
/// recording which buffered validators we should be connected to.
///
/// Groups are kept in insertion order; once `cap` groups are stored, adding a new
/// one evicts the oldest. The code maintains the invariant that `validators.len()`
/// equals the sum of `len` over `group_infos`, and that every mask in
/// `should_be_connected` has exactly `validators.len()` bits.
#[derive(Debug)]
pub struct ValidatorGroupsBuffer {
// Metadata of the buffered groups, oldest at the front.
group_infos: VecDeque<ValidatorsGroupInfo>,
// Discovery ids of all buffered groups, concatenated in the same order as `group_infos`.
validators: VecDeque<AuthorityDiscoveryId>,
// Per-candidate mask over `validators`: bit `i` is set while `validators[i]`
// is still of interest for that candidate.
should_be_connected: HashMap<CandidateHash, BitVec>,
// Maximum number of groups (not validators) kept in the buffer.
cap: NonZeroUsize,
}
impl ValidatorGroupsBuffer {
    /// Creates an empty buffer that retains at most `cap` validator groups.
    pub fn with_capacity(cap: NonZeroUsize) -> Self {
        Self {
            cap,
            should_be_connected: HashMap::new(),
            validators: VecDeque::new(),
            group_infos: VecDeque::new(),
        }
    }

    /// Returns the discovery ids of the validators we should currently be connected to.
    ///
    /// The result first lists, in buffer order, every validator whose bit is set for at
    /// least one tracked candidate. The validators of the most recently added group are
    /// then appended (walking the buffer from its end) unless already present, so that
    /// group is always part of the result.
    pub fn validators_to_connect(&self) -> Vec<AuthorityDiscoveryId> {
        // Union of all per-candidate masks.
        let empty_mask = bitvec![0; self.validators.len()];
        let mask = self
            .should_be_connected
            .values()
            .fold(empty_mask, |merged, candidate_bits| merged | candidate_bits);

        let mut to_connect: Vec<AuthorityDiscoveryId> = Vec::new();
        for (idx, authority_id) in self.validators.iter().enumerate() {
            if mask[idx] {
                to_connect.push(authority_id.clone());
            }
        }

        // The newest group occupies the tail of `validators`.
        if let Some(newest_group) = self.group_infos.back() {
            let newest_validators = self.validators.iter().rev().take(newest_group.len);
            for validator in newest_validators {
                if to_connect.contains(validator) {
                    continue;
                }
                to_connect.push(validator.clone());
            }
        }

        to_connect
    }

    /// Records that `candidate_hash` was advertised to the group identified by
    /// (`session_index`, `group_index`).
    ///
    /// Does nothing when `validators` is empty. If the group is already buffered, the
    /// candidate's bits for that group's stored range are set (the passed `validators`
    /// are not consulted further). Otherwise the group is appended via `push`, which
    /// may evict the oldest group.
    pub fn note_collation_advertised(
        &mut self,
        candidate_hash: CandidateHash,
        session_index: SessionIndex,
        group_index: GroupIndex,
        validators: &[AuthorityDiscoveryId],
    ) {
        if validators.is_empty() {
            return;
        }

        let known_group = self.group_infos.iter().position(|group| {
            group.session_index == session_index && group.group_index == group_index
        });

        match known_group {
            Some(idx) => {
                let group_len = self.group_infos[idx].len;
                // The group's first validator sits after all validators of earlier groups.
                let group_start_idx: usize = self.group_lengths_iter().take(idx).sum();
                self.set_bits(candidate_hash, group_start_idx..(group_start_idx + group_len));
            },
            None => self.push(candidate_hash, session_index, group_index, validators),
        }
    }

    /// Clears the interest bit of `authority_id` for `candidate_hash`.
    ///
    /// Every buffer position holding `authority_id` is cleared. Unknown candidates are
    /// ignored.
    pub fn reset_validator_interest(
        &mut self,
        candidate_hash: CandidateHash,
        authority_id: &AuthorityDiscoveryId,
    ) {
        if let Some(bits) = self.should_be_connected.get_mut(&candidate_hash) {
            let matching_positions = self
                .validators
                .iter()
                .enumerate()
                .filter(|(_, auth_id)| *auth_id == authority_id)
                .map(|(idx, _)| idx);
            for idx in matching_positions {
                bits.set(idx, false);
            }
        }
    }

    /// Stops tracking `candidate_hash` altogether by dropping its mask.
    pub fn remove_candidate(&mut self, candidate_hash: &CandidateHash) {
        self.should_be_connected.remove(candidate_hash);
    }

    /// Appends a new group to the buffer and marks all of its validators as of interest
    /// for `candidate_hash`.
    ///
    /// When the buffer already holds `cap` groups, the oldest group is evicted first:
    /// its validators are removed from the front and every mask is shifted accordingly.
    fn push(
        &mut self,
        candidate_hash: CandidateHash,
        session_index: SessionIndex,
        group_index: GroupIndex,
        validators: &[AuthorityDiscoveryId],
    ) {
        let added = validators.len();

        if self.group_infos.len() >= self.cap.get() {
            let evicted = self.group_infos.pop_front().expect("buf is not empty; qed");
            self.validators.drain(..evicted.len);
            // Drop the evicted group's bits so remaining bits line up with `validators`.
            for bits in self.should_be_connected.values_mut() {
                bits.as_mut_bitslice().shift_left(evicted.len);
            }
        }

        // The new group starts right after all groups that remain in the buffer.
        let group_start_idx: usize = self.group_lengths_iter().sum();

        self.group_infos
            .push_back(ValidatorsGroupInfo { len: added, session_index, group_index });
        self.validators.extend(validators.iter().cloned());

        // Bring every existing mask to the new validator count.
        let new_len = self.validators.len();
        for bits in self.should_be_connected.values_mut() {
            bits.resize(new_len, false);
        }

        self.set_bits(candidate_hash, group_start_idx..(group_start_idx + added));
    }

    /// Sets the bits in `range` for `candidate_hash`, creating an all-zero mask sized to
    /// the current validator count if the candidate is not tracked yet.
    fn set_bits(&mut self, candidate_hash: CandidateHash, range: Range<usize>) {
        let validators_num = self.validators.len();
        let bits = self
            .should_be_connected
            .entry(candidate_hash)
            .or_insert_with(|| bitvec![0; validators_num]);
        bits[range].fill(true);
    }

    /// Iterates over the sizes of the buffered groups, oldest first.
    fn group_lengths_iter(&self) -> impl Iterator<Item = usize> + '_ {
        self.group_infos.iter().map(|info| info.len)
    }
}
/// Six-second duration constant.
///
/// NOTE(review): presumably the `delay` handed to `ResetInterestTimeout::new` by
/// callers outside this file — confirm.
pub const RESET_INTEREST_TIMEOUT: Duration = Duration::from_secs(6);
/// Timer future that carries a candidate hash and a peer id.
///
/// Its `Future` implementation yields `(candidate_hash, peer_id)` once the inner
/// delay completes.
pub struct ResetInterestTimeout {
// Underlying timer.
fut: futures_timer::Delay,
// Candidate returned when the timer fires.
candidate_hash: CandidateHash,
// Peer returned when the timer fires.
peer_id: PeerId,
}
impl ResetInterestTimeout {
    /// Creates a timeout for the given candidate and peer that fires after `delay`.
    pub fn new(candidate_hash: CandidateHash, peer_id: PeerId, delay: Duration) -> Self {
        let fut = futures_timer::Delay::new(delay);
        Self { fut, candidate_hash, peer_id }
    }
}
impl Future for ResetInterestTimeout {
    type Output = (CandidateHash, PeerId);

    /// Polls the inner timer; once it completes, resolves to the stored
    /// `(candidate_hash, peer_id)` pair.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.fut.poll_unpin(cx) {
            Poll::Ready(_) => Poll::Ready((self.candidate_hash, self.peer_id)),
            Poll::Pending => Poll::Pending,
        }
    }
}
// Unit tests for `ValidatorGroupsBuffer`.
#[cfg(test)]
mod tests {
use super::*;
use polkadot_primitives::Hash;
use sp_keyring::Sr25519Keyring;
// With room for a single group, adding a second group evicts the first, and the
// validators of the only (= most recent) group are always reported.
#[test]
fn one_capacity_buffer() {
let cap = NonZeroUsize::new(1).unwrap();
let mut buf = ValidatorGroupsBuffer::with_capacity(cap);
let hash_a = CandidateHash(Hash::repeat_byte(0x1));
let hash_b = CandidateHash(Hash::repeat_byte(0x2));
let validators: Vec<_> = [
Sr25519Keyring::Alice,
Sr25519Keyring::Bob,
Sr25519Keyring::Charlie,
Sr25519Keyring::Dave,
Sr25519Keyring::Ferdie,
]
.into_iter()
.map(|key| AuthorityDiscoveryId::from(key.public()))
.collect();
// An empty buffer yields nothing to connect to.
assert!(buf.validators_to_connect().is_empty());
// Group 0 = [Alice, Bob] is advertised candidate A.
buf.note_collation_advertised(hash_a, 0, GroupIndex(0), &validators[..2]);
assert_eq!(buf.validators_to_connect(), validators[..2].to_vec());
// Clearing Bob's bit changes nothing: he belongs to the most recent group,
// whose validators are always included.
buf.reset_validator_interest(hash_a, &validators[1]);
assert_eq!(buf.validators_to_connect(), validators[0..2].to_vec());
// Group 1 = [Charlie, Dave, Ferdie] replaces group 0 because the capacity is 1.
buf.note_collation_advertised(hash_b, 0, GroupIndex(1), &validators[2..]);
assert_eq!(buf.validators_to_connect(), validators[2..].to_vec());
// Even with every bit cleared, the last group is still reported.
for validator in &validators[2..] {
buf.reset_validator_interest(hash_b, validator);
}
// Sort both sides: last-group validators are appended from the end of the buffer,
// so their order differs from the input order.
let mut expected = validators[2..].to_vec();
expected.sort();
let mut result = buf.validators_to_connect();
result.sort();
assert_eq!(result, expected);
}
// Exercises a capacity-3 buffer: repeated advertisements, interest resets across
// several candidates, and eviction of the oldest group once a fourth group arrives.
#[test]
fn buffer_works() {
let cap = NonZeroUsize::new(3).unwrap();
let mut buf = ValidatorGroupsBuffer::with_capacity(cap);
let hashes: Vec<_> = (0..5).map(|i| CandidateHash(Hash::repeat_byte(i))).collect();
let validators: Vec<_> = [
Sr25519Keyring::Alice,
Sr25519Keyring::Bob,
Sr25519Keyring::Charlie,
Sr25519Keyring::Dave,
Sr25519Keyring::Ferdie,
]
.into_iter()
.map(|key| AuthorityDiscoveryId::from(key.public()))
.collect();
// Candidates 0 and 1 go to group 0 = [Alice, Bob]; candidate 2 goes to
// group 1 = [Charlie, Dave]. The repeated advertisement of candidate 2 sets
// the same bits again.
buf.note_collation_advertised(hashes[0], 0, GroupIndex(0), &validators[..2]);
buf.note_collation_advertised(hashes[1], 0, GroupIndex(0), &validators[..2]);
buf.note_collation_advertised(hashes[2], 0, GroupIndex(1), &validators[2..4]);
buf.note_collation_advertised(hashes[2], 0, GroupIndex(1), &validators[2..4]);
assert_eq!(buf.validators_to_connect(), validators[..4].to_vec());
// Group 1 loses all bits for candidate 2 but remains reported as the last group.
for validator in &validators[2..4] {
buf.reset_validator_interest(hashes[2], validator);
}
// Alice is cleared for candidate 1 only; candidate 0 still keeps her bit set.
buf.reset_validator_interest(hashes[1], &validators[0]);
let mut expected: Vec<_> = validators[..4].iter().cloned().collect();
let mut result = buf.validators_to_connect();
expected.sort();
result.sort();
assert_eq!(result, expected);
// Now Alice has no bit left in any candidate and is not in the last group.
buf.reset_validator_interest(hashes[0], &validators[0]);
let mut expected: Vec<_> = validators[1..4].iter().cloned().collect();
expected.sort();
let mut result = buf.validators_to_connect();
result.sort();
assert_eq!(result, expected);
// Candidate 3 targets the already-buffered group 1; candidate 4 adds
// group 2 = [Ferdie], filling the buffer to its capacity of three groups.
buf.note_collation_advertised(hashes[3], 0, GroupIndex(1), &validators[2..4]);
buf.note_collation_advertised(
hashes[4],
0,
GroupIndex(2),
std::slice::from_ref(&validators[4]),
);
buf.reset_validator_interest(hashes[3], &validators[2]);
// A fourth group, 3 = [Alice], evicts group 0 (the oldest).
buf.note_collation_advertised(
hashes[4],
0,
GroupIndex(3),
std::slice::from_ref(&validators[0]),
);
// Remaining interest: Dave (candidate 3), Ferdie and Alice (candidate 4).
assert_eq!(
buf.validators_to_connect(),
vec![validators[3].clone(), validators[4].clone(), validators[0].clone()]
);
}
}