use crate::{
futures_undead::FuturesUndead,
task::{
strategy::{
do_post_recovery_check, is_unavailable, OngoingRequests, N_PARALLEL,
REGULAR_CHUNKS_REQ_RETRY_LIMIT,
},
RecoveryParams, State,
},
ErasureTask, RecoveryStrategy, LOG_TARGET,
};
use polkadot_node_primitives::AvailableData;
use polkadot_node_subsystem::{overseer, RecoveryError};
use polkadot_primitives::ValidatorIndex;
use futures::{channel::oneshot, SinkExt};
use rand::seq::SliceRandom;
use std::collections::VecDeque;
/// Parameters used to construct a [`FetchChunks`] strategy.
pub struct FetchChunksParams {
/// Number of validators; `FetchChunks::new` creates one `ValidatorIndex` for every value in
/// `0..n_validators`.
pub n_validators: usize,
}
/// Recovery strategy that requests chunks from validators and attempts to reconstruct the
/// available data once the number of held chunks reaches the recovery threshold.
pub struct FetchChunks {
/// Running sum of the error counts reported by `wait_for_chunks` during `run`; feeds the
/// error-rate compensation in `get_desired_request_count`.
error_count: usize,
/// Running sum of the response counts reported by `wait_for_chunks` during `run`.
total_received_responses: usize,
/// Validator indices in shuffled order. `run` filters this queue and then takes it, leaving it
/// empty.
validators: VecDeque<ValidatorIndex>,
/// Collection of ongoing chunk requests, handed to the request-launching and response-waiting
/// helpers in `run`.
requesting_chunks: OngoingRequests,
}
impl FetchChunks {
    /// Creates a strategy covering the validator indices `0..params.n_validators`.
    ///
    /// The indices are shuffled with the thread-local RNG, so the order in which they sit in the
    /// queue is random. Both counters start at zero and the set of ongoing requests is empty.
    pub fn new(params: FetchChunksParams) -> Self {
        let mut shuffled: Vec<ValidatorIndex> = (0..params.n_validators)
            .map(|index| ValidatorIndex(index as u32))
            .collect();
        shuffled.shuffle(&mut rand::thread_rng());

        Self {
            error_count: 0,
            total_received_responses: 0,
            validators: VecDeque::from(shuffled),
            requesting_chunks: FuturesUndead::new(),
        }
    }

    /// Thin adapter over the shared `is_unavailable` helper.
    ///
    /// This associated function takes `(unrequested, in-flight, chunk count, threshold)`, whereas
    /// the shared helper is called with `(chunk count, in-flight, unrequested, threshold)`; the
    /// arguments are merely reordered before forwarding.
    fn is_unavailable(
        unrequested_validators: usize,
        in_flight_requests: usize,
        chunk_count: usize,
        threshold: usize,
    ) -> bool {
        // The unqualified name resolves to the imported free function, not to this method.
        is_unavailable(chunk_count, in_flight_requests, unrequested_validators, threshold)
    }

    /// Number of requests that should be in flight right now.
    ///
    /// Starts from the chunks still missing (`threshold - chunk_count`, saturating at zero) and
    /// adds an allowance proportional to the observed error rate. The result never exceeds
    /// `N_PARALLEL` nor `threshold`.
    fn get_desired_request_count(&self, chunk_count: usize, threshold: usize) -> usize {
        let still_needed = threshold.saturating_sub(chunk_count);

        // Responses received per recorded error; zero when no error has been recorded yet.
        let responses_per_error = match self.error_count {
            0 => 0,
            errors => self.total_received_responses / errors,
        };

        // Extra requests to make up for expected failures; none when the ratio above is zero.
        let compensation = match responses_per_error {
            0 => 0,
            ratio => still_needed / ratio,
        };

        (still_needed + compensation).min(N_PARALLEL).min(threshold)
    }

    /// Hands all received chunks to the erasure task for reconstruction and runs the
    /// post-recovery check on the result.
    ///
    /// `state.received_chunks` is taken out of `state` and left empty.
    ///
    /// # Errors
    ///
    /// * `RecoveryError::ChannelClosed` if the reconstruct request cannot be sent or the response
    ///   channel is dropped before answering.
    /// * `RecoveryError::Invalid` if the erasure task reports a reconstruction error.
    /// * Whatever error `do_post_recovery_check` returns.
    ///
    /// On the last two failure paths the erasure-recovery timer (when present) is stopped and
    /// discarded.
    async fn attempt_recovery<Sender>(
        &mut self,
        state: &mut State,
        common_params: &RecoveryParams,
    ) -> Result<AvailableData, RecoveryError>
    where
        Sender: overseer::AvailabilityRecoverySenderTrait,
    {
        let strategy_type = RecoveryStrategy::<Sender>::strategy_type(self);
        let recovery_timer = common_params.metrics.time_erasure_recovery(strategy_type);

        let (response_tx, response_rx) = oneshot::channel();
        let mut erasure_task_tx = common_params.erasure_task_tx.clone();

        // Strip each received entry down to its chunk payload, keyed by chunk index.
        let chunks = std::mem::take(&mut state.received_chunks)
            .into_iter()
            .map(|(chunk_index, received)| (chunk_index, received.chunk))
            .collect();

        erasure_task_tx
            .send(ErasureTask::Reconstruct(common_params.n_validators, chunks, response_tx))
            .await
            .map_err(|_| RecoveryError::ChannelClosed)?;

        let reconstructed = match response_rx.await.map_err(|_| RecoveryError::ChannelClosed)? {
            Ok(data) => data,
            Err(err) => {
                if let Some(timer) = recovery_timer {
                    timer.stop_and_discard();
                }
                gum::debug!(
                    target: LOG_TARGET,
                    candidate_hash = ?common_params.candidate_hash,
                    erasure_root = ?common_params.erasure_root,
                    ?err,
                    "Data recovery error",
                );
                return Err(RecoveryError::Invalid)
            },
        };

        do_post_recovery_check(common_params, reconstructed)
            .await
            .map_err(|err| {
                if let Some(timer) = recovery_timer {
                    timer.stop_and_discard();
                }
                err
            })
            .map(|data| {
                gum::trace!(
                    target: LOG_TARGET,
                    candidate_hash = ?common_params.candidate_hash,
                    erasure_root = ?common_params.erasure_root,
                    "Data recovery from chunks complete",
                );
                data
            })
    }
}
#[async_trait::async_trait]
impl<Sender> RecoveryStrategy<Sender> for FetchChunks
where
    Sender: overseer::AvailabilityRecoverySenderTrait,
{
    /// Human-readable name of this strategy.
    fn display_name(&self) -> &'static str {
        "Fetch chunks"
    }

    /// Short identifier of this strategy; within this file it is passed to the erasure-recovery
    /// timer and to the chunk request helpers.
    fn strategy_type(&self) -> &'static str {
        "regular_chunks"
    }

    /// Drives the strategy until recovery is attempted or deemed impossible.
    ///
    /// 1. Unless `bypass_availability_store` is set, populates `state` from the availability
    ///    store and drops the validators reported for those local chunks.
    /// 2. Drops validators whose chunk is already in `state.received_chunks` or for which
    ///    `state.can_retry_request` returns `false`.
    /// 3. Loops: attempt recovery once `chunk_count >= threshold`; return
    ///    `RecoveryError::Unavailable` when `is_unavailable` says so; otherwise launch more
    ///    requests, wait for responses and update the response and error counters.
    async fn run(
        mut self: Box<Self>,
        state: &mut State,
        sender: &mut Sender,
        common_params: &RecoveryParams,
    ) -> Result<AvailableData, RecoveryError> {
        // Neither value changes while the strategy runs, so read them once.
        let threshold = common_params.threshold;
        let strategy_type = RecoveryStrategy::<Sender>::strategy_type(&*self);

        if !common_params.bypass_availability_store {
            let local_chunk_indices = state.populate_from_av_store(common_params, sender).await;
            self.validators.retain(|validator_index| {
                local_chunk_indices.iter().all(|(holder, _)| holder != validator_index)
            });
        }

        self.validators.retain(|candidate| {
            let already_received = state
                .received_chunks
                .values()
                .any(|received| candidate == &received.validator_index);
            if already_received {
                return false
            }

            let authority = common_params.validator_authority_keys[candidate.0 as usize].clone();
            state.can_retry_request(&(authority, *candidate), REGULAR_CHUNKS_REQ_RETRY_LIMIT)
        });

        // `self.validators` is left empty from here on; the queue below replaces it.
        let remaining_validators = std::mem::take(&mut self.validators);
        let mut validators_queue: VecDeque<_> = remaining_validators
            .into_iter()
            .map(|index| {
                let authority = common_params.validator_authority_keys[index.0 as usize].clone();
                (authority, index)
            })
            .collect();

        loop {
            // Checked before any request is launched.
            if state.chunk_count() >= threshold {
                return self.attempt_recovery::<Sender>(state, common_params).await
            }

            let out_of_options = Self::is_unavailable(
                validators_queue.len(),
                self.requesting_chunks.total_len(),
                state.chunk_count(),
                threshold,
            );
            if out_of_options {
                gum::debug!(
                    target: LOG_TARGET,
                    candidate_hash = ?common_params.candidate_hash,
                    erasure_root = ?common_params.erasure_root,
                    received = %state.chunk_count(),
                    requesting = %self.requesting_chunks.len(),
                    total_requesting = %self.requesting_chunks.total_len(),
                    n_validators = %common_params.n_validators,
                    "Data recovery from chunks is not possible",
                );
                return Err(RecoveryError::Unavailable)
            }

            let desired_requests_count =
                self.get_desired_request_count(state.chunk_count(), threshold);
            let already_requesting_count = self.requesting_chunks.len();
            gum::debug!(
                target: LOG_TARGET,
                ?common_params.candidate_hash,
                ?desired_requests_count,
                error_count= ?self.error_count,
                total_received = ?self.total_received_responses,
                threshold = ?common_params.threshold,
                ?already_requesting_count,
                "Requesting availability chunks for a candidate",
            );

            state
                .launch_parallel_chunk_requests(
                    strategy_type,
                    common_params,
                    sender,
                    desired_requests_count,
                    &mut validators_queue,
                    &mut self.requesting_chunks,
                )
                .await;

            // The predicate is true once enough chunks are held or recovery is impossible.
            let (responses, errors) = state
                .wait_for_chunks(
                    strategy_type,
                    common_params,
                    REGULAR_CHUNKS_REQ_RETRY_LIMIT,
                    &mut validators_queue,
                    &mut self.requesting_chunks,
                    &mut Vec::new(),
                    |unrequested, in_flight, chunk_count, _systematic_chunk_count| {
                        chunk_count >= threshold ||
                            Self::is_unavailable(unrequested, in_flight, chunk_count, threshold)
                    },
                )
                .await;

            self.total_received_responses += responses;
            self.error_count += errors;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use polkadot_erasure_coding::recovery_threshold;

    /// Walks `get_desired_request_count` through several response/error ratios.
    #[test]
    fn test_get_desired_request_count() {
        let n_validators = 100;
        let threshold = recovery_threshold(n_validators).unwrap();
        let mut strategy = FetchChunks::new(FetchChunksParams { n_validators });

        // Nothing recorded yet: no compensation, all missing chunks are requested.
        assert_eq!(strategy.get_desired_request_count(0, threshold), threshold);

        // One error per response doubles the raw request count, but the cap still applies.
        strategy.error_count = 1;
        strategy.total_received_responses = 1;
        assert_eq!(strategy.get_desired_request_count(0, threshold), threshold);
        assert_eq!(strategy.get_desired_request_count(0, N_PARALLEL + 2), N_PARALLEL);

        // One error in two responses, with one chunk already held.
        strategy.total_received_responses = 2;
        assert_eq!(strategy.get_desired_request_count(1, threshold), threshold);

        // One error in ten responses: missing chunks plus a tenth of them (integer division).
        strategy.total_received_responses = 10;
        assert_eq!(strategy.get_desired_request_count(9, threshold), 27);
        assert_eq!(strategy.get_desired_request_count(9, N_PARALLEL + 9), N_PARALLEL);

        // With the error count reset to zero, only the missing chunks are requested.
        strategy.error_count = 0;
        assert_eq!(strategy.get_desired_request_count(10, threshold), threshold - 10);
    }
}