use sp_io::storage::{commit_transaction, rollback_transaction, start_transaction};
use sp_runtime::{DispatchError, TransactionOutcome, TransactionalError};
/// Counter type for the number of currently nested transactional layers.
pub type Layer = u32;
/// Unhashed storage key under which the current transactional layer count is stored.
/// The key is removed entirely when the count drops back to zero.
pub const TRANSACTION_LEVEL_KEY: &[u8] = b":transaction_level:";
/// Maximum number of nested layers that [`with_transaction`] will open. Once this many layers
/// are open, further attempts fail with `TransactionalError::LimitReached`.
/// [`with_transaction_unchecked`] does not enforce this limit; it only logs a warning.
pub const TRANSACTIONAL_LIMIT: Layer = 255;
fn get_transaction_level() -> Layer {
crate::storage::unhashed::get_or_default::<Layer>(TRANSACTION_LEVEL_KEY)
}
fn set_transaction_level(level: Layer) {
crate::storage::unhashed::put::<Layer>(TRANSACTION_LEVEL_KEY, &level);
}
fn kill_transaction_level() {
crate::storage::unhashed::kill(TRANSACTION_LEVEL_KEY);
}
/// Opens one more transactional layer in the stored counter.
///
/// On success the stored level is incremented. The returned [`StorageLayerGuard`] undoes the
/// increment when it is dropped.
///
/// # Errors
///
/// Returns `Err(())` without touching storage when the level is already at or above
/// [`TRANSACTIONAL_LIMIT`].
fn inc_transaction_level() -> Result<StorageLayerGuard, ()> {
    let current = get_transaction_level();
    if current < TRANSACTIONAL_LIMIT {
        // Cannot overflow: `current` is strictly below the limit.
        set_transaction_level(current + 1);
        Ok(StorageLayerGuard)
    } else {
        Err(())
    }
}
/// Closes one transactional layer in the stored counter.
///
/// - At level `1` the key is removed rather than written as `0`, so no entry is left behind.
/// - At level `0` there is nothing to decrement. A warning is logged and storage is left as is,
///   rather than panicking.
fn dec_transaction_level() {
    match get_transaction_level() {
        0 => {
            log::warn!(
                "We are underflowing with calculating transactional levels. Not great, but let's not panic...",
            );
        },
        1 => kill_transaction_level(),
        depth => set_transaction_level(depth - 1),
    }
}
/// Token representing one transactional layer counted in storage.
///
/// Only `inc_transaction_level` creates it. Dropping it decrements the stored level again, so
/// it must stay alive for the whole duration of the layer. Discarding it immediately (e.g.
/// `inc_transaction_level().unwrap();`) would undo the increment right away. `#[must_use]`
/// turns that mistake into a compiler warning.
#[must_use = "dropping the guard immediately decrements the transaction level"]
struct StorageLayerGuard;

impl Drop for StorageLayerGuard {
    /// Decrements the stored transaction level for the layer this guard represents.
    fn drop(&mut self) {
        dec_transaction_level()
    }
}
/// Returns `true` when the stored transactional level is non-zero.
///
/// A non-zero level means execution is inside at least one layer opened by this module.
pub fn is_transactional() -> bool {
    let depth = get_transaction_level();
    depth != 0
}
/// Executes `f` inside a new storage transaction.
///
/// The transaction is committed or rolled back according to the [`TransactionOutcome`] that `f`
/// returns. The `Result` wrapped in that outcome is handed back unchanged either way.
///
/// # Errors
///
/// Returns `TransactionalError::LimitReached`, converted into `E` via `DispatchError`, without
/// calling `f` when [`TRANSACTIONAL_LIMIT`] layers are already open. Otherwise propagates
/// whatever `f` returned.
pub fn with_transaction<T, E, F>(f: F) -> Result<T, E>
where
    E: From<DispatchError>,
    F: FnOnce() -> TransactionOutcome<Result<T, E>>,
{
    // The level is kept in storage, so it must be bumped *before* `start_transaction`.
    // Otherwise a rollback would undo the increment while the guard still decrements on drop.
    let _layer = inc_transaction_level()
        .map_err(|()| -> DispatchError { TransactionalError::LimitReached.into() })?;
    start_transaction();

    let (commit, result) = match f() {
        TransactionOutcome::Commit(result) => (true, result),
        TransactionOutcome::Rollback(result) => (false, result),
    };
    if commit {
        commit_transaction();
    } else {
        rollback_transaction();
    }
    // `_layer` drops after the transaction is closed, restoring the previous level.
    result
}
/// Like [`with_transaction`], but places no bound on `f`'s error type.
///
/// `f`'s result is returned as `Ok(result)` whether the transaction was committed or rolled
/// back.
///
/// # Errors
///
/// Returns `Err(())` only when the transactional layer could not be opened, i.e. the nesting
/// limit was reached. The underlying `DispatchError` is discarded.
pub fn with_transaction_opaque_err<T, E, F>(f: F) -> Result<Result<T, E>, ()>
where
    F: FnOnce() -> TransactionOutcome<Result<T, E>>,
{
    // Wrap `f`'s result in an always-`Ok` layer so `with_transaction`'s own error channel
    // carries nothing but the layer-limit failure.
    let lifted = move || -> TransactionOutcome<Result<Result<T, E>, DispatchError>> {
        match f() {
            TransactionOutcome::Commit(inner) => TransactionOutcome::Commit(Ok(inner)),
            TransactionOutcome::Rollback(inner) => TransactionOutcome::Rollback(Ok(inner)),
        }
    };
    with_transaction(lifted).map_err(drop)
}
/// Executes `f` inside a new storage transaction without enforcing [`TRANSACTIONAL_LIMIT`].
///
/// If the limit is already reached, a warning is logged and the transaction is opened anyway.
/// In that case the stored level is not incremented, so it stays at the limit. The transaction
/// is committed or rolled back according to the [`TransactionOutcome`] returned by `f`, whose
/// payload is returned as is.
pub fn with_transaction_unchecked<R, F>(f: F) -> R
where
    F: FnOnce() -> TransactionOutcome<R>,
{
    // `Some(guard)` only when the level was actually incremented. The guard lives until the end
    // of the function so the decrement happens after the transaction is closed.
    let _layer = match inc_transaction_level() {
        Ok(guard) => Some(guard),
        Err(()) => {
            log::warn!(
                "The transactional layer limit has been reached, and new transactional layers are being
spawned with `with_transaction_unchecked`. This could be caused by someone trying to
attack your chain, and you should investigate usage of `with_transaction_unchecked` and
potentially migrate to `with_transaction`, which enforces a transactional limit.",
            );
            None
        },
    };
    start_transaction();

    let (commit, value) = match f() {
        TransactionOutcome::Commit(value) => (true, value),
        TransactionOutcome::Rollback(value) => (false, value),
    };
    if commit {
        commit_transaction();
    } else {
        rollback_transaction();
    }
    value
}
/// Runs `f` in a fresh storage layer.
///
/// The layer's changes are committed when `f` returns `Ok` and rolled back when it returns
/// `Err`.
///
/// # Errors
///
/// Propagates `f`'s error, or `TransactionalError::LimitReached` (converted into `E`) when the
/// layer cannot be opened because the nesting limit is reached.
pub fn with_storage_layer<T, E, F>(f: F) -> Result<T, E>
where
    E: From<DispatchError>,
    F: FnOnce() -> Result<T, E>,
{
    with_transaction(|| match f() {
        succeeded @ Ok(_) => TransactionOutcome::Commit(succeeded),
        failed @ Err(_) => TransactionOutcome::Rollback(failed),
    })
}
/// Ensures `f` runs inside some storage layer.
///
/// If no layer is active, a new one is opened via [`with_storage_layer`]. Otherwise `f` runs
/// directly, and the already-open enclosing layer governs whether its changes persist.
///
/// # Errors
///
/// Propagates `f`'s error, or the layer-limit error from [`with_storage_layer`] when a new
/// layer had to be opened and could not be.
pub fn in_storage_layer<T, E, F>(f: F) -> Result<T, E>
where
    E: From<DispatchError>,
    F: FnOnce() -> Result<T, E>,
{
    if !is_transactional() {
        return with_storage_layer(f)
    }
    f()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{assert_noop, assert_ok};
    use sp_io::TestExternalities;
    use sp_runtime::DispatchResult;

    #[test]
    fn is_transactional_should_return_false() {
        TestExternalities::default().execute_with(|| {
            assert!(!is_transactional());
        });
    }

    #[test]
    fn is_transactional_should_not_error_in_with_transaction() {
        TestExternalities::default().execute_with(|| {
            assert_ok!(with_transaction(|| -> TransactionOutcome<DispatchResult> {
                assert!(is_transactional());
                TransactionOutcome::Commit(Ok(()))
            }));
            assert_noop!(
                with_transaction(|| -> TransactionOutcome<DispatchResult> {
                    assert!(is_transactional());
                    TransactionOutcome::Rollback(Err("revert".into()))
                }),
                "revert"
            );
        });
    }

    /// Opens `num` nested `with_transaction` layers and commits each one.
    fn recursive_transactional(num: u32) -> DispatchResult {
        if num == 0 {
            return Ok(())
        }
        with_transaction(|| -> TransactionOutcome<DispatchResult> {
            let res = recursive_transactional(num - 1);
            TransactionOutcome::Commit(res)
        })
    }

    #[test]
    fn transaction_limit_should_work() {
        TestExternalities::default().execute_with(|| {
            assert_eq!(get_transaction_level(), 0);
            assert_ok!(with_transaction(|| -> TransactionOutcome<DispatchResult> {
                assert_eq!(get_transaction_level(), 1);
                TransactionOutcome::Commit(Ok(()))
            }));
            assert_ok!(with_transaction(|| -> TransactionOutcome<DispatchResult> {
                assert_eq!(get_transaction_level(), 1);
                let res = with_transaction(|| -> TransactionOutcome<DispatchResult> {
                    assert_eq!(get_transaction_level(), 2);
                    TransactionOutcome::Commit(Ok(()))
                });
                TransactionOutcome::Commit(res)
            }));
            assert_ok!(recursive_transactional(255));
            assert_noop!(
                recursive_transactional(256),
                sp_runtime::TransactionalError::LimitReached
            );
            assert_eq!(get_transaction_level(), 0);
        });
    }

    /// Opens `num` nested `with_transaction_unchecked` layers and returns the transaction
    /// level observed at the innermost one.
    fn recursive_unchecked(num: u32) -> Layer {
        if num == 0 {
            return get_transaction_level()
        }
        with_transaction_unchecked(|| TransactionOutcome::Commit(recursive_unchecked(num - 1)))
    }

    #[test]
    fn with_transaction_unchecked_saturates_level_and_unwinds_cleanly() {
        TestExternalities::default().execute_with(|| {
            // Layers past the limit still run but do not bump the stored level...
            assert_eq!(recursive_unchecked(TRANSACTIONAL_LIMIT + 5), TRANSACTIONAL_LIMIT);
            // ...and they do not decrement it either, so it returns exactly to zero.
            assert_eq!(get_transaction_level(), 0);
            assert!(!is_transactional());
        });
    }

    #[test]
    fn with_transaction_opaque_err_returns_inner_result() {
        TestExternalities::default().execute_with(|| {
            let committed =
                with_transaction_opaque_err(|| -> TransactionOutcome<Result<u32, ()>> {
                    assert!(is_transactional());
                    TransactionOutcome::Commit(Ok(7))
                });
            assert_eq!(committed, Ok(Ok(7)));

            let reverted =
                with_transaction_opaque_err(|| -> TransactionOutcome<Result<u32, ()>> {
                    TransactionOutcome::Rollback(Err(()))
                });
            assert_eq!(reverted, Ok(Err(())));
            assert_eq!(get_transaction_level(), 0);
        });
    }

    #[test]
    fn in_storage_layer_works() {
        TestExternalities::default().execute_with(|| {
            assert_eq!(get_transaction_level(), 0);
            let res = in_storage_layer(|| -> DispatchResult {
                assert_eq!(get_transaction_level(), 1);
                in_storage_layer(|| -> DispatchResult {
                    assert_eq!(get_transaction_level(), 1);
                    Ok(())
                })
            });
            assert_ok!(res);
            let res = in_storage_layer(|| -> DispatchResult {
                assert_eq!(get_transaction_level(), 1);
                in_storage_layer(|| -> DispatchResult {
                    assert_eq!(get_transaction_level(), 1);
                    Err("epic fail".into())
                })
            });
            assert_noop!(res, "epic fail");
        });
    }
}