1use crate::LOG_TARGET;
22use codec::{Decode, Encode};
23use log::{info, trace};
24use sc_client_api::backend::AuxStore;
25use sp_blockchain::{Error as ClientError, Result as ClientResult};
26use sp_runtime::traits::{Block, NumberFor};
27
/// Aux-store key under which the schema version is written and read back.
const VERSION_KEY: &[u8] = b"mmr_auxschema_version";
/// Aux-store key under which the encoded [`PersistedState`] is written and read back.
const GADGET_STATE: &[u8] = b"mmr_gadget_state";

/// Schema version recorded by `write_current_version`.
const CURRENT_VERSION: u32 = 1;
/// The value persisted under `GADGET_STATE`: a block number of `B`.
///
/// NOTE(review): the log message in `load_or_init_state` calls this the "MMR best canonicalized
/// state" — presumably the number of the best canonicalized block; confirm against the gadget.
pub(crate) type PersistedState<B> = NumberFor<B>;
33
34pub(crate) fn write_current_version<B: AuxStore>(backend: &B) -> ClientResult<()> {
35 info!(target: LOG_TARGET, "write aux schema version {:?}", CURRENT_VERSION);
36 AuxStore::insert_aux(backend, &[(VERSION_KEY, CURRENT_VERSION.encode().as_slice())], &[])
37}
38
39pub(crate) fn write_gadget_state<B: Block, BE: AuxStore>(
41 backend: &BE,
42 state: &PersistedState<B>,
43) -> ClientResult<()> {
44 trace!(target: LOG_TARGET, "persisting {:?}", state);
45 backend.insert_aux(&[(GADGET_STATE, state.encode().as_slice())], &[])
46}
47
48fn load_decode<B: AuxStore, T: Decode>(backend: &B, key: &[u8]) -> ClientResult<Option<T>> {
49 match backend.get_aux(key)? {
50 None => Ok(None),
51 Some(t) => T::decode(&mut &t[..])
52 .map_err(|e| ClientError::Backend(format!("MMR aux DB is corrupted: {}", e)))
53 .map(Some),
54 }
55}
56
57pub(crate) fn load_state<B, BE>(backend: &BE) -> ClientResult<Option<PersistedState<B>>>
59where
60 B: Block,
61 BE: AuxStore,
62{
63 let version: Option<u32> = load_decode(backend, VERSION_KEY)?;
64
65 match version {
66 None => (),
67 Some(1) => return load_decode::<_, PersistedState<B>>(backend, GADGET_STATE),
68 other =>
69 return Err(ClientError::Backend(format!("Unsupported MMR aux DB version: {:?}", other))),
70 }
71
72 Ok(None)
74}
75
76pub(crate) fn load_or_init_state<B, BE>(
78 backend: &BE,
79 default: NumberFor<B>,
80) -> sp_blockchain::Result<NumberFor<B>>
81where
82 B: Block,
83 BE: AuxStore,
84{
85 if let Some(best) = load_state::<B, BE>(backend)? {
87 info!(target: LOG_TARGET, "Loading MMR best canonicalized state from db: {:?}.", best);
88 Ok(best)
89 } else {
90 info!(
91 target: LOG_TARGET,
92 "Loading MMR from pallet genesis on what appears to be the first startup: {:?}.",
93 default
94 );
95 write_current_version(backend)?;
96 write_gadget_state::<B, BE>(backend, &default)?;
97 Ok(default)
98 }
99}
100
// Tests for the aux-DB schema helpers, driven through the mock client from `crate::test_utils`.
//
// NOTE(review): judging by its name and its use below, `run_test_with_mmr_gadget_pre_post_using_client`
// runs its first closure before the gadget is started and its second one afterwards — confirm in
// `test_utils`.
#[cfg(test)]
pub(crate) mod tests {
	use super::*;
	use crate::test_utils::{run_test_with_mmr_gadget_pre_post_using_client, MmrBlock, MockClient};
	use parking_lot::Mutex;
	use sp_runtime::generic::BlockId;
	use std::{sync::Arc, time::Duration};
	use substrate_test_runtime_client::{runtime::Block, Backend};

	// Checks `load_state` against a fresh backend and against one holding only a version entry.
	#[test]
	fn should_load_persistent_sanity_checks() {
		let client = MockClient::new();
		let backend = &*client.backend;

		// Nothing has been written yet, so there is no state to load.
		assert_eq!(load_state::<Block, Backend>(backend).unwrap(), None);

		// After writing the version it must read back as `CURRENT_VERSION`.
		write_current_version(backend).unwrap();
		assert_eq!(load_decode(backend, VERSION_KEY).unwrap(), Some(CURRENT_VERSION));

		// A version entry alone does not produce a state: `GADGET_STATE` is still absent.
		assert_eq!(load_state::<Block, Backend>(backend).unwrap(), None);
	}

	// Checks that the version and state written during one gadget run are visible to the next
	// run on the same client, and that the state advances again in that second run.
	#[test]
	fn should_persist_progress_across_runs() {
		sp_tracing::try_init_simple();

		let client = Arc::new(MockClient::new());
		let backend = client.backend.clone();

		// The aux DB starts out with neither a version nor a state.
		assert_eq!(load_decode::<Backend, Option<u32>>(&*backend, VERSION_KEY).unwrap(), None);
		assert_eq!(load_state::<Block, Backend>(&*backend).unwrap(), None);
		// First run: import three blocks, finalize the last one and expect all three to be
		// canonicalized.
		run_test_with_mmr_gadget_pre_post_using_client(
			client.clone(),
			|_| async {},
			|client| async move {
				let a1 = client.import_block(&BlockId::Number(0), b"a1", Some(0)).await;
				let a2 = client.import_block(&BlockId::Number(1), b"a2", Some(1)).await;
				let a3 = client.import_block(&BlockId::Number(2), b"a3", Some(2)).await;
				client.finalize_block(a3.hash(), Some(3));
				// Presumably gives the gadget time to process the finalization — confirm.
				tokio::time::sleep(Duration::from_millis(200)).await;
				client.assert_canonicalized(&[&a1, &a2, &a3]);
			},
		);

		// Second run on the same client.
		run_test_with_mmr_gadget_pre_post_using_client(
			client.clone(),
			|client| async move {
				// What the first run left behind: the current version and state `3`.
				let backend = &*client.backend;
				assert_eq!(load_decode(backend, VERSION_KEY).unwrap(), Some(CURRENT_VERSION));
				assert_eq!(load_state::<Block, Backend>(backend).unwrap(), Some(3));
			},
			|client| async move {
				let a4 = client.import_block(&BlockId::Number(3), b"a4", Some(3)).await;
				let a5 = client.import_block(&BlockId::Number(4), b"a5", Some(4)).await;
				let a6 = client.import_block(&BlockId::Number(5), b"a6", Some(5)).await;
				client.finalize_block(a6.hash(), Some(6));
				tokio::time::sleep(Duration::from_millis(200)).await;

				// The new blocks are canonicalized and the persisted state has moved to `6`.
				client.assert_canonicalized(&[&a4, &a5, &a6]);
				assert_eq!(load_state::<Block, Backend>(&*client.backend).unwrap(), Some(6));
			},
		);
	}

	// Checks that a second run leaves blocks covered by the persisted state alone: their
	// canonicalization is undone after the first run and must stay undone after the second.
	#[test]
	fn should_resume_from_persisted_state() {
		sp_tracing::try_init_simple();

		let client = Arc::new(MockClient::new());
		// Shared between the closures of both runs so the second run can inspect the blocks
		// imported by the first.
		let blocks = Arc::new(Mutex::new(Vec::<MmrBlock>::new()));
		let blocks_clone = blocks.clone();

		// First run: import and finalize three blocks, then undo their canonicalization.
		run_test_with_mmr_gadget_pre_post_using_client(
			client.clone(),
			|_| async {},
			|client| async move {
				// Note: this guard stays alive across the awaits below.
				let mut blocks = blocks_clone.lock();
				blocks.push(client.import_block(&BlockId::Number(0), b"a1", Some(0)).await);
				blocks.push(client.import_block(&BlockId::Number(1), b"a2", Some(1)).await);
				blocks.push(client.import_block(&BlockId::Number(2), b"a3", Some(2)).await);
				client.finalize_block(blocks.last().unwrap().hash(), Some(3));
				tokio::time::sleep(Duration::from_millis(200)).await;
				let slice: Vec<&MmrBlock> = blocks.iter().collect();
				client.assert_canonicalized(&slice);

				// Revert the canonicalization so a later redo would be detectable.
				for mmr_block in slice {
					client.undo_block_canonicalization(mmr_block)
				}
			},
		);

		let blocks_clone = blocks.clone();
		// Second run on the same client.
		run_test_with_mmr_gadget_pre_post_using_client(
			client.clone(),
			|client| async move {
				let blocks = blocks_clone.lock();
				let slice: Vec<&MmrBlock> = blocks.iter().collect();

				// State `3` is still persisted while the first three blocks are no longer
				// canonicalized.
				assert_eq!(load_state::<Block, Backend>(&*client.backend).unwrap(), Some(3));
				client.assert_not_canonicalized(&slice);
			},
			|client| async move {
				let a4 = client.import_block(&BlockId::Number(3), b"a4", Some(3)).await;
				let a5 = client.import_block(&BlockId::Number(4), b"a5", Some(4)).await;
				let a6 = client.import_block(&BlockId::Number(5), b"a6", Some(5)).await;
				client.finalize_block(a6.hash(), Some(6));
				tokio::time::sleep(Duration::from_millis(200)).await;

				// The first three blocks were not canonicalized again ...
				let block_1_to_3 = blocks.lock();
				let slice: Vec<&MmrBlock> = block_1_to_3.iter().collect();
				client.assert_not_canonicalized(&slice);
				// ... while the new ones were, and the persisted state moved to `6`.
				client.assert_canonicalized(&[&a4, &a5, &a6]);
				assert_eq!(load_state::<Block, Backend>(&*client.backend).unwrap(), Some(6));
			},
		);
	}
}