use codec::{Decode, Encode};
use sp_blockchain::{Error, Result};
use sp_database::{Database, Transaction};
use sp_runtime::traits::AtLeast32Bit;
use std::{cmp::Reverse, collections::BTreeMap};
type DbHash = sp_core::H256;
/// A single leaf entry: a block hash together with its block number.
///
/// The number is stored wrapped in [`Reverse`], the same key form used by
/// `LeafSet::storage`, so it can be passed straight to the map helpers.
#[derive(Debug, Clone, PartialEq, Eq)]
struct LeafSetItem<H, N> {
/// Hash of the leaf block.
hash: H,
/// Block number of the leaf, wrapped in `Reverse`.
number: Reverse<N>,
}
/// Result of `LeafSet::import`, carrying what is needed to undo it
/// (see `Undo::undo_import`).
pub struct ImportOutcome<H, N> {
/// The leaf that was added by the import.
inserted: LeafSetItem<H, N>,
/// Hash of the parent leaf that the import displaced, if the parent was a leaf.
removed: Option<H>,
}
/// Result of `LeafSet::remove`, carrying what is needed to undo it
/// (see `Undo::undo_remove`).
pub struct RemoveOutcome<H, N> {
/// Hash of the parent that was re-inserted as a leaf, if any.
inserted: Option<H>,
/// The leaf that was taken out of the set.
removed: LeafSetItem<H, N>,
}
/// Leaves displaced by a finalization, represented as an iterator of
/// `(number, hash)` pairs.
///
/// The value is consumed either by `LeafSet::remove_displaced_leaves` (to
/// drop the leaves) or by `Undo::undo_finalization` (to put them back).
pub struct FinalizationOutcome<I, H, N>
where
    I: Iterator<Item = (N, H)>,
{
    /// The displaced leaves, yielded as `(number, hash)`.
    removed: I,
}

// Constructing the outcome only stores the iterator, so no ordering bounds
// are required on `H` or `N` here.
impl<I, H, N> FinalizationOutcome<I, H, N>
where
    I: Iterator<Item = (N, H)>,
{
    /// Wraps an iterator over the displaced leaves.
    ///
    /// The iterator is stored as-is and is not advanced by this call.
    pub fn new(new_displaced: I) -> Self {
        Self { removed: new_displaced }
    }
}
/// Set of leaves (blocks without known children) grouped by block number.
///
/// Keys are wrapped in [`Reverse`], so iterating the map visits the highest
/// block number first.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LeafSet<H, N> {
/// Leaf hashes per block number; a level may hold several hashes (forks).
storage: BTreeMap<Reverse<N>, Vec<H>>,
}
impl<H, N> LeafSet<H, N>
where
H: Clone + PartialEq + Decode + Encode,
N: std::fmt::Debug + Copy + AtLeast32Bit + Decode + Encode,
{
pub fn new() -> Self {
Self { storage: BTreeMap::new() }
}
pub fn read_from_db(db: &dyn Database<DbHash>, column: u32, prefix: &[u8]) -> Result<Self> {
let mut storage = BTreeMap::new();
match db.get(column, prefix) {
Some(leaves) => {
let vals: Vec<_> = match Decode::decode(&mut leaves.as_ref()) {
Ok(vals) => vals,
Err(_) => return Err(Error::Backend("Error decoding leaves".into())),
};
for (number, hashes) in vals.into_iter() {
storage.insert(Reverse(number), hashes);
}
},
None => {},
}
Ok(Self { storage })
}
/// Records `hash` at `number` as a new leaf.
///
/// Unless `number` is zero, the parent (`parent_hash` at `number - 1`) is
/// removed from the set if it was a leaf. The returned outcome records the
/// inserted leaf and, when applicable, the displaced parent hash.
pub fn import(&mut self, hash: H, number: N, parent_hash: H) -> ImportOutcome<H, N> {
    let removed = if number == N::zero() {
        // Block zero has no parent level to look at.
        None
    } else {
        let parent_level = Reverse(number - N::one());
        if self.remove_leaf(&parent_level, &parent_hash) {
            Some(parent_hash)
        } else {
            None
        }
    };

    let level = Reverse(number);
    self.insert_leaf(level, hash.clone());

    ImportOutcome { inserted: LeafSetItem { hash, number: level }, removed }
}
/// Removes the leaf `hash` at `number`.
///
/// Returns `None` when that leaf is not in the set. Otherwise, if
/// `parent_hash` is provided and `number` is not zero, the parent is
/// inserted as a leaf at `number - 1`; the outcome reports both changes.
pub fn remove(
    &mut self,
    hash: H,
    number: N,
    parent_hash: Option<H>,
) -> Option<RemoveOutcome<H, N>> {
    let level = Reverse(number);
    if !self.remove_leaf(&level, &hash) {
        return None;
    }

    let inserted = match parent_hash {
        Some(parent_hash) if number != N::zero() => {
            self.insert_leaf(Reverse(number - N::one()), parent_hash.clone());
            Some(parent_hash)
        },
        // Either no parent was given or block zero has no parent level.
        _ => None,
    };

    Some(RemoveOutcome { inserted, removed: LeafSetItem { hash, number: level } })
}
/// Removes every `(number, hash)` pair yielded by `displaced_leaves`.
///
/// Pairs that are not currently leaves are ignored.
pub fn remove_displaced_leaves<I>(&mut self, displaced_leaves: FinalizationOutcome<I, H, N>)
where
    I: Iterator<Item = (N, H)>,
{
    displaced_leaves.removed.for_each(|(number, hash)| {
        // The "was it present" flag is deliberately discarded.
        self.remove_leaf(&Reverse(number), &hash);
    });
}
/// Returns an [`Undo`] handle that mutably borrows this set and can roll
/// back a previous import, removal or finalization from its outcome.
pub fn undo(&mut self) -> Undo<H, N> {
Undo { inner: self }
}
/// Reverts the set to the state where `best_hash` at `best_number` is the
/// highest block.
///
/// Every level with a number greater than `best_number` is dropped, and
/// `best_hash` is inserted as a leaf at `best_number` unless it is already
/// present there.
///
/// Levels are discarded wholesale instead of removing each leaf one by
/// one. This avoids cloning every leaf, and it cannot panic when a level
/// happens to contain the same hash more than once.
pub fn revert(&mut self, best_hash: H, best_number: N) {
    // Keys are `Reverse(number)`, so compare on the inner number.
    self.storage.retain(|level, _| level.0 <= best_number);

    let best_level = Reverse(best_number);
    let leaves_contains_best = self
        .storage
        .get(&best_level)
        .map_or(false, |hashes| hashes.contains(&best_hash));

    if !leaves_contains_best {
        self.insert_leaf(best_level, best_hash);
    }
}
/// Returns clones of all leaf hashes, ordered from the highest block
/// number to the lowest (insertion order within a level).
pub fn hashes(&self) -> Vec<H> {
    self.storage.values().flatten().cloned().collect()
}
/// Returns the total number of leaves across all levels.
pub fn count(&self) -> usize {
    self.storage.values().map(Vec::len).sum()
}
/// Queues a write of the whole leaf set into `tx`.
///
/// The set is encoded as a list of `(number, hashes)` pairs and stored
/// under `prefix` in `column`, the layout `read_from_db` decodes.
pub fn prepare_transaction(
    &mut self,
    tx: &mut Transaction<DbHash>,
    column: u32,
    prefix: &[u8],
) {
    let mut leaves = Vec::with_capacity(self.storage.len());
    for (number, hashes) in &self.storage {
        // Strip the `Reverse` wrapper; only the plain number is persisted.
        leaves.push((number.0, hashes.clone()));
    }
    tx.set_from_vec(column, prefix, leaves.encode());
}
/// Returns `true` if `hash` is a leaf at block number `number`.
pub fn contains(&self, number: N, hash: H) -> bool {
    match self.storage.get(&Reverse(number)) {
        Some(hashes) => hashes.contains(&hash),
        None => false,
    }
}
/// Appends `hash` to the level `number`, creating the level if needed.
///
/// No de-duplication is performed.
fn insert_leaf(&mut self, number: Reverse<N>, hash: H) {
    let level = self.storage.entry(number).or_default();
    level.push(hash);
}
/// Removes every occurrence of `hash` from the level `number`.
///
/// Returns `true` if at least one entry was removed. A level left empty by
/// the removal is deleted from the map.
fn remove_leaf(&mut self, number: &Reverse<N>, hash: &H) -> bool {
    let level = match self.storage.get_mut(number) {
        Some(level) => level,
        None => return false,
    };

    // A shrinking length means at least one matching hash was dropped.
    let before = level.len();
    level.retain(|candidate| candidate != hash);
    let removed = level.len() != before;

    if removed && level.is_empty() {
        self.storage.remove(number);
    }

    removed
}
/// Returns the highest block number present in the set together with the
/// hashes stored at that level, or `None` when the set is empty.
pub fn highest_leaf(&self) -> Option<(N, &[H])> {
    // Keys are `Reverse(number)`, so the first entry is the highest number.
    let (number, hashes) = self.storage.iter().next()?;
    Some((number.0, hashes.as_slice()))
}
}
/// Helper that rolls back operations on a [`LeafSet`] using the outcome
/// values those operations returned. Obtained from `LeafSet::undo`.
pub struct Undo<'a, H: 'a, N: 'a> {
/// The leaf set being modified.
inner: &'a mut LeafSet<H, N>,
}
impl<'a, H: 'a, N: 'a> Undo<'a, H, N>
where
H: Clone + PartialEq + Decode + Encode,
N: std::fmt::Debug + Copy + AtLeast32Bit + Decode + Encode,
{
pub fn undo_import(&mut self, outcome: ImportOutcome<H, N>) {
if let Some(removed_hash) = outcome.removed {
let removed_number = Reverse(outcome.inserted.number.0 - N::one());
self.inner.insert_leaf(removed_number, removed_hash);
}
self.inner.remove_leaf(&outcome.inserted.number, &outcome.inserted.hash);
}
pub fn undo_remove(&mut self, outcome: RemoveOutcome<H, N>) {
if let Some(inserted_hash) = outcome.inserted {
let inserted_number = Reverse(outcome.removed.number.0 - N::one());
self.inner.remove_leaf(&inserted_number, &inserted_hash);
}
self.inner.insert_leaf(outcome.removed.number, outcome.removed.hash);
}
/// Reverses a finalization by re-inserting every displaced
/// `(number, hash)` pair yielded by `outcome`.
pub fn undo_finalization<I>(&mut self, outcome: FinalizationOutcome<I, H, N>)
where
    I: Iterator<Item = (N, H)>,
{
    outcome
        .removed
        .for_each(|(number, hash)| self.inner.insert_leaf(Reverse(number), hash));
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
// Exercises `import` and `undo_import` on a small tree.
// Literals such as `2_1` are plain integers (21) written with a digit
// separator; they appear to encode "height_fork" for readability.
#[test]
fn import_works() {
let mut set = LeafSet::new();
// Build a single chain 0 -> 1_1 -> 2_1 -> 3_1; each import displaces its parent.
set.import(0u32, 0u32, 0u32);
set.import(1_1, 1, 0);
set.import(2_1, 2, 1_1);
set.import(3_1, 3, 2_1);
assert_eq!(set.count(), 1);
assert!(set.contains(3, 3_1));
assert!(!set.contains(2, 2_1));
assert!(!set.contains(1, 1_1));
assert!(!set.contains(0, 0));
// Add forks: their parents are no longer leaves, so nothing is displaced
// except 1_2, which is displaced by 2_3.
set.import(2_2, 2, 1_1);
set.import(1_2, 1, 0);
set.import(2_3, 2, 1_2);
assert_eq!(set.count(), 3);
assert!(set.contains(3, 3_1));
assert!(set.contains(2, 2_2));
assert!(set.contains(2, 2_3));
// Import whose parent is not a leaf: nothing removed, and undo restores the set.
let outcome = set.import(2_4, 2, 1_1);
assert_eq!(outcome.inserted.hash, 2_4);
assert_eq!(outcome.removed, None);
assert_eq!(set.count(), 4);
assert!(set.contains(2, 2_4));
set.undo().undo_import(outcome);
assert_eq!(set.count(), 3);
assert!(set.contains(3, 3_1));
assert!(set.contains(2, 2_2));
assert!(set.contains(2, 2_3));
// Import that displaces a leaf parent (2_3): undo must bring the parent back.
let outcome = set.import(3_2, 3, 2_3);
assert_eq!(outcome.inserted.hash, 3_2);
assert_eq!(outcome.removed, Some(2_3));
assert_eq!(set.count(), 3);
assert!(set.contains(3, 3_2));
set.undo().undo_import(outcome);
assert_eq!(set.count(), 3);
assert!(set.contains(3, 3_1));
assert!(set.contains(2, 2_2));
assert!(set.contains(2, 2_3));
}
// Exercises `remove` with and without a parent hash, and `undo_remove`.
#[test]
fn removal_works() {
let mut set = LeafSet::new();
// Resulting leaves: 11_2 at height 11 and 12_1 at height 12.
set.import(10_1u32, 10u32, 0u32);
set.import(11_1, 11, 10_1);
set.import(11_2, 11, 10_1);
set.import(12_1, 12, 11_1);
// Removing with a parent hash re-inserts the parent as a leaf.
let outcome = set.remove(12_1, 12, Some(11_1)).unwrap();
assert_eq!(outcome.removed.hash, 12_1);
assert_eq!(outcome.inserted, Some(11_1));
assert_eq!(set.count(), 2);
assert!(set.contains(11, 11_1));
assert!(set.contains(11, 11_2));
// Removing without a parent hash inserts nothing.
let outcome = set.remove(11_1, 11, None).unwrap();
assert_eq!(outcome.removed.hash, 11_1);
assert_eq!(outcome.inserted, None);
assert_eq!(set.count(), 1);
assert!(set.contains(11, 11_2));
// Removing the last leaf at height 11 makes its parent the only leaf.
let outcome = set.remove(11_2, 11, Some(10_1)).unwrap();
assert_eq!(outcome.removed.hash, 11_2);
assert_eq!(outcome.inserted, Some(10_1));
assert_eq!(set.count(), 1);
assert!(set.contains(10, 10_1));
// Undo restores the removed leaf and drops the re-inserted parent.
set.undo().undo_remove(outcome);
assert_eq!(set.count(), 1);
assert!(set.contains(11, 11_2));
}
// Round-trips a leaf set through an in-memory database:
// `prepare_transaction` followed by `read_from_db` must yield an equal set.
#[test]
fn flush_to_disk() {
const PREFIX: &[u8] = b"abcdefg";
let db = Arc::new(sp_database::MemDb::default());
let mut set = LeafSet::new();
set.import(0u32, 0u32, 0u32);
set.import(1_1, 1, 0);
set.import(2_1, 2, 1_1);
set.import(3_1, 3, 2_1);
// Persist the set under column 0 / PREFIX.
let mut tx = Transaction::new();
set.prepare_transaction(&mut tx, 0, PREFIX);
db.commit(tx).unwrap();
// Reading back from the same location must reproduce the original set.
let set2 = LeafSet::read_from_db(&*db, 0, PREFIX).unwrap();
assert_eq!(set, set2);
}
// Two different hashes imported at the same block number must both be
// tracked as leaves of that level.
#[test]
fn two_leaves_same_height_can_be_included() {
let mut set = LeafSet::new();
set.import(1_1u32, 10u32, 0u32);
set.import(1_2, 10, 0);
assert!(set.storage.contains_key(&Reverse(10)));
assert!(set.contains(10, 1_1));
assert!(set.contains(10, 1_2));
// A hash that was never imported is not reported.
assert!(!set.contains(10, 1_3));
}
}