// polkavm/caller.rs

1use crate::api::BackendAccess;
2use crate::tracer::Tracer;
3use crate::Gas;
4use core::mem::MaybeUninit;
5use polkavm_common::error::Trap;
6use polkavm_common::program::Reg;
7use polkavm_common::utils::{Access, AsUninitSliceMut};
8use std::rc::{Rc, Weak};
9
/// Type-erased hostcall context shared by [`Caller`] and [`CallerRef`].
///
/// Both pointers start out null (see `CallerRaw::new`) and are filled in by `Caller::wrap`
/// before a hostcall handler runs.
pub(crate) struct CallerRaw {
    // Type-erased `*mut T` pointing at the handler's user data; set by `Caller::wrap`.
    user_data: *mut core::ffi::c_void,
    // Type-erased `*mut BackendAccess` for the running instance; set by `Caller::wrap`.
    access: *mut core::ffi::c_void,
    // Optional tracer notified of register and memory writes performed from a hostcall.
    tracer: Option<Tracer>,
}
15
// SAFETY: Every method that dereferences the raw `user_data`/`access` pointers is `unsafe`, and its
// callers (`Caller`/`CallerRef`) only invoke it while those pointers are valid. The only safe methods,
// `new` and `tracer(&mut self)`, never touch the raw pointers.
// NOTE(review): moving a `CallerRaw` also moves its `Option<Tracer>`; this presumably relies on
// `Tracer` being safe to send across threads — TODO confirm.
unsafe impl Send for CallerRaw {}

// SAFETY: `CallerRaw` exposes no safe `&self` method, so a shared `&CallerRaw` on another thread
// cannot reach the raw pointers or the tracer without going through an `unsafe` method whose
// caller upholds the invariants.
unsafe impl Sync for CallerRaw {}
21
22impl CallerRaw {
23    pub(crate) fn new(tracer: Option<Tracer>) -> Self {
24        CallerRaw {
25            user_data: core::ptr::null_mut(),
26            access: core::ptr::null_mut(),
27            tracer,
28        }
29    }
30
31    unsafe fn data<T>(&self) -> &T {
32        // SAFETY: The caller will make sure that the invariants hold.
33        unsafe { &*(self.user_data as *const T) }
34    }
35
36    unsafe fn data_mut<T>(&mut self) -> &mut T {
37        // SAFETY: The caller will make sure that the invariants hold.
38        unsafe { &mut *self.user_data.cast::<T>() }
39    }
40
41    unsafe fn access(&self) -> &BackendAccess {
42        // SAFETY: The caller will make sure that the invariants hold.
43        unsafe { &*(self.access.cast::<BackendAccess>().cast_const()) }
44    }
45
46    unsafe fn access_mut(&mut self) -> &mut BackendAccess {
47        // SAFETY: The caller will make sure that the invariants hold.
48        unsafe { &mut *self.access.cast::<BackendAccess>() }
49    }
50
51    pub(crate) fn tracer(&mut self) -> Option<&mut Tracer> {
52        self.tracer.as_mut()
53    }
54
55    unsafe fn get_reg(&self, reg: Reg) -> u32 {
56        // SAFETY: The caller will make sure that the invariants hold.
57        let value = unsafe { self.access() }.get_reg(reg);
58        log::trace!("Getting register (during hostcall): {reg} = 0x{value:x}");
59        value
60    }
61
62    unsafe fn set_reg(&mut self, reg: Reg, value: u32) {
63        log::trace!("Setting register (during hostcall): {reg} = 0x{value:x}");
64
65        // SAFETY: The caller will make sure that the invariants hold.
66        unsafe { self.access_mut() }.set_reg(reg, value);
67
68        if let Some(ref mut tracer) = self.tracer() {
69            tracer.on_set_reg_in_hostcall(reg, value);
70        }
71    }
72
73    unsafe fn read_memory_into_slice<'slice, B>(&self, address: u32, buffer: &'slice mut B) -> Result<&'slice mut [u8], Trap>
74    where
75        B: ?Sized + AsUninitSliceMut,
76    {
77        // SAFETY: The caller will make sure that the invariants hold.
78        let access = unsafe { self.access() };
79
80        log::trace!(
81            "Reading memory (during hostcall): 0x{:x}-0x{:x} ({} bytes)",
82            address,
83            (address as usize + buffer.as_uninit_slice_mut().len()) as u32,
84            buffer.as_uninit_slice_mut().len()
85        );
86        access.read_memory_into_slice(address, buffer)
87    }
88
89    unsafe fn read_memory_into_vec(&self, address: u32, length: u32) -> Result<Vec<u8>, Trap> {
90        log::trace!(
91            "Reading memory (during hostcall): 0x{:x}-0x{:x} ({} bytes)",
92            address,
93            address.wrapping_add(length),
94            length
95        );
96
97        // SAFETY: The caller will make sure that the invariants hold.
98        unsafe { self.access() }.read_memory_into_vec(address, length)
99    }
100
101    unsafe fn read_u32(&self, address: u32) -> Result<u32, Trap> {
102        let mut buffer: MaybeUninit<[u8; 4]> = MaybeUninit::uninit();
103
104        // SAFETY: The caller will make sure that the invariants hold.
105        let slice = unsafe { self.read_memory_into_slice(address, &mut buffer) }?;
106        let value = u32::from_le_bytes([slice[0], slice[1], slice[2], slice[3]]);
107        Ok(value)
108    }
109
110    unsafe fn write_memory(&mut self, address: u32, data: &[u8]) -> Result<(), Trap> {
111        log::trace!(
112            "Writing memory (during hostcall): 0x{:x}-0x{:x} ({} bytes)",
113            address,
114            (address as usize + data.len()) as u32,
115            data.len()
116        );
117
118        // SAFETY: The caller will make sure that the invariants hold.
119        let result = unsafe { self.access_mut() }.write_memory(address, data);
120
121        if let Some(ref mut tracer) = self.tracer() {
122            tracer.on_memory_write_in_hostcall(address, data, result.is_ok())?;
123        }
124
125        result
126    }
127
128    unsafe fn sbrk(&mut self, size: u32) -> Option<u32> {
129        // SAFETY: The caller will make sure that the invariants hold.
130        unsafe { self.access_mut() }.sbrk(size)
131    }
132
133    unsafe fn gas_remaining(&self) -> Option<Gas> {
134        // SAFETY: The caller will make sure that the invariants hold.
135        unsafe { self.access() }.gas_remaining()
136    }
137
138    unsafe fn consume_gas(&mut self, gas: u64) {
139        // SAFETY: The caller will make sure that the invariants hold.
140        unsafe { self.access_mut() }.consume_gas(gas)
141    }
142}
143
/// A handle used to access the execution context.
///
/// Only obtainable inside the callback passed to `Caller::wrap`; see [`Caller::into_ref`] for a
/// variant with dynamically checked lifetimes.
pub struct Caller<'a, T> {
    // The shared raw context whose pointers were set up by `Caller::wrap`.
    raw: &'a mut CallerRaw,
    // Points at the `Option<Rc<()>>` local inside `Caller::wrap`; `into_ref` stores the `Rc`
    // backing its `CallerRef` here.
    lifetime: *mut Option<Rc<()>>,
    // Ties the handle to the user-data type `T` and the borrow lifetime `'a`.
    _phantom: core::marker::PhantomData<&'a mut T>,
}
150
151impl<'a, T> Caller<'a, T> {
152    pub(crate) fn wrap<R>(
153        user_data: &mut T,
154        access: &'a mut BackendAccess<'_>,
155        raw: &'a mut CallerRaw,
156        callback: impl FnOnce(Self) -> R,
157    ) -> R
158    where
159        T: 'a,
160    {
161        raw.user_data = (user_data as *mut T).cast::<core::ffi::c_void>();
162        raw.access = (access as *mut BackendAccess).cast::<core::ffi::c_void>();
163
164        let mut lifetime = None;
165        let caller = Caller {
166            raw,
167            lifetime: &mut lifetime,
168            _phantom: core::marker::PhantomData,
169        };
170
171        let result = callback(caller);
172
173        core::mem::drop(lifetime);
174        result
175    }
176
177    /// Creates a caller handle with dynamically checked borrow rules.
178    pub fn into_ref(self) -> CallerRef<T> {
179        let lifetime = Rc::new(());
180        let lifetime_weak = Rc::downgrade(&lifetime);
181
182        // SAFETY: This can only be called from inside of `Caller::wrap` so the pointer to `lifetime` is always valid.
183        unsafe {
184            *self.lifetime = Some(lifetime);
185        }
186
187        CallerRef {
188            raw: self.raw,
189            lifetime: lifetime_weak,
190            _phantom: core::marker::PhantomData,
191        }
192    }
193
194    pub fn split(self) -> (Caller<'a, ()>, &'a mut T) {
195        let Caller { raw, lifetime, _phantom } = self;
196        let dummy: *mut () = &mut () as *mut ();
197        let dummy: *mut core::ffi::c_void = dummy.cast();
198        let user_data: *mut core::ffi::c_void = core::mem::replace(&mut raw.user_data, dummy);
199        let user_data: *mut T = user_data.cast();
200
201        // SAFETY: This can only be called from inside of `Caller::wrap` so this is always valid.
202        let user_data = unsafe { &mut *user_data };
203        let caller = Caller {
204            raw,
205            lifetime,
206            _phantom: core::marker::PhantomData,
207        };
208
209        (caller, user_data)
210    }
211
212    pub fn data(&self) -> &T {
213        // SAFETY: This can only be called from inside of `Caller::wrap` so this is always valid.
214        unsafe { self.raw.data() }
215    }
216
217    pub fn data_mut(&mut self) -> &mut T {
218        // SAFETY: This can only be called from inside of `Caller::wrap` so this is always valid.
219        unsafe { self.raw.data_mut() }
220    }
221
222    pub fn get_reg(&self, reg: Reg) -> u32 {
223        // SAFETY: This can only be called from inside of `Caller::wrap` so this is always valid.
224        unsafe { self.raw.get_reg(reg) }
225    }
226
227    pub fn set_reg(&mut self, reg: Reg, value: u32) {
228        // SAFETY: This can only be called from inside of `Caller::wrap` so this is always valid.
229        unsafe { self.raw.set_reg(reg, value) }
230    }
231
232    pub fn read_memory_into_slice<'slice, B>(&self, address: u32, buffer: &'slice mut B) -> Result<&'slice mut [u8], Trap>
233    where
234        B: ?Sized + AsUninitSliceMut,
235    {
236        // SAFETY: This can only be called from inside of `Caller::wrap` so this is always valid.
237        unsafe { self.raw.read_memory_into_slice(address, buffer) }
238    }
239
240    pub fn read_memory_into_vec(&self, address: u32, length: u32) -> Result<Vec<u8>, Trap> {
241        // SAFETY: This can only be called from inside of `Caller::wrap` so this is always valid.
242        unsafe { self.raw.read_memory_into_vec(address, length) }
243    }
244
245    pub fn read_u32(&self, address: u32) -> Result<u32, Trap> {
246        // SAFETY: This can only be called from inside of `Caller::wrap` so this is always valid.
247        unsafe { self.raw.read_u32(address) }
248    }
249
250    pub fn write_memory(&mut self, address: u32, data: &[u8]) -> Result<(), Trap> {
251        // SAFETY: This can only be called from inside of `Caller::wrap` so this is always valid.
252        unsafe { self.raw.write_memory(address, data) }
253    }
254
255    pub fn sbrk(&mut self, size: u32) -> Option<u32> {
256        // SAFETY: This can only be called from inside of `Caller::wrap` so this is always valid.
257        unsafe { self.raw.sbrk(size) }
258    }
259
260    pub fn gas_remaining(&self) -> Option<Gas> {
261        // SAFETY: This can only be called from inside of `Caller::wrap` so this is always valid.
262        unsafe { self.raw.gas_remaining() }
263    }
264
265    pub fn consume_gas(&mut self, gas: u64) {
266        // SAFETY: This can only be called from inside of `Caller::wrap` so this is always valid.
267        unsafe { self.raw.consume_gas(gas) }
268    }
269}
270
/// A handle used to access the execution context, with erased lifetimes for convenience.
///
/// Can only be used from within the handler to which the original [`Caller`] was passed.
/// Will panic if used incorrectly.
pub struct CallerRef<T> {
    // Raw pointer to the shared context; only dereferenced after the liveness check passes.
    raw: *mut CallerRaw,
    // Weak handle to the `Rc` that `Caller::into_ref` stored in `Caller::wrap`'s local slot; its
    // strong count drops to zero once `wrap` returns.
    lifetime: Weak<()>,
    // Records the user-data type `T` without storing a `T`.
    _phantom: core::marker::PhantomData<T>,
}
280
281impl<T> CallerRef<T> {
282    fn check_lifetime_or_panic(&self) {
283        assert!(self.lifetime.strong_count() > 0, "CallerRef accessed outside of a hostcall handler");
284    }
285
286    pub fn data(&self) -> &T {
287        self.check_lifetime_or_panic();
288
289        // SAFETY: We've made sure the lifetime is valid.
290        unsafe { (*self.raw).data() }
291    }
292
293    pub fn data_mut(&mut self) -> &mut T {
294        self.check_lifetime_or_panic();
295
296        // SAFETY: We've made sure the lifetime is valid.
297        unsafe { (*self.raw).data_mut() }
298    }
299
300    pub fn get_reg(&self, reg: Reg) -> u32 {
301        self.check_lifetime_or_panic();
302
303        // SAFETY: We've made sure the lifetime is valid.
304        unsafe { (*self.raw).get_reg(reg) }
305    }
306
307    pub fn set_reg(&mut self, reg: Reg, value: u32) {
308        self.check_lifetime_or_panic();
309
310        // SAFETY: We've made sure the lifetime is valid.
311        unsafe { (*self.raw).set_reg(reg, value) }
312    }
313
314    pub fn read_memory_into_slice<'slice, B>(&self, address: u32, buffer: &'slice mut B) -> Result<&'slice mut [u8], Trap>
315    where
316        B: ?Sized + AsUninitSliceMut,
317    {
318        self.check_lifetime_or_panic();
319
320        // SAFETY: We've made sure the lifetime is valid.
321        unsafe { (*self.raw).read_memory_into_slice(address, buffer) }
322    }
323
324    pub fn read_memory_into_vec(&self, address: u32, length: u32) -> Result<Vec<u8>, Trap> {
325        self.check_lifetime_or_panic();
326
327        // SAFETY: We've made sure the lifetime is valid.
328        unsafe { (*self.raw).read_memory_into_vec(address, length) }
329    }
330
331    pub fn read_u32(&self, address: u32) -> Result<u32, Trap> {
332        self.check_lifetime_or_panic();
333
334        // SAFETY: We've made sure the lifetime is valid.
335        unsafe { (*self.raw).read_u32(address) }
336    }
337
338    pub fn write_memory(&mut self, address: u32, data: &[u8]) -> Result<(), Trap> {
339        self.check_lifetime_or_panic();
340
341        // SAFETY: We've made sure the lifetime is valid.
342        unsafe { (*self.raw).write_memory(address, data) }
343    }
344
345    pub fn gas_remaining(&self) -> Option<Gas> {
346        self.check_lifetime_or_panic();
347
348        // SAFETY: We've made sure the lifetime is valid.
349        unsafe { (*self.raw).gas_remaining() }
350    }
351
352    pub fn consume_gas(&mut self, gas: u64) {
353        self.check_lifetime_or_panic();
354
355        // SAFETY: We've made sure the lifetime is valid.
356        unsafe { (*self.raw).consume_gas(gas) }
357    }
358}
359
// Source: https://users.rust-lang.org/t/a-macro-to-assert-that-a-type-does-not-implement-trait-bounds/31179
//
// Compile-time assertion that `$x` does NOT implement all of the listed traits.
//
// `AmbiguousIfImpl<()>` is implemented for every `Check<T>`. `AmbiguousIfImpl<u8>` is implemented
// only when `T` satisfies every listed bound. If `$x` satisfies them all, both impls match and the
// inference of `_` in `<Check<$x> as AmbiguousIfImpl<_>>` becomes ambiguous, which is a compile
// error. Otherwise only the `()` impl applies and the closure compiles; it is never called.
macro_rules! assert_not_impl {
    ($x:ty, $($t:path),+ $(,)*) => {
        const _: fn() -> () = || {
            struct Check<T: ?Sized>(T);
            trait AmbiguousIfImpl<A> { fn some_item() { } }

            impl<T: ?Sized> AmbiguousIfImpl<()> for Check<T> { }
            impl<T: ?Sized $(+ $t)*> AmbiguousIfImpl<u8> for Check<T> { }

            <Check::<$x> as AmbiguousIfImpl<_>>::some_item()
        };
    };
}
374
// `Caller` and `CallerRef` hold raw pointers into the current hostcall's context (including a
// pointer to a stack-local `Option<Rc<()>>` and a `Weak<()>`). They must never be moved to or
// shared with another thread; these checks fail to compile if that ever becomes possible.
assert_not_impl!(CallerRef<()>, Send);
assert_not_impl!(CallerRef<()>, Sync);
assert_not_impl!(Caller<'static, ()>, Send);
assert_not_impl!(Caller<'static, ()>, Sync);