referrerpolicy=no-referrer-when-downgrade

pallet_revive/precompiles/builtin/
modexp.rs

1// This file is part of Substrate.
2
3// Copyright (C) Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: Apache-2.0
5
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18use crate::{
19	precompiles::{BuiltinAddressMatcher, Error, Ext, PrimitivePrecompile},
20	vm::RuntimeCosts,
21	Config,
22};
23use alloc::{vec, vec::Vec};
24use core::{cmp::max, marker::PhantomData, num::NonZero};
25use num_bigint::BigUint;
26use num_integer::Integer;
27use num_traits::{One, ToPrimitive, Zero};
28use sp_runtime::DispatchError;
29
/// Lowest price (in gas) any modexp call can have; the floor mandated by EIP-2565.
const MIN_GAS_COST: u64 = 200;
32
/// Builtin precompile computing `base ^ exponent % modulus` over big unsigned integers.
///
/// The input is laid out as described in EIP-198:
/// 1) 32 bytes: big-endian byte length of `base`
/// 2) 32 bytes: big-endian byte length of `exponent`
/// 3) 32 bytes: big-endian byte length of `modulus`
/// 4) `base`, big-endian, using the length from 1)
/// 5) `exponent`, big-endian, using the length from 2)
/// 6) `modulus`, big-endian, using the length from 3)
///
/// Input shorter than the declared layout is treated as if padded with zero bytes.
///
/// # Note
///
/// Each of the three declared lengths is capped at 1024 bytes; larger values make the call
/// fail instead of being priced by gas alone.
/// See <https://eips.ethereum.org/EIPS/eip-198>.
pub struct Modexp<T>(PhantomData<T>);
49
50impl<T: Config> PrimitivePrecompile for Modexp<T> {
51	type T = T;
52	const MATCHER: BuiltinAddressMatcher = BuiltinAddressMatcher::Fixed(NonZero::new(5).unwrap());
53	const HAS_CONTRACT_INFO: bool = false;
54
55	fn call(
56		_address: &[u8; 20],
57		input: Vec<u8>,
58		env: &mut impl Ext<T = Self::T>,
59	) -> Result<Vec<u8>, Error> {
60		let mut input_offset = 0;
61
62		// Yellowpaper: whenever the input is too short, the missing bytes are
63		// considered to be zero.
64		let mut base_len_buf = [0u8; 32];
65		read_input(&input, &mut base_len_buf, &mut input_offset);
66		let mut exp_len_buf = [0u8; 32];
67		read_input(&input, &mut exp_len_buf, &mut input_offset);
68		let mut mod_len_buf = [0u8; 32];
69		read_input(&input, &mut mod_len_buf, &mut input_offset);
70
71		// reasonable assumption: this must fit within the Ethereum EVM's max stack size
72		let max_size_big = BigUint::from(1024u32);
73
74		let base_len_big = BigUint::from_bytes_be(&base_len_buf);
75		if base_len_big > max_size_big {
76			Err(DispatchError::from("unreasonably large base length"))?;
77		}
78
79		let exp_len_big = BigUint::from_bytes_be(&exp_len_buf);
80		if exp_len_big > max_size_big {
81			Err(DispatchError::from("unreasonably exponent length"))?;
82		}
83
84		let mod_len_big = BigUint::from_bytes_be(&mod_len_buf);
85		if mod_len_big > max_size_big {
86			Err(DispatchError::from("unreasonably large modulus length"))?;
87		}
88
89		// bounds check handled above
90		let base_len = base_len_big.to_usize().expect("base_len out of bounds");
91		let exp_len = exp_len_big.to_usize().expect("exp_len out of bounds");
92		let mod_len = mod_len_big.to_usize().expect("mod_len out of bounds");
93
94		// if mod_len is 0 output must be empty
95		if mod_len == 0 {
96			return Ok(Vec::new())
97		}
98
99		// Gas formula allows arbitrary large exp_len when base and modulus are empty, so we need to
100		// handle empty base first.
101		let r = if base_len == 0 && mod_len == 0 {
102			env.gas_meter_mut().charge(RuntimeCosts::Modexp(MIN_GAS_COST))?;
103
104			BigUint::zero()
105		} else {
106			// read the numbers themselves.
107			let mut base_buf = vec![0u8; base_len];
108			read_input(&input, &mut base_buf, &mut input_offset);
109			let base = BigUint::from_bytes_be(&base_buf);
110
111			let mut exp_buf = vec![0u8; exp_len];
112			read_input(&input, &mut exp_buf, &mut input_offset);
113			let exponent = BigUint::from_bytes_be(&exp_buf);
114
115			let mut mod_buf = vec![0u8; mod_len];
116			read_input(&input, &mut mod_buf, &mut input_offset);
117			let modulus = BigUint::from_bytes_be(&mod_buf);
118
119			// do our gas accounting
120			let gas_cost = calculate_gas_cost(
121				base_len as u64,
122				mod_len as u64,
123				&exponent,
124				&exp_buf,
125				modulus.is_even(),
126			);
127
128			env.gas_meter_mut().charge(RuntimeCosts::Modexp(gas_cost))?;
129
130			if modulus.is_zero() || modulus.is_one() {
131				BigUint::zero()
132			} else {
133				base.modpow(&exponent, &modulus)
134			}
135		};
136
137		// write output to given memory, left padded and same length as the modulus.
138		let bytes = r.to_bytes_be();
139
140		// always true except in the case of zero-length modulus, which leads to
141		// output of length and value 1.
142		if bytes.len() == mod_len {
143			Ok(bytes.to_vec())
144		} else if bytes.len() < mod_len {
145			let mut ret = Vec::with_capacity(mod_len);
146			ret.extend(core::iter::repeat(0).take(mod_len - bytes.len()));
147			ret.extend_from_slice(&bytes[..]);
148			Ok(ret)
149		} else {
150			return Err(DispatchError::from("failed").into());
151		}
152	}
153}
154
155// Calculate gas cost according to EIP 2565:
156// https://eips.ethereum.org/EIPS/eip-2565
157fn calculate_gas_cost(
158	base_length: u64,
159	mod_length: u64,
160	exponent: &BigUint,
161	exponent_bytes: &[u8],
162	mod_is_even: bool,
163) -> u64 {
164	fn calculate_multiplication_complexity(base_length: u64, mod_length: u64) -> u64 {
165		let max_length = max(base_length, mod_length);
166		let mut words = max_length / 8;
167		if max_length % 8 > 0 {
168			words += 1;
169		}
170
171		// Note: can't overflow because we take words to be some u64 value / 8, which is
172		// necessarily less than sqrt(u64::MAX).
173		// Additionally, both base_length and mod_length are bounded to 1024, so this has
174		// an upper bound of roughly (1024 / 8) squared
175		words * words
176	}
177
178	fn calculate_iteration_count(exponent: &BigUint, exponent_bytes: &[u8]) -> u64 {
179		let mut iteration_count: u64 = 0;
180		let exp_length = exponent_bytes.len() as u64;
181
182		if exp_length <= 32 && exponent.is_zero() {
183			iteration_count = 0;
184		} else if exp_length <= 32 {
185			iteration_count = exponent.bits() - 1;
186		} else if exp_length > 32 {
187			// from the EIP spec:
188			// (8 * (exp_length - 32)) + ((exponent & (2**256 - 1)).bit_length() - 1)
189			//
190			// Notes:
191			// * exp_length is bounded to 1024 and is > 32
192			// * exponent can be zero, so we subtract 1 after adding the other terms (whose sum must
193			//   be > 0)
194			// * the addition can't overflow because the terms are both capped at roughly 8 * max
195			//   size of exp_length (1024)
196			// * the EIP spec is written in python, in which (exponent & (2**256 - 1)) takes the
197			//   FIRST 32 bytes. However this `BigUint` `&` operator takes the LAST 32 bytes. We
198			//   thus instead take the bytes manually.
199			let exponent_head = BigUint::from_bytes_be(&exponent_bytes[..32]);
200
201			iteration_count = (8 * (exp_length - 32)) + exponent_head.bits() - 1;
202		}
203
204		max(iteration_count, 1)
205	}
206
207	let multiplication_complexity = calculate_multiplication_complexity(base_length, mod_length);
208	let iteration_count = calculate_iteration_count(exponent, exponent_bytes);
209	max(MIN_GAS_COST, multiplication_complexity * iteration_count / 3)
210		.saturating_mul(if mod_is_even { 20 } else { 1 })
211}
212
/// Copies bytes from `source`, starting at `*source_offset`, into `target`.
///
/// The offset always advances by `target.len()`, whether or not `source` still holds that
/// many bytes. Only the bytes actually available are copied. The rest of `target` is left
/// untouched, so a zero-initialised `target` ends up zero-padded.
fn read_input(source: &[u8], target: &mut [u8], source_offset: &mut usize) {
	let start = *source_offset;
	*source_offset = start + target.len();

	// `get` yields `None` once `start` is past the end, and an empty slice when it sits
	// exactly at the end. Either way nothing gets copied.
	if let Some(remaining) = source.get(start..) {
		let count = remaining.len().min(target.len());
		target[..count].copy_from_slice(&remaining[..count]);
	}
}
229
#[cfg(test)]
mod tests {
	use super::*;
	use crate::{
		precompiles::tests::{run_primitive, run_test_vectors},
		tests::Test,
	};
	use alloy_core::hex;

	/// Hex-decodes a test input.
	fn decode(input: &str) -> Vec<u8> {
		hex::decode(input).expect("Decode failed")
	}

	/// Runs the precompile on `input` and checks the output.
	///
	/// The output must be exactly `mod_len` bytes long (same length as the modulus) and must
	/// encode `expected` big-endian.
	fn assert_modexp_output(input: Vec<u8>, mod_len: usize, expected: u32) {
		let output = run_primitive::<Modexp<Test>>(input).unwrap();
		assert_eq!(output.len(), mod_len);
		assert_eq!(BigUint::from_bytes_be(&output), BigUint::from(expected));
	}

	#[test]
	fn process_consensus_tests() {
		run_test_vectors::<Modexp<Test>>(include_str!("./testdata/5-modexp_eip2565.json"));
	}

	#[test]
	fn test_empty_input() {
		let output = run_primitive::<Modexp<Test>>(Vec::new()).unwrap();
		assert!(output.is_empty());
	}

	#[test]
	fn test_insufficient_input() {
		// All three lengths are 1, but no operand bytes follow, so they read as zeroes.
		let input = decode(
			"0000000000000000000000000000000000000000000000000000000000000001\
			0000000000000000000000000000000000000000000000000000000000000001\
			0000000000000000000000000000000000000000000000000000000000000001",
		);

		assert_eq!(run_primitive::<Modexp<Test>>(input).unwrap(), vec![0x00]);
	}

	#[test]
	fn test_excessive_input() {
		// The declared base length is far beyond the 1024-byte limit.
		let input = decode(
			"1000000000000000000000000000000000000000000000000000000000000001\
			0000000000000000000000000000000000000000000000000000000000000001\
			0000000000000000000000000000000000000000000000000000000000000001",
		);

		let err = run_primitive::<Modexp<Test>>(input).unwrap_err();
		let Error::Error(crate::ExecError { error: DispatchError::Other(reason), .. }) = err else {
			panic!("Unexpected error");
		};
		assert_eq!(reason, "unreasonably large base length");
	}

	#[test]
	fn test_simple_inputs() {
		let input = decode(
			"0000000000000000000000000000000000000000000000000000000000000001\
			0000000000000000000000000000000000000000000000000000000000000001\
			0000000000000000000000000000000000000000000000000000000000000001\
			03\
			05\
			07",
		);

		// 3 ^ 5 % 7 == 5
		assert_modexp_output(input, 1, 5);
	}

	#[test]
	fn test_large_inputs() {
		let input = decode(
			"0000000000000000000000000000000000000000000000000000000000000020\
			0000000000000000000000000000000000000000000000000000000000000020\
			0000000000000000000000000000000000000000000000000000000000000020\
			000000000000000000000000000000000000000000000000000000000000EA5F\
			0000000000000000000000000000000000000000000000000000000000000015\
			0000000000000000000000000000000000000000000000000000000000003874",
		);

		// 59999 ^ 21 % 14452 = 10055
		assert_modexp_output(input, 32, 10055);
	}

	#[test]
	fn test_large_computation() {
		let input = decode(
			"0000000000000000000000000000000000000000000000000000000000000001\
			0000000000000000000000000000000000000000000000000000000000000020\
			0000000000000000000000000000000000000000000000000000000000000020\
			03\
			fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2e\
			fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f",
		);

		assert_modexp_output(input, 32, 1);
	}

	#[test]
	fn test_zero_exp_with_33_length() {
		// This is a regression test which ensures that the 'iteration_count' calculation
		// in 'calculate_iteration_count' cannot underflow.
		//
		// In debug mode, this underflow could cause a panic. Otherwise, it causes N**0 to
		// be calculated at more-than-normal expense.
		//
		// TODO: cite security advisory
		let input = vec![
			0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 33, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
		];

		assert_modexp_output(input, 1, 0);
	}

	#[test]
	fn test_long_exp_gas_cost_matches_specs() {
		use crate::{call_builder::CallSetup, gas::Token, tests::ExtBuilder};

		let input = vec![
			0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 38, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			16, 0, 0, 0, 255, 255, 255, 2, 0, 0, 179, 0, 0, 2, 0, 0, 122, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 255, 251, 0, 0, 0, 0, 4, 38, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 255, 255, 255, 2, 0, 0, 179, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
			0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 255,
			255, 255, 255, 249,
		];

		ExtBuilder::default().build().execute_with(|| {
			let mut call_setup = CallSetup::<Test>::default();
			let (mut ext, _) = call_setup.ext();
			let address = <Modexp<Test>>::MATCHER.base_address();

			let consumed_before = ext.gas_meter().gas_consumed();
			<Modexp<Test>>::call(&address, input, &mut ext).unwrap();
			let charged = ext.gas_meter().gas_consumed() - consumed_before;

			// geth reports 7104 gas for this input. The extra x20 factor is the even-modulus
			// surcharge applied by `calculate_gas_cost`.
			assert_eq!(charged, Token::<Test>::weight(&RuntimeCosts::Modexp(7104 * 20)));
		})
	}
}