Skip to main content

revive_strategy/cheatcodes/
mock_handler.rs

1use std::{
2    cell::RefCell,
3    collections::{BTreeMap, VecDeque},
4    rc::Rc,
5};
6
7use alloy_primitives::{Address, Bytes, map::foldhash::HashMap, ruint::aliases::U256};
8use foundry_cheatcodes::{DealRecord, Ecx, MockCallDataContext, MockCallReturnData};
9use foundry_evm::constants::CHEATCODE_ADDRESS;
10use polkadot_sdk::{
11    frame_system,
12    pallet_revive::{
13        self, AccountId32Mapper, AddressMapper, DelegateInfo, ExecOrigin, ExecReturnValue, Pallet,
14        mock::MockHandler,
15    },
16    pallet_revive_uapi::ReturnFlags,
17    polkadot_sdk_frame::prelude::OriginFor,
18    sp_core::{H160, U256 as SpU256},
19};
20use revive_env::Runtime;
21
22use revm::interpreter::InstructionResult;
23
24// Implementation object that holds the mock state and implements the MockHandler trait for Revive.
25// It is only purpose is to make transferring the mock state into the Revive EVM easier and then
26// synchronize whatever mocks got consumed back into the Cheatcodes state after the call.
27#[derive(Clone)]
28pub(crate) struct MockHandlerImpl {
29    inner: Rc<RefCell<MockHandlerInner<Runtime>>>,
30    pub origin: ExecOrigin<Runtime>,
31}
32
33impl MockHandlerImpl {
34    /// Creates a new MockHandlerImpl from the given Ecx and Cheatcodes state.
35    pub(crate) fn new(
36        ecx: &Ecx<'_, '_, '_>,
37        caller: &Address,
38        origin: &Address,
39        target_address: Option<&Address>,
40        callee: Option<&Address>,
41        state: &mut foundry_cheatcodes::Cheatcodes,
42    ) -> Self {
43        let inject_env = MockHandlerInner::new(ecx, caller, target_address, callee, state);
44        Self {
45            inner: Rc::new(RefCell::new(inject_env)),
46            origin: ExecOrigin::<Runtime>::from_runtime_origin(OriginFor::<Runtime>::signed(
47                AccountId32Mapper::<Runtime>::to_fallback_account_id(&H160::from_slice(
48                    origin.as_slice(),
49                )),
50            ))
51            .expect("Could not create tx origin"),
52        }
53    }
54
55    /// Updates the given Cheatcodes state with the current mock state.
56    /// This is used to synchronize the mock state after a call has been executed in Revive
57    pub(crate) fn update_state_mocks(&self, state: &mut foundry_cheatcodes::Cheatcodes) {
58        let mock_inner = self.inner.borrow();
59        state.mocked_calls = mock_inner.mocked_calls.clone();
60        state.mocked_functions = mock_inner.mocked_functions.clone();
61    }
62
63    /// Syncs balances for pranked accounts between REVM and pallet-revive.
64    ///
65    /// If the account was explicitly dealt to via vm.deal(), sync that balance to pallet-revive.
66    /// This handles cases where vm.deal() was called in a callback and pallet-revive's balance
67    /// diverged from REVM's balance.
68    ///
69    /// If the account was NOT dealt to and has 0 balance, fund with u128::MAX so fuzzed
70    /// prank addresses can make calls in pallet-revive.
71    pub(crate) fn fund_pranked_accounts(account: Address, eth_deals: &[DealRecord]) {
72        let account_h160 = H160::from_slice(account.as_slice());
73
74        // Check if account was explicitly dealt to via vm.deal()
75        // Use the most recent deal record for this account
76        if let Some(deal) = eth_deals.iter().rev().find(|d| d.address == account) {
77            // Sync the dealt balance to pallet-revive
78            let target_balance =
79                SpU256::from_little_endian(&deal.new_balance.as_le_bytes()).min(u128::MAX.into());
80            let pvm_balance = Pallet::<Runtime>::evm_balance(&account_h160);
81            if pvm_balance != target_balance {
82                Pallet::<Runtime>::set_evm_balance(&account_h160, target_balance)
83                    .expect("Could not sync dealt account balance");
84            }
85            return;
86        }
87
88        // Account was not dealt to - fund with u128::MAX if balance is 0
89        let pvm_balance = Pallet::<Runtime>::evm_balance(&account_h160);
90        if pvm_balance == 0.into() {
91            Pallet::<Runtime>::set_evm_balance(&account_h160, u128::MAX.into())
92                .expect("Could not fund pranked account");
93        }
94    }
95}
96
97impl MockHandler<Runtime> for MockHandlerImpl {
98    fn mock_call(
99        &self,
100        callee: H160,
101        call_data: &[u8],
102        value_transferred: polkadot_sdk::pallet_revive::U256,
103    ) -> Option<pallet_revive::ExecReturnValue> {
104        // Check if trying to call cheatcode address from pallet-revive
105        if Address::from_slice(callee.as_bytes()) == CHEATCODE_ADDRESS {
106            return Some(ExecReturnValue {
107                flags: ReturnFlags::REVERT,
108                data: b"Cheatcodes are not available in polkadot runtime.".to_vec(),
109            });
110        }
111
112        let mut mock_inner = self.inner.borrow_mut();
113        let ctx = MockCallDataContext {
114            calldata: call_data.to_vec().into(),
115            value: Some(U256::from_limbs(value_transferred.0)),
116        };
117
118        // Use the same logic as in inspect.rs to find the correct mocked call and consume some of
119        // them. https://github.com/paritytech/foundry-polkadot/blob/26eda0de53ac03f7ac9b6a6023d8243101cffaf1/crates/cheatcodes/src/inspector.rs#L1013
120        if let Some(mock_data) =
121            mock_inner.mocked_calls.get_mut(&Address::from_slice(callee.as_bytes()))
122            && let Some(return_data_queue) = match mock_data.get_mut(&ctx) {
123                Some(found) => Some(found),
124                None => mock_data
125                    .iter_mut()
126                    .find(|(key, _)| {
127                        ctx.calldata.starts_with(&key.calldata)
128                            && (key.value.is_none()
129                                || ctx.value == key.value
130                                || (ctx.value == Some(U256::ZERO) && key.value.is_none()))
131                    })
132                    .map(|(_, v)| v),
133            }
134            && let Some(return_data) = if return_data_queue.len() == 1 {
135                // If the mocked calls stack has a single element in it, don't empty it
136                return_data_queue.front().map(|x| x.to_owned())
137            } else {
138                // Else, we pop the front element
139                return_data_queue.pop_front()
140            }
141        {
142            return Some(ExecReturnValue {
143                flags: if matches!(return_data.ret_type, InstructionResult::Revert) {
144                    ReturnFlags::REVERT
145                } else {
146                    ReturnFlags::default()
147                },
148                data: return_data.data.0.to_vec(),
149            });
150        }
151
152        None
153    }
154
155    fn mock_caller(&self, frames_len: usize) -> Option<OriginFor<Runtime>> {
156        let mock_inner = self.inner.borrow();
157        if frames_len == 0 && mock_inner.delegated_caller.is_none() {
158            return Some(mock_inner.caller.clone());
159        }
160        None
161    }
162
163    fn mock_origin(&self) -> Option<&ExecOrigin<Runtime>> {
164        Some(&self.origin)
165    }
166
167    fn mock_delegated_caller(
168        &self,
169        dest: H160,
170        input_data: &[u8],
171    ) -> Option<DelegateInfo<Runtime>> {
172        let mock_inner = self.inner.borrow();
173
174        // Mocked functions are implemented by making use of the hooks for delegated calls.
175        if let Some(mocked_function) =
176            mock_inner.mocked_functions.get(&Address::from_slice(dest.as_bytes()))
177        {
178            let input_data = Bytes::from(input_data.to_vec());
179            if let Some(target) = mocked_function
180                .get(&input_data)
181                .or_else(|| input_data.get(..4).and_then(|selector| mocked_function.get(selector)))
182            {
183                return Some(DelegateInfo {
184                    caller:
185        ExecOrigin::<Runtime>::from_runtime_origin(OriginFor::<Runtime>::signed(
186                        <revive_env::Runtime as
187        polkadot_sdk::pallet_revive::Config>::AddressMapper::to_account_id(&dest),
188                    )).ok()?,
189                callee: H160::from_slice(target.as_slice())
190                }
191                );
192            }
193        }
194
195        mock_inner.delegated_caller.as_ref().and_then(|delegate_caller| {
196            Some(DelegateInfo {
197                caller: ExecOrigin::<Runtime>::from_runtime_origin(delegate_caller.clone()).ok()?,
198                callee: mock_inner.callee,
199            })
200        })
201    }
202}
203
204// Internal struct that holds the mock state. It is wrapped in an Arc<Mutex<>> in MockHandlerImpl
205// to make it easier to transfer the state into Revive and back and be able to mutate it from the
206// MockHandler trait methods.
207#[derive(Clone)]
208struct MockHandlerInner<T: frame_system::Config + pallet_revive::Config> {
209    pub caller: OriginFor<T>,
210    pub delegated_caller: Option<OriginFor<T>>,
211    pub callee: H160,
212
213    pub mocked_calls: HashMap<Address, BTreeMap<MockCallDataContext, VecDeque<MockCallReturnData>>>,
214    pub mocked_functions: HashMap<Address, HashMap<Bytes, Address>>,
215}
216
217impl MockHandlerInner<Runtime> {
218    /// Creates a new MockHandlerInner from the given Ecx and Cheatcodes state.
219    /// Also returns whether a prank is currently enabled.
220    fn new(
221        _ecx: &Ecx<'_, '_, '_>,
222        caller: &Address,
223        target_address: Option<&Address>,
224        callee: Option<&Address>,
225        state: &mut foundry_cheatcodes::Cheatcodes,
226    ) -> Self {
227        let pranked_caller =
228            OriginFor::<Runtime>::signed(AccountId32Mapper::<Runtime>::to_fallback_account_id(
229                &H160::from_slice(caller.as_slice()),
230            ));
231
232        let delegated_caller = target_address.map(|addr| {
233            OriginFor::<Runtime>::signed(AccountId32Mapper::<Runtime>::to_fallback_account_id(
234                &H160::from_slice(addr.as_slice()),
235            ))
236        });
237
238        Self {
239            caller: pranked_caller,
240            delegated_caller,
241            mocked_calls: state.mocked_calls.clone(),
242            callee: callee.map(|addr| H160::from_slice(addr.as_slice())).unwrap_or_default(),
243            mocked_functions: state.mocked_functions.clone(),
244        }
245    }
246}