use crate::Error;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
use core::convert::TryInto;
use core::fmt::Debug;
use unsigned_varint::encode as varint_encode;
#[cfg(feature = "std")]
use std::io;
#[cfg(not(feature = "std"))]
use core2::io;
/// A multihash with inline storage for a digest of at most `S` bytes.
///
/// Only the first `size` bytes of `digest` are part of the value (see `digest()`); the rest of
/// the array is unused storage.
///
/// NOTE: `Ord`/`PartialOrd`/`Debug` are derived, so they compare/print `code`, `size` and the
/// *entire* `digest` array in field order, whereas `PartialEq` (hand-written) compares only `code`
/// and `digest()`. The two agree only while the unused tail of `digest` is zero. Reordering the
/// fields below changes the derived ordering.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialOrd)]
pub struct Multihash<const S: usize> {
    // Code of the multihash; encoded as an unsigned varint by `write`.
    code: u64,
    // Number of meaningful bytes at the start of `digest`.
    size: u8,
    // Digest storage; only `digest[..size]` is significant.
    digest: [u8; S],
}
impl<const S: usize> Default for Multihash<S> {
fn default() -> Self {
Self {
code: 0,
size: 0,
digest: [0; S],
}
}
}
impl<const S: usize> Multihash<S> {
pub const fn wrap(code: u64, input_digest: &[u8]) -> Result<Self, Error> {
if input_digest.len() > S {
return Err(Error::invalid_size(input_digest.len() as _));
}
let size = input_digest.len();
let mut digest = [0; S];
let mut i = 0;
while i < size {
digest[i] = input_digest[i];
i += 1;
}
Ok(Self {
code,
size: size as u8,
digest,
})
}
pub const fn code(&self) -> u64 {
self.code
}
pub const fn size(&self) -> u8 {
self.size
}
pub fn digest(&self) -> &[u8] {
&self.digest[..self.size as usize]
}
pub fn read<R: io::Read>(r: R) -> Result<Self, Error>
where
Self: Sized,
{
let (code, size, digest) = read_multihash(r)?;
Ok(Self { code, size, digest })
}
pub fn from_bytes(mut bytes: &[u8]) -> Result<Self, Error>
where
Self: Sized,
{
let result = Self::read(&mut bytes)?;
if !bytes.is_empty() {
return Err(Error::invalid_size(bytes.len().try_into().expect(
"Currently the maximum size is 255, therefore always fits into usize",
)));
}
Ok(result)
}
pub fn write<W: io::Write>(&self, w: W) -> Result<usize, Error> {
write_multihash(w, self.code(), self.size(), self.digest())
}
pub fn encoded_len(&self) -> usize {
let mut code_buf = varint_encode::u64_buffer();
let code = varint_encode::u64(self.code, &mut code_buf);
let mut size_buf = varint_encode::u8_buffer();
let size = varint_encode::u8(self.size, &mut size_buf);
code.len() + size.len() + usize::from(self.size)
}
#[cfg(feature = "alloc")]
pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = Vec::with_capacity(self.size().into());
let written = self
.write(&mut bytes)
.expect("writing to a vec should never fail");
debug_assert_eq!(written, bytes.len());
bytes
}
pub fn truncate(&self, size: u8) -> Self {
let mut mh = *self;
mh.size = mh.size.min(size);
mh
}
pub fn resize<const R: usize>(&self) -> Result<Multihash<R>, Error> {
let size = self.size as usize;
if size > R {
return Err(Error::invalid_size(self.size as u64));
}
let mut mh = Multihash {
code: self.code,
size: self.size,
digest: [0; R],
};
mh.digest[..size].copy_from_slice(&self.digest[..size]);
Ok(mh)
}
pub fn into_inner(self) -> (u64, [u8; S], u8) {
let Self { code, digest, size } = self;
(code, digest, size)
}
}
#[allow(clippy::derived_hash_with_manual_eq)]
impl<const S: usize> core::hash::Hash for Multihash<S> {
    /// Hashes exactly what `PartialEq` compares: the code and the visible digest bytes.
    fn hash<H: core::hash::Hasher>(&self, hasher: &mut H) {
        let visible = self.digest();
        core::hash::Hash::hash(&self.code, hasher);
        core::hash::Hash::hash(visible, hasher);
    }
}
#[cfg(feature = "alloc")]
impl<const S: usize> From<Multihash<S>> for Vec<u8> {
    /// Serializes the multihash; same bytes as `Multihash::to_bytes`.
    fn from(mh: Multihash<S>) -> Self {
        Multihash::<S>::to_bytes(&mh)
    }
}
impl<const A: usize, const B: usize> PartialEq<Multihash<B>> for Multihash<A> {
    /// Two multihashes are equal when their codes and visible digests match, regardless of the
    /// storage capacities `A` and `B`.
    fn eq(&self, other: &Multihash<B>) -> bool {
        if self.code != other.code {
            return false;
        }
        self.digest() == other.digest()
    }
}
#[cfg(feature = "scale-codec")]
impl<const S: usize> parity_scale_codec::Encode for Multihash<S> {
    // Layout: SCALE-encoded `code` (u64), SCALE-encoded `size` (u8), then the `size` digest bytes
    // written raw via `Output::write`; the preceding `size` already records their length.
    fn encode_to<EncOut: parity_scale_codec::Output + ?Sized>(&self, dest: &mut EncOut) {
        let digest_bytes = self.digest();
        parity_scale_codec::Encode::encode_to(&self.code, dest);
        parity_scale_codec::Encode::encode_to(&self.size, dest);
        parity_scale_codec::Output::write(dest, digest_bytes);
    }
}
// Marker impl: a `Multihash<N>` may be encoded wherever a `Multihash<N>` encoding is expected.
#[cfg(feature = "scale-codec")]
impl<const N: usize> parity_scale_codec::EncodeLike for Multihash<N> {}
#[cfg(feature = "scale-codec")]
impl<const S: usize> parity_scale_codec::Decode for Multihash<S> {
    /// Inverse of the `Encode` impl: reads `code` (u64), `size` (u8), then exactly `size` raw
    /// digest bytes. Fails with "invalid size" if `size` exceeds the capacity `S`.
    fn decode<DecIn: parity_scale_codec::Input>(
        input: &mut DecIn,
    ) -> Result<Self, parity_scale_codec::Error> {
        let code = <u64 as parity_scale_codec::Decode>::decode(input)?;
        let size = <u8 as parity_scale_codec::Decode>::decode(input)?;
        let len = usize::from(size);
        if len > S {
            return Err(parity_scale_codec::Error::from("invalid size"));
        }
        let mut digest = [0u8; S];
        parity_scale_codec::Input::read(input, &mut digest[..len])?;
        Ok(Self { code, size, digest })
    }
}
/// Writes `varint(code)`, `varint(size)` and `digest` to `w`, returning the total byte count.
///
/// Stops at the first I/O error, which is converted into the crate's `Error`.
fn write_multihash<W>(mut w: W, code: u64, size: u8, digest: &[u8]) -> Result<usize, Error>
where
    W: io::Write,
{
    let mut code_scratch = varint_encode::u64_buffer();
    let mut size_scratch = varint_encode::u8_buffer();
    let pieces: [&[u8]; 3] = [
        varint_encode::u64(code, &mut code_scratch),
        varint_encode::u8(size, &mut size_scratch),
        digest,
    ];
    let mut total = 0;
    for piece in pieces.iter().copied() {
        w.write_all(piece)
            .map_err(crate::error::io_to_multihash_error)?;
        total += piece.len();
    }
    Ok(total)
}
fn read_multihash<R, const S: usize>(mut r: R) -> Result<(u64, u8, [u8; S]), Error>
where
R: io::Read,
{
let code = read_u64(&mut r)?;
let size = read_u64(&mut r)?;
if size > S as u64 || size > u8::MAX as u64 {
return Err(Error::invalid_size(size));
}
let mut digest = [0; S];
r.read_exact(&mut digest[..size as usize])
.map_err(crate::error::io_to_multihash_error)?;
Ok((code, size as u8, digest))
}
/// Reads one unsigned varint from `r` (std build), mapping the varint crate's error into `Error`.
#[cfg(feature = "std")]
pub(crate) fn read_u64<R: io::Read>(r: R) -> Result<u64, Error> {
    let outcome = unsigned_varint::io::read_u64(r);
    outcome.map_err(crate::error::unsigned_varint_to_multihash_error)
}
/// Reads one unsigned varint from `r` one byte at a time (no_std build).
///
/// Errors:
/// - the reader hits EOF before a terminating byte: insufficient-varint-bytes;
/// - no terminating byte within the `u64` buffer length: varint-overflow;
/// - decoding the collected bytes fails: the mapped decode error.
#[cfg(not(feature = "std"))]
pub(crate) fn read_u64<R: io::Read>(mut r: R) -> Result<u64, Error> {
    use unsigned_varint::decode;
    let mut buf = varint_encode::u64_buffer();
    let mut len = 0;
    while len < buf.len() {
        let got = r
            .read(&mut buf[len..=len])
            .map_err(crate::error::io_to_multihash_error)?;
        if got == 0 {
            return Err(Error::insufficient_varint_bytes());
        }
        let byte = buf[len];
        len += 1;
        if decode::is_last(byte) {
            return decode::u64(&buf[..len])
                .map(|(value, _rest)| value)
                .map_err(crate::error::unsigned_varint_decode_to_multihash_error);
        }
    }
    Err(Error::varint_overflow())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg(feature = "scale-codec")]
    fn test_scale() {
        use parity_scale_codec::{Decode, Encode};
        let mh1 = Multihash::<32>::wrap(0, b"hello world").unwrap();
        let mh1_bytes = mh1.encode();
        let mh2: Multihash<32> = Decode::decode(&mut &mh1_bytes[..]).unwrap();
        assert_eq!(mh1, mh2);
        let mh3 = Multihash::<64>::wrap(0, b"hello world").unwrap();
        let mh3_bytes = mh3.encode();
        let mh4: Multihash<64> = Decode::decode(&mut &mh3_bytes[..]).unwrap();
        assert_eq!(mh3, mh4);
        // The SCALE encoding does not depend on the storage capacity.
        assert_eq!(mh1_bytes, mh3_bytes);
    }

    #[test]
    fn test_eq_sizes() {
        let mh1 = Multihash::<32>::default();
        let mh2 = Multihash::<64>::default();
        assert_eq!(mh1, mh2);
    }

    #[test]
    fn decode_non_minimal_error() {
        let data = [241, 0, 0, 0, 0, 0, 128, 132, 132, 132, 58];
        let result = read_u64(&data[..]);
        assert!(result.is_err());
    }

    #[test]
    fn from_bytes_parses_code_size_and_digest() {
        // code 0x12 (single-byte varint), size 3, then three digest bytes.
        let mh = Multihash::<32>::from_bytes(&[0x12, 3, 1, 2, 3]).unwrap();
        assert_eq!(mh.code(), 0x12);
        assert_eq!(mh.size(), 3);
        assert_eq!(mh.digest(), &[1u8, 2, 3][..]);
        assert_eq!(mh.encoded_len(), 5);
    }

    #[test]
    fn from_bytes_rejects_trailing_bytes() {
        assert!(Multihash::<32>::from_bytes(&[0x12, 3, 1, 2, 3, 0]).is_err());
    }

    #[test]
    fn from_bytes_rejects_digest_larger_than_capacity() {
        let mut bytes = [0u8; 35];
        bytes[0] = 0x12;
        bytes[1] = 33;
        assert!(Multihash::<32>::from_bytes(&bytes).is_err());
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn to_bytes_roundtrip() {
        let mh = Multihash::<64>::wrap(0x12, b"hello world").unwrap();
        let bytes = mh.to_bytes();
        assert_eq!(bytes.len(), mh.encoded_len());
        assert_eq!(Multihash::<64>::from_bytes(&bytes).unwrap(), mh);
    }
}