1use crate::{NodeCodec, StorageProof};
24use codec::Encode;
25use hash_db::Hasher;
26use parking_lot::{Mutex, MutexGuard};
27use std::{
28 collections::{HashMap, HashSet},
29 marker::PhantomData,
30 mem,
31 ops::DerefMut,
32 sync::{
33 atomic::{AtomicUsize, Ordering},
34 Arc,
35 },
36};
37use trie_db::{RecordedForKey, TrieAccess};
38
/// `tracing` target used by all log output of this module.
const LOG_TARGET: &str = "trie-recorder";
40
41#[derive(Default)]
43struct Transaction<H> {
44 recorded_keys: HashMap<H, HashMap<Arc<[u8]>, Option<RecordedForKey>>>,
49 accessed_nodes: HashSet<H>,
53}
54
55struct RecorderInner<H> {
57 recorded_keys: HashMap<H, HashMap<Arc<[u8]>, RecordedForKey>>,
61
62 transactions: Vec<Transaction<H>>,
64
65 accessed_nodes: HashMap<H, Vec<u8>>,
69}
70
71impl<H> Default for RecorderInner<H> {
72 fn default() -> Self {
73 Self {
74 recorded_keys: Default::default(),
75 accessed_nodes: Default::default(),
76 transactions: Vec::new(),
77 }
78 }
79}
80
81pub struct Recorder<H: Hasher> {
87 inner: Arc<Mutex<RecorderInner<H::Out>>>,
88 encoded_size_estimation: Arc<AtomicUsize>,
92}
93
94impl<H: Hasher> Default for Recorder<H> {
95 fn default() -> Self {
96 Self { inner: Default::default(), encoded_size_estimation: Arc::new(0.into()) }
97 }
98}
99
100impl<H: Hasher> Clone for Recorder<H> {
101 fn clone(&self) -> Self {
102 Self {
103 inner: self.inner.clone(),
104 encoded_size_estimation: self.encoded_size_estimation.clone(),
105 }
106 }
107}
108
109impl<H: Hasher> Recorder<H> {
110 pub fn recorded_keys(&self) -> HashMap<<H as Hasher>::Out, HashMap<Arc<[u8]>, RecordedForKey>> {
114 self.inner.lock().recorded_keys.clone()
115 }
116
117 #[inline]
124 pub fn as_trie_recorder(&self, storage_root: H::Out) -> TrieRecorder<'_, H> {
125 TrieRecorder::<H> {
126 inner: self.inner.lock(),
127 storage_root,
128 encoded_size_estimation: self.encoded_size_estimation.clone(),
129 _phantom: PhantomData,
130 }
131 }
132
133 pub fn drain_storage_proof(self) -> StorageProof {
142 let mut recorder = mem::take(&mut *self.inner.lock());
143 StorageProof::new(recorder.accessed_nodes.drain().map(|(_, v)| v))
144 }
145
146 pub fn to_storage_proof(&self) -> StorageProof {
153 let recorder = self.inner.lock();
154 StorageProof::new(recorder.accessed_nodes.values().cloned())
155 }
156
157 pub fn estimate_encoded_size(&self) -> usize {
162 self.encoded_size_estimation.load(Ordering::Relaxed)
163 }
164
165 pub fn reset(&self) {
169 mem::take(&mut *self.inner.lock());
170 self.encoded_size_estimation.store(0, Ordering::Relaxed);
171 }
172
173 pub fn start_transaction(&self) {
175 let mut inner = self.inner.lock();
176 inner.transactions.push(Default::default());
177 }
178
179 pub fn rollback_transaction(&self) -> Result<(), ()> {
183 let mut inner = self.inner.lock();
184
185 let mut new_encoded_size_estimation = self.encoded_size_estimation.load(Ordering::Relaxed);
188 let transaction = inner.transactions.pop().ok_or(())?;
189
190 transaction.accessed_nodes.into_iter().for_each(|n| {
191 if let Some(old) = inner.accessed_nodes.remove(&n) {
192 new_encoded_size_estimation =
193 new_encoded_size_estimation.saturating_sub(old.encoded_size());
194 }
195 });
196
197 transaction.recorded_keys.into_iter().for_each(|(storage_root, keys)| {
198 keys.into_iter().for_each(|(k, old_state)| {
199 if let Some(state) = old_state {
200 inner.recorded_keys.entry(storage_root).or_default().insert(k, state);
201 } else {
202 inner.recorded_keys.entry(storage_root).or_default().remove(&k);
203 }
204 });
205 });
206
207 self.encoded_size_estimation
208 .store(new_encoded_size_estimation, Ordering::Relaxed);
209
210 Ok(())
211 }
212
213 pub fn commit_transaction(&self) -> Result<(), ()> {
217 let mut inner = self.inner.lock();
218
219 let transaction = inner.transactions.pop().ok_or(())?;
220
221 if let Some(parent_transaction) = inner.transactions.last_mut() {
222 parent_transaction.accessed_nodes.extend(transaction.accessed_nodes);
223
224 transaction.recorded_keys.into_iter().for_each(|(storage_root, keys)| {
225 keys.into_iter().for_each(|(k, old_state)| {
226 parent_transaction
227 .recorded_keys
228 .entry(storage_root)
229 .or_default()
230 .entry(k)
231 .or_insert(old_state);
232 })
233 });
234 }
235
236 Ok(())
237 }
238}
239
240impl<H: Hasher> crate::ProofSizeProvider for Recorder<H> {
241 fn estimate_encoded_size(&self) -> usize {
242 Recorder::estimate_encoded_size(self)
243 }
244}
245
/// Records trie accesses into a [`Recorder`] for one storage root.
///
/// Created by [`Recorder::as_trie_recorder`]. It owns the guard of the recorder's internal lock,
/// so the recorder (and all of its clones) stay locked for as long as this value lives.
pub struct TrieRecorder<'a, H: Hasher> {
    /// Guard on the recorder's shared state; dropping this value releases the lock.
    inner: MutexGuard<'a, RecorderInner<H::Out>>,
    /// Root of the trie whose key accesses are recorded.
    storage_root: H::Out,
    /// Shared with the originating [`Recorder`]. It is raised by the encoded size of every node
    /// or value that gets newly recorded.
    encoded_size_estimation: Arc<AtomicUsize>,
    _phantom: PhantomData<H>,
}
253
254impl<H: Hasher> crate::TrieRecorderProvider<H> for Recorder<H> {
255 type Recorder<'a> = TrieRecorder<'a, H> where H: 'a;
256
257 fn drain_storage_proof(self) -> Option<StorageProof> {
258 Some(Recorder::drain_storage_proof(self))
259 }
260
261 fn as_trie_recorder(&self, storage_root: H::Out) -> Self::Recorder<'_> {
262 Recorder::as_trie_recorder(&self, storage_root)
263 }
264}
265
266impl<'a, H: Hasher> TrieRecorder<'a, H> {
267 fn update_recorded_keys(&mut self, full_key: &[u8], access: RecordedForKey) {
269 let inner = self.inner.deref_mut();
270
271 let entry =
272 inner.recorded_keys.entry(self.storage_root).or_default().entry(full_key.into());
273
274 let key = entry.key().clone();
275
276 let entry = if matches!(access, RecordedForKey::Value) {
279 entry.and_modify(|e| {
280 if let Some(tx) = inner.transactions.last_mut() {
281 tx.recorded_keys
283 .entry(self.storage_root)
284 .or_default()
285 .entry(key.clone())
286 .or_insert(Some(*e));
287 }
288
289 *e = access;
290 })
291 } else {
292 entry
293 };
294
295 entry.or_insert_with(|| {
296 if let Some(tx) = inner.transactions.last_mut() {
297 tx.recorded_keys
299 .entry(self.storage_root)
300 .or_default()
301 .entry(key)
302 .or_insert(None);
303 }
304
305 access
306 });
307 }
308}
309
310impl<'a, H: Hasher> trie_db::TrieRecorder<H::Out> for TrieRecorder<'a, H> {
311 fn record(&mut self, access: TrieAccess<H::Out>) {
312 let mut encoded_size_update = 0;
313
314 match access {
315 TrieAccess::NodeOwned { hash, node_owned } => {
316 tracing::trace!(
317 target: LOG_TARGET,
318 hash = ?hash,
319 "Recording node",
320 );
321
322 let inner = self.inner.deref_mut();
323
324 inner.accessed_nodes.entry(hash).or_insert_with(|| {
325 let node = node_owned.to_encoded::<NodeCodec<H>>();
326
327 encoded_size_update += node.encoded_size();
328
329 if let Some(tx) = inner.transactions.last_mut() {
330 tx.accessed_nodes.insert(hash);
331 }
332
333 node
334 });
335 },
336 TrieAccess::EncodedNode { hash, encoded_node } => {
337 tracing::trace!(
338 target: LOG_TARGET,
339 hash = ?hash,
340 "Recording node",
341 );
342
343 let inner = self.inner.deref_mut();
344
345 inner.accessed_nodes.entry(hash).or_insert_with(|| {
346 let node = encoded_node.into_owned();
347
348 encoded_size_update += node.encoded_size();
349
350 if let Some(tx) = inner.transactions.last_mut() {
351 tx.accessed_nodes.insert(hash);
352 }
353
354 node
355 });
356 },
357 TrieAccess::Value { hash, value, full_key } => {
358 tracing::trace!(
359 target: LOG_TARGET,
360 hash = ?hash,
361 key = ?sp_core::hexdisplay::HexDisplay::from(&full_key),
362 "Recording value",
363 );
364
365 let inner = self.inner.deref_mut();
366
367 inner.accessed_nodes.entry(hash).or_insert_with(|| {
368 let value = value.into_owned();
369
370 encoded_size_update += value.encoded_size();
371
372 if let Some(tx) = inner.transactions.last_mut() {
373 tx.accessed_nodes.insert(hash);
374 }
375
376 value
377 });
378
379 self.update_recorded_keys(full_key, RecordedForKey::Value);
380 },
381 TrieAccess::Hash { full_key } => {
382 tracing::trace!(
383 target: LOG_TARGET,
384 key = ?sp_core::hexdisplay::HexDisplay::from(&full_key),
385 "Recorded hash access for key",
386 );
387
388 self.update_recorded_keys(full_key, RecordedForKey::Hash);
391 },
392 TrieAccess::NonExisting { full_key } => {
393 tracing::trace!(
394 target: LOG_TARGET,
395 key = ?sp_core::hexdisplay::HexDisplay::from(&full_key),
396 "Recorded non-existing value access for key",
397 );
398
399 self.update_recorded_keys(full_key, RecordedForKey::Value);
403 },
404 TrieAccess::InlineValue { full_key } => {
405 tracing::trace!(
406 target: LOG_TARGET,
407 key = ?sp_core::hexdisplay::HexDisplay::from(&full_key),
408 "Recorded inline value access for key",
409 );
410
411 self.update_recorded_keys(full_key, RecordedForKey::Value);
414 },
415 };
416
417 self.encoded_size_estimation.fetch_add(encoded_size_update, Ordering::Relaxed);
418 }
419
420 fn trie_nodes_recorded_for_key(&self, key: &[u8]) -> RecordedForKey {
421 self.inner
422 .recorded_keys
423 .get(&self.storage_root)
424 .and_then(|k| k.get(key).copied())
425 .unwrap_or(RecordedForKey::None)
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tests::create_trie;
    use trie_db::{Trie, TrieDBBuilder, TrieRecorder};

    type MemoryDB = crate::MemoryDB<sp_core::Blake2Hasher>;
    type Layout = crate::LayoutV1<sp_core::Blake2Hasher>;
    type Recorder = super::Recorder<sp_core::Blake2Hasher>;

    /// Four distinct key/value pairs shared by all tests; test `i` usually reads `TEST_DATA[i]`.
    const TEST_DATA: &[(&[u8], &[u8])] =
        &[(b"key1", &[1; 64]), (b"key2", &[2; 64]), (b"key3", &[3; 64]), (b"key4", &[4; 64])];

    /// A proof drained after reading one key must be enough to read that key again from a
    /// database built only from the proof.
    #[test]
    fn recorder_works() {
        let (db, root) = create_trie::<Layout>(TEST_DATA);

        let recorder = Recorder::default();

        // Scoped so the `TrieRecorder` (and with it the recorder's lock) is dropped before
        // draining.
        {
            let mut trie_recorder = recorder.as_trie_recorder(root);
            let trie = TrieDBBuilder::<Layout>::new(&db, &root)
                .with_recorder(&mut trie_recorder)
                .build();
            assert_eq!(TEST_DATA[0].1.to_vec(), trie.get(TEST_DATA[0].0).unwrap().unwrap());
        }

        let storage_proof = recorder.drain_storage_proof();
        let memory_db: MemoryDB = storage_proof.into_memory_db();

        // The proof alone must contain everything needed to read the key again.
        let trie = TrieDBBuilder::<Layout>::new(&memory_db, &root).build();
        assert_eq!(TEST_DATA[0].1.to_vec(), trie.get(TEST_DATA[0].0).unwrap().unwrap());
    }

    /// Snapshot of a recorder's counters, used to compare recorder states across transactions.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
    struct RecorderStats {
        accessed_nodes: usize,
        recorded_keys: usize,
        estimated_size: usize,
    }

    impl RecorderStats {
        /// Takes a snapshot of `recorder`'s current counters.
        fn extract(recorder: &Recorder) -> Self {
            let inner = recorder.inner.lock();

            // Total number of recorded keys over all storage roots.
            let recorded_keys =
                inner.recorded_keys.iter().flat_map(|(_, keys)| keys.keys()).count();

            Self {
                recorded_keys,
                accessed_nodes: inner.accessed_nodes.len(),
                // Reads an atomic without locking, so calling it while `inner` is held is fine.
                estimated_size: recorder.estimate_encoded_size(),
            }
        }
    }

    /// Four nested transactions each read one more key. Rolling them back one by one must
    /// restore exactly the stats seen before each transaction. The proof must then only cover
    /// the keys that are still recorded.
    #[test]
    fn recorder_transactions_rollback_work() {
        let (db, root) = create_trie::<Layout>(TEST_DATA);

        let recorder = Recorder::default();
        // `stats[n]` is the snapshot after `n` transactions have read their key.
        let mut stats = vec![RecorderStats::default()];

        for i in 0..4 {
            recorder.start_transaction();
            {
                let mut trie_recorder = recorder.as_trie_recorder(root);
                let trie = TrieDBBuilder::<Layout>::new(&db, &root)
                    .with_recorder(&mut trie_recorder)
                    .build();

                assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
            }
            stats.push(RecorderStats::extract(&recorder));
        }

        assert_eq!(4, recorder.inner.lock().transactions.len());

        // Iteration `i` checks the state with `4 - i` transactions left, then rolls one back.
        for i in 0..5 {
            assert_eq!(stats[4 - i], RecorderStats::extract(&recorder));

            let storage_proof = recorder.to_storage_proof();
            let memory_db: MemoryDB = storage_proof.into_memory_db();

            let trie = TrieDBBuilder::<Layout>::new(&memory_db, &root).build();

            // Keys read by rolled back transactions must be missing from the proof.
            for a in 0..4 {
                if a < 4 - i {
                    assert_eq!(TEST_DATA[a].1.to_vec(), trie.get(TEST_DATA[a].0).unwrap().unwrap());
                } else {
                    assert!(trie.get(TEST_DATA[a].0).is_err());
                }
            }

            if i < 4 {
                recorder.rollback_transaction().unwrap();
            }
        }

        assert_eq!(0, recorder.inner.lock().transactions.len());
    }

    /// Committing all nested transactions must keep everything that was recorded in them.
    #[test]
    fn recorder_transactions_commit_work() {
        let (db, root) = create_trie::<Layout>(TEST_DATA);

        let recorder = Recorder::default();

        for i in 0..4 {
            recorder.start_transaction();
            {
                let mut trie_recorder = recorder.as_trie_recorder(root);
                let trie = TrieDBBuilder::<Layout>::new(&db, &root)
                    .with_recorder(&mut trie_recorder)
                    .build();

                assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
            }
        }

        let stats = RecorderStats::extract(&recorder);
        assert_eq!(4, recorder.inner.lock().transactions.len());

        for _ in 0..4 {
            recorder.commit_transaction().unwrap();
        }
        assert_eq!(0, recorder.inner.lock().transactions.len());
        // Committing must not change what is recorded.
        assert_eq!(stats, RecorderStats::extract(&recorder));

        let storage_proof = recorder.to_storage_proof();
        let memory_db: MemoryDB = storage_proof.into_memory_db();

        let trie = TrieDBBuilder::<Layout>::new(&memory_db, &root).build();

        // All four keys must be readable from the proof.
        for i in 0..4 {
            assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
        }
    }

    /// Mixes the two operations. Transactions reading keys 0 and 1 are started and the one for
    /// key 1 is rolled back. Then transactions reading keys 2 and 3 are started and the one for
    /// key 3 is rolled back. Committing the remaining two must leave exactly keys 0 and 2 in
    /// the proof.
    #[test]
    fn recorder_transactions_commit_and_rollback_work() {
        let (db, root) = create_trie::<Layout>(TEST_DATA);

        let recorder = Recorder::default();

        for i in 0..2 {
            recorder.start_transaction();
            {
                let mut trie_recorder = recorder.as_trie_recorder(root);
                let trie = TrieDBBuilder::<Layout>::new(&db, &root)
                    .with_recorder(&mut trie_recorder)
                    .build();

                assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
            }
        }

        // Undo the read of key 1.
        recorder.rollback_transaction().unwrap();

        for i in 2..4 {
            recorder.start_transaction();
            {
                let mut trie_recorder = recorder.as_trie_recorder(root);
                let trie = TrieDBBuilder::<Layout>::new(&db, &root)
                    .with_recorder(&mut trie_recorder)
                    .build();

                assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
            }
        }

        // Undo the read of key 3.
        recorder.rollback_transaction().unwrap();

        assert_eq!(2, recorder.inner.lock().transactions.len());

        for _ in 0..2 {
            recorder.commit_transaction().unwrap();
        }

        assert_eq!(0, recorder.inner.lock().transactions.len());

        let storage_proof = recorder.to_storage_proof();
        let memory_db: MemoryDB = storage_proof.into_memory_db();

        let trie = TrieDBBuilder::<Layout>::new(&memory_db, &root).build();

        // Only the committed reads (keys 0 and 2) are part of the proof.
        for i in 0..4 {
            if i % 2 == 0 {
                assert_eq!(TEST_DATA[i].1.to_vec(), trie.get(TEST_DATA[i].0).unwrap().unwrap());
            } else {
                assert!(trie.get(TEST_DATA[i].0).is_err());
            }
        }
    }

    /// Checks how `RecordedForKey` for a single key evolves across nested transactions. A
    /// `Value` access upgrades a `Hash` record, a `Hash` access never downgrades a `Value`
    /// record, and rollbacks restore the state from before each transaction.
    #[test]
    fn recorder_transaction_accessed_keys_works() {
        let key = TEST_DATA[0].0;
        let (db, root) = create_trie::<Layout>(TEST_DATA);

        let recorder = Recorder::default();

        // Nothing recorded yet.
        {
            let trie_recorder = recorder.as_trie_recorder(root);
            assert!(matches!(trie_recorder.trie_nodes_recorded_for_key(key), RecordedForKey::None));
        }

        // Transaction 1: only the value hash is accessed.
        recorder.start_transaction();
        {
            let mut trie_recorder = recorder.as_trie_recorder(root);
            let trie = TrieDBBuilder::<Layout>::new(&db, &root)
                .with_recorder(&mut trie_recorder)
                .build();

            assert_eq!(
                sp_core::Blake2Hasher::hash(TEST_DATA[0].1),
                trie.get_hash(TEST_DATA[0].0).unwrap().unwrap()
            );
            assert!(matches!(trie_recorder.trie_nodes_recorded_for_key(key), RecordedForKey::Hash));
        }

        // Nested transaction 2: the value is accessed, upgrading `Hash` to `Value`.
        recorder.start_transaction();
        {
            let mut trie_recorder = recorder.as_trie_recorder(root);
            let trie = TrieDBBuilder::<Layout>::new(&db, &root)
                .with_recorder(&mut trie_recorder)
                .build();

            assert_eq!(TEST_DATA[0].1.to_vec(), trie.get(TEST_DATA[0].0).unwrap().unwrap());
            assert!(matches!(
                trie_recorder.trie_nodes_recorded_for_key(key),
                RecordedForKey::Value,
            ));
        }

        // Rolling back transaction 2 restores the `Hash` state.
        recorder.rollback_transaction().unwrap();
        {
            let trie_recorder = recorder.as_trie_recorder(root);
            assert!(matches!(trie_recorder.trie_nodes_recorded_for_key(key), RecordedForKey::Hash));
        }

        // Rolling back transaction 1 forgets the key entirely.
        recorder.rollback_transaction().unwrap();
        {
            let trie_recorder = recorder.as_trie_recorder(root);
            assert!(matches!(trie_recorder.trie_nodes_recorded_for_key(key), RecordedForKey::None));
        }

        // New transaction 1: the value is accessed directly.
        recorder.start_transaction();
        {
            let mut trie_recorder = recorder.as_trie_recorder(root);
            let trie = TrieDBBuilder::<Layout>::new(&db, &root)
                .with_recorder(&mut trie_recorder)
                .build();

            assert_eq!(TEST_DATA[0].1.to_vec(), trie.get(TEST_DATA[0].0).unwrap().unwrap());
            assert!(matches!(
                trie_recorder.trie_nodes_recorded_for_key(key),
                RecordedForKey::Value,
            ));
        }

        // Nested transaction 2: a hash access must not downgrade the `Value` record.
        recorder.start_transaction();
        {
            let mut trie_recorder = recorder.as_trie_recorder(root);
            let trie = TrieDBBuilder::<Layout>::new(&db, &root)
                .with_recorder(&mut trie_recorder)
                .build();

            assert_eq!(
                sp_core::Blake2Hasher::hash(TEST_DATA[0].1),
                trie.get_hash(TEST_DATA[0].0).unwrap().unwrap()
            );
            assert!(matches!(
                trie_recorder.trie_nodes_recorded_for_key(key),
                RecordedForKey::Value
            ));
        }

        // Transaction 2 changed nothing for the key, so the rollback keeps `Value`.
        recorder.rollback_transaction().unwrap();
        {
            let trie_recorder = recorder.as_trie_recorder(root);
            assert!(matches!(
                trie_recorder.trie_nodes_recorded_for_key(key),
                RecordedForKey::Value
            ));
        }

        // Rolling back transaction 1 forgets the key again.
        recorder.rollback_transaction().unwrap();
        {
            let trie_recorder = recorder.as_trie_recorder(root);
            assert!(matches!(trie_recorder.trie_nodes_recorded_for_key(key), RecordedForKey::None));
        }
    }
}