use super::Precompile;
use crate::{Config, ExecReturnValue, GasMeter, RuntimeCosts};
use alloc::{vec, vec::Vec};
use core::cmp::max;
use num_bigint::BigUint;
use num_integer::Integer;
use num_traits::{One, Zero};
use pallet_revive_uapi::ReturnFlags;
/// Unit type for the modular exponentiation precompile; the logic lives in its
/// [`Precompile`] implementation.
pub struct Modexp;
/// Floor for the gas cost computed by `calculate_gas_cost`. The floor is applied before that
/// function's even-modulus multiplier, so the final cost for an even modulus can be higher.
const MIN_GAS_COST: u64 = 200;
/// Computes the gas to charge for one modular exponentiation.
///
/// * `base_length` / `mod_length` - declared operand lengths in bytes.
/// * `exponent` - the exponent as a number; consulted when it is at most 32 bytes long.
/// * `exponent_bytes` - the raw big-endian exponent; its length selects the pricing branch and
///   its first 32 bytes are consulted when it is longer than 32 bytes.
/// * `mod_is_even` - whether the modulus is even, which multiplies the cost by 20.
///
/// The result is `max(MIN_GAS_COST, complexity * iterations / 3)`, scaled by the even-modulus
/// multiplier with saturation.
fn calculate_gas_cost(
    base_length: u64,
    mod_length: u64,
    exponent: &BigUint,
    exponent_bytes: &[u8],
    mod_is_even: bool,
) -> u64 {
    // Multiplication complexity: the longer operand, rounded up to 8-byte words, squared.
    let longest = max(base_length, mod_length);
    let words = longest / 8 + u64::from(longest % 8 > 0);
    let multiplication_complexity = words * words;

    // Iteration count: derived from the bit length of the exponent. Exponents longer than
    // 32 bytes contribute 8 per extra byte plus the bit length of their 32-byte head.
    let exp_length = exponent_bytes.len() as u64;
    let raw_iterations = if exp_length > 32 {
        let head = BigUint::from_bytes_be(&exponent_bytes[..32]);
        // The head may be zero; `8 * (exp_length - 32)` is at least 8 here, so subtracting
        // one last cannot underflow.
        8 * (exp_length - 32) + head.bits() - 1
    } else if exponent.is_zero() {
        0
    } else {
        // Non-zero exponent, so `bits()` is at least 1.
        exponent.bits() - 1
    };
    let iteration_count = max(raw_iterations, 1);

    let floored_cost = max(MIN_GAS_COST, multiplication_complexity * iteration_count / 3);
    let multiplier = if mod_is_even { 20 } else { 1 };
    floored_cost.saturating_mul(multiplier)
}
/// Copies bytes from `source`, starting at `*source_offset`, into the front of `target`.
///
/// `*source_offset` is always advanced by `target.len()`, even when fewer bytes (or none) are
/// available. Bytes of `target` that cannot be filled from `source` are left untouched, so a
/// zero-initialised `target` reads as if `source` were zero-padded.
fn read_input(source: &[u8], target: &mut [u8], source_offset: &mut usize) {
    let start = *source_offset;
    *source_offset = start + target.len();

    // `get` yields `None` when `start` lies beyond the end of `source`; a start exactly at the
    // end yields an empty slice, which copies nothing.
    if let Some(available) = source.get(start..) {
        let count = available.len().min(target.len());
        target[..count].copy_from_slice(&available[..count]);
    }
}
impl<T: Config> Precompile<T> for Modexp {
    /// Runs the modular exponentiation precompile on `input`.
    ///
    /// `input` starts with three 32-byte big-endian length words (`base_len`, `exp_len`,
    /// `mod_len`), followed by the base, exponent and modulus of those lengths. Input shorter
    /// than declared is read as if it were padded with zero bytes.
    ///
    /// On success the returned data is `base^exponent mod modulus`, big-endian and left-padded
    /// with zeros to exactly `mod_len` bytes. It is empty when `mod_len` is zero, and all zeros
    /// when the modulus is zero or one.
    ///
    /// # Errors
    ///
    /// * `"unreasonably large base/exponent/modulus length"` when the respective declared length
    ///   exceeds 1024 bytes.
    /// * The error produced by `gas_meter.charge` when the computed cost cannot be charged.
    /// * `"failed"` if the result does not fit into `mod_len` bytes.
    fn execute(gas_meter: &mut GasMeter<T>, input: &[u8]) -> Result<ExecReturnValue, &'static str> {
        /// Largest accepted operand length in bytes.
        const MAX_OPERAND_LEN: usize = 1024;

        /// Reads one 32-byte big-endian length word at `*offset` and returns it as `usize`,
        /// or `too_large` when it exceeds `MAX_OPERAND_LEN`.
        fn read_length(
            input: &[u8],
            offset: &mut usize,
            too_large: &'static str,
        ) -> Result<usize, &'static str> {
            let mut word = [0u8; 32];
            read_input(input, &mut word, offset);
            // A value that does not even fit into `usize` is necessarily above the limit, so
            // both failure modes map to the same error and no panic path remains.
            let parsed: Result<usize, _> = BigUint::from_bytes_be(&word).try_into();
            match parsed {
                Ok(len) if len <= MAX_OPERAND_LEN => Ok(len),
                _ => Err(too_large),
            }
        }

        let mut input_offset = 0;
        let base_len = read_length(input, &mut input_offset, "unreasonably large base length")?;
        let exp_len =
            read_length(input, &mut input_offset, "unreasonably large exponent length")?;
        let mod_len = read_length(input, &mut input_offset, "unreasonably large modulus length")?;

        // With an empty modulus the output must be empty.
        // NOTE(review): this path returns without charging anything on `gas_meter`; confirm
        // that the caller accounts for a base cost of the call.
        if mod_len == 0 {
            return Ok(ExecReturnValue { data: Vec::new(), flags: ReturnFlags::empty() });
        }

        let mut base_buf = vec![0u8; base_len];
        read_input(input, &mut base_buf, &mut input_offset);
        let base = BigUint::from_bytes_be(&base_buf);

        let mut exp_buf = vec![0u8; exp_len];
        read_input(input, &mut exp_buf, &mut input_offset);
        let exponent = BigUint::from_bytes_be(&exp_buf);

        let mut mod_buf = vec![0u8; mod_len];
        read_input(input, &mut mod_buf, &mut input_offset);
        let modulus = BigUint::from_bytes_be(&mod_buf);

        // Charge before doing the expensive exponentiation.
        let gas_cost = calculate_gas_cost(
            base_len as u64,
            mod_len as u64,
            &exponent,
            &exp_buf,
            modulus.is_even(),
        );
        gas_meter.charge(RuntimeCosts::Modexp(gas_cost))?;

        // A zero or one modulus always gives zero; handling it here also keeps a zero modulus
        // away from `modpow`.
        let result = if modulus.is_zero() || modulus.is_one() {
            BigUint::zero()
        } else {
            base.modpow(&exponent, &modulus)
        };

        // Left-pad the big-endian result with zeros to exactly `mod_len` bytes.
        let bytes = result.to_bytes_be();
        if bytes.len() > mod_len {
            return Err("failed");
        }
        let mut data = vec![0u8; mod_len];
        data[mod_len - bytes.len()..].copy_from_slice(&bytes);
        Ok(ExecReturnValue { data, flags: ReturnFlags::empty() })
    }
}
// Unit tests for the `Modexp` precompile: shared JSON test vectors, input-length edge cases,
// a few hand-checked exponentiations and one gas-accounting check.
#[cfg(test)]
mod tests {
use super::*;
use crate::pure_precompiles::test::*;
use alloy_core::hex;
// Runs the shared test vectors stored in `testdata/5-modexp_eip2565.json`.
#[test]
fn process_consensus_tests() -> Result<(), String> {
test_precompile_test_vectors::<Modexp>(include_str!("./testdata/5-modexp_eip2565.json"))?;
Ok(())
}
// Empty input: every length word reads as zero, so the output is empty.
#[test]
fn test_empty_input() {
let input = Vec::new();
let result = run_precompile::<Modexp>(input).unwrap();
assert_eq!(result.data, Vec::<u8>::new());
}
// The header declares three 1-byte operands but no operand bytes follow; the missing bytes
// are read as zero and the output is a single zero byte.
#[test]
fn test_insufficient_input() {
let input = hex::decode(
"0000000000000000000000000000000000000000000000000000000000000001\
0000000000000000000000000000000000000000000000000000000000000001\
0000000000000000000000000000000000000000000000000000000000000001",
)
.expect("Decode failed");
let result = run_precompile::<Modexp>(input).unwrap();
assert_eq!(result.data, vec![0x00]);
}
// The base length word has its most significant byte set, far above the accepted maximum,
// so the call is rejected.
#[test]
fn test_excessive_input() {
let input = hex::decode(
"1000000000000000000000000000000000000000000000000000000000000001\
0000000000000000000000000000000000000000000000000000000000000001\
0000000000000000000000000000000000000000000000000000000000000001",
)
.expect("Decode failed");
let result = run_precompile::<Modexp>(input).unwrap_err();
assert_eq!(result, "unreasonably large base length");
}
// 1-byte operands: 3^5 mod 7 = 5, returned as a single byte.
#[test]
fn test_simple_inputs() {
let input = hex::decode(
"0000000000000000000000000000000000000000000000000000000000000001\
0000000000000000000000000000000000000000000000000000000000000001\
0000000000000000000000000000000000000000000000000000000000000001\
03\
05\
07",
)
.expect("Decode failed");
let precompile_result = run_precompile::<Modexp>(input).unwrap();
assert_eq!(precompile_result.data.len(), 1); let result = BigUint::from_bytes_be(&precompile_result.data[..]);
let expected = BigUint::parse_bytes(b"5", 10).unwrap();
assert_eq!(result, expected);
}
// 32-byte operands: 0xEA5F^0x15 mod 0x3874, with the output padded to 32 bytes.
#[test]
fn test_large_inputs() {
let input = hex::decode(
"0000000000000000000000000000000000000000000000000000000000000020\
0000000000000000000000000000000000000000000000000000000000000020\
0000000000000000000000000000000000000000000000000000000000000020\
000000000000000000000000000000000000000000000000000000000000EA5F\
0000000000000000000000000000000000000000000000000000000000000015\
0000000000000000000000000000000000000000000000000000000000003874",
)
.expect("Decode failed");
let precompile_result = run_precompile::<Modexp>(input).unwrap();
assert_eq!(precompile_result.data.len(), 32); let result = BigUint::from_bytes_be(&precompile_result.data[..]);
let expected = BigUint::parse_bytes(b"10055", 10).unwrap();
assert_eq!(result, expected);
}
// 1-byte base (3) with a 32-byte exponent equal to the 32-byte modulus minus one; the expected
// result is 1.
#[test]
fn test_large_computation() {
let input = hex::decode(
"0000000000000000000000000000000000000000000000000000000000000001\
0000000000000000000000000000000000000000000000000000000000000020\
0000000000000000000000000000000000000000000000000000000000000020\
03\
fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2e\
fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f",
)
.expect("Decode failed");
let precompile_result = run_precompile::<Modexp>(input).unwrap();
assert_eq!(precompile_result.data.len(), 32); let result = BigUint::from_bytes_be(&precompile_result.data[..]);
let expected = BigUint::parse_bytes(b"1", 10).unwrap();
assert_eq!(result, expected);
}
// Declares a 33-byte exponent whose bytes are all zero, exercising the pricing branch for
// exponents longer than 32 bytes with a zero head; the call must succeed and return 0.
#[test]
fn test_zero_exp_with_33_length() {
let input = vec![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 33, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
];
let precompile_result = run_precompile::<Modexp>(input).unwrap();
assert_eq!(precompile_result.data.len(), 1); let result = BigUint::from_bytes_be(&precompile_result.data[..]);
let expected = BigUint::parse_bytes(b"0", 10).unwrap();
assert_eq!(result, expected);
}
// Checks the gas charged for an input with an exponent longer than 32 bytes against a fixed
// expected value. The factor 20 matches the multiplier `calculate_gas_cost` applies for an
// even modulus.
#[test]
fn test_long_exp_gas_cost_matches_specs() {
use crate::{gas::Token, tests::Test, GasMeter, Weight};
let input = vec![
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 38, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
16, 0, 0, 0, 255, 255, 255, 2, 0, 0, 179, 0, 0, 2, 0, 0, 122, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 255, 251, 0, 0, 0, 0, 4, 38, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 255, 255, 255, 2, 0, 0, 179, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255,
255, 255, 255, 249,
];
// Use an effectively unlimited meter so only the amount consumed is under test.
let mut gas_meter = GasMeter::<Test>::new(Weight::MAX);
Modexp::execute(&mut gas_meter, &input).unwrap();
assert_eq!(
gas_meter.gas_consumed(),
Token::<Test>::weight(&RuntimeCosts::Modexp(7104 * 20))
);
}
}