use super::{ProofToHashes, ProvingTrie, TrieError};
use crate::{Decode, DispatchError, Encode};
use binary_merkle_tree::{merkle_proof, merkle_root, MerkleProof};
use codec::MaxEncodedLen;
use sp_std::{collections::btree_map::BTreeMap, vec::Vec};
/// A simple proving trie backed by a binary Merkle tree.
///
/// Every `(Key, Value)` pair is kept in memory in a `BTreeMap`, and `root` commits to the
/// encoded pairs in key order.
pub struct BasicProvingTrie<Hashing: sp_core::Hasher, Key, Value> {
	/// All items in the trie, ordered by key.
	db: BTreeMap<Key, Value>,
	/// Merkle root computed over the encoded entries of `db`.
	root: Hashing::Out,
	/// Marker tying the `Key` and `Value` type parameters to the struct.
	_phantom: core::marker::PhantomData<(Key, Value)>,
}
impl<Hashing, Key, Value> ProvingTrie<Hashing, Key, Value> for BasicProvingTrie<Hashing, Key, Value>
where
	Hashing: sp_core::Hasher,
	Hashing::Out: Encode + Decode,
	Key: Encode + Decode + Ord,
	Value: Encode + Decode + Clone,
{
	/// Build a trie from `items` and compute its Merkle root.
	///
	/// Items are stored in a `BTreeMap`, so leaves are ordered by `Key`. If a key occurs more
	/// than once, the value from its last occurrence is kept (and, as with `BTreeMap::insert`,
	/// the key from its first occurrence). Each leaf is the encoding of a `(key, value)` pair.
	/// This implementation never returns an error.
	fn generate_for<I>(items: I) -> Result<Self, DispatchError>
	where
		I: IntoIterator<Item = (Key, Value)>,
	{
		// `extend` inserts pair by pair, so duplicates behave exactly like repeated `insert`.
		let mut db = BTreeMap::<Key, Value>::new();
		db.extend(items);
		let root = merkle_root::<Hashing, _>(db.iter().map(|leaf| leaf.encode()));
		Ok(Self { db, root, _phantom: Default::default() })
	}

	/// The Merkle root committing to every `(key, value)` pair in the trie.
	fn root(&self) -> &Hashing::Out {
		&self.root
	}

	/// Return a clone of the value stored under `key`, or `None` if it is absent.
	fn query(&self, key: &Key) -> Option<Value> {
		// `key` is already a reference; no need to borrow it again.
		self.db.get(key).cloned()
	}

	/// Create an encoded `MerkleProof` showing that `key` and its stored value are in the trie.
	///
	/// # Errors
	///
	/// Returns `TrieError::IncompleteDatabase` if `key` is not in the trie.
	fn create_proof(&self, key: &Key) -> Result<Vec<u8>, DispatchError> {
		// Leaves are laid out in `BTreeMap` key order (the order `generate_for` hashed them
		// in), so the key's position among the keys is its leaf index. Resolve it first so a
		// missing key fails before any leaf is encoded.
		let index = self
			.db
			.keys()
			.position(|k| k == key)
			.ok_or(TrieError::IncompleteDatabase)?;
		let leaves: Vec<Vec<u8>> = self.db.iter().map(|leaf| leaf.encode()).collect();
		// NOTE(review): `as u32` would truncate if the trie ever held more than `u32::MAX` items.
		let proof = merkle_proof::<Hashing, Vec<Vec<u8>>, Vec<u8>>(leaves, index as u32);
		Ok(proof.encode())
	}

	/// Check an encoded proof for `(key, value)` against `root`.
	///
	/// Delegates to the module-level `verify_proof` function.
	fn verify_proof(
		root: &Hashing::Out,
		proof: &[u8],
		key: &Key,
		value: &Value,
	) -> Result<(), DispatchError> {
		verify_proof::<Hashing, Key, Value>(root, proof, key, value)
	}
}
impl<Hashing, Key, Value> ProofToHashes for BasicProvingTrie<Hashing, Key, Value>
where
Hashing: sp_core::Hasher,
Hashing::Out: MaxEncodedLen + Decode,
Key: Decode,
Value: Decode,
{
type Proof = [u8];
fn proof_to_hashes(proof: &[u8]) -> Result<u32, DispatchError> {
let decoded_proof: MerkleProof<Hashing::Out, Vec<u8>> =
Decode::decode(&mut &proof[..]).map_err(|_| TrieError::IncompleteProof)?;
let depth = decoded_proof.proof.len();
Ok(depth as u32)
}
}
pub fn verify_proof<Hashing, Key, Value>(
root: &Hashing::Out,
proof: &[u8],
key: &Key,
value: &Value,
) -> Result<(), DispatchError>
where
Hashing: sp_core::Hasher,
Hashing::Out: Decode,
Key: Encode + Decode,
Value: Encode + Decode,
{
let decoded_proof: MerkleProof<Hashing::Out, Vec<u8>> =
Decode::decode(&mut &proof[..]).map_err(|_| TrieError::IncompleteProof)?;
if *root != decoded_proof.root {
return Err(TrieError::RootMismatch.into());
}
if (key, value).encode() != decoded_proof.leaf {
return Err(TrieError::ValueMismatch.into());
}
if binary_merkle_tree::verify_proof::<Hashing, _, _>(
&decoded_proof.root,
decoded_proof.proof,
decoded_proof.number_of_leaves,
decoded_proof.leaf_index,
&decoded_proof.leaf,
) {
Ok(())
} else {
Err(TrieError::IncompleteProof.into())
}
}
#[cfg(test)]
mod tests {
	use super::*;
	use crate::traits::BlakeTwo256;
	use sp_core::H256;
	use sp_std::collections::btree_map::BTreeMap;

	// Test trie mapping a `u32` key to a `u128` balance.
	type BalanceTrie = BasicProvingTrie<BlakeTwo256, u32, u128>;

	// Root of a trie built from no items.
	fn empty_root() -> H256 {
		let tree = BalanceTrie::generate_for(Vec::new()).unwrap();
		*tree.root()
	}

	// Build a trie holding `i -> i` for every `i` in `0..100`, with a few sanity checks.
	fn create_balance_trie() -> BalanceTrie {
		let mut map = BTreeMap::<u32, u128>::new();
		for i in 0..100u32 {
			map.insert(i, i.into());
		}
		let balance_trie = BalanceTrie::generate_for(map).unwrap();
		let root = *balance_trie.root();
		assert!(root != empty_root());
		assert_eq!(balance_trie.query(&6u32), Some(6u128));
		assert_eq!(balance_trie.query(&9u32), Some(9u128));
		assert_eq!(balance_trie.query(&69u32), Some(69u128));
		balance_trie
	}

	#[test]
	fn empty_trie_works() {
		let empty_trie = BalanceTrie::generate_for(Vec::new()).unwrap();
		assert_eq!(*empty_trie.root(), empty_root());
	}

	#[test]
	fn basic_end_to_end_single_value() {
		let balance_trie = create_balance_trie();
		let root = *balance_trie.root();
		let proof = balance_trie.create_proof(&6u32).unwrap();
		// The proof for key 6 must verify only for (6, 6).
		for i in 0..200u32 {
			if i == 6 {
				assert_eq!(
					verify_proof::<BlakeTwo256, _, _>(&root, &proof, &i, &u128::from(i)),
					Ok(())
				);
				assert_eq!(
					verify_proof::<BlakeTwo256, _, _>(&root, &proof, &i, &u128::from(i + 1)),
					Err(TrieError::ValueMismatch.into())
				);
			} else {
				assert!(
					verify_proof::<BlakeTwo256, _, _>(&root, &proof, &i, &u128::from(i)).is_err()
				);
			}
		}
	}

	#[test]
	fn proof_fails_with_bad_data() {
		let balance_trie = create_balance_trie();
		let root = *balance_trie.root();
		let proof = balance_trie.create_proof(&6u32).unwrap();
		assert_eq!(verify_proof::<BlakeTwo256, _, _>(&root, &proof, &6u32, &6u128), Ok(()));
		assert_eq!(
			verify_proof::<BlakeTwo256, _, _>(&Default::default(), &proof, &6u32, &6u128),
			Err(TrieError::RootMismatch.into())
		);
		assert_eq!(
			verify_proof::<BlakeTwo256, _, _>(&root, &[], &6u32, &6u128),
			Err(TrieError::IncompleteProof.into())
		);
	}

	#[test]
	fn create_proof_fails_for_missing_key() {
		let balance_trie = create_balance_trie();
		// Keys 0..100 are present; 100 is one past the end.
		assert_eq!(
			balance_trie.create_proof(&100u32),
			Err(TrieError::IncompleteDatabase.into())
		);
		let empty_trie = BalanceTrie::generate_for(Vec::new()).unwrap();
		assert_eq!(empty_trie.create_proof(&0u32), Err(TrieError::IncompleteDatabase.into()));
	}

	#[test]
	fn duplicate_keys_keep_last_value() {
		let trie = BalanceTrie::generate_for([(1u32, 10u128), (2, 20), (1, 11)]).unwrap();
		assert_eq!(trie.query(&1u32), Some(11u128));
		assert_eq!(trie.query(&2u32), Some(20u128));
		// The overwritten value leaves no trace: the root matches a trie built without it.
		let deduped = BalanceTrie::generate_for([(1u32, 11u128), (2, 20)]).unwrap();
		assert_eq!(trie.root(), deduped.root());
	}

	#[test]
	fn assert_structure_of_merkle_proof() {
		let balance_trie = create_balance_trie();
		let root = *balance_trie.root();
		let proof = balance_trie.create_proof(&6u32).unwrap();
		let decoded_proof: MerkleProof<H256, Vec<u8>> = Decode::decode(&mut &proof[..]).unwrap();
		let constructed_proof = MerkleProof::<H256, Vec<u8>> {
			root,
			proof: decoded_proof.proof.clone(),
			number_of_leaves: 100,
			leaf_index: 6,
			leaf: (6u32, 6u128).encode(),
		};
		assert_eq!(constructed_proof, decoded_proof);
	}

	#[test]
	fn proof_to_hashes() {
		// For `i` leaves the proof path should contain ceil(log2(i)) hashes.
		let mut i: u32 = 1;
		while i < 10_000_000 {
			let trie = BalanceTrie::generate_for((0..i).map(|i| (i, u128::from(i)))).unwrap();
			let proof = trie.create_proof(&0).unwrap();
			let hashes = BalanceTrie::proof_to_hashes(&proof).unwrap();
			let log2 = (i as f64).log2().ceil() as u32;
			assert_eq!(hashes, log2);
			i = i * 10;
		}
	}

	#[test]
	fn proof_to_hashes_rejects_undecodable_proof() {
		let empty: &[u8] = &[];
		assert_eq!(BalanceTrie::proof_to_hashes(empty), Err(TrieError::IncompleteProof.into()));

		let proof = create_balance_trie().create_proof(&6u32).unwrap();
		let truncated: &[u8] = &proof[..proof.len() - 1];
		assert_eq!(
			BalanceTrie::proof_to_hashes(truncated),
			Err(TrieError::IncompleteProof.into())
		);
	}
}