polkavm_common/
varint.rs

/// Returns how many bytes follow the first byte in the varint encoding of a `u32`
/// whose `leading_zeros()` count is `leading_zeros`.
///
/// The computation is branch-free. For `leading_zeros` in `0..=32` the result depends on
/// the number of significant bits (`32 - leading_zeros`) as follows:
///
/// | significant bits | result |
/// |------------------|--------|
/// | `0..=7`          | 0      |
/// | `8..=14`         | 1      |
/// | `15..=21`        | 2      |
/// | `22..=28`        | 3      |
/// | `29..=32`        | 4      |
#[inline]
fn get_varint_length(leading_zeros: u32) -> u32 {
    let significant_bits = 32 - leading_zeros;
    let whole_bytes = significant_bits / 8;
    ((whole_bytes + significant_bits) ^ whole_bytes) / 8
}
7
/// The largest number of bytes a single encoded varint can occupy:
/// one prefix byte plus every byte of a `u32`.
pub const MAX_VARINT_LENGTH: usize = 1 + core::mem::size_of::<u32>();
9
10// TODO: Apply zigzag encoding to the varints before serialization/after deserialization.
11// (Otherwise negative offsets will always be encoded with the maximum number of bytes.)
12
/// Decodes a varint whose first byte has already been pulled out of the stream.
///
/// The number of leading one-bits in `first_byte` is the number of bytes that follow it.
/// The bits of `first_byte` below that prefix are the most significant bits of the value,
/// and the following bytes (taken from the start of `input`) are its low bytes in
/// little-endian order.
///
/// Returns `Some((consumed, value))`, where `consumed` is the number of bytes read from
/// `input` (the first byte is not counted). Returns `None` if `input` holds fewer bytes
/// than the prefix announces, or if the prefix announces more than four bytes.
///
/// When four bytes follow, they already carry all 32 bits of the value, so any remaining
/// low bits in `first_byte` are not payload and are ignored.
#[inline]
pub(crate) fn read_varint(input: &[u8], first_byte: u8) -> Option<(usize, u32)> {
    let length = (!first_byte).leading_zeros();
    let upper_mask = 0b11111111_u32 >> length;
    // `checked_shl` yields `None` once the shift reaches 32 bits. A wrapping shift would
    // instead shift by zero and leak the prefix byte's low bits into the result.
    let upper_bits = (upper_mask & u32::from(first_byte))
        .checked_shl(length * 8)
        .unwrap_or(0);

    let lower_bits = match *input.get(..length as usize)? {
        [] => 0,
        [a] => u32::from(a),
        [a, b] => u32::from_le_bytes([a, b, 0, 0]),
        [a, b, c] => u32::from_le_bytes([a, b, c, 0]),
        [a, b, c, d] => u32::from_le_bytes([a, b, c, d]),
        // Five or more leading one-bits: not a valid 32-bit varint.
        _ => return None,
    };

    Some((length as usize, upper_bits | lower_bits))
}
30
31#[inline]
32pub fn write_varint(value: u32, buffer: &mut [u8]) -> usize {
33    let varint_length = get_varint_length(value.leading_zeros());
34    match varint_length {
35        0 => buffer[0] = value as u8,
36        1 => {
37            buffer[0] = 0b10000000 | (value >> 8) as u8;
38            let bytes = value.to_le_bytes();
39            buffer[1] = bytes[0];
40        }
41        2 => {
42            buffer[0] = 0b11000000 | (value >> 16) as u8;
43            let bytes = value.to_le_bytes();
44            buffer[1] = bytes[0];
45            buffer[2] = bytes[1];
46        }
47        3 => {
48            buffer[0] = 0b11100000 | (value >> 24) as u8;
49            let bytes = value.to_le_bytes();
50            buffer[1] = bytes[0];
51            buffer[2] = bytes[1];
52            buffer[3] = bytes[2];
53        }
54        4 => {
55            buffer[0] = 0b11110000;
56            let bytes = value.to_le_bytes();
57            buffer[1] = bytes[0];
58            buffer[2] = bytes[1];
59            buffer[3] = bytes[2];
60            buffer[4] = bytes[3];
61        }
62        _ => unreachable!(),
63    }
64
65    varint_length as usize + 1
66}
67
// Round-trip property: anything `write_varint` emits must be decoded by `read_varint`
// back into the same value, with a matching byte count.
#[cfg(test)]
proptest::proptest! {
    #[allow(clippy::ignored_unit_patterns)]
    #[test]
    fn varint_serialization(value in 0u32..=u32::MAX) {
        let mut encoded = [0; MAX_VARINT_LENGTH];
        let written = write_varint(value, &mut encoded);
        // `read_varint` takes the first byte separately from the bytes that follow it,
        // and reports only how many of the following bytes it consumed.
        let (decoded_tail_length, decoded_value) = read_varint(&encoded[1..], encoded[0]).unwrap();
        assert_eq!(decoded_value, value, "value mismatch");
        assert_eq!(decoded_tail_length + 1, written, "length mismatch")
    }
}