1use crate::{
19 precompiles::{BuiltinAddressMatcher, Error, Ext, PrimitivePrecompile},
20 vm::RuntimeCosts,
21 Config,
22};
23use alloc::{vec, vec::Vec};
24use core::{cmp::max, marker::PhantomData, num::NonZero};
25use num_bigint::BigUint;
26use num_integer::Integer;
27use num_traits::{One, ToPrimitive, Zero};
28use sp_runtime::DispatchError;
29
/// Lower bound applied to the computed modexp gas cost (before the even-modulus
/// multiplier) in `calculate_gas_cost`.
const MIN_GAS_COST: u64 = 200;
32
/// The modular exponentiation precompile.
///
/// The call input is parsed as three 32-byte big-endian length fields (base length,
/// exponent length, modulus length) followed by the base, exponent and modulus bytes
/// themselves. The output is `base^exponent mod modulus`, left-padded with zeros to
/// exactly the modulus length.
///
/// `T` is only used to select the runtime configuration; no value of it is stored.
pub struct Modexp<T>(PhantomData<T>);
49
50impl<T: Config> PrimitivePrecompile for Modexp<T> {
51 type T = T;
52 const MATCHER: BuiltinAddressMatcher = BuiltinAddressMatcher::Fixed(NonZero::new(5).unwrap());
53 const HAS_CONTRACT_INFO: bool = false;
54
55 fn call(
56 _address: &[u8; 20],
57 input: Vec<u8>,
58 env: &mut impl Ext<T = Self::T>,
59 ) -> Result<Vec<u8>, Error> {
60 let mut input_offset = 0;
61
62 let mut base_len_buf = [0u8; 32];
65 read_input(&input, &mut base_len_buf, &mut input_offset);
66 let mut exp_len_buf = [0u8; 32];
67 read_input(&input, &mut exp_len_buf, &mut input_offset);
68 let mut mod_len_buf = [0u8; 32];
69 read_input(&input, &mut mod_len_buf, &mut input_offset);
70
71 let max_size_big = BigUint::from(1024u32);
73
74 let base_len_big = BigUint::from_bytes_be(&base_len_buf);
75 if base_len_big > max_size_big {
76 Err(DispatchError::from("unreasonably large base length"))?;
77 }
78
79 let exp_len_big = BigUint::from_bytes_be(&exp_len_buf);
80 if exp_len_big > max_size_big {
81 Err(DispatchError::from("unreasonably exponent length"))?;
82 }
83
84 let mod_len_big = BigUint::from_bytes_be(&mod_len_buf);
85 if mod_len_big > max_size_big {
86 Err(DispatchError::from("unreasonably large modulus length"))?;
87 }
88
89 let base_len = base_len_big.to_usize().expect("base_len out of bounds");
91 let exp_len = exp_len_big.to_usize().expect("exp_len out of bounds");
92 let mod_len = mod_len_big.to_usize().expect("mod_len out of bounds");
93
94 if mod_len == 0 {
96 return Ok(Vec::new())
97 }
98
99 let r = if base_len == 0 && mod_len == 0 {
102 env.gas_meter_mut().charge(RuntimeCosts::Modexp(MIN_GAS_COST))?;
103
104 BigUint::zero()
105 } else {
106 let mut base_buf = vec![0u8; base_len];
108 read_input(&input, &mut base_buf, &mut input_offset);
109 let base = BigUint::from_bytes_be(&base_buf);
110
111 let mut exp_buf = vec![0u8; exp_len];
112 read_input(&input, &mut exp_buf, &mut input_offset);
113 let exponent = BigUint::from_bytes_be(&exp_buf);
114
115 let mut mod_buf = vec![0u8; mod_len];
116 read_input(&input, &mut mod_buf, &mut input_offset);
117 let modulus = BigUint::from_bytes_be(&mod_buf);
118
119 let gas_cost = calculate_gas_cost(
121 base_len as u64,
122 mod_len as u64,
123 &exponent,
124 &exp_buf,
125 modulus.is_even(),
126 );
127
128 env.gas_meter_mut().charge(RuntimeCosts::Modexp(gas_cost))?;
129
130 if modulus.is_zero() || modulus.is_one() {
131 BigUint::zero()
132 } else {
133 base.modpow(&exponent, &modulus)
134 }
135 };
136
137 let bytes = r.to_bytes_be();
139
140 if bytes.len() == mod_len {
143 Ok(bytes.to_vec())
144 } else if bytes.len() < mod_len {
145 let mut ret = Vec::with_capacity(mod_len);
146 ret.extend(core::iter::repeat(0).take(mod_len - bytes.len()));
147 ret.extend_from_slice(&bytes[..]);
148 Ok(ret)
149 } else {
150 return Err(DispatchError::from("failed").into());
151 }
152 }
153}
154
155fn calculate_gas_cost(
158 base_length: u64,
159 mod_length: u64,
160 exponent: &BigUint,
161 exponent_bytes: &[u8],
162 mod_is_even: bool,
163) -> u64 {
164 fn calculate_multiplication_complexity(base_length: u64, mod_length: u64) -> u64 {
165 let max_length = max(base_length, mod_length);
166 let mut words = max_length / 8;
167 if max_length % 8 > 0 {
168 words += 1;
169 }
170
171 words * words
176 }
177
178 fn calculate_iteration_count(exponent: &BigUint, exponent_bytes: &[u8]) -> u64 {
179 let mut iteration_count: u64 = 0;
180 let exp_length = exponent_bytes.len() as u64;
181
182 if exp_length <= 32 && exponent.is_zero() {
183 iteration_count = 0;
184 } else if exp_length <= 32 {
185 iteration_count = exponent.bits() - 1;
186 } else if exp_length > 32 {
187 let exponent_head = BigUint::from_bytes_be(&exponent_bytes[..32]);
200
201 iteration_count = (8 * (exp_length - 32)) + exponent_head.bits() - 1;
202 }
203
204 max(iteration_count, 1)
205 }
206
207 let multiplication_complexity = calculate_multiplication_complexity(base_length, mod_length);
208 let iteration_count = calculate_iteration_count(exponent, exponent_bytes);
209 max(MIN_GAS_COST, multiplication_complexity * iteration_count / 3)
210 .saturating_mul(if mod_is_even { 20 } else { 1 })
211}
212
/// Copies bytes from `source`, starting at `*source_offset`, into the front of `target`.
///
/// `*source_offset` is always advanced by `target.len()`, no matter how many bytes were
/// actually available. If `source` runs out early, only the available bytes are copied
/// and the remainder of `target` is left untouched; if the offset is at or past the end
/// of `source`, nothing is copied at all.
fn read_input(source: &[u8], target: &mut [u8], source_offset: &mut usize) {
    let start = *source_offset;
    let wanted = target.len();
    *source_offset += wanted;

    // `get` yields `None` when `start` lies beyond the end of `source`, and an empty
    // slice when it is exactly at the end; both cases copy nothing.
    if let Some(available) = source.get(start..) {
        let copy_len = available.len().min(wanted);
        target[..copy_len].copy_from_slice(&available[..copy_len]);
    }
}
229
/// Unit tests for the modexp precompile: shared test vectors, edge cases of the input
/// parsing, and a gas-cost check.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        precompiles::tests::{run_primitive, run_test_vectors},
        tests::Test,
    };
    use alloy_core::hex;

    /// Runs the precompile against the test vectors in `5-modexp_eip2565.json`.
    #[test]
    fn process_consensus_tests() {
        run_test_vectors::<Modexp<Test>>(include_str!("./testdata/5-modexp_eip2565.json"));
    }

    /// An empty input decodes to all-zero lengths, so the output is empty.
    #[test]
    fn test_empty_input() {
        let input = Vec::new();
        let result = run_primitive::<Modexp<Test>>(input).unwrap();
        assert_eq!(result, Vec::<u8>::new());
    }

    /// The header declares three 1-byte operands but no operand bytes follow. Missing
    /// bytes are read as zero, so the modulus is zero and the 1-byte result is zero.
    #[test]
    fn test_insufficient_input() {
        let input = hex::decode(
            "0000000000000000000000000000000000000000000000000000000000000001\
            0000000000000000000000000000000000000000000000000000000000000001\
            0000000000000000000000000000000000000000000000000000000000000001",
        )
        .expect("Decode failed");

        let result = run_primitive::<Modexp<Test>>(input).unwrap();
        assert_eq!(result, vec![0x00]);
    }

    /// A base length far above the 1024-byte limit must be rejected with an error.
    #[test]
    fn test_excessive_input() {
        let input = hex::decode(
            "1000000000000000000000000000000000000000000000000000000000000001\
            0000000000000000000000000000000000000000000000000000000000000001\
            0000000000000000000000000000000000000000000000000000000000000001",
        )
        .expect("Decode failed");

        let result = run_primitive::<Modexp<Test>>(input).unwrap_err();
        if let Error::Error(crate::ExecError { error: DispatchError::Other(reason), .. }) = result {
            assert_eq!(reason, "unreasonably large base length");
        } else {
            panic!("Unexpected error");
        }
    }

    /// 3^5 mod 7 = 5, with every operand one byte long.
    #[test]
    fn test_simple_inputs() {
        let input = hex::decode(
            "0000000000000000000000000000000000000000000000000000000000000001\
            0000000000000000000000000000000000000000000000000000000000000001\
            0000000000000000000000000000000000000000000000000000000000000001\
            03\
            05\
            07",
        )
        .expect("Decode failed");

        let precompile_result = run_primitive::<Modexp<Test>>(input).unwrap();
        // The output length must equal the declared modulus length.
        assert_eq!(precompile_result.len(), 1);
        let result = BigUint::from_bytes_be(&precompile_result[..]);
        let expected = BigUint::parse_bytes(b"5", 10).unwrap();
        assert_eq!(result, expected);
    }

    /// Same kind of computation with 32-byte operands; the result is zero-padded to
    /// 32 bytes.
    #[test]
    fn test_large_inputs() {
        let input = hex::decode(
            "0000000000000000000000000000000000000000000000000000000000000020\
            0000000000000000000000000000000000000000000000000000000000000020\
            0000000000000000000000000000000000000000000000000000000000000020\
            000000000000000000000000000000000000000000000000000000000000EA5F\
            0000000000000000000000000000000000000000000000000000000000000015\
            0000000000000000000000000000000000000000000000000000000000003874",
        )
        .expect("Decode failed");

        let precompile_result = run_primitive::<Modexp<Test>>(input).unwrap();
        // The output length must equal the declared modulus length.
        assert_eq!(precompile_result.len(), 32);
        let result = BigUint::from_bytes_be(&precompile_result[..]);
        let expected = BigUint::parse_bytes(b"10055", 10).unwrap();
        assert_eq!(result, expected);
    }

    /// A 1-byte base (3) raised to a full 32-byte exponent modulo a 32-byte modulus;
    /// the expected result is 1.
    #[test]
    fn test_large_computation() {
        let input = hex::decode(
            "0000000000000000000000000000000000000000000000000000000000000001\
            0000000000000000000000000000000000000000000000000000000000000020\
            0000000000000000000000000000000000000000000000000000000000000020\
            03\
            fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2e\
            fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f",
        )
        .expect("Decode failed");

        let precompile_result = run_primitive::<Modexp<Test>>(input).unwrap();
        assert_eq!(precompile_result.len(), 32);
        let result = BigUint::from_bytes_be(&precompile_result[..]);
        let expected = BigUint::parse_bytes(b"1", 10).unwrap();
        assert_eq!(result, expected);
    }

    /// Declares a 33-byte exponent whose bytes are all zero. This exercises the
    /// `exp_length > 32` branch of the iteration count with a head whose bit length is
    /// zero, which must not underflow.
    #[test]
    fn test_zero_exp_with_33_length() {
        // Header: base_len = 1, exp_len = 33, mod_len = 1, followed by the operand bytes.
        let input = vec![
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 33, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        ];

        let precompile_result = run_primitive::<Modexp<Test>>(input).unwrap();
        assert_eq!(precompile_result.len(), 1);
        let result = BigUint::from_bytes_be(&precompile_result[..]);
        let expected = BigUint::parse_bytes(b"0", 10).unwrap();
        assert_eq!(result, expected);
    }

    /// Checks the gas charged for an input with an exponent longer than 32 bytes
    /// (header: base_len = 0, exp_len = 38, mod_len = 96).
    #[test]
    fn test_long_exp_gas_cost_matches_specs() {
        use crate::{call_builder::CallSetup, gas::Token, tests::ExtBuilder};

        let input = vec![
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 38, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            16, 0, 0, 0, 255, 255, 255, 2, 0, 0, 179, 0, 0, 2, 0, 0, 122, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 255, 251, 0, 0, 0, 0, 4, 38, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 255, 255, 255, 2, 0, 0, 179, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255,
            255, 255, 255, 249,
        ];

        ExtBuilder::default().build().execute_with(|| {
            let mut call_setup = CallSetup::<Test>::default();
            let (mut ext, _) = call_setup.ext();

            // Measure only the gas consumed by the precompile call itself.
            let before = ext.gas_meter().gas_consumed();
            <Modexp<Test>>::call(&<Modexp<Test>>::MATCHER.base_address(), input, &mut ext).unwrap();
            let after = ext.gas_meter().gas_consumed();

            // 7104 is the base cost for this input; the factor 20 is the multiplier
            // `calculate_gas_cost` applies when the modulus is even.
            assert_eq!(after - before, Token::<Test>::weight(&RuntimeCosts::Modexp(7104 * 20)));
        })
    }
}