1use codec::{Decode, Encode};
22use sp_blockchain::{Error, Result};
23use sp_database::{Database, Transaction};
24use sp_runtime::traits::AtLeast32Bit;
25use std::{cmp::Reverse, collections::BTreeMap};
26
27type DbHash = sp_core::H256;
28
/// A single leaf entry: a block hash together with its block number.
///
/// The number is wrapped in [`Reverse`], the same key type used by `LeafSet`'s storage map.
#[derive(Debug, Clone, PartialEq, Eq)]
struct LeafSetItem<H, N> {
	// Block hash of the leaf.
	hash: H,
	// Block number of the leaf, reversed so it sorts highest-first.
	number: Reverse<N>,
}
34
/// Changes made to a `LeafSet` by `LeafSet::import`.
///
/// Pass it to `Undo::undo_import` to revert the import.
pub struct ImportOutcome<H, N> {
	/// The imported block, which has been added as a leaf.
	inserted: LeafSetItem<H, N>,
	/// The parent hash, if the parent was a leaf and was removed by the import.
	removed: Option<H>,
}
40
/// Changes made to a `LeafSet` by `LeafSet::remove`.
///
/// Pass it to `Undo::undo_remove` to revert the removal.
pub struct RemoveOutcome<H, N> {
	/// The parent hash, if it was re-added as a leaf one level below `removed`.
	inserted: Option<H>,
	/// The block that was removed from the leaf set.
	removed: LeafSetItem<H, N>,
}
46
/// Leaves displaced by a finalization, yielded lazily as `(number, hash)` pairs.
///
/// `LeafSet::remove_displaced_leaves` consumes it to drop those leaves, and
/// `Undo::undo_finalization` consumes it to re-add them.
pub struct FinalizationOutcome<I, H, N>
where
	I: Iterator<Item = (N, H)>,
{
	removed: I,
}

impl<I, H, N> FinalizationOutcome<I, H, N>
where
	I: Iterator<Item = (N, H)>,
{
	/// Wrap an iterator over the displaced `(number, hash)` leaves.
	///
	/// The iterator is stored untouched and only consumed later, so no ordering is required
	/// of `H` or `N`. Earlier versions demanded `H: Ord, N: Ord`, which excluded hash types
	/// that are merely `PartialEq`, as `LeafSet` itself allows.
	pub fn new(new_displaced: I) -> Self {
		FinalizationOutcome { removed: new_displaced }
	}
}
64
/// In-memory set of leaf block hashes, grouped by block number.
///
/// Keys are wrapped in [`Reverse`], so iterating the map visits the highest block number
/// first. Several leaves may share the same height.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LeafSet<H, N> {
	// Leaf hashes per block number; map iteration order is descending by number.
	storage: BTreeMap<Reverse<N>, Vec<H>>,
}
72
73impl<H, N> LeafSet<H, N>
74where
75 H: Clone + PartialEq + Decode + Encode,
76 N: std::fmt::Debug + Copy + AtLeast32Bit + Decode + Encode,
77{
78 pub fn new() -> Self {
80 Self { storage: BTreeMap::new() }
81 }
82
83 pub fn read_from_db(db: &dyn Database<DbHash>, column: u32, prefix: &[u8]) -> Result<Self> {
85 let mut storage = BTreeMap::new();
86
87 match db.get(column, prefix) {
88 Some(leaves) => {
89 let vals: Vec<_> = match Decode::decode(&mut leaves.as_ref()) {
90 Ok(vals) => vals,
91 Err(_) => return Err(Error::Backend("Error decoding leaves".into())),
92 };
93 for (number, hashes) in vals.into_iter() {
94 storage.insert(Reverse(number), hashes);
95 }
96 },
97 None => {},
98 }
99 Ok(Self { storage })
100 }
101
102 pub fn import(&mut self, hash: H, number: N, parent_hash: H) -> ImportOutcome<H, N> {
104 let number = Reverse(number);
105
106 let removed = if number.0 != N::zero() {
107 let parent_number = Reverse(number.0 - N::one());
108 self.remove_leaf(&parent_number, &parent_hash).then(|| parent_hash)
109 } else {
110 None
111 };
112
113 self.insert_leaf(number, hash.clone());
114
115 ImportOutcome { inserted: LeafSetItem { hash, number }, removed }
116 }
117
118 pub fn remove(
127 &mut self,
128 hash: H,
129 number: N,
130 parent_hash: Option<H>,
131 ) -> Option<RemoveOutcome<H, N>> {
132 let number = Reverse(number);
133
134 if !self.remove_leaf(&number, &hash) {
135 return None
136 }
137
138 let inserted = parent_hash.and_then(|parent_hash| {
139 if number.0 != N::zero() {
140 let parent_number = Reverse(number.0 - N::one());
141 self.insert_leaf(parent_number, parent_hash.clone());
142 Some(parent_hash)
143 } else {
144 None
145 }
146 });
147
148 Some(RemoveOutcome { inserted, removed: LeafSetItem { hash, number } })
149 }
150
151 pub fn remove_displaced_leaves<I>(&mut self, displaced_leaves: FinalizationOutcome<I, H, N>)
153 where
154 I: Iterator<Item = (N, H)>,
155 {
156 for (number, hash) in displaced_leaves.removed {
157 self.remove_leaf(&Reverse(number), &hash);
158 }
159 }
160
161 pub fn undo(&mut self) -> Undo<H, N> {
168 Undo { inner: self }
169 }
170
171 pub fn revert(&mut self, best_hash: H, best_number: N) {
174 let items = self
175 .storage
176 .iter()
177 .flat_map(|(number, hashes)| hashes.iter().map(move |h| (h.clone(), *number)))
178 .collect::<Vec<_>>();
179
180 for (hash, number) in items {
181 if number.0 > best_number {
182 assert!(
183 self.remove_leaf(&number, &hash),
184 "item comes from an iterator over storage; qed",
185 );
186 }
187 }
188
189 let best_number = Reverse(best_number);
190 let leaves_contains_best = self
191 .storage
192 .get(&best_number)
193 .map_or(false, |hashes| hashes.contains(&best_hash));
194
195 if !leaves_contains_best {
198 self.insert_leaf(best_number, best_hash.clone());
199 }
200 }
201
202 pub fn hashes(&self) -> Vec<H> {
205 self.storage.iter().flat_map(|(_, hashes)| hashes.iter()).cloned().collect()
206 }
207
208 pub fn count(&self) -> usize {
210 self.storage.values().map(|level| level.len()).sum()
211 }
212
213 pub fn prepare_transaction(
215 &mut self,
216 tx: &mut Transaction<DbHash>,
217 column: u32,
218 prefix: &[u8],
219 ) {
220 let leaves: Vec<_> = self.storage.iter().map(|(n, h)| (n.0, h.clone())).collect();
221 tx.set_from_vec(column, prefix, leaves.encode());
222 }
223
224 pub fn contains(&self, number: N, hash: H) -> bool {
226 self.storage
227 .get(&Reverse(number))
228 .map_or(false, |hashes| hashes.contains(&hash))
229 }
230
231 fn insert_leaf(&mut self, number: Reverse<N>, hash: H) {
232 self.storage.entry(number).or_insert_with(Vec::new).push(hash);
233 }
234
235 fn remove_leaf(&mut self, number: &Reverse<N>, hash: &H) -> bool {
237 let mut empty = false;
238 let removed = self.storage.get_mut(number).map_or(false, |leaves| {
239 let mut found = false;
240 leaves.retain(|h| {
241 if h == hash {
242 found = true;
243 false
244 } else {
245 true
246 }
247 });
248
249 if leaves.is_empty() {
250 empty = true
251 }
252
253 found
254 });
255
256 if removed && empty {
257 self.storage.remove(number);
258 }
259
260 removed
261 }
262
263 pub fn highest_leaf(&self) -> Option<(N, &[H])> {
265 self.storage.iter().next().map(|(k, v)| (k.0, &v[..]))
266 }
267}
268
/// Helper for reverting operations on a `LeafSet`, obtained from `LeafSet::undo`.
///
/// Each method takes the outcome value returned by the corresponding `LeafSet` operation and
/// reverts that operation's effect on the borrowed set.
pub struct Undo<'a, H: 'a, N: 'a> {
	// The leaf set being modified.
	inner: &'a mut LeafSet<H, N>,
}
273
274impl<'a, H: 'a, N: 'a> Undo<'a, H, N>
275where
276 H: Clone + PartialEq + Decode + Encode,
277 N: std::fmt::Debug + Copy + AtLeast32Bit + Decode + Encode,
278{
279 pub fn undo_import(&mut self, outcome: ImportOutcome<H, N>) {
282 if let Some(removed_hash) = outcome.removed {
283 let removed_number = Reverse(outcome.inserted.number.0 - N::one());
284 self.inner.insert_leaf(removed_number, removed_hash);
285 }
286 self.inner.remove_leaf(&outcome.inserted.number, &outcome.inserted.hash);
287 }
288
289 pub fn undo_remove(&mut self, outcome: RemoveOutcome<H, N>) {
292 if let Some(inserted_hash) = outcome.inserted {
293 let inserted_number = Reverse(outcome.removed.number.0 - N::one());
294 self.inner.remove_leaf(&inserted_number, &inserted_hash);
295 }
296 self.inner.insert_leaf(outcome.removed.number, outcome.removed.hash);
297 }
298
299 pub fn undo_finalization<I>(&mut self, outcome: FinalizationOutcome<I, H, N>)
302 where
303 I: Iterator<Item = (N, H)>,
304 {
305 for (number, hash) in outcome.removed {
306 self.inner.storage.entry(Reverse(number)).or_default().push(hash);
307 }
308 }
309}
310
#[cfg(test)]
mod tests {
	use super::*;
	use std::sync::Arc;

	/// Assert that each `(number, hash)` pair in `expected` is currently a leaf.
	fn assert_leaves(leaves: &LeafSet<u32, u32>, expected: &[(u32, u32)]) {
		for &(number, hash) in expected {
			assert!(leaves.contains(number, hash));
		}
	}

	/// Assert that none of the `(number, hash)` pairs in `unexpected` is currently a leaf.
	fn assert_not_leaves(leaves: &LeafSet<u32, u32>, unexpected: &[(u32, u32)]) {
		for &(number, hash) in unexpected {
			assert!(!leaves.contains(number, hash));
		}
	}

	#[test]
	fn import_works() {
		let mut leaves = LeafSet::new();
		leaves.import(0u32, 0u32, 0u32);

		// Linear chain 0 <- 1_1 <- 2_1 <- 3_1: only the tip stays a leaf.
		leaves.import(1_1, 1, 0);
		leaves.import(2_1, 2, 1_1);
		leaves.import(3_1, 3, 2_1);

		assert_eq!(leaves.count(), 1);
		assert_leaves(&leaves, &[(3, 3_1)]);
		assert_not_leaves(&leaves, &[(2, 2_1), (1, 1_1), (0, 0)]);

		// Forks: 2_2 off 1_1, and a second branch 1_2 <- 2_3 off genesis.
		leaves.import(2_2, 2, 1_1);
		leaves.import(1_2, 1, 0);
		leaves.import(2_3, 2, 1_2);

		assert_eq!(leaves.count(), 3);
		assert_leaves(&leaves, &[(3, 3_1), (2, 2_2), (2, 2_3)]);

		// 1_1 is no longer a leaf, so another child of it displaces nothing.
		let outcome = leaves.import(2_4, 2, 1_1);
		assert_eq!(outcome.inserted.hash, 2_4);
		assert_eq!(outcome.removed, None);
		assert_eq!(leaves.count(), 4);
		assert_leaves(&leaves, &[(2, 2_4)]);

		leaves.undo().undo_import(outcome);
		assert_eq!(leaves.count(), 3);
		assert_leaves(&leaves, &[(3, 3_1), (2, 2_2), (2, 2_3)]);

		// 2_3 is a leaf, so importing a child of it displaces it.
		let outcome = leaves.import(3_2, 3, 2_3);
		assert_eq!(outcome.inserted.hash, 3_2);
		assert_eq!(outcome.removed, Some(2_3));
		assert_eq!(leaves.count(), 3);
		assert_leaves(&leaves, &[(3, 3_2)]);

		leaves.undo().undo_import(outcome);
		assert_eq!(leaves.count(), 3);
		assert_leaves(&leaves, &[(3, 3_1), (2, 2_2), (2, 2_3)]);
	}

	#[test]
	fn removal_works() {
		let mut leaves = LeafSet::new();
		leaves.import(10_1u32, 10u32, 0u32);
		leaves.import(11_1, 11, 10_1);
		leaves.import(11_2, 11, 10_1);
		leaves.import(12_1, 12, 11_1);

		// 12_1 is the only child of 11_1, so 11_1 is re-added as a leaf.
		let outcome = leaves.remove(12_1, 12, Some(11_1)).unwrap();
		assert_eq!(outcome.removed.hash, 12_1);
		assert_eq!(outcome.inserted, Some(11_1));
		assert_eq!(leaves.count(), 2);
		assert_leaves(&leaves, &[(11, 11_1), (11, 11_2)]);

		// Without a parent hash nothing is re-added.
		let outcome = leaves.remove(11_1, 11, None).unwrap();
		assert_eq!(outcome.removed.hash, 11_1);
		assert_eq!(outcome.inserted, None);
		assert_eq!(leaves.count(), 1);
		assert_leaves(&leaves, &[(11, 11_2)]);

		let outcome = leaves.remove(11_2, 11, Some(10_1)).unwrap();
		assert_eq!(outcome.removed.hash, 11_2);
		assert_eq!(outcome.inserted, Some(10_1));
		assert_eq!(leaves.count(), 1);
		assert_leaves(&leaves, &[(10, 10_1)]);

		// Undoing the last removal restores 11_2 and drops 10_1 again.
		leaves.undo().undo_remove(outcome);
		assert_eq!(leaves.count(), 1);
		assert_leaves(&leaves, &[(11, 11_2)]);
	}

	#[test]
	fn flush_to_disk() {
		const PREFIX: &[u8] = b"abcdefg";
		let db = Arc::new(sp_database::MemDb::default());

		let mut leaves = LeafSet::new();
		leaves.import(0u32, 0u32, 0u32);
		leaves.import(1_1, 1, 0);
		leaves.import(2_1, 2, 1_1);
		leaves.import(3_1, 3, 2_1);

		// Persist the set, load it back and compare.
		let mut tx = Transaction::new();
		leaves.prepare_transaction(&mut tx, 0, PREFIX);
		db.commit(tx).unwrap();

		let restored = LeafSet::read_from_db(&*db, 0, PREFIX).unwrap();
		assert_eq!(leaves, restored);
	}

	#[test]
	fn two_leaves_same_height_can_be_included() {
		let mut leaves = LeafSet::new();

		leaves.import(1_1u32, 10u32, 0u32);
		leaves.import(1_2, 10, 0);

		// Both blocks share height 10 and therefore the same storage bucket.
		assert!(leaves.storage.contains_key(&Reverse(10)));
		assert_leaves(&leaves, &[(10, 1_1), (10, 1_2)]);
		assert_not_leaves(&leaves, &[(10, 1_3)]);
	}
}