#![cfg_attr(not(feature = "std"), no_std)]
use sp_std::{prelude::*, result::Result};
use codec::{Decode, Encode};
use sp_inherents::{InherentData, InherentIdentifier, IsFatalError};
use sp_runtime::traits::{Block as BlockT, NumberFor};
pub use sp_inherents::Error;
pub const INHERENT_IDENTIFIER: InherentIdentifier = *b"tx_proof";
pub const DEFAULT_STORAGE_PERIOD: u32 = 100800;
pub const CHUNK_SIZE: usize = 256;
#[derive(Encode, sp_runtime::RuntimeDebug)]
#[cfg_attr(feature = "std", derive(Decode))]
pub enum InherentError {
	InvalidProof,
	TrieError,
}
impl IsFatalError for InherentError {
	fn is_fatal_error(&self) -> bool {
		true
	}
}
#[derive(Encode, Decode, Clone, PartialEq, Debug, scale_info::TypeInfo)]
pub struct TransactionStorageProof {
	pub chunk: Vec<u8>,
	pub proof: Vec<Vec<u8>>,
}
pub trait TransactionStorageProofInherentData {
	fn storage_proof(&self) -> Result<Option<TransactionStorageProof>, Error>;
}
impl TransactionStorageProofInherentData for InherentData {
	fn storage_proof(&self) -> Result<Option<TransactionStorageProof>, Error> {
		self.get_data(&INHERENT_IDENTIFIER)
	}
}
#[cfg(feature = "std")]
pub struct InherentDataProvider {
	proof: Option<TransactionStorageProof>,
}
#[cfg(feature = "std")]
impl InherentDataProvider {
	pub fn new(proof: Option<TransactionStorageProof>) -> Self {
		InherentDataProvider { proof }
	}
}
#[cfg(feature = "std")]
#[async_trait::async_trait]
impl sp_inherents::InherentDataProvider for InherentDataProvider {
	async fn provide_inherent_data(&self, inherent_data: &mut InherentData) -> Result<(), Error> {
		if let Some(proof) = &self.proof {
			inherent_data.put_data(INHERENT_IDENTIFIER, proof)
		} else {
			Ok(())
		}
	}
	async fn try_handle_error(
		&self,
		identifier: &InherentIdentifier,
		mut error: &[u8],
	) -> Option<Result<(), Error>> {
		if *identifier != INHERENT_IDENTIFIER {
			return None
		}
		let error = InherentError::decode(&mut error).ok()?;
		Some(Err(Error::Application(Box::from(format!("{:?}", error)))))
	}
}
pub fn random_chunk(random_hash: &[u8], total_chunks: u32) -> u32 {
	let mut buf = [0u8; 8];
	buf.copy_from_slice(&random_hash[0..8]);
	let random_u64 = u64::from_be_bytes(buf);
	(random_u64 % total_chunks as u64) as u32
}
pub fn encode_index(input: u32) -> Vec<u8> {
	codec::Encode::encode(&codec::Compact(input))
}
pub trait IndexedBody<B: BlockT> {
	fn block_indexed_body(&self, number: NumberFor<B>) -> Result<Option<Vec<Vec<u8>>>, Error>;
	fn number(&self, hash: B::Hash) -> Result<Option<NumberFor<B>>, Error>;
}
#[cfg(feature = "std")]
pub mod registration {
	use super::*;
	use sp_runtime::traits::{Block as BlockT, One, Saturating, Zero};
	use sp_trie::TrieMut;
	type Hasher = sp_core::Blake2Hasher;
	type TrieLayout = sp_trie::LayoutV1<Hasher>;
	pub fn new_data_provider<B, C>(
		client: &C,
		parent: &B::Hash,
	) -> Result<InherentDataProvider, Error>
	where
		B: BlockT,
		C: IndexedBody<B>,
	{
		let parent_number = client.number(*parent)?.unwrap_or(Zero::zero());
		let number = parent_number
			.saturating_add(One::one())
			.saturating_sub(DEFAULT_STORAGE_PERIOD.into());
		if number.is_zero() {
			return Ok(InherentDataProvider::new(None))
		}
		let proof = match client.block_indexed_body(number)? {
			Some(transactions) if !transactions.is_empty() =>
				Some(build_proof(parent.as_ref(), transactions)?),
			Some(_) | None => {
				None
			},
		};
		Ok(InherentDataProvider::new(proof))
	}
	pub fn build_proof(
		random_hash: &[u8],
		transactions: Vec<Vec<u8>>,
	) -> Result<TransactionStorageProof, Error> {
		let mut db = sp_trie::MemoryDB::<Hasher>::default();
		let mut target_chunk = None;
		let mut target_root = Default::default();
		let mut target_chunk_key = Default::default();
		let mut chunk_proof = Default::default();
		let total_chunks: u64 = transactions
			.iter()
			.map(|t| ((t.len() + CHUNK_SIZE - 1) / CHUNK_SIZE) as u64)
			.sum();
		let mut buf = [0u8; 8];
		buf.copy_from_slice(&random_hash[0..8]);
		let random_u64 = u64::from_be_bytes(buf);
		let target_chunk_index = random_u64 % total_chunks;
		let mut chunk_index = 0;
		for transaction in transactions {
			let mut transaction_root = sp_trie::empty_trie_root::<TrieLayout>();
			{
				let mut trie =
					sp_trie::TrieDBMutBuilder::<TrieLayout>::new(&mut db, &mut transaction_root)
						.build();
				let chunks = transaction.chunks(CHUNK_SIZE).map(|c| c.to_vec());
				for (index, chunk) in chunks.enumerate() {
					let index = encode_index(index as u32);
					trie.insert(&index, &chunk).map_err(|e| Error::Application(Box::new(e)))?;
					if chunk_index == target_chunk_index {
						target_chunk = Some(chunk);
						target_chunk_key = index;
					}
					chunk_index += 1;
				}
				trie.commit();
			}
			if target_chunk.is_some() && target_root == Default::default() {
				target_root = transaction_root;
				chunk_proof = sp_trie::generate_trie_proof::<TrieLayout, _, _, _>(
					&db,
					transaction_root,
					&[target_chunk_key.clone()],
				)
				.map_err(|e| Error::Application(Box::new(e)))?;
			}
		}
		Ok(TransactionStorageProof { proof: chunk_proof, chunk: target_chunk.unwrap() })
	}
	#[test]
	fn build_proof_check() {
		use std::str::FromStr;
		let random = [0u8; 32];
		let proof = build_proof(&random, vec![vec![42]]).unwrap();
		let root = sp_core::H256::from_str(
			"0xff8611a4d212fc161dae19dd57f0f1ba9309f45d6207da13f2d3eab4c6839e91",
		)
		.unwrap();
		sp_trie::verify_trie_proof::<TrieLayout, _, _, _>(
			&root,
			&proof.proof,
			&[(encode_index(0), Some(proof.chunk))],
		)
		.unwrap();
	}
}