base_x/
bigint.rs

1#[cfg(not(feature = "std"))]
2use alloc::vec::Vec;
3
4#[cfg(not(feature = "std"))]
5use core as std;
6
7use std::{ptr, u32};
8
/// A deliberately minimal arbitrary-precision unsigned integer that stores
/// its value as a vector of `u32` chunks, most significant chunk first.
///
/// It can only do a few things:
/// - Be instantiated from an arbitrary big-endian byte slice.
/// - Be converted to a vector of big-endian bytes.
/// - Do a division by `u32`, mutating self and returning the remainder.
/// - Do a multiplication with addition in one pass.
/// - Check if it's zero.
///
/// Turns out those are all the operations you need to encode and decode
/// base58, or anything else, really.
///
/// The representation is not normalized: leading zero chunks are allowed,
/// and an empty `chunks` vector is a valid representation of zero.
pub struct BigUint {
    /// The chunks of the number in big-endian order (`chunks[0]` is the
    /// most significant one).
    pub chunks: Vec<u32>,
}

impl BigUint {
    /// Create a zero-valued `BigUint` whose backing vector has room for
    /// `capacity` chunks. The returned value holds a single `0` chunk.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        let mut chunks = Vec::with_capacity(capacity);

        chunks.push(0);

        BigUint { chunks }
    }

    /// Divide self by `divider`, return the remainder of the operation.
    ///
    /// If the most significant chunk is zero after the division it is
    /// removed, so repeated division eventually leaves `chunks` empty.
    ///
    /// # Panics
    ///
    /// Panics on division by zero if `divider` is `0` and `self` holds at
    /// least one chunk.
    #[inline]
    pub fn div_mod(&mut self, divider: u32) -> u32 {
        let divider = u64::from(divider);
        let mut carry = 0u64;

        // Schoolbook long division, most significant chunk first. `carry`
        // is always smaller than `divider`, so every quotient fits in `u32`.
        for chunk in self.chunks.iter_mut() {
            carry = (carry << 32) | u64::from(*chunk);
            *chunk = (carry / divider) as u32;
            carry %= divider;
        }

        // Drop (at most) one leading zero chunk so the number shrinks as it
        // gets divided down.
        if self.chunks.first() == Some(&0) {
            self.chunks.remove(0);
        }

        carry as u32
    }

    /// Perform a multiplication followed by addition. This is a reverse
    /// of `div_mod` in the sense that when supplied the remainder for
    /// addition and the same base for multiplication as division, the
    /// result is the original BigUint.
    ///
    /// An empty `chunks` vector is treated as zero, so the result is then
    /// simply `addition`.
    #[inline]
    pub fn mul_add(&mut self, multiplicator: u32, addition: u32) {
        // Seeding the carry with the addend folds the addition into the
        // least significant chunk and keeps it even when there are no
        // chunks at all. The sum cannot overflow `u64`:
        // (2^32 - 1)^2 + (2^32 - 1) < 2^64.
        let mut carry = u64::from(addition);

        for chunk in self.chunks.iter_mut().rev() {
            carry += u64::from(*chunk) * u64::from(multiplicator);
            *chunk = carry as u32;
            carry >>= 32;
        }

        // Whatever is left over becomes a new most significant chunk.
        if carry > 0 {
            self.chunks.insert(0, carry as u32);
        }
    }

    /// Check if self is zero.
    ///
    /// Returns `true` when every chunk is zero, which includes the case of
    /// no chunks at all.
    #[inline]
    pub fn is_zero(&self) -> bool {
        self.chunks.iter().all(|&chunk| chunk == 0)
    }

    /// Convert self into big-endian bytes with all leading zero bytes
    /// stripped. A zero value yields an empty vector.
    #[inline]
    pub fn into_bytes_be(self) -> Vec<u8> {
        // Skip the chunks that are entirely zero.
        let zero_chunks = self
            .chunks
            .iter()
            .take_while(|&&chunk| chunk == 0)
            .count();

        let (first, rest) = match self.chunks[zero_chunks..].split_first() {
            Some(parts) => parts,
            // Nothing but zeros (or no chunks): the number is zero.
            None => return Vec::new(),
        };

        // `first` is non-zero, so it has at most three leading zero bytes.
        let lead = (first.leading_zeros() / 8) as usize;

        let mut bytes = Vec::with_capacity(4 - lead + rest.len() * 4);

        bytes.extend_from_slice(&first.to_be_bytes()[lead..]);

        for chunk in rest {
            bytes.extend_from_slice(&chunk.to_be_bytes());
        }

        bytes
    }

    /// Build a `BigUint` from big-endian bytes.
    ///
    /// When the input length is not a multiple of four, the leading
    /// (most significant) chunk is made from the first `len % 4` bytes.
    /// Leading zero bytes are kept as-is, and an empty slice produces an
    /// empty `chunks` vector.
    #[inline]
    pub fn from_bytes_be(bytes: &[u8]) -> Self {
        /// Pack up to four big-endian bytes into a single chunk.
        fn pack(group: &[u8]) -> u32 {
            group
                .iter()
                .fold(0, |acc, &byte| (acc << 8) | u32::from(byte))
        }

        let modulo = bytes.len() % 4;

        // `head` is the short leading group (possibly empty); `tail` is a
        // whole number of four-byte groups.
        let (head, tail) = bytes.split_at(modulo);

        let mut chunks = Vec::with_capacity(bytes.len() / 4 + (modulo > 0) as usize);

        if !head.is_empty() {
            chunks.push(pack(head));
        }

        chunks.extend(tail.chunks(4).map(pack));

        BigUint { chunks }
    }
}
151
#[cfg(test)]
mod tests {
    #![allow(clippy::unreadable_literal)]
    use super::BigUint;

    /// Join the two leading chunks of `big` into one `u64`, with
    /// `chunks[0]` as the high half.
    fn merge_two(big: &BigUint) -> u64 {
        let high = u64::from(big.chunks[0]);
        let low = u64::from(big.chunks[1]);

        (high << 32) | low
    }

    /// Fourteen input bytes are expected to become four chunks, the first
    /// one holding only the two leading bytes.
    #[test]
    fn big_uint_from_bytes() {
        let input: [u8; 14] = [
            0xDE, 0xAD, 0x00, 0x00, 0x00, 0x13, 0x37, 0xAD, 0x00, 0x00, 0x00, 0x00, 0xDE, 0xAD,
        ];
        let expected = vec![0x0000DEAD, 0x00000013, 0x37AD0000, 0x0000DEAD];

        let parsed = BigUint::from_bytes_be(&input);

        assert_eq!(parsed.chunks, expected);
    }

    /// `div_mod` has to agree with native `u64` division and remainder.
    #[test]
    fn big_uint_rem_div() {
        let mut number = BigUint {
            chunks: vec![0x136AD712, 0x84322759],
        };

        let remainder = number.div_mod(58);
        let quotient = merge_two(&number);

        assert_eq!(quotient, 0x136AD71284322759 / 58);
        assert_eq!(u64::from(remainder), 0x136AD71284322759 % 58);
    }

    /// `mul_add` has to agree with native `u64` multiply-then-add.
    #[test]
    fn big_uint_add_mul() {
        let mut number = BigUint {
            chunks: vec![0x000AD712, 0x84322759],
        };

        number.mul_add(58, 37);
        let product = merge_two(&number);

        assert_eq!(product, (0x000AD71284322759 * 58) + 37);
    }
}