use crate::{NodeCodec, StorageProof};
use codec::Encode;
use hash_db::Hasher;
use parking_lot::{Mutex, MutexGuard};
use std::{
collections::{HashMap, HashSet},
marker::PhantomData,
mem,
ops::DerefMut,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use trie_db::{RecordedForKey, TrieAccess};
/// `target` used for every `tracing` event emitted by this module.
const LOG_TARGET: &str = "trie-recorder";
/// Undo information for one open transaction of a [`Recorder`].
///
/// Only changes made while the transaction was the innermost one are tracked here; on commit
/// they are merged into the parent transaction (if any).
#[derive(Default)]
struct Transaction<H> {
/// Per storage root and key: the state the key had before this transaction first changed it,
/// or `None` if the key was not recorded at all before.
recorded_keys: HashMap<H, HashMap<Arc<[u8]>, Option<RecordedForKey>>>,
/// Hashes of the nodes/values that were newly inserted into `accessed_nodes` during this
/// transaction (entries that already existed are not listed).
accessed_nodes: HashSet<H>,
}
/// The state shared (behind a mutex) by all clones of a [`Recorder`].
struct RecorderInner<H> {
/// Per storage root: which trie data was recorded for every accessed key.
recorded_keys: HashMap<H, HashMap<Arc<[u8]>, RecordedForKey>>,
/// Stack of open transactions; the last element is the innermost one.
transactions: Vec<Transaction<H>>,
/// Every recorded trie node or value in encoded form, keyed by its hash.
accessed_nodes: HashMap<H, Vec<u8>>,
}
impl<H> Default for RecorderInner<H> {
fn default() -> Self {
Self {
recorded_keys: Default::default(),
accessed_nodes: Default::default(),
transactions: Vec::new(),
}
}
}
/// Records trie accesses so that a [`StorageProof`] can be built from them.
///
/// Both fields are reference counted, so all clones of a recorder share the same state.
pub struct Recorder<H: Hasher> {
/// The recorded state; kept locked for the whole lifetime of every `TrieRecorder` handed out.
inner: Arc<Mutex<RecorderInner<H::Out>>>,
/// Running sum of `encoded_size()` of the recorded nodes/values (reduced again on rollback).
encoded_size_estimation: Arc<AtomicUsize>,
}
impl<H: Hasher> Default for Recorder<H> {
fn default() -> Self {
Self { inner: Default::default(), encoded_size_estimation: Arc::new(0.into()) }
}
}
impl<H: Hasher> Clone for Recorder<H> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
encoded_size_estimation: self.encoded_size_estimation.clone(),
}
}
}
impl<H: Hasher> Recorder<H> {
pub fn recorded_keys(&self) -> HashMap<<H as Hasher>::Out, HashMap<Arc<[u8]>, RecordedForKey>> {
self.inner.lock().recorded_keys.clone()
}
#[inline]
pub fn as_trie_recorder(&self, storage_root: H::Out) -> TrieRecorder<'_, H> {
TrieRecorder::<H> {
inner: self.inner.lock(),
storage_root,
encoded_size_estimation: self.encoded_size_estimation.clone(),
_phantom: PhantomData,
}
}
pub fn drain_storage_proof(self) -> StorageProof {
let mut recorder = mem::take(&mut *self.inner.lock());
StorageProof::new(recorder.accessed_nodes.drain().map(|(_, v)| v))
}
pub fn to_storage_proof(&self) -> StorageProof {
let recorder = self.inner.lock();
StorageProof::new(recorder.accessed_nodes.values().cloned())
}
pub fn estimate_encoded_size(&self) -> usize {
self.encoded_size_estimation.load(Ordering::Relaxed)
}
pub fn reset(&self) {
mem::take(&mut *self.inner.lock());
self.encoded_size_estimation.store(0, Ordering::Relaxed);
}
pub fn start_transaction(&self) {
let mut inner = self.inner.lock();
inner.transactions.push(Default::default());
}
pub fn rollback_transaction(&self) -> Result<(), ()> {
let mut inner = self.inner.lock();
let mut new_encoded_size_estimation = self.encoded_size_estimation.load(Ordering::Relaxed);
let transaction = inner.transactions.pop().ok_or(())?;
transaction.accessed_nodes.into_iter().for_each(|n| {
if let Some(old) = inner.accessed_nodes.remove(&n) {
new_encoded_size_estimation =
new_encoded_size_estimation.saturating_sub(old.encoded_size());
}
});
transaction.recorded_keys.into_iter().for_each(|(storage_root, keys)| {
keys.into_iter().for_each(|(k, old_state)| {
if let Some(state) = old_state {
inner.recorded_keys.entry(storage_root).or_default().insert(k, state);
} else {
inner.recorded_keys.entry(storage_root).or_default().remove(&k);
}
});
});
self.encoded_size_estimation
.store(new_encoded_size_estimation, Ordering::Relaxed);
Ok(())
}
pub fn commit_transaction(&self) -> Result<(), ()> {
let mut inner = self.inner.lock();
let transaction = inner.transactions.pop().ok_or(())?;
if let Some(parent_transaction) = inner.transactions.last_mut() {
parent_transaction.accessed_nodes.extend(transaction.accessed_nodes);
transaction.recorded_keys.into_iter().for_each(|(storage_root, keys)| {
keys.into_iter().for_each(|(k, old_state)| {
parent_transaction
.recorded_keys
.entry(storage_root)
.or_default()
.entry(k)
.or_insert(old_state);
})
});
}
Ok(())
}
}
impl<H: Hasher> crate::ProofSizeProvider for Recorder<H> {
fn estimate_encoded_size(&self) -> usize {
Recorder::estimate_encoded_size(self)
}
}
/// A [`trie_db::TrieRecorder`] view on a [`Recorder`], created by `Recorder::as_trie_recorder`.
///
/// Holds the recorder's mutex locked for its whole lifetime.
pub struct TrieRecorder<'a, H: Hasher> {
/// Guard over the state shared by all clones of the originating recorder.
inner: MutexGuard<'a, RecorderInner<H::Out>>,
/// Root of the trie being accessed; used as the grouping key in `recorded_keys`.
storage_root: H::Out,
/// Shared size estimation, increased whenever a new node or value is recorded.
encoded_size_estimation: Arc<AtomicUsize>,
_phantom: PhantomData<H>,
}
/// Lets generic code obtain [`TrieRecorder`]s and the final proof from a [`Recorder`].
///
/// Both methods forward to the inherent methods of the same name.
impl<H: Hasher> crate::TrieRecorderProvider<H> for Recorder<H> {
	type Recorder<'a> = TrieRecorder<'a, H> where H: 'a;

	fn drain_storage_proof(self) -> Option<StorageProof> {
		Some(Recorder::drain_storage_proof(self))
	}

	fn as_trie_recorder(&self, storage_root: H::Out) -> Self::Recorder<'_> {
		// `self` already is a `&Recorder<H>`; borrowing it again would create a `&&Recorder<H>`
		// that only type-checks through deref coercion (clippy `needless_borrow`).
		Recorder::as_trie_recorder(self, storage_root)
	}
}
impl<'a, H: Hasher> TrieRecorder<'a, H> {
/// Updates the [`RecordedForKey`] state of `full_key` in the trie rooted at `self.storage_root`.
///
/// A key that was not recorded yet gets `access`. An already recorded key is only overwritten
/// when `access` is [`RecordedForKey::Value`]; any other access leaves it untouched.
/// If a transaction is open, the state the key had before this transaction first changed it is
/// saved in the innermost transaction (`None` for a newly recorded key), so that
/// `Recorder::rollback_transaction` can restore it.
fn update_recorded_keys(&mut self, full_key: &[u8], access: RecordedForKey) {
let inner = self.inner.deref_mut();
let entry =
inner.recorded_keys.entry(self.storage_root).or_default().entry(full_key.into());
// Copy of the (`Arc<[u8]>`) map key for the transaction bookkeeping below.
let key = entry.key().clone();
// Only a `Value` access may overwrite an existing entry (e.g. `Hash` -> `Value`).
let entry = if matches!(access, RecordedForKey::Value) {
entry.and_modify(|e| {
if let Some(tx) = inner.transactions.last_mut() {
// `or_insert` keeps the state saved first in this transaction, i.e. the
// state from before the transaction started.
tx.recorded_keys
.entry(self.storage_root)
.or_default()
.entry(key.clone())
.or_insert(Some(*e));
}
*e = access;
})
} else {
entry
};
entry.or_insert_with(|| {
if let Some(tx) = inner.transactions.last_mut() {
// The key was not recorded before, so rolling back must remove it again.
tx.recorded_keys
.entry(self.storage_root)
.or_default()
.entry(key)
.or_insert(None);
}
access
});
}
}
impl<'a, H: Hasher> trie_db::TrieRecorder<H::Out> for TrieRecorder<'a, H> {
fn record(&mut self, access: TrieAccess<H::Out>) {
let mut encoded_size_update = 0;
match access {
TrieAccess::NodeOwned { hash, node_owned } => {
tracing::trace!(
target: LOG_TARGET,
hash = ?hash,
"Recording node",
);
let inner = self.inner.deref_mut();
inner.accessed_nodes.entry(hash).or_insert_with(|| {
let node = node_owned.to_encoded::<NodeCodec<H>>();
encoded_size_update += node.encoded_size();
if let Some(tx) = inner.transactions.last_mut() {
tx.accessed_nodes.insert(hash);
}
node
});
},
TrieAccess::EncodedNode { hash, encoded_node } => {
tracing::trace!(
target: LOG_TARGET,
hash = ?hash,
"Recording node",
);
let inner = self.inner.deref_mut();
inner.accessed_nodes.entry(hash).or_insert_with(|| {
let node = encoded_node.into_owned();
encoded_size_update += node.encoded_size();
if let Some(tx) = inner.transactions.last_mut() {
tx.accessed_nodes.insert(hash);
}
node
});
},
TrieAccess::Value { hash, value, full_key } => {
tracing::trace!(
target: LOG_TARGET,
hash = ?hash,
key = ?sp_core::hexdisplay::HexDisplay::from(&full_key),
"Recording value",
);
let inner = self.inner.deref_mut();
inner.accessed_nodes.entry(hash).or_insert_with(|| {
let value = value.into_owned();
encoded_size_update += value.encoded_size();
if let Some(tx) = inner.transactions.last_mut() {
tx.accessed_nodes.insert(hash);
}
value
});
self.update_recorded_keys(full_key, RecordedForKey::Value);
},
TrieAccess::Hash { full_key } => {
tracing::trace!(
target: LOG_TARGET,
key = ?sp_core::hexdisplay::HexDisplay::from(&full_key),
"Recorded hash access for key",
);
self.update_recorded_keys(full_key, RecordedForKey::Hash);
},
TrieAccess::NonExisting { full_key } => {
tracing::trace!(
target: LOG_TARGET,
key = ?sp_core::hexdisplay::HexDisplay::from(&full_key),
"Recorded non-existing value access for key",
);
self.update_recorded_keys(full_key, RecordedForKey::Value);
},
TrieAccess::InlineValue { full_key } => {
tracing::trace!(
target: LOG_TARGET,
key = ?sp_core::hexdisplay::HexDisplay::from(&full_key),
"Recorded inline value access for key",
);
self.update_recorded_keys(full_key, RecordedForKey::Value);
},
};
self.encoded_size_estimation.fetch_add(encoded_size_update, Ordering::Relaxed);
}
fn trie_nodes_recorded_for_key(&self, key: &[u8]) -> RecordedForKey {
self.inner
.recorded_keys
.get(&self.storage_root)
.and_then(|k| k.get(key).copied())
.unwrap_or(RecordedForKey::None)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::create_trie;
use trie_db::{Trie, TrieDBBuilder, TrieRecorder};
type MemoryDB = crate::MemoryDB<sp_core::Blake2Hasher>;
type Layout = crate::LayoutV1<sp_core::Blake2Hasher>;
type Recorder = super::Recorder<sp_core::Blake2Hasher>;
// Four key/value pairs (64-byte values) used to build the test trie.
const TEST_DATA: &[(&[u8], &[u8])] =
&[(b"key1", &[1; 64]), (b"key2", &[2; 64]), (b"key3", &[3; 64]), (b"key4", &[4; 64])];
// Reading a key through the recorder must record enough data for a trie backed only by the
// drained proof to read the same key.
#[test]
fn recorder_works() {
let (db, root) = create_trie::<Layout>(TEST_DATA);
let recorder = Recorder::default();
{
let mut trie_recorder = recorder.as_trie_recorder(root);
let trie = TrieDBBuilder::<Layout>::new(&db, &root)
.with_recorder(&mut trie_recorder)
.build();
assert_eq!(TEST_DATA[0].1.to_vec(), trie.get(TEST_DATA[0].0).unwrap().unwrap());
}
let storage_proof = recorder.drain_storage_proof();
let memory_db: MemoryDB = storage_proof.into_memory_db();
let trie = TrieDBBuilder::<Layout>::new(&memory_db, &root).build();
assert_eq!(TEST_DATA[0].1.to_vec(), trie.get(TEST_DATA[0].0).unwrap().unwrap());
}
// Snapshot of a recorder's observable state, compared before/after transaction operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
struct RecorderStats {
accessed_nodes: usize,
recorded_keys: usize,
estimated_size: usize,
}
impl RecorderStats {
// Counts recorded nodes and recorded keys (summed over all storage roots) and reads the
// current size estimation.
fn extract(recorder: &Recorder) -> Self {
let inner = recorder.inner.lock();
let recorded_keys =
inner.recorded_keys.iter().flat_map(|(_, keys)| keys.keys()).count();
Self {
recorded_keys,
accessed_nodes: inner.accessed_nodes.len(),
estimated_size: recorder.estimate_encoded_size(),
}
}
}
// Reads key `i` inside the `i`-th nested transaction, then rolls the transactions back one by
// one: after each rollback the stats must equal the snapshot taken before that transaction,
// and the proof must only serve the keys read in transactions that are still open.
#[test]
fn recorder_transactions_rollback_work() {
let (db, root) = create_trie::<Layout>(TEST_DATA);
let recorder = Recorder::default();
let mut stats = vec![RecorderStats::default()];
for i in 0..4 {
recorder.start_transaction();
{
let mut trie_recorder = recorder.as_trie_recorder(root);
let trie = TrieDBBuilder::<Layout>::new(&db, &root)
.with_recorder(&mut trie_recorder)
.build();
assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
}
stats.push(RecorderStats::extract(&recorder));
}
assert_eq!(4, recorder.inner.lock().transactions.len());
for i in 0..5 {
assert_eq!(stats[4 - i], RecorderStats::extract(&recorder));
let storage_proof = recorder.to_storage_proof();
let memory_db: MemoryDB = storage_proof.into_memory_db();
let trie = TrieDBBuilder::<Layout>::new(&memory_db, &root).build();
for a in 0..4 {
if a < 4 - i {
assert_eq!(TEST_DATA[a].1.to_vec(), trie.get(TEST_DATA[a].0).unwrap().unwrap());
} else {
// Data for rolled-back keys is missing from the proof, so the lookup errors.
assert!(trie.get(TEST_DATA[a].0).is_err());
}
}
if i < 4 {
recorder.rollback_transaction().unwrap();
}
}
assert_eq!(0, recorder.inner.lock().transactions.len());
}
// Committing all nested transactions must leave the recorded state unchanged and the proof
// must serve every key that was read.
#[test]
fn recorder_transactions_commit_work() {
let (db, root) = create_trie::<Layout>(TEST_DATA);
let recorder = Recorder::default();
for i in 0..4 {
recorder.start_transaction();
{
let mut trie_recorder = recorder.as_trie_recorder(root);
let trie = TrieDBBuilder::<Layout>::new(&db, &root)
.with_recorder(&mut trie_recorder)
.build();
assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
}
}
let stats = RecorderStats::extract(&recorder);
assert_eq!(4, recorder.inner.lock().transactions.len());
for _ in 0..4 {
recorder.commit_transaction().unwrap();
}
assert_eq!(0, recorder.inner.lock().transactions.len());
assert_eq!(stats, RecorderStats::extract(&recorder));
let storage_proof = recorder.to_storage_proof();
let memory_db: MemoryDB = storage_proof.into_memory_db();
let trie = TrieDBBuilder::<Layout>::new(&memory_db, &root).build();
for i in 0..4 {
assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
}
}
// Keys 0 and 1 are read in two nested transactions and the inner one (key 1) is rolled back;
// keys 2 and 3 are read in two further nested transactions and the inner one (key 3) is rolled
// back. After committing the remaining two transactions only keys 0 and 2 are provable.
#[test]
fn recorder_transactions_commit_and_rollback_work() {
let (db, root) = create_trie::<Layout>(TEST_DATA);
let recorder = Recorder::default();
for i in 0..2 {
recorder.start_transaction();
{
let mut trie_recorder = recorder.as_trie_recorder(root);
let trie = TrieDBBuilder::<Layout>::new(&db, &root)
.with_recorder(&mut trie_recorder)
.build();
assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
}
}
recorder.rollback_transaction().unwrap();
for i in 2..4 {
recorder.start_transaction();
{
let mut trie_recorder = recorder.as_trie_recorder(root);
let trie = TrieDBBuilder::<Layout>::new(&db, &root)
.with_recorder(&mut trie_recorder)
.build();
assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
}
}
recorder.rollback_transaction().unwrap();
assert_eq!(2, recorder.inner.lock().transactions.len());
for _ in 0..2 {
recorder.commit_transaction().unwrap();
}
assert_eq!(0, recorder.inner.lock().transactions.len());
let storage_proof = recorder.to_storage_proof();
let memory_db: MemoryDB = storage_proof.into_memory_db();
let trie = TrieDBBuilder::<Layout>::new(&memory_db, &root).build();
for i in 0..4 {
if i % 2 == 0 {
assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
} else {
assert!(trie.get(TEST_DATA[i].0).is_err());
}
}
}
// Checks the per-key `RecordedForKey` state across nested transactions:
// - a hash access records `Hash`, a later value access upgrades it to `Value`, and rollbacks
//   restore `Hash` and then `None`;
// - once `Value` is recorded, a hash access in a nested transaction keeps `Value`, and rolling
//   back that transaction keeps `Value` until the outer transaction is rolled back too.
#[test]
fn recorder_transaction_accessed_keys_works() {
let key = TEST_DATA[0].0;
let (db, root) = create_trie::<Layout>(TEST_DATA);
let recorder = Recorder::default();
{
let trie_recorder = recorder.as_trie_recorder(root);
assert!(matches!(trie_recorder.trie_nodes_recorded_for_key(key), RecordedForKey::None));
}
recorder.start_transaction();
{
let mut trie_recorder = recorder.as_trie_recorder(root);
let trie = TrieDBBuilder::<Layout>::new(&db, &root)
.with_recorder(&mut trie_recorder)
.build();
assert_eq!(
sp_core::Blake2Hasher::hash(TEST_DATA[0].1),
trie.get_hash(TEST_DATA[0].0).unwrap().unwrap()
);
assert!(matches!(trie_recorder.trie_nodes_recorded_for_key(key), RecordedForKey::Hash));
}
recorder.start_transaction();
{
let mut trie_recorder = recorder.as_trie_recorder(root);
let trie = TrieDBBuilder::<Layout>::new(&db, &root)
.with_recorder(&mut trie_recorder)
.build();
assert_eq!(TEST_DATA[0].1.to_vec(), trie.get(TEST_DATA[0].0).unwrap().unwrap());
assert!(matches!(
trie_recorder.trie_nodes_recorded_for_key(key),
RecordedForKey::Value,
));
}
recorder.rollback_transaction().unwrap();
{
let trie_recorder = recorder.as_trie_recorder(root);
assert!(matches!(trie_recorder.trie_nodes_recorded_for_key(key), RecordedForKey::Hash));
}
recorder.rollback_transaction().unwrap();
{
let trie_recorder = recorder.as_trie_recorder(root);
assert!(matches!(trie_recorder.trie_nodes_recorded_for_key(key), RecordedForKey::None));
}
recorder.start_transaction();
{
let mut trie_recorder = recorder.as_trie_recorder(root);
let trie = TrieDBBuilder::<Layout>::new(&db, &root)
.with_recorder(&mut trie_recorder)
.build();
assert_eq!(TEST_DATA[0].1.to_vec(), trie.get(TEST_DATA[0].0).unwrap().unwrap());
assert!(matches!(
trie_recorder.trie_nodes_recorded_for_key(key),
RecordedForKey::Value,
));
}
recorder.start_transaction();
{
let mut trie_recorder = recorder.as_trie_recorder(root);
let trie = TrieDBBuilder::<Layout>::new(&db, &root)
.with_recorder(&mut trie_recorder)
.build();
assert_eq!(
sp_core::Blake2Hasher::hash(TEST_DATA[0].1),
trie.get_hash(TEST_DATA[0].0).unwrap().unwrap()
);
// A hash access must not downgrade an already recorded `Value`.
assert!(matches!(
trie_recorder.trie_nodes_recorded_for_key(key),
RecordedForKey::Value
));
}
recorder.rollback_transaction().unwrap();
{
let trie_recorder = recorder.as_trie_recorder(root);
assert!(matches!(
trie_recorder.trie_nodes_recorded_for_key(key),
RecordedForKey::Value
));
}
recorder.rollback_transaction().unwrap();
{
let trie_recorder = recorder.as_trie_recorder(root);
assert!(matches!(trie_recorder.trie_nodes_recorded_for_key(key), RecordedForKey::None));
}
}
}