use crate::{
exec::{AccountIdOf, Key},
storage::WriteOutcome,
Config, Error,
};
use alloc::{collections::BTreeMap, vec::Vec};
use codec::Encode;
use core::{marker::PhantomData, mem};
use frame_support::DefaultNoBound;
use sp_runtime::{DispatchError, DispatchResult, Saturating};
/// Bookkeeping for a single metering scope: how much has been charged and how much is allowed.
#[derive(Default, Debug)]
pub struct MeterEntry {
    /// Amount charged to this entry so far.
    pub amount: u32,
    /// Upper bound that `amount` is checked against before a charge is accepted.
    pub limit: u32,
}

impl MeterEntry {
    /// Creates an entry with the given `limit` and nothing charged yet.
    fn new(limit: u32) -> Self {
        Self { limit, amount: 0 }
    }

    /// Returns `true` if charging `amount` on top of what is already charged would exceed
    /// the limit.
    ///
    /// The sum is checked instead of saturated: a saturated sum clamps to `u32::MAX` and
    /// would therefore never compare greater than a limit of `u32::MAX`, letting a charge
    /// that overflows `u32` slip through unnoticed.
    fn exceeds_limit(&self, amount: u32) -> bool {
        match self.amount.checked_add(amount) {
            Some(total) => total > self.limit,
            // The true sum does not even fit into `u32`, so it is above any possible limit.
            None => true,
        }
    }

    /// Adds the amount charged to `rhs` to this entry, saturating at `u32::MAX`.
    ///
    /// The limit of `rhs` is discarded; this entry keeps its own limit.
    fn absorb(&mut self, rhs: Self) {
        self.amount = self.amount.saturating_add(rhs.amount);
    }
}
#[derive(DefaultNoBound)]
pub struct StorageMeter<T: Config> {
nested_meters: Vec<MeterEntry>,
root_meter: MeterEntry,
_phantom: PhantomData<T>,
}
impl<T: Config> StorageMeter<T> {
const STORAGE_FRACTION_DENOMINATOR: u32 = 16;
fn new(memory_limit: u32) -> Self {
Self { root_meter: MeterEntry::new(memory_limit), ..Default::default() }
}
fn charge(&mut self, amount: u32) -> DispatchResult {
let meter = self.current_mut();
if meter.exceeds_limit(amount) {
return Err(Error::<T>::OutOfTransientStorage.into());
}
meter.amount.saturating_accrue(amount);
Ok(())
}
fn revert(&mut self) {
self.nested_meters.pop().expect(
"A call to revert a meter must be preceded by a corresponding call to start a meter;
the code within this crate makes sure that this is always the case; qed",
);
}
fn start(&mut self) {
let meter = self.current();
let mut transaction_limit = meter.limit.saturating_sub(meter.amount);
if !self.nested_meters.is_empty() {
transaction_limit.saturating_reduce(
transaction_limit.saturating_div(Self::STORAGE_FRACTION_DENOMINATOR),
);
}
self.nested_meters.push(MeterEntry::new(transaction_limit));
}
fn commit(&mut self) {
let transaction_meter = self.nested_meters.pop().expect(
"A call to commit a meter must be preceded by a corresponding call to start a meter;
the code within this crate makes sure that this is always the case; qed",
);
self.current_mut().absorb(transaction_meter)
}
#[cfg(test)]
fn total_amount(&self) -> u32 {
self.nested_meters
.iter()
.fold(self.root_meter.amount, |acc, e| acc.saturating_add(e.amount))
}
pub fn current_mut(&mut self) -> &mut MeterEntry {
self.nested_meters.last_mut().unwrap_or(&mut self.root_meter)
}
pub fn current(&self) -> &MeterEntry {
self.nested_meters.last().unwrap_or(&self.root_meter)
}
}
/// Undo record for one storage write: the key that was written and what it held before.
struct JournalEntry {
    /// Storage key the write went to.
    key: Vec<u8>,
    /// Value stored under `key` before the write; `None` if the key did not exist.
    prev_value: Option<Vec<u8>>,
}

impl JournalEntry {
    /// Creates the undo record for a write to `key` that replaced `prev_value`.
    fn new(key: Vec<u8>, prev_value: Option<Vec<u8>>) -> Self {
        JournalEntry { key, prev_value }
    }

    /// Undoes the recorded write by writing the previous value back into `storage`
    /// (which removes the key when there was no previous value).
    fn revert(self, storage: &mut Storage) {
        let JournalEntry { key, prev_value } = self;
        storage.write(&key, prev_value);
    }
}
/// Ordered list of undo records; the most recent write is last.
struct Journal(Vec<JournalEntry>);

impl Journal {
    /// Creates an empty journal.
    fn new() -> Self {
        Journal(Vec::new())
    }

    /// Appends `entry` as the most recent undo record.
    fn push(&mut self, entry: JournalEntry) {
        let Journal(entries) = self;
        entries.push(entry);
    }

    /// Number of undo records currently held.
    fn len(&self) -> usize {
        let Journal(entries) = self;
        entries.len()
    }

    /// Removes every record at index `checkpoint` or later and reverts them on `storage`,
    /// newest first.
    ///
    /// # Panics
    ///
    /// Panics if `checkpoint` is greater than the number of records.
    fn rollback(&mut self, storage: &mut Storage, checkpoint: usize) {
        let Journal(entries) = self;
        for entry in entries.drain(checkpoint..).rev() {
            entry.revert(storage);
        }
    }
}
/// Ordered in-memory map from raw storage keys to raw values.
#[derive(Default)]
struct Storage(BTreeMap<Vec<u8>, Vec<u8>>);

impl Storage {
    /// Returns a copy of the value stored under `key`, or `None` if the key is absent.
    ///
    /// Takes a byte slice so callers need not own a `Vec`; a `&Vec<u8>` still coerces.
    fn read(&self, key: &[u8]) -> Option<Vec<u8>> {
        self.0.get(key).cloned()
    }

    /// Stores `value` under `key`, or removes the key when `value` is `None`.
    ///
    /// Returns the value that was stored under `key` before the call, if any.
    fn write(&mut self, key: &[u8], value: Option<Vec<u8>>) -> Option<Vec<u8>> {
        match value {
            Some(value) => self.0.insert(key.to_vec(), value),
            None => self.0.remove(key),
        }
    }
}
/// In-memory storage with nested transactions that can be rolled back, and with metering of
/// the memory it allocates.
pub struct TransientStorage<T: Config> {
    /// Current key/value contents.
    storage: Storage,
    /// Undo records of the writes performed so far.
    journal: Journal,
    /// Meter that bounds how much may be allocated.
    meter: StorageMeter<T>,
    /// Journal length recorded at the start of every transaction that is still open.
    checkpoints: Vec<usize>,
}

impl<T: Config> TransientStorage<T> {
    /// Creates an empty storage whose meter is limited to `memory_limit`.
    pub fn new(memory_limit: u32) -> Self {
        TransientStorage {
            storage: Default::default(),
            journal: Journal::new(),
            checkpoints: Default::default(),
            meter: StorageMeter::new(memory_limit),
        }
    }

    /// Returns a copy of the value stored by `account` under `key`, or `None` if there is none.
    pub fn read(&self, account: &AccountIdOf<T>, key: &Key) -> Option<Vec<u8>> {
        self.storage.read(&Self::storage_key(&account.encode(), &key.hash()))
    }

    /// Writes `value` for `account` under `key`; `None` removes the entry.
    ///
    /// Writing the value that is already stored is a no-op: nothing is charged and nothing
    /// is journaled. Only writes of `Some(..)` are charged to the meter.
    ///
    /// The outcome is `WriteOutcome::New` when there was no previous value. Otherwise `take`
    /// decides between `WriteOutcome::Taken`, which carries the previous value, and
    /// `WriteOutcome::Overwritten`, which carries its length.
    ///
    /// # Errors
    ///
    /// Propagates the meter's error when the charge for the write does not fit into the
    /// current limit; storage and journal are left untouched in that case.
    pub fn write(
        &mut self,
        account: &AccountIdOf<T>,
        key: &Key,
        value: Option<Vec<u8>>,
        take: bool,
    ) -> Result<WriteOutcome, DispatchError> {
        let key = Self::storage_key(&account.encode(), &key.hash());
        let prev_value = self.storage.read(&key);
        if prev_value != value {
            if let Some(value) = &value {
                // Every charged write accounts for the new value plus the journal entry,
                // which holds its own copy of the key.
                let key_len = key.capacity();
                let mut amount = value
                    .capacity()
                    .saturating_add(key_len)
                    .saturating_add(mem::size_of::<JournalEntry>());
                if prev_value.is_none() {
                    // A previously absent key additionally creates a new storage entry.
                    amount.saturating_accrue(key_len.saturating_add(mem::size_of::<Vec<u8>>()));
                }
                // Saturate instead of truncating: where `usize` is wider than `u32`, a plain
                // `as` cast would wrap an oversized allocation around to a small charge.
                let amount = amount.min(u32::MAX as usize) as u32;
                self.meter.charge(amount)?;
            }
            self.storage.write(&key, value);
            self.journal.push(JournalEntry::new(key, prev_value.clone()));
        }
        Ok(match (take, prev_value) {
            (_, None) => WriteOutcome::New,
            (false, Some(prev_value)) => WriteOutcome::Overwritten(prev_value.len() as _),
            (true, Some(prev_value)) => WriteOutcome::Taken(prev_value),
        })
    }

    /// Opens a new (possibly nested) transaction: starts a nested meter entry and records
    /// the current journal length as the point to roll back to.
    pub fn start_transaction(&mut self) {
        self.meter.start();
        self.checkpoints.push(self.journal.len());
    }

    /// Discards the innermost open transaction: its meter entry is dropped and every write
    /// journaled since it was started is undone.
    ///
    /// # Panics
    ///
    /// Panics if no transaction is open.
    pub fn rollback_transaction(&mut self) {
        let checkpoint = self
            .checkpoints
            .pop()
            .expect(
                "A call to rollback_transaction must be preceded by a corresponding call to start_transaction;
the code within this crate makes sure that this is always the case; qed"
            );
        self.meter.revert();
        self.journal.rollback(&mut self.storage, checkpoint);
    }

    /// Closes the innermost open transaction and keeps its writes: the checkpoint is dropped
    /// and the meter entry is merged into the one below it. The journal entries stay, so an
    /// enclosing transaction can still undo them.
    ///
    /// # Panics
    ///
    /// Panics if no transaction is open.
    pub fn commit_transaction(&mut self) {
        self.checkpoints
            .pop()
            .expect(
                "A call to commit_transaction must be preceded by a corresponding call to start_transaction;
the code within this crate makes sure that this is always the case; qed"
            );
        self.meter.commit();
    }

    /// Mutable access to the meter, available to tests and runtime benchmarks only.
    #[cfg(any(test, feature = "runtime-benchmarks"))]
    pub fn meter(&mut self) -> &mut StorageMeter<T> {
        &mut self.meter
    }

    /// Builds the internal storage key by concatenating the bytes of `account` and `key`.
    fn storage_key(account: &[u8], key: &[u8]) -> Vec<u8> {
        let mut storage_key = Vec::with_capacity(account.len() + key.len());
        storage_key.extend_from_slice(account);
        storage_key.extend_from_slice(key);
        storage_key
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{test_utils::*, tests::Test, Error};
use core::u32::MAX;
/// Amount the meter charges for writing `value` for `account` under `key` into a fresh
/// storage whose limit is `u32::MAX`.
fn allocation_size(account: &AccountIdOf<Test>, key: &Key, value: Option<Vec<u8>>) -> u32 {
    let mut scratch = TransientStorage::<Test>::new(MAX);
    let written = scratch.write(account, key, value, false);
    written.expect("Could not write to transient storage.");
    scratch.meter().current().amount
}
#[test]
fn read_write_works() {
    let key_1 = Key::Fix([1; 32]);
    let key_2 = Key::Fix([2; 32]);
    let key_3 = Key::Fix([3; 32]);
    let mut transient = TransientStorage::<Test>::new(2048);

    // First writes to unused keys report `New`, whatever `take` is set to.
    assert_eq!(transient.write(&ALICE, &key_1, Some(vec![1]), false), Ok(WriteOutcome::New));
    assert_eq!(transient.write(&ALICE, &key_2, Some(vec![2]), true), Ok(WriteOutcome::New));
    assert_eq!(transient.write(&BOB, &key_3, Some(vec![3]), false), Ok(WriteOutcome::New));
    assert_eq!(transient.read(&ALICE, &key_1), Some(vec![1]));
    assert_eq!(transient.read(&ALICE, &key_2), Some(vec![2]));
    assert_eq!(transient.read(&BOB, &key_3), Some(vec![3]));

    // Overwriting reports the old length, or hands back the old value when `take` is set.
    assert_eq!(
        transient.write(&ALICE, &key_2, Some(vec![4, 5]), false),
        Ok(WriteOutcome::Overwritten(1))
    );
    assert_eq!(
        transient.write(&BOB, &key_3, Some(vec![6, 7]), true),
        Ok(WriteOutcome::Taken(vec![3]))
    );
    assert_eq!(transient.read(&ALICE, &key_1), Some(vec![1]));
    assert_eq!(transient.read(&ALICE, &key_2), Some(vec![4, 5]));
    assert_eq!(transient.read(&BOB, &key_3), Some(vec![6, 7]));

    // An empty value is still a stored value.
    assert_eq!(
        transient.write(&BOB, &key_3, Some(vec![]), true),
        Ok(WriteOutcome::Taken(vec![6, 7]))
    );
    assert_eq!(transient.read(&BOB, &key_3), Some(vec![]));

    // Writing `None` removes the entry.
    assert_eq!(transient.write(&BOB, &key_3, None, true), Ok(WriteOutcome::Taken(vec![])));
    assert_eq!(transient.read(&BOB, &key_3), None);
}
#[test]
fn read_write_with_var_sized_keys_works() {
    let alice_key = Key::try_from_var([1; 64].to_vec()).unwrap();
    let bob_key = Key::try_from_var([2; 64].to_vec()).unwrap();
    let mut transient = TransientStorage::<Test>::new(2048);

    // Fresh variable-sized keys behave like fixed-sized ones.
    assert_eq!(transient.write(&ALICE, &alice_key, Some(vec![1]), false), Ok(WriteOutcome::New));
    assert_eq!(transient.write(&BOB, &bob_key, Some(vec![2, 3]), false), Ok(WriteOutcome::New));
    assert_eq!(transient.read(&ALICE, &alice_key), Some(vec![1]));
    assert_eq!(transient.read(&BOB, &bob_key), Some(vec![2, 3]));

    // Overwriting reports the length of the replaced value.
    assert_eq!(
        transient.write(&ALICE, &alice_key, Some(vec![4, 5]), false),
        Ok(WriteOutcome::Overwritten(1))
    );
    assert_eq!(transient.read(&ALICE, &alice_key), Some(vec![4, 5]));
}
#[test]
fn rollback_transaction_works() {
    let slot = Key::Fix([1; 32]);
    let mut transient = TransientStorage::<Test>::new(1024);

    transient.start_transaction();
    let outcome = transient.write(&ALICE, &slot, Some(vec![1]), false);
    assert_eq!(outcome, Ok(WriteOutcome::New));
    transient.rollback_transaction();

    // The rolled-back write must be gone.
    assert_eq!(transient.read(&ALICE, &slot), None);
}
#[test]
fn commit_transaction_works() {
    let slot = Key::Fix([1; 32]);
    let mut transient = TransientStorage::<Test>::new(1024);

    transient.start_transaction();
    let outcome = transient.write(&ALICE, &slot, Some(vec![1]), false);
    assert_eq!(outcome, Ok(WriteOutcome::New));
    transient.commit_transaction();

    // The committed write must still be readable.
    assert_eq!(transient.read(&ALICE, &slot), Some(vec![1]));
}
#[test]
fn overwrite_and_commmit_transaction_works() {
    let slot = Key::Fix([1; 32]);
    let mut transient = TransientStorage::<Test>::new(1024);

    transient.start_transaction();
    let first = transient.write(&ALICE, &slot, Some(vec![1]), false);
    assert_eq!(first, Ok(WriteOutcome::New));
    let second = transient.write(&ALICE, &slot, Some(vec![1, 2]), false);
    assert_eq!(second, Ok(WriteOutcome::Overwritten(1)));
    transient.commit_transaction();

    // The later of the two writes wins.
    assert_eq!(transient.read(&ALICE, &slot), Some(vec![1, 2]));
}
#[test]
fn rollback_in_nested_transaction_works() {
    let slot = Key::Fix([1; 32]);
    let mut transient = TransientStorage::<Test>::new(1024);

    // Outer transaction: this write is committed at the end.
    transient.start_transaction();
    assert_eq!(transient.write(&ALICE, &slot, Some(vec![1]), false), Ok(WriteOutcome::New));

    // Inner transaction: this write is rolled back.
    transient.start_transaction();
    assert_eq!(transient.write(&BOB, &slot, Some(vec![1]), false), Ok(WriteOutcome::New));
    transient.rollback_transaction();

    transient.commit_transaction();

    assert_eq!(transient.read(&ALICE, &slot), Some(vec![1]));
    assert_eq!(transient.read(&BOB, &slot), None);
}
#[test]
fn commit_in_nested_transaction_works() {
    let slot = Key::Fix([1; 32]);
    let mut transient = TransientStorage::<Test>::new(1024);

    // Three nesting levels, one write per level.
    transient.start_transaction();
    assert_eq!(transient.write(&ALICE, &slot, Some(vec![1]), false), Ok(WriteOutcome::New));
    transient.start_transaction();
    assert_eq!(transient.write(&BOB, &slot, Some(vec![2]), false), Ok(WriteOutcome::New));
    transient.start_transaction();
    assert_eq!(transient.write(&CHARLIE, &slot, Some(vec![3]), false), Ok(WriteOutcome::New));

    // Committing every level keeps every write.
    transient.commit_transaction();
    transient.commit_transaction();
    transient.commit_transaction();

    assert_eq!(transient.read(&ALICE, &slot), Some(vec![1]));
    assert_eq!(transient.read(&BOB, &slot), Some(vec![2]));
    assert_eq!(transient.read(&CHARLIE, &slot), Some(vec![3]));
}
#[test]
fn rollback_all_transactions_works() {
    let slot = Key::Fix([1; 32]);
    let mut transient = TransientStorage::<Test>::new(1024);

    // Three nesting levels, one write per level.
    transient.start_transaction();
    assert_eq!(transient.write(&ALICE, &slot, Some(vec![1]), false), Ok(WriteOutcome::New));
    transient.start_transaction();
    assert_eq!(transient.write(&BOB, &slot, Some(vec![2]), false), Ok(WriteOutcome::New));
    transient.start_transaction();
    assert_eq!(transient.write(&CHARLIE, &slot, Some(vec![3]), false), Ok(WriteOutcome::New));

    // The two inner levels are committed, but rolling back the outermost one undoes them too.
    transient.commit_transaction();
    transient.commit_transaction();
    transient.rollback_transaction();

    assert_eq!(transient.read(&ALICE, &slot), None);
    assert_eq!(transient.read(&BOB, &slot), None);
    assert_eq!(transient.read(&CHARLIE, &slot), None);
}
#[test]
fn metering_transactions_works() {
    let key_1 = Key::Fix([1; 32]);
    let key_2 = Key::Fix([2; 32]);
    let size = allocation_size(&ALICE, &key_1, Some(vec![1u8; 4096]));
    let mut transient = TransientStorage::<Test>::new(size * 2);

    // First transaction uses up half of the limit.
    transient.start_transaction();
    assert_eq!(
        transient.write(&ALICE, &key_1, Some(vec![1u8; 4096]), false),
        Ok(WriteOutcome::New)
    );
    let limit = transient.meter().current().limit;
    transient.commit_transaction();

    // Second transaction starts with exactly the remaining half.
    transient.start_transaction();
    assert_eq!(transient.meter().current().limit, limit - size);
    let remaining = {
        let entry = transient.meter().current();
        entry.limit - entry.amount
    };
    assert_eq!(remaining, size);
    assert_eq!(
        transient.write(&ALICE, &key_2, Some(vec![1u8; 4096]), false),
        Ok(WriteOutcome::New)
    );
    assert_eq!(transient.meter().current().amount, size);
    transient.commit_transaction();

    assert_eq!(transient.meter().total_amount(), size * 2);
}
#[test]
fn metering_nested_transactions_works() {
    let key_1 = Key::Fix([1; 32]);
    let key_2 = Key::Fix([2; 32]);
    let size = allocation_size(&ALICE, &key_1, Some(vec![1u8; 4096]));
    let mut transient = TransientStorage::<Test>::new(size * 3);

    transient.start_transaction();
    let limit = transient.meter().current().limit;
    assert_eq!(
        transient.write(&ALICE, &key_1, Some(vec![1u8; 4096]), false),
        Ok(WriteOutcome::New)
    );

    // The nested transaction gets less than what its parent has left.
    transient.start_transaction();
    assert_eq!(transient.meter().total_amount(), size);
    let nested_limit = transient.meter().current().limit;
    assert!(nested_limit < limit - size);
    assert_eq!(
        transient.write(&ALICE, &key_2, Some(vec![1u8; 4096]), false),
        Ok(WriteOutcome::New)
    );
    transient.commit_transaction();

    // Back in the outer transaction: original limit, and it now carries the whole amount.
    assert_eq!(transient.meter().current().limit, limit);
    let total = transient.meter().total_amount();
    let current = transient.meter().current().amount;
    assert_eq!(total, current);
    transient.commit_transaction();
}
#[test]
fn metering_transaction_fails() {
    let slot = Key::Fix([1; 32]);
    let size = allocation_size(&ALICE, &slot, Some(vec![1u8; 4096]));
    // One unit short of what the write needs.
    let mut transient = TransientStorage::<Test>::new(size - 1);

    transient.start_transaction();
    let outcome = transient.write(&ALICE, &slot, Some(vec![1u8; 4096]), false);
    assert_eq!(outcome, Err(Error::<Test>::OutOfTransientStorage.into()));
    // A rejected write is not charged.
    assert_eq!(transient.meter.current().amount, 0);
    transient.commit_transaction();

    assert_eq!(transient.meter.total_amount(), 0);
}
#[test]
fn metering_nested_transactions_fails() {
    let key_1 = Key::Fix([1; 32]);
    let key_2 = Key::Fix([2; 32]);
    let size = allocation_size(&ALICE, &key_1, Some(vec![1u8; 4096]));
    let mut transient = TransientStorage::<Test>::new(size * 2);

    transient.start_transaction();
    assert_eq!(
        transient.write(&ALICE, &key_1, Some(vec![1u8; 4096]), false),
        Ok(WriteOutcome::New)
    );

    // The nested transaction may not use everything its parent has left, so a second write
    // of the same size no longer fits.
    transient.start_transaction();
    assert_eq!(
        transient.write(&ALICE, &key_2, Some(vec![1u8; 4096]), false),
        Err(Error::<Test>::OutOfTransientStorage.into())
    );
    transient.commit_transaction();
    transient.commit_transaction();

    // Only the successful write has been charged.
    assert_eq!(transient.meter.total_amount(), size);
}
#[test]
fn metering_nested_transaction_with_rollback_works() {
    let key_1 = Key::Fix([1; 32]);
    let key_2 = Key::Fix([2; 32]);
    let size = allocation_size(&ALICE, &key_1, Some(vec![1u8; 4096]));
    let mut transient = TransientStorage::<Test>::new(size * 2);

    transient.start_transaction();
    let limit = transient.meter.current().limit;

    // Charge inside a nested transaction, then throw it away.
    transient.start_transaction();
    assert_eq!(
        transient.write(&ALICE, &key_2, Some(vec![1u8; 4096]), false),
        Ok(WriteOutcome::New)
    );
    transient.rollback_transaction();

    // The rollback released the charge and left the outer limit untouched.
    assert_eq!(transient.meter.total_amount(), 0);
    assert_eq!(transient.meter.current().limit, limit);

    assert_eq!(
        transient.write(&ALICE, &key_1, Some(vec![1u8; 4096]), false),
        Ok(WriteOutcome::New)
    );
    let amount = transient.meter().current().amount;
    assert_eq!(transient.meter().total_amount(), amount);
    transient.commit_transaction();
}
#[test]
fn metering_with_rollback_works() {
    let key_1 = Key::Fix([1; 32]);
    let key_2 = Key::Fix([2; 32]);
    let size = allocation_size(&ALICE, &key_1, Some(vec![1u8; 4096]));
    let mut transient = TransientStorage::<Test>::new(size * 5);

    transient.start_transaction();
    assert_eq!(
        transient.write(&ALICE, &key_1, Some(vec![1u8; 4096]), false),
        Ok(WriteOutcome::New)
    );
    let amount = transient.meter.total_amount();

    // Second level.
    transient.start_transaction();
    assert_eq!(
        transient.write(&ALICE, &key_2, Some(vec![1u8; 4096]), false),
        Ok(WriteOutcome::New)
    );

    // Third level, committed into the second one.
    transient.start_transaction();
    assert_eq!(
        transient.write(&BOB, &key_1, Some(vec![1u8; 4096]), false),
        Ok(WriteOutcome::New)
    );
    transient.commit_transaction();

    // Rolling back the second level also releases what the third level committed into it.
    transient.rollback_transaction();
    assert_eq!(transient.meter.total_amount(), amount);
    transient.commit_transaction();
}
}