sc_network_sync/strategy/
state_sync.rs1use crate::{
22	schema::v1::{KeyValueStateEntry, StateEntry, StateRequest, StateResponse},
23	LOG_TARGET,
24};
25use codec::{Decode, Encode};
26use log::debug;
27use sc_client_api::{CompactProof, KeyValueStates, ProofProvider};
28use sc_consensus::ImportedState;
29use smallvec::SmallVec;
30use sp_core::storage::well_known_keys;
31use sp_runtime::{
32	traits::{Block as BlockT, Header, NumberFor},
33	Justifications,
34};
35use std::{collections::HashMap, fmt, sync::Arc};
36
37pub trait StateSyncProvider<B: BlockT>: Send + Sync {
39	fn import(&mut self, response: StateResponse) -> ImportResult<B>;
41	fn next_request(&self) -> StateRequest;
43	fn is_complete(&self) -> bool;
45	fn target_number(&self) -> NumberFor<B>;
47	fn target_hash(&self) -> B::Hash;
49	fn progress(&self) -> StateSyncProgress;
51}
52
/// Phase of a state sync, as reported in [`StateSyncProgress`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StateSyncPhase {
	/// State entries are still being fetched from peers.
	DownloadingState,
	/// The download has finished and the collected state is handed over for import.
	ImportingState,
}

impl fmt::Display for StateSyncPhase {
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		let label = match self {
			StateSyncPhase::DownloadingState => "Downloading state",
			StateSyncPhase::ImportingState => "Importing state",
		};
		f.write_str(label)
	}
}

/// Snapshot of state sync progress.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StateSyncProgress {
	/// Estimated completion, in percent.
	pub percentage: u32,
	/// Number of bytes counted as imported so far.
	pub size: u64,
	/// Phase the sync is currently in.
	pub phase: StateSyncPhase,
}
81
82pub enum ImportResult<B: BlockT> {
84	Import(B::Hash, B::Header, ImportedState<B>, Option<Vec<B::Extrinsic>>, Option<Justifications>),
86	Continue,
88	BadResponse,
90}
91
92struct StateSyncMetadata<B: BlockT> {
93	last_key: SmallVec<[Vec<u8>; 2]>,
94	target_header: B::Header,
95	target_body: Option<Vec<B::Extrinsic>>,
96	target_justifications: Option<Justifications>,
97	complete: bool,
98	imported_bytes: u64,
99	skip_proof: bool,
100}
101
102impl<B: BlockT> StateSyncMetadata<B> {
103	fn target_hash(&self) -> B::Hash {
104		self.target_header.hash()
105	}
106
107	fn target_number(&self) -> NumberFor<B> {
109		*self.target_header.number()
110	}
111
112	fn target_root(&self) -> B::Hash {
113		*self.target_header.state_root()
114	}
115
116	fn next_request(&self) -> StateRequest {
117		StateRequest {
118			block: self.target_hash().encode(),
119			start: self.last_key.clone().into_vec(),
120			no_proof: self.skip_proof,
121		}
122	}
123
124	fn progress(&self) -> StateSyncProgress {
125		let cursor = *self.last_key.get(0).and_then(|last| last.get(0)).unwrap_or(&0u8);
126		let percent_done = cursor as u32 * 100 / 256;
127		StateSyncProgress {
128			percentage: percent_done,
129			size: self.imported_bytes,
130			phase: if self.complete {
131				StateSyncPhase::ImportingState
132			} else {
133				StateSyncPhase::DownloadingState
134			},
135		}
136	}
137}
138
139pub struct StateSync<B: BlockT, Client> {
143	metadata: StateSyncMetadata<B>,
144	state: HashMap<Vec<u8>, (Vec<(Vec<u8>, Vec<u8>)>, Vec<Vec<u8>>)>,
145	client: Arc<Client>,
146}
147
148impl<B, Client> StateSync<B, Client>
149where
150	B: BlockT,
151	Client: ProofProvider<B> + Send + Sync + 'static,
152{
153	pub fn new(
155		client: Arc<Client>,
156		target_header: B::Header,
157		target_body: Option<Vec<B::Extrinsic>>,
158		target_justifications: Option<Justifications>,
159		skip_proof: bool,
160	) -> Self {
161		Self {
162			client,
163			metadata: StateSyncMetadata {
164				last_key: SmallVec::default(),
165				target_header,
166				target_body,
167				target_justifications,
168				complete: false,
169				imported_bytes: 0,
170				skip_proof,
171			},
172			state: HashMap::default(),
173		}
174	}
175
176	fn process_state_key_values(
177		&mut self,
178		state_root: Vec<u8>,
179		key_values: impl IntoIterator<Item = (Vec<u8>, Vec<u8>)>,
180	) {
181		let is_top = state_root.is_empty();
182
183		let entry = self.state.entry(state_root).or_default();
184
185		if entry.0.len() > 0 && entry.1.len() > 1 {
186			return;
189		}
190
191		let mut child_storage_roots = Vec::new();
192
193		for (key, value) in key_values {
194			if is_top && well_known_keys::is_child_storage_key(key.as_slice()) {
196				child_storage_roots.push((value, key));
197			} else {
198				self.metadata.imported_bytes += key.len() as u64;
199				entry.0.push((key, value));
200			}
201		}
202
203		for (root, storage_key) in child_storage_roots {
204			self.state.entry(root).or_default().1.push(storage_key);
205		}
206	}
207
208	fn process_state_verified(&mut self, values: KeyValueStates) {
209		for values in values.0 {
210			self.process_state_key_values(values.state_root, values.key_values);
211		}
212	}
213
214	fn process_state_unverified(&mut self, response: StateResponse) -> bool {
215		let mut complete = true;
216		if self.metadata.last_key.len() == 2 && response.entries[0].entries.is_empty() {
221			self.metadata.last_key.pop();
223		} else {
224			self.metadata.last_key.clear();
225		}
226		for state in response.entries {
227			debug!(
228				target: LOG_TARGET,
229				"Importing state from {:?} to {:?}",
230				state.entries.last().map(|e| sp_core::hexdisplay::HexDisplay::from(&e.key)),
231				state.entries.first().map(|e| sp_core::hexdisplay::HexDisplay::from(&e.key)),
232			);
233
234			if !state.complete {
235				if let Some(e) = state.entries.last() {
236					self.metadata.last_key.push(e.key.clone());
237				}
238				complete = false;
239			}
240
241			let KeyValueStateEntry { state_root, entries, complete: _ } = state;
242			self.process_state_key_values(
243				state_root,
244				entries.into_iter().map(|StateEntry { key, value }| (key, value)),
245			);
246		}
247		complete
248	}
249}
250
251impl<B, Client> StateSyncProvider<B> for StateSync<B, Client>
252where
253	B: BlockT,
254	Client: ProofProvider<B> + Send + Sync + 'static,
255{
256	fn import(&mut self, response: StateResponse) -> ImportResult<B> {
258		if response.entries.is_empty() && response.proof.is_empty() {
259			debug!(target: LOG_TARGET, "Bad state response");
260			return ImportResult::BadResponse
261		}
262		if !self.metadata.skip_proof && response.proof.is_empty() {
263			debug!(target: LOG_TARGET, "Missing proof");
264			return ImportResult::BadResponse
265		}
266		let complete = if !self.metadata.skip_proof {
267			debug!(target: LOG_TARGET, "Importing state from {} trie nodes", response.proof.len());
268			let proof_size = response.proof.len() as u64;
269			let proof = match CompactProof::decode(&mut response.proof.as_ref()) {
270				Ok(proof) => proof,
271				Err(e) => {
272					debug!(target: LOG_TARGET, "Error decoding proof: {:?}", e);
273					return ImportResult::BadResponse
274				},
275			};
276			let (values, completed) = match self.client.verify_range_proof(
277				self.metadata.target_root(),
278				proof,
279				self.metadata.last_key.as_slice(),
280			) {
281				Err(e) => {
282					debug!(
283						target: LOG_TARGET,
284						"StateResponse failed proof verification: {}",
285						e,
286					);
287					return ImportResult::BadResponse
288				},
289				Ok(values) => values,
290			};
291			debug!(target: LOG_TARGET, "Imported with {} keys", values.len());
292
293			let complete = completed == 0;
294			if !complete && !values.update_last_key(completed, &mut self.metadata.last_key) {
295				debug!(target: LOG_TARGET, "Error updating key cursor, depth: {}", completed);
296			};
297
298			self.process_state_verified(values);
299			self.metadata.imported_bytes += proof_size;
300			complete
301		} else {
302			self.process_state_unverified(response)
303		};
304		if complete {
305			self.metadata.complete = true;
306			let target_hash = self.metadata.target_hash();
307			ImportResult::Import(
308				target_hash,
309				self.metadata.target_header.clone(),
310				ImportedState { block: target_hash, state: std::mem::take(&mut self.state).into() },
311				self.metadata.target_body.clone(),
312				self.metadata.target_justifications.clone(),
313			)
314		} else {
315			ImportResult::Continue
316		}
317	}
318
319	fn next_request(&self) -> StateRequest {
321		self.metadata.next_request()
322	}
323
324	fn is_complete(&self) -> bool {
326		self.metadata.complete
327	}
328
329	fn target_number(&self) -> NumberFor<B> {
331		self.metadata.target_number()
332	}
333
334	fn target_hash(&self) -> B::Hash {
336		self.metadata.target_hash()
337	}
338
339	fn progress(&self) -> StateSyncProgress {
341		self.metadata.progress()
342	}
343}