use std::{
collections::{HashMap, HashSet},
pin::Pin,
task::Poll,
time::Duration,
};
use futures::{channel::oneshot, future::poll_fn, Future};
use futures_timer::Delay;
use indexmap::{map::Entry, IndexMap};
use polkadot_node_network_protocol::request_response::v1::DisputeRequest;
use polkadot_node_primitives::{DisputeMessage, DisputeStatus};
use polkadot_node_subsystem::{
messages::DisputeCoordinatorMessage, overseer, ActiveLeavesUpdate, SubsystemSender,
};
use polkadot_node_subsystem_util::{nesting_sender::NestingSender, runtime::RuntimeInfo};
use polkadot_primitives::{CandidateHash, Hash, SessionIndex};
mod send_task;
use send_task::SendTask;
pub use send_task::TaskFinish;
mod error;
pub use error::{Error, FatalError, JfyiError, Result};
use self::error::JfyiErrorResult;
use crate::{Metrics, LOG_TARGET, SEND_RATE_LIMIT};
/// Messages routed back to [`DisputeSender`] from work it handed off elsewhere.
///
/// Both variants are handled in `DisputeSender::on_message`.
#[derive(Debug)]
pub enum DisputeSenderMessage {
/// Outcome of a single send reported by a `SendTask`; produced through the nested sender that
/// `DisputeSender::start_sender` passes to every newly created `SendTask`.
TaskFinish(TaskFinish),
/// Result of the background `get_active_disputes` request spawned by
/// `DisputeSender::update_leaves`: the `(session, candidate, status)` triples reported by the
/// dispute coordinator, or `JfyiError::AskActiveDisputesCanceled` if the response channel was
/// dropped.
ActiveDisputesReady(JfyiErrorResult<Vec<(SessionIndex, CandidateHash, DisputeStatus)>>),
}
/// Drives the sending of dispute messages: keeps one `SendTask` per disputed candidate and the
/// active heads/sessions those tasks are created and refreshed against.
pub struct DisputeSender<M> {
/// Heads currently considered active; maintained from `ActiveLeavesUpdate`s in `update_leaves`.
active_heads: Vec<Hash>,
/// Session index obtained for each active head (via `get_session_index_for_child`), mapped to
/// one head that yielded it. If several heads yield the same session, the last one wins.
active_sessions: HashMap<SessionIndex, Hash>,
/// One send task per candidate. An `IndexMap` so that iteration in
/// `handle_new_active_disputes` follows insertion order.
disputes: IndexMap<CandidateHash, SendTask<M>>,
/// Channel for messages back to this sender (finished sends, active-dispute query results).
tx: NestingSender<M, DisputeSenderMessage>,
/// `Some` while a background `get_active_disputes` request is outstanding; `None` otherwise.
waiting_for_active_disputes: Option<WaitForActiveDisputesState>,
/// Throttles creation of new send tasks and refreshes of existing ones.
rate_limit: RateLimit,
/// Metrics, updated on every finished send and handed to send tasks.
metrics: Metrics,
}
/// State kept while waiting for the dispute coordinator to report the active disputes.
struct WaitForActiveDisputesState {
/// Whether the set of active sessions changed in the leaf update that issued the request or in
/// any later leaf update that arrived before the answer (the flags are ORed together).
have_new_sessions: bool,
}
#[overseer::contextbounds(DisputeDistribution, prefix = self::overseer)]
impl<M: 'static + Send + Sync> DisputeSender<M> {
    /// Create a sender with no active heads, sessions or disputes.
    ///
    /// `tx` routes results of spawned work (finished sends, active-dispute queries) back to
    /// this sender as [`DisputeSenderMessage`]s, which are handled by [`Self::on_message`].
    pub fn new(tx: NestingSender<M, DisputeSenderMessage>, metrics: Metrics) -> Self {
        Self {
            active_heads: Vec::new(),
            active_sessions: HashMap::new(),
            disputes: IndexMap::new(),
            tx,
            waiting_for_active_disputes: None,
            rate_limit: RateLimit::new(),
            metrics,
        }
    }

    /// Start sending out `msg`, unless a send task for its candidate already exists.
    ///
    /// Waits for the rate limit before creating the new `SendTask`. The task receives the
    /// current active sessions and a sender that reports back via
    /// [`DisputeSenderMessage::TaskFinish`].
    ///
    /// # Errors
    ///
    /// Propagates errors from `SendTask::new`; no task is recorded in that case.
    pub async fn start_sender<Context>(
        &mut self,
        ctx: &mut Context,
        runtime: &mut RuntimeInfo,
        msg: DisputeMessage,
    ) -> Result<()> {
        let req: DisputeRequest = msg.into();
        let candidate_hash = req.0.candidate_receipt.hash();
        match self.disputes.entry(candidate_hash) {
            Entry::Occupied(_) => {
                gum::trace!(target: LOG_TARGET, ?candidate_hash, "Dispute sending already active.");
            },
            Entry::Vacant(vacant) => {
                self.rate_limit.limit("in start_sender", candidate_hash).await;
                let send_task = SendTask::new(
                    ctx,
                    runtime,
                    &self.active_sessions,
                    NestingSender::new(self.tx.clone(), DisputeSenderMessage::TaskFinish),
                    req,
                    &self.metrics,
                )
                .await?;
                vacant.insert(send_task);
            },
        }
        Ok(())
    }

    /// Handle a message produced by a send task or by the background active-disputes query.
    ///
    /// # Errors
    ///
    /// Returns the query error if fetching active disputes failed, and propagates errors from
    /// refreshing the send tasks.
    pub async fn on_message<Context>(
        &mut self,
        ctx: &mut Context,
        runtime: &mut RuntimeInfo,
        msg: DisputeSenderMessage,
    ) -> Result<()> {
        match msg {
            DisputeSenderMessage::TaskFinish(msg) => {
                let TaskFinish { candidate_hash, receiver, result } = msg;
                self.metrics.on_sent_request(result.as_metrics_label());
                match self.disputes.get_mut(&candidate_hash) {
                    Some(task) => {
                        task.on_finished_send(&receiver, result);
                    },
                    None => {
                        // Presumably a dispute removed by `handle_new_active_disputes` while
                        // results of its sends were still queued.
                        gum::trace!(
                            target: LOG_TARGET,
                            ?result,
                            "Received `FromSendingTask::Finished` for non existing dispute."
                        );
                    },
                }
            },
            DisputeSenderMessage::ActiveDisputesReady(result) => {
                // The outstanding request has been answered. Clear the waiting state even if the
                // request failed, so the next leaf update spawns a fresh one.
                let have_new_sessions = self
                    .waiting_for_active_disputes
                    .take()
                    .map_or(false, |state| state.have_new_sessions);
                let active_disputes = result?;
                self.handle_new_active_disputes(ctx, runtime, active_disputes, have_new_sessions)
                    .await?;
            },
        }
        Ok(())
    }

    /// Apply an `ActiveLeavesUpdate`.
    ///
    /// Updates the active heads and refreshes the active sessions. It then asks the dispute
    /// coordinator for the active disputes in a spawned background task; the answer arrives as
    /// [`DisputeSenderMessage::ActiveDisputesReady`]. If a request is still outstanding, no new
    /// one is spawned. Only the "sessions changed" flag is merged into the waiting state, so the
    /// pending answer also accounts for this update.
    ///
    /// # Errors
    ///
    /// Propagates errors from refreshing sessions. Returns `FatalError::SpawnTask` if the
    /// background task cannot be spawned.
    pub async fn update_leaves<Context>(
        &mut self,
        ctx: &mut Context,
        runtime: &mut RuntimeInfo,
        update: ActiveLeavesUpdate,
    ) -> Result<()> {
        let ActiveLeavesUpdate { activated, deactivated } = update;
        let deactivated: HashSet<_> = deactivated.into_iter().collect();
        self.active_heads.retain(|h| !deactivated.contains(h));
        self.active_heads.extend(activated.into_iter().map(|l| l.hash));
        let have_new_sessions = self.refresh_sessions(ctx, runtime).await?;
        match self.waiting_for_active_disputes.take() {
            // Not waiting for data yet: request an update in the background.
            None => {
                self.waiting_for_active_disputes =
                    Some(WaitForActiveDisputesState { have_new_sessions });
                let mut sender = ctx.sender().clone();
                let mut tx = self.tx.clone();
                let get_active_disputes_task = async move {
                    let result = get_active_disputes(&mut sender).await;
                    let result =
                        tx.send_message(DisputeSenderMessage::ActiveDisputesReady(result)).await;
                    if let Err(err) = result {
                        gum::debug!(
                            target: LOG_TARGET,
                            ?err,
                            "Sending `DisputeSenderMessage` from background task failed."
                        );
                    }
                };
                ctx.spawn("get_active_disputes", Box::pin(get_active_disputes_task))
                    .map_err(FatalError::SpawnTask)?;
            },
            // Still waiting: remember whether sessions changed in this update as well.
            Some(state) => {
                let have_new_sessions = state.have_new_sessions || have_new_sessions;
                self.waiting_for_active_disputes =
                    Some(WaitForActiveDisputesState { have_new_sessions });
                gum::debug!(
                    target: LOG_TARGET,
                    "Dispute coordinator slow? We are still waiting for data on next active leaves update."
                );
            },
        }
        Ok(())
    }

    /// Reconcile the send tasks with the disputes the coordinator reports as active.
    ///
    /// Drops tasks for candidates that are no longer active, then walks the remaining tasks in
    /// insertion order. A task is refreshed if sessions changed or it has failed sends.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `SendTask::refresh_sends`.
    async fn handle_new_active_disputes<Context>(
        &mut self,
        ctx: &mut Context,
        runtime: &mut RuntimeInfo,
        active_disputes: Vec<(SessionIndex, CandidateHash, DisputeStatus)>,
        have_new_sessions: bool,
    ) -> Result<()> {
        let active_disputes: HashSet<_> = active_disputes.into_iter().map(|(_, c, _)| c).collect();
        // `IndexMap::retain` keeps the relative order of the remaining entries.
        self.disputes
            .retain(|candidate_hash, _| active_disputes.contains(candidate_hash));
        let mut should_rate_limit = true;
        for (candidate_hash, dispute) in self.disputes.iter_mut() {
            if have_new_sessions || dispute.has_failed_sends() {
                if should_rate_limit {
                    self.rate_limit
                        .limit("while going through new sessions/failed sends", *candidate_hash)
                        .await;
                }
                let sends_happened = dispute
                    .refresh_sends(ctx, runtime, &self.active_sessions, &self.metrics)
                    .await?;
                // Only throttle the next refresh if this one actually sent something because of
                // new sessions. Refreshes driven only by failed sends do not slow down the rest.
                should_rate_limit = sends_happened && have_new_sessions;
            }
        }
        Ok(())
    }

    /// Recompute the active sessions from the active heads.
    ///
    /// Returns `true` if the set of active session indices changed.
    ///
    /// # Errors
    ///
    /// Propagates errors from `get_active_session_indices`; `active_sessions` is left untouched
    /// in that case.
    async fn refresh_sessions<Context>(
        &mut self,
        ctx: &mut Context,
        runtime: &mut RuntimeInfo,
    ) -> Result<bool> {
        let new_sessions = get_active_session_indices(ctx, runtime, &self.active_heads).await?;
        // Two key sets are equal iff they have the same size and one contains all keys of the
        // other. Checking that directly avoids allocating two temporary `HashSet`s.
        let updated = new_sessions.len() != self.active_sessions.len() ||
            new_sessions.keys().any(|session| !self.active_sessions.contains_key(session));
        self.active_sessions = new_sessions;
        Ok(updated)
    }
}
/// Minimal rate limiter: a delay that has to elapse before the next rate-limited action.
struct RateLimit {
/// Delay to wait for; replaced by a fresh `SEND_RATE_LIMIT` delay after each `RateLimit::limit`.
limit: Delay,
}
impl RateLimit {
fn new() -> Self {
Self { limit: Delay::new(Duration::new(0, 0)) }
}
fn new_limit() -> Self {
Self { limit: Delay::new(SEND_RATE_LIMIT) }
}
async fn limit(&mut self, occasion: &'static str, candidate_hash: CandidateHash) {
let mut num_wakes: u32 = 0;
poll_fn(|cx| {
let old_limit = Pin::new(&mut self.limit);
match old_limit.poll(cx) {
Poll::Pending => {
gum::debug!(
target: LOG_TARGET,
?occasion,
?candidate_hash,
?num_wakes,
"Sending rate limit hit, slowing down requests"
);
num_wakes += 1;
Poll::Pending
},
Poll::Ready(()) => Poll::Ready(()),
}
})
.await;
*self = Self::new_limit();
}
}
/// Map each active head to the session index of its child, keyed by session index.
///
/// Also asks `runtime` for the `SessionInfo` of each such session so it gets cached. A failure
/// there is only logged, since the session index itself was already obtained. If several heads
/// yield the same session, the last head in `active_heads` is recorded for it.
///
/// # Errors
///
/// Propagates the first error from `get_session_index_for_child`.
#[overseer::contextbounds(DisputeDistribution, prefix = self::overseer)]
async fn get_active_session_indices<Context>(
    ctx: &mut Context,
    runtime: &mut RuntimeInfo,
    active_heads: &[Hash],
) -> Result<HashMap<SessionIndex, Hash>> {
    let mut indices = HashMap::with_capacity(active_heads.len());
    for head in active_heads {
        let session_index = runtime.get_session_index_for_child(ctx.sender(), *head).await?;
        // Best effort: warm the `SessionInfo` cache, but don't fail on it.
        if let Err(err) =
            runtime.get_session_info_by_index(ctx.sender(), *head, session_index).await
        {
            gum::debug!(target: LOG_TARGET, ?err, ?session_index, "Can't cache SessionInfo");
        }
        indices.insert(session_index, *head);
    }
    Ok(indices)
}
/// Ask the dispute coordinator for the disputes it currently considers active.
///
/// # Errors
///
/// Returns `JfyiError::AskActiveDisputesCanceled` if the response channel is dropped before an
/// answer is sent.
async fn get_active_disputes<Sender>(
    sender: &mut Sender,
) -> JfyiErrorResult<Vec<(SessionIndex, CandidateHash, DisputeStatus)>>
where
    Sender: SubsystemSender<DisputeCoordinatorMessage>,
{
    let (response_tx, response_rx) = oneshot::channel();
    sender.send_message(DisputeCoordinatorMessage::ActiveDisputes(response_tx)).await;
    match response_rx.await {
        Ok(active_disputes) => Ok(active_disputes),
        Err(_canceled) => Err(JfyiError::AskActiveDisputesCanceled),
    }
}