1use super::{ProofToHashes, ProvingTrie, TrieError};
24use crate::{Decode, DispatchError, Encode};
25use alloc::{collections::BTreeMap, vec::Vec};
26use binary_merkle_tree::{merkle_proof, merkle_root, MerkleProof};
27use codec::MaxEncodedLen;
28
29pub struct BasicProvingTrie<Hashing, Key, Value>
32where
33	Hashing: sp_core::Hasher,
34{
35	db: BTreeMap<Key, Value>,
36	root: Hashing::Out,
37	_phantom: core::marker::PhantomData<(Key, Value)>,
38}
39
40impl<Hashing, Key, Value> ProvingTrie<Hashing, Key, Value> for BasicProvingTrie<Hashing, Key, Value>
41where
42	Hashing: sp_core::Hasher,
43	Hashing::Out: Encode + Decode,
44	Key: Encode + Decode + Ord,
45	Value: Encode + Decode + Clone,
46{
47	fn generate_for<I>(items: I) -> Result<Self, DispatchError>
49	where
50		I: IntoIterator<Item = (Key, Value)>,
51	{
52		let mut db = BTreeMap::default();
53		for (key, value) in items.into_iter() {
54			db.insert(key, value);
55		}
56		let root = merkle_root::<Hashing, _>(db.iter().map(|item| item.encode()));
57		Ok(Self { db, root, _phantom: Default::default() })
58	}
59
60	fn root(&self) -> &Hashing::Out {
62		&self.root
63	}
64
65	fn query(&self, key: &Key) -> Option<Value> {
68		self.db.get(&key).cloned()
69	}
70
71	fn create_proof(&self, key: &Key) -> Result<Vec<u8>, DispatchError> {
74		let mut encoded = Vec::with_capacity(self.db.len());
75		let mut found_index = None;
76
77		for (i, (k, v)) in self.db.iter().enumerate() {
79			if k == key {
81				found_index = Some(i);
82			}
83
84			encoded.push((k, v).encode());
85		}
86
87		let index = found_index.ok_or(TrieError::IncompleteDatabase)?;
88		let proof = merkle_proof::<Hashing, Vec<Vec<u8>>, Vec<u8>>(encoded, index as u32);
89		Ok(proof.encode())
90	}
91
92	fn verify_proof(
94		root: &Hashing::Out,
95		proof: &[u8],
96		key: &Key,
97		value: &Value,
98	) -> Result<(), DispatchError> {
99		verify_proof::<Hashing, Key, Value>(root, proof, key, value)
100	}
101}
102
103impl<Hashing, Key, Value> ProofToHashes for BasicProvingTrie<Hashing, Key, Value>
104where
105	Hashing: sp_core::Hasher,
106	Hashing::Out: MaxEncodedLen + Decode,
107	Key: Decode,
108	Value: Decode,
109{
110	type Proof = [u8];
112	fn proof_to_hashes(proof: &[u8]) -> Result<u32, DispatchError> {
116		let decoded_proof: MerkleProof<Hashing::Out, Vec<u8>> =
117			Decode::decode(&mut &proof[..]).map_err(|_| TrieError::IncompleteProof)?;
118		let depth = decoded_proof.proof.len();
119		Ok(depth as u32)
120	}
121}
122
123pub fn verify_proof<Hashing, Key, Value>(
125	root: &Hashing::Out,
126	proof: &[u8],
127	key: &Key,
128	value: &Value,
129) -> Result<(), DispatchError>
130where
131	Hashing: sp_core::Hasher,
132	Hashing::Out: Decode,
133	Key: Encode + Decode,
134	Value: Encode + Decode,
135{
136	let decoded_proof: MerkleProof<Hashing::Out, Vec<u8>> =
137		Decode::decode(&mut &proof[..]).map_err(|_| TrieError::IncompleteProof)?;
138	if *root != decoded_proof.root {
139		return Err(TrieError::RootMismatch.into());
140	}
141
142	if (key, value).encode() != decoded_proof.leaf {
143		return Err(TrieError::ValueMismatch.into());
144	}
145
146	if binary_merkle_tree::verify_proof::<Hashing, _, _>(
147		&decoded_proof.root,
148		decoded_proof.proof,
149		decoded_proof.number_of_leaves,
150		decoded_proof.leaf_index,
151		&decoded_proof.leaf,
152	) {
153		Ok(())
154	} else {
155		Err(TrieError::IncompleteProof.into())
156	}
157}
158
#[cfg(test)]
mod tests {
	use super::*;
	use crate::traits::BlakeTwo256;
	use sp_core::H256;
	use std::collections::BTreeMap;

	type BalanceTrie = BasicProvingTrie<BlakeTwo256, u32, u128>;

	/// Root of a trie built from no items.
	fn empty_root() -> H256 {
		let tree = BalanceTrie::generate_for(Vec::new()).unwrap();
		*tree.root()
	}

	/// Build a trie mapping `i -> i` for `i in 0..100`, sanity-checking a few lookups.
	fn create_balance_trie() -> BalanceTrie {
		let mut map = BTreeMap::<u32, u128>::new();
		for i in 0..100u32 {
			map.insert(i, i.into());
		}

		let balance_trie = BalanceTrie::generate_for(map).unwrap();

		let root = *balance_trie.root();
		assert_ne!(root, empty_root());

		assert_eq!(balance_trie.query(&6u32), Some(6u128));
		assert_eq!(balance_trie.query(&9u32), Some(9u128));
		assert_eq!(balance_trie.query(&69u32), Some(69u128));

		balance_trie
	}

	#[test]
	fn empty_trie_works() {
		let empty_trie = BalanceTrie::generate_for(Vec::new()).unwrap();
		assert_eq!(*empty_trie.root(), empty_root());
		// Nothing can be queried or proven in an empty trie.
		assert_eq!(empty_trie.query(&0u32), None);
		assert_eq!(empty_trie.create_proof(&0u32), Err(TrieError::IncompleteDatabase.into()));
	}

	#[test]
	fn missing_key_is_rejected() {
		let balance_trie = create_balance_trie();
		assert_eq!(balance_trie.query(&100u32), None);
		assert_eq!(balance_trie.create_proof(&100u32), Err(TrieError::IncompleteDatabase.into()));
	}

	#[test]
	fn basic_end_to_end_single_value() {
		let balance_trie = create_balance_trie();
		let root = *balance_trie.root();

		let proof = balance_trie.create_proof(&6u32).unwrap();

		// The proof for key 6 must verify only for the pair (6, 6).
		for i in 0..200u32 {
			if i == 6 {
				assert_eq!(
					verify_proof::<BlakeTwo256, _, _>(&root, &proof, &i, &u128::from(i)),
					Ok(())
				);
				assert_eq!(
					verify_proof::<BlakeTwo256, _, _>(&root, &proof, &i, &u128::from(i + 1)),
					Err(TrieError::ValueMismatch.into())
				);
			} else {
				assert!(
					verify_proof::<BlakeTwo256, _, _>(&root, &proof, &i, &u128::from(i)).is_err()
				);
			}
		}
	}

	#[test]
	fn proof_fails_with_bad_data() {
		let balance_trie = create_balance_trie();
		let root = *balance_trie.root();

		let proof = balance_trie.create_proof(&6u32).unwrap();

		// Correct data verifies.
		assert_eq!(verify_proof::<BlakeTwo256, _, _>(&root, &proof, &6u32, &6u128), Ok(()));

		// A different root is rejected.
		assert_eq!(
			verify_proof::<BlakeTwo256, _, _>(&Default::default(), &proof, &6u32, &6u128),
			Err(TrieError::RootMismatch.into())
		);

		// An undecodable proof is rejected.
		assert_eq!(
			verify_proof::<BlakeTwo256, _, _>(&root, &[], &6u32, &6u128),
			Err(TrieError::IncompleteProof.into())
		);
	}

	#[test]
	fn assert_structure_of_merkle_proof() {
		let balance_trie = create_balance_trie();
		let root = *balance_trie.root();
		let proof = balance_trie.create_proof(&6u32).unwrap();
		let decoded_proof: MerkleProof<H256, Vec<u8>> = Decode::decode(&mut &proof[..]).unwrap();

		let constructed_proof = MerkleProof::<H256, Vec<u8>> {
			root,
			proof: decoded_proof.proof.clone(),
			number_of_leaves: 100,
			leaf_index: 6,
			leaf: (6u32, 6u128).encode(),
		};
		assert_eq!(constructed_proof, decoded_proof);
	}

	#[test]
	fn proof_to_hashes() {
		// Trie sizes 1, 10, 100, ..., 1_000_000.
		for exponent in 0..7u32 {
			let size = 10u32.pow(exponent);
			let trie =
				BalanceTrie::generate_for((0..size).map(|n| (n, u128::from(n)))).unwrap();
			let proof = trie.create_proof(&0).unwrap();
			let hashes = BalanceTrie::proof_to_hashes(&proof).unwrap();
			// A proof over `size` leaves is expected to carry ceil(log2(size)) hashes.
			let expected = f64::from(size).log2().ceil() as u32;

			assert_eq!(hashes, expected);
		}
	}
}