1use codec::{Decode, Encode};
22use sp_blockchain::{Error, Result};
23use sp_database::{Database, Transaction};
24use sp_runtime::traits::AtLeast32Bit;
25use std::{cmp::Reverse, collections::BTreeMap};
26
27type DbHash = sp_core::H256;
28
/// A single leaf: a block hash paired with its block number.
///
/// The number is wrapped in [`Reverse`] so that it can be used directly as a key
/// of the leaf set's map, whose keys are `Reverse<N>`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct LeafSetItem<H, N> {
    /// Hash of the leaf block.
    hash: H,
    /// Number of the leaf block, in the reversed ordering used by the set.
    number: Reverse<N>,
}
34
35pub struct ImportOutcome<H, N> {
37 inserted: LeafSetItem<H, N>,
38 removed: Option<H>,
39}
40
41pub struct RemoveOutcome<H, N> {
43 inserted: Option<H>,
44 removed: LeafSetItem<H, N>,
45}
46
/// A batch of `(number, hash)` leaves.
///
/// `LeafSet::remove_displaced_leaves` removes every listed leaf from the set, and
/// `Undo::undo_finalization` pushes every listed leaf back into it.
pub struct FinalizationOutcome<I, H, N>
where
    I: Iterator<Item = (N, H)>,
{
    /// The `(number, hash)` pairs, consumed lazily by whoever applies the outcome.
    removed: I,
}

// No `Ord` bounds here: the constructor only stores the iterator and never
// compares hashes or numbers, so requiring `Ord` would needlessly reject callers.
impl<I, H, N> FinalizationOutcome<I, H, N>
where
    I: Iterator<Item = (N, H)>,
{
    /// Wraps an iterator of `(number, hash)` leaves.
    ///
    /// The iterator is stored as-is; nothing is consumed until the outcome is applied.
    pub fn new(new_displaced: I) -> Self {
        Self { removed: new_displaced }
    }
}
64
/// The set of current leaves, grouped by block number.
///
/// Levels are keyed by `Reverse(number)`, so iterating the map visits the highest
/// block number first. Each level holds the hashes of the leaves at that number.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LeafSet<H, N> {
    /// Leaf hashes per block number, highest number first.
    storage: BTreeMap<Reverse<N>, Vec<H>>,
}
72
73impl<H, N> LeafSet<H, N>
74where
75 H: Clone + PartialEq + Decode + Encode,
76 N: std::fmt::Debug + Copy + AtLeast32Bit + Decode + Encode,
77{
78 pub fn new() -> Self {
80 Self { storage: BTreeMap::new() }
81 }
82
83 pub fn read_from_db(db: &dyn Database<DbHash>, column: u32, prefix: &[u8]) -> Result<Self> {
85 let mut storage = BTreeMap::new();
86
87 match db.get(column, prefix) {
88 Some(leaves) => {
89 let vals: Vec<_> = match Decode::decode(&mut leaves.as_ref()) {
90 Ok(vals) => vals,
91 Err(_) => return Err(Error::Backend("Error decoding leaves".into())),
92 };
93 for (number, hashes) in vals.into_iter() {
94 storage.insert(Reverse(number), hashes);
95 }
96 },
97 None => {},
98 }
99 Ok(Self { storage })
100 }
101
102 pub fn import(&mut self, hash: H, number: N, parent_hash: H) -> ImportOutcome<H, N> {
104 let number = Reverse(number);
105
106 let removed = if number.0 != N::zero() {
107 let parent_number = Reverse(number.0 - N::one());
108 self.remove_leaf(&parent_number, &parent_hash).then(|| parent_hash)
109 } else {
110 None
111 };
112
113 self.insert_leaf(number, hash.clone());
114
115 ImportOutcome { inserted: LeafSetItem { hash, number }, removed }
116 }
117
118 pub fn remove(
127 &mut self,
128 hash: H,
129 number: N,
130 parent_hash: Option<H>,
131 ) -> Option<RemoveOutcome<H, N>> {
132 let number = Reverse(number);
133
134 if !self.remove_leaf(&number, &hash) {
135 return None
136 }
137
138 let inserted = parent_hash.and_then(|parent_hash| {
139 if number.0 != N::zero() {
140 let parent_number = Reverse(number.0 - N::one());
141 self.insert_leaf(parent_number, parent_hash.clone());
142 Some(parent_hash)
143 } else {
144 None
145 }
146 });
147
148 Some(RemoveOutcome { inserted, removed: LeafSetItem { hash, number } })
149 }
150
151 pub fn remove_displaced_leaves<I>(&mut self, displaced_leaves: FinalizationOutcome<I, H, N>)
153 where
154 I: Iterator<Item = (N, H)>,
155 {
156 for (number, hash) in displaced_leaves.removed {
157 self.remove_leaf(&Reverse(number), &hash);
158 }
159 }
160
161 pub fn undo(&mut self) -> Undo<'_, H, N> {
168 Undo { inner: self }
169 }
170
171 pub fn revert(&mut self, best_hash: H, best_number: N) -> impl Iterator<Item = (H, N)> {
176 let items = self
177 .storage
178 .iter()
179 .flat_map(|(number, hashes)| hashes.iter().map(move |h| (h.clone(), number.0)))
180 .collect::<Vec<_>>();
181
182 for (hash, number) in &items {
183 if *number > best_number {
184 assert!(
185 self.remove_leaf(&Reverse(*number), &hash),
186 "item comes from an iterator over storage; qed",
187 );
188 }
189 }
190
191 let best_number_rev = Reverse(best_number);
192 let leaves_contains_best = self
193 .storage
194 .get(&best_number_rev)
195 .map_or(false, |hashes| hashes.contains(&best_hash));
196
197 if !leaves_contains_best {
200 self.insert_leaf(best_number_rev, best_hash.clone());
201 }
202
203 items.into_iter().filter(move |(_, n)| *n > best_number)
204 }
205
206 pub fn hashes(&self) -> Vec<H> {
209 self.storage.iter().flat_map(|(_, hashes)| hashes.iter()).cloned().collect()
210 }
211
212 pub fn count(&self) -> usize {
214 self.storage.values().map(|level| level.len()).sum()
215 }
216
217 pub fn prepare_transaction(
219 &mut self,
220 tx: &mut Transaction<DbHash>,
221 column: u32,
222 prefix: &[u8],
223 ) {
224 let leaves: Vec<_> = self.storage.iter().map(|(n, h)| (n.0, h.clone())).collect();
225 tx.set_from_vec(column, prefix, leaves.encode());
226 }
227
228 pub fn contains(&self, number: N, hash: H) -> bool {
230 self.storage
231 .get(&Reverse(number))
232 .map_or(false, |hashes| hashes.contains(&hash))
233 }
234
235 fn insert_leaf(&mut self, number: Reverse<N>, hash: H) {
236 self.storage.entry(number).or_insert_with(Vec::new).push(hash);
237 }
238
239 fn remove_leaf(&mut self, number: &Reverse<N>, hash: &H) -> bool {
241 let mut empty = false;
242 let removed = self.storage.get_mut(number).map_or(false, |leaves| {
243 let mut found = false;
244 leaves.retain(|h| {
245 if h == hash {
246 found = true;
247 false
248 } else {
249 true
250 }
251 });
252
253 if leaves.is_empty() {
254 empty = true
255 }
256
257 found
258 });
259
260 if removed && empty {
261 self.storage.remove(number);
262 }
263
264 removed
265 }
266
267 pub fn highest_leaf(&self) -> Option<(N, &[H])> {
269 self.storage.iter().next().map(|(k, v)| (k.0, &v[..]))
270 }
271}
272
273pub struct Undo<'a, H: 'a, N: 'a> {
275 inner: &'a mut LeafSet<H, N>,
276}
277
278impl<'a, H: 'a, N: 'a> Undo<'a, H, N>
279where
280 H: Clone + PartialEq + Decode + Encode,
281 N: std::fmt::Debug + Copy + AtLeast32Bit + Decode + Encode,
282{
283 pub fn undo_import(&mut self, outcome: ImportOutcome<H, N>) {
286 if let Some(removed_hash) = outcome.removed {
287 let removed_number = Reverse(outcome.inserted.number.0 - N::one());
288 self.inner.insert_leaf(removed_number, removed_hash);
289 }
290 self.inner.remove_leaf(&outcome.inserted.number, &outcome.inserted.hash);
291 }
292
293 pub fn undo_remove(&mut self, outcome: RemoveOutcome<H, N>) {
296 if let Some(inserted_hash) = outcome.inserted {
297 let inserted_number = Reverse(outcome.removed.number.0 - N::one());
298 self.inner.remove_leaf(&inserted_number, &inserted_hash);
299 }
300 self.inner.insert_leaf(outcome.removed.number, outcome.removed.hash);
301 }
302
303 pub fn undo_finalization<I>(&mut self, outcome: FinalizationOutcome<I, H, N>)
306 where
307 I: Iterator<Item = (N, H)>,
308 {
309 for (number, hash) in outcome.removed {
310 self.inner.storage.entry(Reverse(number)).or_default().push(hash);
311 }
312 }
313}
314
// Hashes in these tests are plain integers written as `<number>_<index>`; the
// underscore is only a digit separator, so `2_1` is the integer 21.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn import_works() {
        let mut set = LeafSet::new();
        set.import(0u32, 0u32, 0u32);

        // A single chain: each import displaces its parent.
        set.import(1_1, 1, 0);
        set.import(2_1, 2, 1_1);
        set.import(3_1, 3, 2_1);

        assert_eq!(set.count(), 1);
        assert!(set.contains(3, 3_1));
        assert!(!set.contains(2, 2_1));
        assert!(!set.contains(1, 1_1));
        assert!(!set.contains(0, 0));

        // Forks add leaves without displacing non-leaf parents.
        set.import(2_2, 2, 1_1);
        set.import(1_2, 1, 0);
        set.import(2_3, 2, 1_2);

        assert_eq!(set.count(), 3);
        assert!(set.contains(3, 3_1));
        assert!(set.contains(2, 2_2));
        assert!(set.contains(2, 2_3));

        // Import whose parent is not a leaf: nothing is removed.
        let outcome = set.import(2_4, 2, 1_1);
        assert_eq!(outcome.inserted.hash, 2_4);
        assert_eq!(outcome.removed, None);
        assert_eq!(set.count(), 4);
        assert!(set.contains(2, 2_4));

        set.undo().undo_import(outcome);
        assert_eq!(set.count(), 3);
        assert!(set.contains(3, 3_1));
        assert!(set.contains(2, 2_2));
        assert!(set.contains(2, 2_3));

        // Import whose parent is a leaf: the parent is displaced and later restored.
        let outcome = set.import(3_2, 3, 2_3);
        assert_eq!(outcome.inserted.hash, 3_2);
        assert_eq!(outcome.removed, Some(2_3));
        assert_eq!(set.count(), 3);
        assert!(set.contains(3, 3_2));

        set.undo().undo_import(outcome);
        assert_eq!(set.count(), 3);
        assert!(set.contains(3, 3_1));
        assert!(set.contains(2, 2_2));
        assert!(set.contains(2, 2_3));
    }

    #[test]
    fn removal_works() {
        let mut set = LeafSet::new();
        set.import(10_1u32, 10u32, 0u32);
        set.import(11_1, 11, 10_1);
        set.import(11_2, 11, 10_1);
        set.import(12_1, 12, 11_1);

        let outcome = set.remove(12_1, 12, Some(11_1)).unwrap();
        assert_eq!(outcome.removed.hash, 12_1);
        assert_eq!(outcome.inserted, Some(11_1));
        assert_eq!(set.count(), 2);
        assert!(set.contains(11, 11_1));
        assert!(set.contains(11, 11_2));

        // Without a parent hash nothing is inserted in place of the removed leaf.
        let outcome = set.remove(11_1, 11, None).unwrap();
        assert_eq!(outcome.removed.hash, 11_1);
        assert_eq!(outcome.inserted, None);
        assert_eq!(set.count(), 1);
        assert!(set.contains(11, 11_2));

        let outcome = set.remove(11_2, 11, Some(10_1)).unwrap();
        assert_eq!(outcome.removed.hash, 11_2);
        assert_eq!(outcome.inserted, Some(10_1));
        assert_eq!(set.count(), 1);
        assert!(set.contains(10, 10_1));

        set.undo().undo_remove(outcome);
        assert_eq!(set.count(), 1);
        assert!(set.contains(11, 11_2));
    }

    #[test]
    fn flush_to_disk() {
        const PREFIX: &[u8] = b"abcdefg";
        let db = Arc::new(sp_database::MemDb::default());

        let mut set = LeafSet::new();
        set.import(0u32, 0u32, 0u32);

        set.import(1_1, 1, 0);
        set.import(2_1, 2, 1_1);
        set.import(3_1, 3, 2_1);

        let mut tx = Transaction::new();

        set.prepare_transaction(&mut tx, 0, PREFIX);
        db.commit(tx).unwrap();

        // What was written must decode back into an identical set.
        let set2 = LeafSet::read_from_db(&*db, 0, PREFIX).unwrap();
        assert_eq!(set, set2);
    }

    #[test]
    fn two_leaves_same_height_can_be_included() {
        let mut set = LeafSet::new();

        set.import(1_1u32, 10u32, 0u32);
        set.import(1_2, 10, 0);

        assert!(set.storage.contains_key(&Reverse(10)));
        assert!(set.contains(10, 1_1));
        assert!(set.contains(10, 1_2));
        assert!(!set.contains(10, 1_3));
    }

    #[test]
    fn revert_works() {
        let mut set = LeafSet::new();
        set.import(0u32, 0u32, 0u32);
        set.import(1_1, 1, 0);
        set.import(2_1, 2, 1_1);
        set.import(3_1, 3, 2_1);
        set.import(2_2, 2, 1_1);
        // Leaves are now 3_1 at height 3 and 2_2 at height 2.

        // The best block is already a leaf: only higher leaves are dropped.
        let mut forked = set.clone();
        let reverted: Vec<_> = forked.revert(2_2, 2).collect();
        assert_eq!(reverted, vec![(3_1, 3)]);
        assert_eq!(forked.count(), 1);
        assert!(forked.contains(2, 2_2));

        // The best block is not a leaf: everything above goes and it becomes one.
        let reverted: Vec<_> = set.revert(1_1, 1).collect();
        assert_eq!(reverted, vec![(3_1, 3), (2_2, 2)]);
        assert_eq!(set.count(), 1);
        assert!(set.contains(1, 1_1));
        assert!(!set.contains(3, 3_1));
        assert!(!set.contains(2, 2_2));
    }

    #[test]
    fn finalization_removes_displaced_leaves_and_can_be_undone() {
        let mut set = LeafSet::new();
        set.import(10_1u32, 10u32, 0u32);
        set.import(11_1, 11, 10_1);
        set.import(11_2, 11, 10_1);

        set.remove_displaced_leaves(FinalizationOutcome::new(vec![(11, 11_2)].into_iter()));
        assert_eq!(set.count(), 1);
        assert!(set.contains(11, 11_1));
        assert!(!set.contains(11, 11_2));

        set.undo()
            .undo_finalization(FinalizationOutcome::new(vec![(11, 11_2)].into_iter()));
        assert_eq!(set.count(), 2);
        assert!(set.contains(11, 11_1));
        assert!(set.contains(11, 11_2));
    }

    #[test]
    fn hashes_and_highest_leaf_are_ordered_by_height() {
        let mut set = LeafSet::new();
        assert_eq!(set.highest_leaf(), None);

        set.import(0u32, 0u32, 0u32);
        set.import(1_1, 1, 0);
        set.import(2_1, 2, 1_1);
        set.import(3_1, 3, 2_1);
        set.import(2_2, 2, 1_1);
        set.import(1_2, 1, 0);
        set.import(2_3, 2, 1_2);

        assert_eq!(set.hashes(), vec![3_1, 2_2, 2_3]);
        assert_eq!(set.highest_leaf(), Some((3, &[3_1][..])));
    }
}