use codec::{Decode, Encode};
use polkadot_node_primitives::{AvailableData, Proof};
use polkadot_primitives::{BlakeTwo256, Hash as H256, HashT};
use sp_core::Blake2Hasher;
use sp_trie::{
trie_types::{TrieDBBuilder, TrieDBMutBuilderV0 as TrieDBMutBuilder},
LayoutV0, MemoryDB, Trie, TrieMut, EMPTY_PREFIX,
};
use thiserror::Error;
use novelpoly::{CodeParams, WrappedShard};
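// We are limited to the field order of GF(2^16), which is 65536.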
const MAX_VALIDATORS: usize = novelpoly::f2e16::FIELD_SIZE;
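/// Errors in erasure coding.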
#[derive(Debug, Clone, PartialEq, Error)]
pub enum Error {
#[error("There are too many validators")]
TooManyValidators,
#[error("Expected at least 2 validators")]
NotEnoughValidators,
#[error("Validator count mismatches between encoding and decoding")]
WrongValidatorCount,
#[error("Not enough chunks to reconstruct message")]
NotEnoughChunks,
#[error("Too many chunks present")]
TooManyChunks,
#[error("Chunks are not uniform, mismatch in length or are zero sized")]
NonUniformChunks,
#[error("Uneven length is not valid for field GF(2^16)")]
UnevenLength,
#[error("Chunk is out of bounds: {chunk_index} not included in 0..{n_validators}")]
ChunkIndexOutOfBounds { chunk_index: usize, n_validators: usize },
#[error("Reconstructed payload invalid")]
BadPayload,
#[error("Unable to decode reconstructed payload: {0}")]
Decode(#[source] codec::Error),
#[error("Invalid branch proof")]
InvalidBranchProof,
#[error("Branch is out of bounds")]
BranchOutOfBounds,
#[error("An unknown error has appeared when reconstructing erasure code chunks")]
UnknownReconstruction,
#[error("An unknown error has appeared when deriving code parameters from validator count")]
UnknownCodeParam,
}
impl From<novelpoly::Error> for Error {
fn from(error: novelpoly::Error) -> Self {
match error {
novelpoly::Error::NeedMoreShards { .. } => Self::NotEnoughChunks,
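			// Note: `ParamterMustBePowerOf2` (sic) is the variant name as spelled in `novelpoly`.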
novelpoly::Error::ParamterMustBePowerOf2 { .. } => Self::UnevenLength,
novelpoly::Error::WantedShardCountTooHigh(_) => Self::TooManyValidators,
novelpoly::Error::WantedShardCountTooLow(_) => Self::NotEnoughValidators,
novelpoly::Error::PayloadSizeIsZero { .. } => Self::BadPayload,
novelpoly::Error::InconsistentShardLengths { .. } => Self::NonUniformChunks,
_ => Self::UnknownReconstruction,
}
}
}
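/// Obtain a threshold of chunks that should be enough to recover the data.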
pub const fn recovery_threshold(n_validators: usize) -> Result<usize, Error> {
if n_validators > MAX_VALIDATORS {
return Err(Error::TooManyValidators)
}
if n_validators <= 1 {
return Err(Error::NotEnoughValidators)
}
let needed = n_validators.saturating_sub(1) / 3;
Ok(needed + 1)
}
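/// Obtain the threshold of systematic chunks that should be enough to recover the data.
///
/// If the regular `recovery_threshold` is a power of two, this is the same value;
/// otherwise it is the next lower power of two.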
pub fn systematic_recovery_threshold(n_validators: usize) -> Result<usize, Error> {
code_params(n_validators).map(|params| params.k())
}
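// We need to be able to reconstruct from 1/3 - eps of the chunks.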
fn code_params(n_validators: usize) -> Result<CodeParams, Error> {
let n_wanted = n_validators;
let k_wanted = recovery_threshold(n_wanted)?;
	if n_wanted > MAX_VALIDATORS {
return Err(Error::TooManyValidators)
}
CodeParams::derive_parameters(n_wanted, k_wanted).map_err(|e| match e {
novelpoly::Error::WantedShardCountTooHigh(_) => Error::TooManyValidators,
novelpoly::Error::WantedShardCountTooLow(_) => Error::NotEnoughValidators,
_ => Error::UnknownCodeParam,
})
}
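/// Reconstruct the v1 available data from the set of systematic chunks.
///
/// Provide a vector containing the first k chunks in order. If too few chunks are provided,
/// recovery is not possible.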
pub fn reconstruct_from_systematic_v1(
n_validators: usize,
chunks: Vec<Vec<u8>>,
) -> Result<AvailableData, Error> {
reconstruct_from_systematic(n_validators, chunks)
}
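/// Reconstruct decodable data from the set of systematic chunks.
///
/// Provide a vector containing the first k chunks in order. If too few chunks are provided,
/// recovery is not possible.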
pub fn reconstruct_from_systematic<T: Decode>(
n_validators: usize,
chunks: Vec<Vec<u8>>,
) -> Result<T, Error> {
let code_params = code_params(n_validators)?;
let k = code_params.k();
for chunk_data in chunks.iter().take(k) {
if chunk_data.len() % 2 != 0 {
return Err(Error::UnevenLength)
}
}
	let bytes = code_params.make_encoder().reconstruct_from_systematic(
		chunks.into_iter().take(k).map(WrappedShard::new).collect(),
	)?;
	Decode::decode(&mut &bytes[..]).map_err(Error::Decode)
}
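/// Obtain erasure-coded chunks for v1 `AvailableData`, one for each validator.
///
/// Works only for up to 65536 validators, and at least 2 are required.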
pub fn obtain_chunks_v1(n_validators: usize, data: &AvailableData) -> Result<Vec<Vec<u8>>, Error> {
obtain_chunks(n_validators, data)
}
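/// Obtain erasure-coded chunks, one for each validator.
///
/// Works only for up to 65536 validators, and at least 2 are required.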
pub fn obtain_chunks<T: Encode>(n_validators: usize, data: &T) -> Result<Vec<Vec<u8>>, Error> {
let params = code_params(n_validators)?;
let encoded = data.encode();
if encoded.is_empty() {
return Err(Error::BadPayload)
}
let shards = params
.make_encoder()
.encode::<WrappedShard>(&encoded[..])
.expect("Payload non-empty, shard sizes are uniform, and validator numbers checked; qed");
Ok(shards.into_iter().map(|w: WrappedShard| w.into_inner()).collect())
}
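/// Reconstruct the v1 available data from a set of chunks.
///
/// Provide an iterator containing chunk data and the corresponding index.
/// The indices of the present chunks must be indicated. If too few chunks
/// are provided, recovery is not possible.
///
/// Works only for up to 65536 validators, and at least 2 are required.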
pub fn reconstruct_v1<'a, I: 'a>(n_validators: usize, chunks: I) -> Result<AvailableData, Error>
where
I: IntoIterator<Item = (&'a [u8], usize)>,
{
reconstruct(n_validators, chunks)
}
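/// Reconstruct decodable data from a set of chunks.
///
/// Provide an iterator containing chunk data and the corresponding index.
/// The indices of the present chunks must be indicated. If too few chunks
/// are provided, recovery is not possible.
///
/// Works only for up to 65536 validators, and at least 2 are required.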
pub fn reconstruct<'a, I: 'a, T: Decode>(n_validators: usize, chunks: I) -> Result<T, Error>
where
I: IntoIterator<Item = (&'a [u8], usize)>,
{
let params = code_params(n_validators)?;
let mut received_shards: Vec<Option<WrappedShard>> = vec![None; n_validators];
	for (chunk_data, chunk_idx) in chunks.into_iter().take(n_validators) {
		// Check the index before indexing into `received_shards`, so a malformed
		// chunk index yields an error instead of an out-of-bounds panic.
		if chunk_idx >= n_validators {
			return Err(Error::ChunkIndexOutOfBounds { chunk_index: chunk_idx, n_validators })
		}
		if chunk_data.len() % 2 != 0 {
			return Err(Error::UnevenLength)
		}
		received_shards[chunk_idx] = Some(WrappedShard::new(chunk_data.to_vec()));
	}
let payload_bytes = params.make_encoder().reconstruct(received_shards)?;
Decode::decode(&mut &payload_bytes[..]).map_err(|_| Error::BadPayload)
}
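/// An iterator that yields merkle branches and chunk data for all chunks to
/// be sent to other validators.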
pub struct Branches<'a, I> {
trie_storage: MemoryDB<Blake2Hasher>,
root: H256,
chunks: &'a [I],
current_pos: usize,
}
impl<'a, I: AsRef<[u8]>> Branches<'a, I> {
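	/// Get the root hash of the trie.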
pub fn root(&self) -> H256 {
self.root
}
}
impl<'a, I: AsRef<[u8]>> Iterator for Branches<'a, I> {
type Item = (Proof, &'a [u8]);
fn next(&mut self) -> Option<Self::Item> {
use sp_trie::Recorder;
let mut recorder = Recorder::<LayoutV0<Blake2Hasher>>::new();
let res = {
let trie = TrieDBBuilder::new(&self.trie_storage, &self.root)
.with_recorder(&mut recorder)
.build();
(self.current_pos as u32).using_encoded(|s| trie.get(s))
};
match res.expect("all nodes in trie present; qed") {
Some(_) => {
let nodes: Vec<Vec<u8>> = recorder.drain().into_iter().map(|r| r.data).collect();
let chunk = self.chunks.get(self.current_pos).expect(
"there is a one-to-one mapping of chunks to valid merkle branches; qed",
);
self.current_pos += 1;
Proof::try_from(nodes).ok().map(|proof| (proof, chunk.as_ref()))
},
None => None,
}
}
}
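/// Construct a trie from chunks of an erasure-coded value. This returns the root hash and an
/// iterator of merkle proofs, one for each chunk.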
pub fn branches<'a, I: 'a>(chunks: &'a [I]) -> Branches<'a, I>
where
I: AsRef<[u8]>,
{
let mut trie_storage: MemoryDB<Blake2Hasher> = MemoryDB::default();
let mut root = H256::default();
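	// Construct a trie mapping each chunk's index to its hash.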
{
let mut trie = TrieDBMutBuilder::new(&mut trie_storage, &mut root).build();
for (i, chunk) in chunks.as_ref().iter().enumerate() {
(i as u32).using_encoded(|encoded_index| {
let chunk_hash = BlakeTwo256::hash(chunk.as_ref());
trie.insert(encoded_index, chunk_hash.as_ref())
.expect("a fresh trie stored in memory cannot have errors loading nodes; qed");
})
}
}
Branches { trie_storage, root, chunks, current_pos: 0 }
}
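/// Verify a merkle branch, yielding the chunk hash meant to be present at that index.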
pub fn branch_hash(root: &H256, branch_nodes: &Proof, index: usize) -> Result<H256, Error> {
let mut trie_storage: MemoryDB<Blake2Hasher> = MemoryDB::default();
for node in branch_nodes.iter() {
(&mut trie_storage as &mut sp_trie::HashDB<_>).insert(EMPTY_PREFIX, node);
}
	let trie = TrieDBBuilder::new(&trie_storage, root).build();
let res = (index as u32).using_encoded(|key| {
trie.get_with(key, |raw_hash: &[u8]| H256::decode(&mut &raw_hash[..]))
});
match res {
Ok(Some(Ok(hash))) => Ok(hash),
		Ok(Some(Err(_))) => Err(Error::InvalidBranchProof), // Hash failed to decode.
		Ok(None) => Err(Error::BranchOutOfBounds),
Err(_) => Err(Error::InvalidBranchProof),
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use polkadot_node_primitives::{AvailableData, BlockData, PoV};
use polkadot_primitives::{HeadData, PersistedValidationData};
use quickcheck::{Arbitrary, Gen, QuickCheck};
const KEY_INDEX_NIBBLE_SIZE: usize = 4;
#[derive(Clone, Debug)]
struct ArbitraryAvailableData(AvailableData);
impl Arbitrary for ArbitraryAvailableData {
fn arbitrary(g: &mut Gen) -> Self {
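			// Limit the PoV length to 1 MiB, otherwise the test takes too long.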
let pov_len = (u32::arbitrary(g) % (1024 * 1024)).max(2);
let pov = (0..pov_len).map(|_| u8::arbitrary(g)).collect();
let pvd = PersistedValidationData {
parent_head: HeadData((0..u16::arbitrary(g)).map(|_| u8::arbitrary(g)).collect()),
relay_parent_number: u32::arbitrary(g),
relay_parent_storage_root: [u8::arbitrary(g); 32].into(),
max_pov_size: u32::arbitrary(g),
};
ArbitraryAvailableData(AvailableData {
pov: Arc::new(PoV { block_data: BlockData(pov) }),
validation_data: pvd,
})
}
}
#[test]
fn field_order_is_right_size() {
assert_eq!(MAX_VALIDATORS, 65536);
}
#[test]
fn round_trip_works() {
let pov = PoV { block_data: BlockData((0..255).collect()) };
let available_data = AvailableData { pov: pov.into(), validation_data: Default::default() };
let chunks = obtain_chunks(10, &available_data).unwrap();
assert_eq!(chunks.len(), 10);
let reconstructed: AvailableData = reconstruct(
10,
[(&*chunks[1], 1), (&*chunks[4], 4), (&*chunks[6], 6), (&*chunks[9], 9)]
.iter()
.cloned(),
)
.unwrap();
assert_eq!(reconstructed, available_data);
}
#[test]
fn round_trip_systematic_works() {
fn property(available_data: ArbitraryAvailableData, n_validators: u16) {
let n_validators = n_validators.max(2);
let kpow2 = systematic_recovery_threshold(n_validators as usize).unwrap();
let chunks = obtain_chunks(n_validators as usize, &available_data.0).unwrap();
assert_eq!(
reconstruct_from_systematic_v1(
n_validators as usize,
chunks.into_iter().take(kpow2).collect()
)
.unwrap(),
available_data.0
);
}
QuickCheck::new().quickcheck(property as fn(ArbitraryAvailableData, u16))
}
#[test]
fn reconstruct_does_not_panic_on_low_validator_count() {
let reconstructed = reconstruct_v1(1, [].iter().cloned());
assert_eq!(reconstructed, Err(Error::NotEnoughValidators));
}
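	// A minimal regression test for the bounds check in `reconstruct`: a chunk index of
	// `n_validators` or more must yield `ChunkIndexOutOfBounds` rather than a panic.
	#[test]
	fn reconstruct_errors_on_out_of_bounds_chunk_index() {
		let pov = PoV { block_data: BlockData((0..255).collect()) };
		let available_data = AvailableData { pov: pov.into(), validation_data: Default::default() };
		let chunks = obtain_chunks(10, &available_data).unwrap();
		let result: Result<AvailableData, _> =
			reconstruct(10, [(&*chunks[0], 10)].iter().cloned());
		assert_eq!(
			result,
			Err(Error::ChunkIndexOutOfBounds { chunk_index: 10, n_validators: 10 })
		);
	}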
fn generate_trie_and_generate_proofs(magnitude: u32) {
let n_validators = 2_u32.pow(magnitude) as usize;
let pov = PoV { block_data: BlockData(vec![2; n_validators / KEY_INDEX_NIBBLE_SIZE]) };
let available_data = AvailableData { pov: pov.into(), validation_data: Default::default() };
let chunks = obtain_chunks(magnitude as usize, &available_data).unwrap();
assert_eq!(chunks.len() as u32, magnitude);
let branches = branches(chunks.as_ref());
let root = branches.root();
let proofs: Vec<_> = branches.map(|(proof, _)| proof).collect();
assert_eq!(proofs.len() as u32, magnitude);
for (i, proof) in proofs.into_iter().enumerate() {
let encode = Encode::encode(&proof);
let decode = Decode::decode(&mut &encode[..]).unwrap();
assert_eq!(proof, decode);
assert_eq!(encode, Encode::encode(&decode));
assert_eq!(branch_hash(&root, &proof, i).unwrap(), BlakeTwo256::hash(&chunks[i]));
}
}
#[test]
fn roundtrip_proof_encoding() {
for i in 2..16 {
generate_trie_and_generate_proofs(i);
}
}
}