use crate::{
schema::v1::{KeyValueStateEntry, StateEntry, StateRequest, StateResponse},
LOG_TARGET,
};
use codec::{Decode, Encode};
use log::debug;
use sc_client_api::{CompactProof, KeyValueStates, ProofProvider};
use sc_consensus::ImportedState;
use smallvec::SmallVec;
use sp_core::storage::well_known_keys;
use sp_runtime::{
traits::{Block as BlockT, Header, NumberFor},
Justifications,
};
use std::{collections::HashMap, fmt, sync::Arc};
pub trait StateSyncProvider<B: BlockT>: Send + Sync {
fn import(&mut self, response: StateResponse) -> ImportResult<B>;
fn next_request(&self) -> StateRequest;
fn is_complete(&self) -> bool;
fn target_number(&self) -> NumberFor<B>;
fn target_hash(&self) -> B::Hash;
fn progress(&self) -> StateSyncProgress;
}
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum StateSyncPhase {
DownloadingState,
ImportingState,
}
impl fmt::Display for StateSyncPhase {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::DownloadingState => write!(f, "Downloading state"),
Self::ImportingState => write!(f, "Importing state"),
}
}
}
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct StateSyncProgress {
pub percentage: u32,
pub size: u64,
pub phase: StateSyncPhase,
}
pub enum ImportResult<B: BlockT> {
Import(B::Hash, B::Header, ImportedState<B>, Option<Vec<B::Extrinsic>>, Option<Justifications>),
Continue,
BadResponse,
}
struct StateSyncMetadata<B: BlockT> {
last_key: SmallVec<[Vec<u8>; 2]>,
target_header: B::Header,
target_body: Option<Vec<B::Extrinsic>>,
target_justifications: Option<Justifications>,
complete: bool,
imported_bytes: u64,
skip_proof: bool,
}
impl<B: BlockT> StateSyncMetadata<B> {
fn target_hash(&self) -> B::Hash {
self.target_header.hash()
}
fn target_number(&self) -> NumberFor<B> {
*self.target_header.number()
}
fn target_root(&self) -> B::Hash {
*self.target_header.state_root()
}
fn next_request(&self) -> StateRequest {
StateRequest {
block: self.target_hash().encode(),
start: self.last_key.clone().into_vec(),
no_proof: self.skip_proof,
}
}
fn progress(&self) -> StateSyncProgress {
let cursor = *self.last_key.get(0).and_then(|last| last.get(0)).unwrap_or(&0u8);
let percent_done = cursor as u32 * 100 / 256;
StateSyncProgress {
percentage: percent_done,
size: self.imported_bytes,
phase: if self.complete {
StateSyncPhase::ImportingState
} else {
StateSyncPhase::DownloadingState
},
}
}
}
pub struct StateSync<B: BlockT, Client> {
metadata: StateSyncMetadata<B>,
state: HashMap<Vec<u8>, (Vec<(Vec<u8>, Vec<u8>)>, Vec<Vec<u8>>)>,
client: Arc<Client>,
}
impl<B, Client> StateSync<B, Client>
where
B: BlockT,
Client: ProofProvider<B> + Send + Sync + 'static,
{
pub fn new(
client: Arc<Client>,
target_header: B::Header,
target_body: Option<Vec<B::Extrinsic>>,
target_justifications: Option<Justifications>,
skip_proof: bool,
) -> Self {
Self {
client,
metadata: StateSyncMetadata {
last_key: SmallVec::default(),
target_header,
target_body,
target_justifications,
complete: false,
imported_bytes: 0,
skip_proof,
},
state: HashMap::default(),
}
}
fn process_state_key_values(
&mut self,
state_root: Vec<u8>,
key_values: impl IntoIterator<Item = (Vec<u8>, Vec<u8>)>,
) {
let is_top = state_root.is_empty();
let entry = self.state.entry(state_root).or_default();
if entry.0.len() > 0 && entry.1.len() > 1 {
return;
}
let mut child_storage_roots = Vec::new();
for (key, value) in key_values {
if is_top && well_known_keys::is_child_storage_key(key.as_slice()) {
child_storage_roots.push((value, key));
} else {
self.metadata.imported_bytes += key.len() as u64;
entry.0.push((key, value));
}
}
for (root, storage_key) in child_storage_roots {
self.state.entry(root).or_default().1.push(storage_key);
}
}
fn process_state_verified(&mut self, values: KeyValueStates) {
for values in values.0 {
self.process_state_key_values(values.state_root, values.key_values);
}
}
fn process_state_unverified(&mut self, response: StateResponse) -> bool {
let mut complete = true;
if self.metadata.last_key.len() == 2 && response.entries[0].entries.is_empty() {
self.metadata.last_key.pop();
} else {
self.metadata.last_key.clear();
}
for state in response.entries {
debug!(
target: LOG_TARGET,
"Importing state from {:?} to {:?}",
state.entries.last().map(|e| sp_core::hexdisplay::HexDisplay::from(&e.key)),
state.entries.first().map(|e| sp_core::hexdisplay::HexDisplay::from(&e.key)),
);
if !state.complete {
if let Some(e) = state.entries.last() {
self.metadata.last_key.push(e.key.clone());
}
complete = false;
}
let KeyValueStateEntry { state_root, entries, complete: _ } = state;
self.process_state_key_values(
state_root,
entries.into_iter().map(|StateEntry { key, value }| (key, value)),
);
}
complete
}
}
impl<B, Client> StateSyncProvider<B> for StateSync<B, Client>
where
B: BlockT,
Client: ProofProvider<B> + Send + Sync + 'static,
{
fn import(&mut self, response: StateResponse) -> ImportResult<B> {
if response.entries.is_empty() && response.proof.is_empty() {
debug!(target: LOG_TARGET, "Bad state response");
return ImportResult::BadResponse
}
if !self.metadata.skip_proof && response.proof.is_empty() {
debug!(target: LOG_TARGET, "Missing proof");
return ImportResult::BadResponse
}
let complete = if !self.metadata.skip_proof {
debug!(target: LOG_TARGET, "Importing state from {} trie nodes", response.proof.len());
let proof_size = response.proof.len() as u64;
let proof = match CompactProof::decode(&mut response.proof.as_ref()) {
Ok(proof) => proof,
Err(e) => {
debug!(target: LOG_TARGET, "Error decoding proof: {:?}", e);
return ImportResult::BadResponse
},
};
let (values, completed) = match self.client.verify_range_proof(
self.metadata.target_root(),
proof,
self.metadata.last_key.as_slice(),
) {
Err(e) => {
debug!(
target: LOG_TARGET,
"StateResponse failed proof verification: {}",
e,
);
return ImportResult::BadResponse
},
Ok(values) => values,
};
debug!(target: LOG_TARGET, "Imported with {} keys", values.len());
let complete = completed == 0;
if !complete && !values.update_last_key(completed, &mut self.metadata.last_key) {
debug!(target: LOG_TARGET, "Error updating key cursor, depth: {}", completed);
};
self.process_state_verified(values);
self.metadata.imported_bytes += proof_size;
complete
} else {
self.process_state_unverified(response)
};
if complete {
self.metadata.complete = true;
let target_hash = self.metadata.target_hash();
ImportResult::Import(
target_hash,
self.metadata.target_header.clone(),
ImportedState { block: target_hash, state: std::mem::take(&mut self.state).into() },
self.metadata.target_body.clone(),
self.metadata.target_justifications.clone(),
)
} else {
ImportResult::Continue
}
}
fn next_request(&self) -> StateRequest {
self.metadata.next_request()
}
fn is_complete(&self) -> bool {
self.metadata.complete
}
fn target_number(&self) -> NumberFor<B> {
self.metadata.target_number()
}
fn target_hash(&self) -> B::Hash {
self.metadata.target_hash()
}
fn progress(&self) -> StateSyncProgress {
self.metadata.progress()
}
}