use std::{marker::PhantomData, pin::Pin, sync::Arc};
use crate::{graph::ChainApi, LOG_TARGET};
use sc_utils::mpsc::{tracing_unbounded, TracingUnboundedReceiver, TracingUnboundedSender};
use sp_blockchain::HashAndNumber;
use sp_runtime::traits::Block as BlockT;
use super::tx_mem_pool::TxMemPool;
use futures::prelude::*;
use super::view::{FinishRevalidationWorkerChannels, View};
/// A single unit of work sent from [`RevalidationQueue`] to the background
/// [`RevalidationWorker`].
enum WorkerPayload<Api: ChainApi<Block = Block> + 'static, Block: BlockT> {
	/// Revalidate the given view; the channels are handed unchanged to `View::revalidate`.
	RevalidateView(Arc<View<Api>>, FinishRevalidationWorkerChannels<Api>),
	/// Revalidate the given mempool; the block identifier is handed unchanged to
	/// `TxMemPool::revalidate`.
	RevalidateMempool(Arc<TxMemPool<Api, Block>>, HashAndNumber<Block>),
}
/// Background worker that drains the revalidation queue.
///
/// It carries no runtime state; the type parameter only pins the block type that the
/// worker's `run` loop operates on.
struct RevalidationWorker<Block>
where
	Block: BlockT,
{
	_phantom: PhantomData<Block>,
}
impl<Block> RevalidationWorker<Block>
where
Block: BlockT,
<Block as BlockT>::Hash: Unpin,
{
fn new() -> Self {
Self { _phantom: Default::default() }
}
pub async fn run<Api: ChainApi<Block = Block> + 'static>(
self,
from_queue: TracingUnboundedReceiver<WorkerPayload<Api, Block>>,
) {
let mut from_queue = from_queue.fuse();
loop {
let Some(payload) = from_queue.next().await else {
break;
};
match payload {
WorkerPayload::RevalidateView(view, worker_channels) =>
view.revalidate(worker_channels).await,
WorkerPayload::RevalidateMempool(mempool, finalized_hash_and_number) =>
mempool.revalidate(finalized_hash_and_number).await,
};
}
}
}
/// Entry point for requesting view and mempool revalidation.
///
/// When a background worker is attached, requests are forwarded to it over a channel.
/// Otherwise each request is executed inline by the calling task.
pub struct RevalidationQueue<Api: ChainApi<Block = Block> + 'static, Block: BlockT> {
	/// Sending half of the channel to the background worker, or `None` when requests are
	/// handled inline.
	background: Option<TracingUnboundedSender<WorkerPayload<Api, Block>>>,
}
impl<Api, Block> RevalidationQueue<Api, Block>
where
	Api: ChainApi<Block = Block> + 'static,
	Block: BlockT,
	<Block as BlockT>::Hash: Unpin,
{
	/// Creates a queue without a background worker.
	///
	/// Every request is executed inline. The `revalidate_*` futures complete only once the
	/// revalidation itself has finished.
	pub fn new() -> Self {
		Self { background: None }
	}

	/// Creates a queue backed by a background worker.
	///
	/// Returns the queue together with the worker future. Futures are lazy, so the caller
	/// must spawn (or otherwise poll) the returned future. Until then, requests sent through
	/// the queue are only buffered in the channel and never processed.
	pub fn new_with_worker() -> (Self, Pin<Box<dyn Future<Output = ()> + Send>>) {
		// NOTE(review): `100_000` is presumably the tracing channel's queue-length warning
		// threshold rather than a capacity limit (the channel is unbounded) — confirm.
		let (to_worker, from_queue) = tracing_unbounded("mpsc_revalidation_queue", 100_000);
		(Self { background: Some(to_worker) }, RevalidationWorker::new().run(from_queue).boxed())
	}

	/// Requests revalidation of `view`.
	///
	/// With a background worker, the request is enqueued and this returns without waiting for
	/// the revalidation to run. If enqueueing fails, the error is logged at `warn` level and
	/// the request is not retried. Without a worker, the view is revalidated inline.
	pub async fn revalidate_view(
		&self,
		view: Arc<View<Api>>,
		finish_revalidation_worker_channels: FinishRevalidationWorkerChannels<Api>,
	) {
		log::trace!(
			target: LOG_TARGET,
			"revalidation_queue::revalidate_view: Sending view to revalidation queue at {}",
			view.at.hash
		);

		if let Some(to_worker) = &self.background {
			if let Err(e) = to_worker.unbounded_send(WorkerPayload::RevalidateView(
				view,
				finish_revalidation_worker_channels,
			)) {
				log::warn!(target: LOG_TARGET, "revalidation_queue::revalidate_view: Failed to update background worker: {:?}", e);
			}
		} else {
			view.revalidate(finish_revalidation_worker_channels).await
		}
	}

	/// Requests revalidation of `mempool` at `finalized_hash`.
	///
	/// With a background worker, the request is enqueued and this returns without waiting for
	/// the revalidation to run. If enqueueing fails, the error is logged at `warn` level and
	/// the request is not retried. Without a worker, the mempool is revalidated inline.
	pub async fn revalidate_mempool(
		&self,
		mempool: Arc<TxMemPool<Api, Block>>,
		finalized_hash: HashAndNumber<Block>,
	) {
		log::trace!(
			target: LOG_TARGET,
			"Sent mempool to revalidation queue at hash: {:?}",
			finalized_hash
		);

		if let Some(to_worker) = &self.background {
			if let Err(e) =
				to_worker.unbounded_send(WorkerPayload::RevalidateMempool(mempool, finalized_hash))
			{
				log::warn!(target: LOG_TARGET, "Failed to update background worker: {:?}", e);
			}
		} else {
			mempool.revalidate(finalized_hash).await
		}
	}
}

/// Same as [`RevalidationQueue::new`]: a queue without a background worker.
///
/// Provided because `new()` takes no arguments (clippy `new_without_default`).
impl<Api, Block> Default for RevalidationQueue<Api, Block>
where
	Api: ChainApi<Block = Block> + 'static,
	Block: BlockT,
	<Block as BlockT>::Hash: Unpin,
{
	fn default() -> Self {
		Self::new()
	}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
common::tests::{uxt, TestApi},
fork_aware_txpool::view::FinishRevalidationLocalChannels,
TimedTransactionSource,
};
use futures::executor::block_on;
use substrate_test_runtime::{AccountId, Transfer, H256};
use substrate_test_runtime_client::Sr25519Keyring::Alice;
// Submits one transaction to a view, then revalidates the view through a queue that has no
// background worker, and checks that the transaction was validated again and is still ready.
#[test]
fn revalidation_queue_works() {
let api = Arc::new(TestApi::default());
let block0 = api.expect_hash_and_number(0);
let view = Arc::new(View::new(
api.clone(),
block0,
Default::default(),
Default::default(),
false.into(),
));
// `new()` attaches no background worker, so `revalidate_view` below runs the revalidation
// inline; `block_on` therefore returns only after revalidation has completed.
let queue = Arc::new(RevalidationQueue::new());
let uxt = uxt(Transfer {
from: Alice.into(),
to: AccountId::from_h256(H256::from_low_u64_be(2)),
amount: 5,
nonce: 0,
});
let _ = block_on(view.submit_many(std::iter::once((
TimedTransactionSource::new_external(false),
uxt.clone().into(),
))));
// Submission produced exactly one validation request.
assert_eq!(api.validation_requests().len(), 1);
let (finish_revalidation_request_tx, finish_revalidation_request_rx) =
tokio::sync::mpsc::channel(1);
let (revalidation_result_tx, revalidation_result_rx) = tokio::sync::mpsc::channel(1);
let finish_revalidation_worker_channels = FinishRevalidationWorkerChannels::new(
finish_revalidation_request_rx,
revalidation_result_tx,
);
// Bound to a named `_`-prefixed variable (not `let _ =`), so the local halves stay alive
// until the end of the test. NOTE(review): presumably this keeps the worker-side channels
// from being observed as closed during revalidation — confirm.
let _finish_revalidation_local_channels = FinishRevalidationLocalChannels::new(
finish_revalidation_request_tx,
revalidation_result_rx,
);
block_on(queue.revalidate_view(view.clone(), finish_revalidation_worker_channels));
// Revalidation issued one more validation request, and the transaction is still ready.
assert_eq!(api.validation_requests().len(), 2);
assert_eq!(view.status().ready, 1);
}
}