use futures::{stream::FusedStream, StreamExt};
use sc_consensus::{BlockImport, StateAction};
use sc_utils::mpsc::{tracing_unbounded, TracingUnboundedReceiver, TracingUnboundedSender};
use sp_api::{ApiExt, CallApiAt, CallContext, Core, ProvideRuntimeApi, StorageProof};
use sp_runtime::traits::{Block as BlockT, Header as _};
use sp_trie::proof_size_extension::ProofSizeExt;
use std::sync::Arc;
/// Receiving side paired with [`SlotBasedBlockImport`], created by
/// [`SlotBasedBlockImport::new`].
///
/// Yields the `(block, storage proof)` pairs that the block import recorded while executing
/// blocks during import.
pub struct SlotBasedBlockImportHandle<Block> {
/// Receiving end of the channel whose sender is held by [`SlotBasedBlockImport`].
receiver: TracingUnboundedReceiver<(Block, StorageProof)>,
}
impl<Block> SlotBasedBlockImportHandle<Block> {
pub async fn next(&mut self) -> (Block, StorageProof) {
loop {
if self.receiver.is_terminated() {
futures::pending!()
} else if let Some(res) = self.receiver.next().await {
return res
}
}
}
}
/// Block import wrapper that may execute an incoming block itself (with proof recording
/// enabled) before delegating the actual import to `inner`.
///
/// Recorded `(block, storage proof)` pairs are delivered to the paired
/// [`SlotBasedBlockImportHandle`]. No execution or recording happens here once the channel
/// reports closed (presumably because the handle was dropped).
pub struct SlotBasedBlockImport<Block, BI, Client> {
/// The wrapped block import that performs the actual import.
inner: BI,
/// Client used to obtain the runtime API for executing blocks and to load parent state.
client: Arc<Client>,
/// Sending end of the channel read by [`SlotBasedBlockImportHandle`].
sender: TracingUnboundedSender<(Block, StorageProof)>,
}
impl<Block, BI, Client> SlotBasedBlockImport<Block, BI, Client> {
	/// Wraps `inner` and returns the block import together with the handle that receives the
	/// recorded `(block, storage proof)` pairs.
	///
	/// Both halves are connected by a `tracing_unbounded` channel named
	/// `SlotBasedBlockImportChannel`. The `1000` is presumably the channel's queue-size
	/// warning level; confirm against sc-utils.
	pub fn new(inner: BI, client: Arc<Client>) -> (Self, SlotBasedBlockImportHandle<Block>) {
		let (tx, rx) = tracing_unbounded("SlotBasedBlockImportChannel", 1000);

		let import = Self { inner, client, sender: tx };
		let handle = SlotBasedBlockImportHandle { receiver: rx };

		(import, handle)
	}
}
impl<Block, BI: Clone, Client> Clone for SlotBasedBlockImport<Block, BI, Client> {
fn clone(&self) -> Self {
Self { inner: self.inner.clone(), client: self.client.clone(), sender: self.sender.clone() }
}
}
#[async_trait::async_trait]
impl<Block, BI, Client> BlockImport<Block> for SlotBasedBlockImport<Block, BI, Client>
where
	Block: BlockT,
	BI: BlockImport<Block> + Send + Sync,
	BI::Error: Into<sp_consensus::Error>,
	Client: ProvideRuntimeApi<Block> + CallApiAt<Block> + Send + Sync,
	Client::StateBackend: Send,
	Client::Api: Core<Block>,
{
	type Error = sp_consensus::Error;

	/// Forwards the pre-import check to the wrapped block import unchanged.
	async fn check_block(
		&self,
		block: sc_consensus::BlockCheckParams<Block>,
	) -> Result<sc_consensus::ImportResult, Self::Error> {
		self.inner.check_block(block).await.map_err(Into::into)
	}

	/// Imports `params` through the wrapped block import.
	///
	/// When the block still needs execution, it is first executed here with proof recording
	/// enabled. The resulting storage changes are passed down as
	/// [`StateAction::ApplyChanges`] so the inner import does not execute the block again.
	/// The recorded `(block, storage proof)` pair is sent to the
	/// [`SlotBasedBlockImportHandle`] only once the inner import reports
	/// [`sc_consensus::ImportResult::Imported`].
	///
	/// # Errors
	///
	/// Fails if any of the following happens:
	/// - executing the block fails;
	/// - loading the parent state fails;
	/// - computing the storage changes fails;
	/// - the computed state root differs from the header
	///   (`sp_blockchain::Error::InvalidStateRoot`);
	/// - the inner import fails.
	async fn import_block(
		&self,
		mut params: sc_consensus::BlockImportParams<Block>,
	) -> Result<sc_consensus::ImportResult, Self::Error> {
		// Only intercept imports that the inner block import would execute itself.
		// `ApplyChanges` means the storage changes are already known (presumably the node
		// built the block or received it via state sync). `Skip` means the caller explicitly
		// asked for no execution, and the parent state may not be available then. Both are
		// passed through untouched.
		// NOTE(review): `ExecuteIfPossible` is still executed here unconditionally, so a
		// missing parent state surfaces as an error instead of a skipped execution (as before).
		let needs_execution =
			matches!(params.state_action, StateAction::Execute | StateAction::ExecuteIfPossible);

		let recorded = if needs_execution && !self.sender.is_closed() {
			let mut runtime_api = self.client.runtime_api();
			runtime_api.set_call_context(CallContext::Onchain);

			// Record the proof and expose its size to the runtime via `ProofSizeExt`.
			runtime_api.record_proof();
			let recorder = runtime_api
				.proof_recorder()
				.expect("Proof recording is enabled in the line above; qed.");
			runtime_api.register_extension(ProofSizeExt::new(recorder));

			let parent_hash = *params.header.parent_hash();
			let block = Block::new(params.header.clone(), params.body.clone().unwrap_or_default());

			runtime_api
				.execute_block(parent_hash, block.clone())
				.map_err(|e| Box::new(e) as Box<_>)?;

			let storage_proof =
				runtime_api.extract_proof().expect("Proof recording was enabled above; qed");

			let state = self.client.state_at(parent_hash).map_err(|e| Box::new(e) as Box<_>)?;
			let gen_storage_changes = runtime_api
				.into_storage_changes(&state, parent_hash)
				.map_err(sp_consensus::Error::ChainLookup)?;

			// Reject blocks whose header does not commit to the state we just computed.
			if params.header.state_root() != &gen_storage_changes.transaction_storage_root {
				return Err(sp_consensus::Error::Other(Box::new(
					sp_blockchain::Error::InvalidStateRoot,
				)))
			}

			// Hand the computed changes down so the inner import does not re-execute.
			params.state_action = StateAction::ApplyChanges(sc_consensus::StorageChanges::Changes(
				gen_storage_changes,
			));

			Some((block, storage_proof))
		} else {
			None
		};

		let result: Result<sc_consensus::ImportResult, sp_consensus::Error> =
			self.inner.import_block(params).await.map_err(Into::into);

		// Only hand out blocks that actually made it into the chain. A failed send just means
		// the handle was dropped in the meantime, which is not an import error.
		if let Some(recorded) = recorded {
			if matches!(result, Ok(sc_consensus::ImportResult::Imported(_))) {
				let _ = self.sender.unbounded_send(recorded);
			}
		}

		result
	}
}