use super::{ProofToHashes, ProvingTrie, TrieError};
use crate::{Decode, DispatchError, Encode};
use codec::MaxEncodedLen;
use sp_std::vec::Vec;
use sp_trie::{
trie_types::{TrieDBBuilder, TrieDBMutBuilderV1},
LayoutV1, MemoryDB, Trie, TrieMut,
};
/// A proving trie held entirely in memory.
///
/// Entries are inserted SCALE-encoded into an `sp_trie` `LayoutV1` trie, so `Key` and
/// `Value` only exist at the type level; the backing database stores raw bytes.
pub struct BasicProvingTrie<Hashing, Key, Value>
where
    Hashing: sp_core::Hasher,
{
    /// Root hash of the trie whose nodes live in `db`.
    root: Hashing::Out,
    /// In-memory node storage for the trie.
    db: MemoryDB<Hashing>,
    /// Ties the `Key`/`Value` type parameters to the struct without storing them.
    _phantom: core::marker::PhantomData<(Key, Value)>,
}
impl<Hashing, Key, Value> BasicProvingTrie<Hashing, Key, Value>
where
    Hashing: sp_core::Hasher,
    Key: Encode,
{
    /// Creates a single proof covering every key in `keys`.
    ///
    /// Each key is SCALE-encoded before the lookup. The resulting proof (a list of trie
    /// nodes) is returned SCALE-encoded as a `Vec<Vec<u8>>`, which is the format
    /// [`verify_multi_proof`] decodes.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `sp_trie::generate_trie_proof`, converted through
    /// [`TrieError`] into a [`DispatchError`].
    pub fn create_multi_proof(&self, keys: &[Key]) -> Result<Vec<u8>, DispatchError> {
        // Borrow the slice with `iter()`; `into_iter()` on a `&[T]` reads as if it consumed it.
        let encoded_keys: Vec<Vec<u8>> = keys.iter().map(|key| key.encode()).collect();
        sp_trie::generate_trie_proof::<LayoutV1<Hashing>, _, _, _>(
            &self.db,
            self.root,
            &encoded_keys,
        )
        .map_err(|err| TrieError::from(*err).into())
        .map(|structured_proof| structured_proof.encode())
    }
}
impl<Hashing, Key, Value> ProvingTrie<Hashing, Key, Value> for BasicProvingTrie<Hashing, Key, Value>
where
    Hashing: sp_core::Hasher,
    Key: Encode,
    Value: Encode + Decode,
{
    /// Builds a trie containing every `(key, value)` pair yielded by `items`.
    ///
    /// Both key and value are stored SCALE-encoded. Fails with a `DispatchError` as soon
    /// as any insertion into the underlying trie reports an error.
    fn generate_for<I>(items: I) -> Result<Self, DispatchError>
    where
        I: IntoIterator<Item = (Key, Value)>,
    {
        let mut db = MemoryDB::default();
        let mut root = Default::default();
        let mut trie = TrieDBMutBuilderV1::new(&mut db, &mut root).build();
        for (key, value) in items {
            let outcome =
                key.using_encoded(|raw_key| value.using_encoded(|raw_value| trie.insert(raw_key, raw_value)));
            if outcome.is_err() {
                return Err("failed to insert into trie".into());
            }
        }
        // The mutable trie borrows `db` and `root`; it must be dropped before they can be
        // moved into `Self`. NOTE(review): trie-db's `TrieDBMut` is understood to commit the
        // root on drop — confirm against the trie-db version in use.
        drop(trie);
        Ok(Self { db, root, _phantom: Default::default() })
    }

    /// Borrows the current root hash.
    fn root(&self) -> &Hashing::Out {
        &self.root
    }

    /// Looks up `key` and decodes the stored bytes as `Value`.
    ///
    /// Yields `None` when the key is absent, when the trie lookup errors, or when the
    /// stored bytes do not decode as `Value`.
    fn query(&self, key: &Key) -> Option<Value> {
        let trie = TrieDBBuilder::new(&self.db, &self.root).build();
        let stored = key.using_encoded(|raw_key| trie.get(raw_key)).ok()??;
        Value::decode(&mut &stored[..]).ok()
    }

    /// Creates a proof for a single `key`, returned as a SCALE-encoded `Vec<Vec<u8>>`.
    fn create_proof(&self, key: &Key) -> Result<Vec<u8>, DispatchError> {
        let encoded_key = key.encode();
        let structured_proof = sp_trie::generate_trie_proof::<LayoutV1<Hashing>, _, _, _>(
            &self.db,
            self.root,
            &[encoded_key],
        )
        .map_err(|err| -> DispatchError { TrieError::from(*err).into() })?;
        Ok(structured_proof.encode())
    }

    /// Checks that `proof` shows `key` maps to `value` under `root`.
    fn verify_proof(
        root: &Hashing::Out,
        proof: &[u8],
        key: &Key,
        value: &Value,
    ) -> Result<(), DispatchError> {
        // Forward to the module-level free function of the same name.
        verify_proof::<Hashing, Key, Value>(root, proof, key, value)
    }
}
impl<Hashing, Key, Value> ProofToHashes for BasicProvingTrie<Hashing, Key, Value>
where
    Hashing: sp_core::Hasher,
    Hashing::Out: MaxEncodedLen,
{
    type Proof = [u8];

    /// Returns the number of trie nodes contained in an encoded proof.
    ///
    /// Only the SCALE length prefix of the encoded `Vec<Vec<u8>>` is read. The nodes
    /// themselves are neither decoded nor validated.
    ///
    /// # Errors
    ///
    /// Returns [`TrieError::DecodeError`] if the length prefix cannot be read or the
    /// length does not fit in a `u32`.
    fn proof_to_hashes(proof: &[u8]) -> Result<u32, DispatchError> {
        use codec::DecodeLength;
        let node_count =
            <Vec<Vec<u8>> as DecodeLength>::len(proof).map_err(|_| TrieError::DecodeError)?;
        // Convert with a check instead of a silently truncating `as` cast.
        // NOTE(review): codec's `Vec` length prefix is understood to be a compact `u32`, so
        // this branch should be unreachable in practice.
        u32::try_from(node_count).map_err(|_| TrieError::DecodeError.into())
    }
}
/// Verifies that `proof` shows `key` maps to `value` in the trie with root `root`.
///
/// `proof` must be a SCALE-encoded `Vec<Vec<u8>>` of trie nodes, as produced by
/// `BasicProvingTrie::create_proof`. `key` and `value` are SCALE-encoded before being
/// checked against the proof.
///
/// # Errors
///
/// * [`TrieError::DecodeError`] if `proof` does not decode as `Vec<Vec<u8>>`.
/// * Otherwise, the error from `sp_trie::verify_trie_proof` converted through
///   [`TrieError`] (e.g. `TrieError::RootMismatch` when the value or root is wrong).
pub fn verify_proof<Hashing, Key, Value>(
    root: &Hashing::Out,
    proof: &[u8],
    key: &Key,
    value: &Value,
) -> Result<(), DispatchError>
where
    Hashing: sp_core::Hasher,
    Key: Encode,
    Value: Encode,
{
    let structured_proof: Vec<Vec<u8>> =
        Decode::decode(&mut &proof[..]).map_err(|_| TrieError::DecodeError)?;
    // `root` is already a reference; pass it through rather than re-borrowing it.
    sp_trie::verify_trie_proof::<LayoutV1<Hashing>, _, _, _>(
        root,
        &structured_proof,
        &[(key.encode(), Some(value.encode()))],
    )
    .map_err(|err| TrieError::from(err).into())
}
/// Verifies that `proof` shows every `(key, value)` pair in `items` is present in the
/// trie with root `root`.
///
/// `proof` must be a SCALE-encoded `Vec<Vec<u8>>` of trie nodes, as produced by
/// `BasicProvingTrie::create_multi_proof`. Every key and value is SCALE-encoded before
/// being checked against the proof.
///
/// # Errors
///
/// * [`TrieError::DecodeError`] if `proof` does not decode as `Vec<Vec<u8>>`.
/// * Otherwise, the error from `sp_trie::verify_trie_proof` converted through
///   [`TrieError`].
pub fn verify_multi_proof<Hashing, Key, Value>(
    root: &Hashing::Out,
    proof: &[u8],
    items: &[(Key, Value)],
) -> Result<(), DispatchError>
where
    Hashing: sp_core::Hasher,
    Key: Encode,
    Value: Encode,
{
    let structured_proof: Vec<Vec<u8>> =
        Decode::decode(&mut &proof[..]).map_err(|_| TrieError::DecodeError)?;
    // Each item is expected to be present, hence `Some(value)` for every key.
    let items_encoded: Vec<(Vec<u8>, Option<Vec<u8>>)> = items
        .iter()
        .map(|(key, value)| (key.encode(), Some(value.encode())))
        .collect();
    sp_trie::verify_trie_proof::<LayoutV1<Hashing>, _, _, _>(
        root,
        &structured_proof,
        &items_encoded,
    )
    .map_err(|err| TrieError::from(err).into())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::BlakeTwo256;
    use sp_core::H256;
    use sp_std::collections::btree_map::BTreeMap;

    type BalanceTrie = BasicProvingTrie<BlakeTwo256, u32, u128>;

    /// Root hash of a `LayoutV1` trie with no entries.
    fn empty_root() -> H256 {
        sp_trie::empty_trie_root::<LayoutV1<BlakeTwo256>>()
    }

    /// Builds a trie mapping `i -> i` for `i` in `0..100` and sanity-checks a few lookups.
    fn create_balance_trie() -> BalanceTrie {
        let map: BTreeMap<u32, u128> = (0..100u32).map(|i| (i, u128::from(i))).collect();
        let balance_trie = BalanceTrie::generate_for(map).unwrap();
        assert_ne!(*balance_trie.root(), empty_root());
        assert_eq!(balance_trie.query(&6u32), Some(6u128));
        assert_eq!(balance_trie.query(&9u32), Some(9u128));
        assert_eq!(balance_trie.query(&69u32), Some(69u128));
        assert_eq!(balance_trie.query(&6969u32), None);
        balance_trie
    }

    #[test]
    fn empty_trie_works() {
        let empty_trie = BalanceTrie::generate_for(Vec::new()).unwrap();
        assert_eq!(*empty_trie.root(), empty_root());
    }

    #[test]
    fn basic_end_to_end_single_value() {
        let balance_trie = create_balance_trie();
        let root = *balance_trie.root();
        let proof = balance_trie.create_proof(&6u32).unwrap();
        for i in 0..200u32 {
            if i == 6 {
                assert_eq!(
                    verify_proof::<BlakeTwo256, _, _>(&root, &proof, &i, &u128::from(i)),
                    Ok(())
                );
                assert_eq!(
                    verify_proof::<BlakeTwo256, _, _>(&root, &proof, &i, &u128::from(i + 1)),
                    Err(TrieError::RootMismatch.into())
                );
            } else {
                // A proof for key 6 must not validate any other key.
                assert!(
                    verify_proof::<BlakeTwo256, _, _>(&root, &proof, &i, &u128::from(i)).is_err()
                );
            }
        }
    }

    #[test]
    fn basic_end_to_end_multi() {
        let balance_trie = create_balance_trie();
        let root = *balance_trie.root();
        let proof = balance_trie.create_multi_proof(&[6u32, 9u32, 69u32]).unwrap();
        let items = [(6u32, 6u128), (9u32, 9u128), (69u32, 69u128)];
        assert_eq!(verify_multi_proof::<BlakeTwo256, _, _>(&root, &proof, &items), Ok(()));
    }

    #[test]
    fn proof_fails_with_bad_data() {
        let balance_trie = create_balance_trie();
        let root = *balance_trie.root();
        let proof = balance_trie.create_proof(&6u32).unwrap();
        assert_eq!(verify_proof::<BlakeTwo256, _, _>(&root, &proof, &6u32, &6u128), Ok(()));
        // Wrong root.
        assert_eq!(
            verify_proof::<BlakeTwo256, _, _>(&Default::default(), &proof, &6u32, &6u128),
            Err(TrieError::RootMismatch.into())
        );
        // Proof for a different key.
        let bad_proof = balance_trie.create_proof(&99u32).unwrap();
        assert_eq!(
            verify_proof::<BlakeTwo256, _, _>(&root, &bad_proof, &6u32, &6u128),
            Err(TrieError::ExtraneousHashReference.into())
        );
    }

    #[test]
    fn proof_to_hashes() {
        // ceil(log16(x)), computed in floating point.
        let ceil_log16 = |x: u32| -> u32 {
            let x_f64 = x as f64;
            (x_f64.ln() / 16_f64.ln()).ceil() as u32
        };

        let mut size: u32 = 1;
        while size < 10_000_000 {
            let trie =
                BalanceTrie::generate_for((0..size).map(|k| (k, u128::from(k)))).unwrap();
            let proof = trie.create_proof(&0).unwrap();
            let hashes = BalanceTrie::proof_to_hashes(&proof).unwrap();
            let expected = ceil_log16(size).max(1);
            assert_eq!(hashes, expected);
            size *= 10;
        }
    }
}