polkavm/
linker.rs

1use crate::api::{MemoryProtection, RegValue};
2use crate::error::bail;
3use crate::program::ProgramSymbol;
4use crate::{Error, InterruptKind, Module, ProgramCounter, RawInstance, Reg};
5use alloc::borrow::ToOwned;
6use alloc::format;
7use alloc::string::String;
8use alloc::sync::Arc;
9use alloc::vec::Vec;
10use core::marker::PhantomData;
11
12#[cfg(not(feature = "std"))]
13use alloc::collections::btree_map::Entry;
14#[cfg(not(feature = "std"))]
15use alloc::collections::BTreeMap as LookupMap;
16
17#[cfg(feature = "std")]
18use std::collections::hash_map::Entry;
19#[cfg(feature = "std")]
20use std::collections::HashMap as LookupMap;
21
22trait CallFn<UserData, UserError>: Send + Sync {
23    fn call(&self, user_data: &mut UserData, instance: &mut RawInstance) -> Result<(), UserError>;
24}
25
26#[repr(transparent)]
27pub struct CallFnArc<UserData, UserError>(Arc<dyn CallFn<UserData, UserError>>);
28
29type FallbackHandlerArc<UserData, UserError> = Arc<dyn Fn(Caller<UserData>, u32) -> Result<(), UserError> + Send + Sync + 'static>;
30
31impl<UserData, UserError> Clone for CallFnArc<UserData, UserError> {
32    fn clone(&self) -> Self {
33        Self(Arc::clone(&self.0))
34    }
35}
36
37pub trait IntoCallFn<UserData, UserError, Params, Result>: Send + Sync + 'static {
38    #[doc(hidden)]
39    const _REGS_REQUIRED_32: usize;
40
41    #[doc(hidden)]
42    const _REGS_REQUIRED_64: usize;
43
44    #[doc(hidden)]
45    fn _into_extern_fn(self) -> CallFnArc<UserData, UserError>;
46}
47
48/// A type which can be marshalled through the VM's FFI boundary.
49pub trait AbiTy: Sized + Send + 'static {
50    #[doc(hidden)]
51    const _REGS_REQUIRED_32: usize;
52
53    #[doc(hidden)]
54    const _REGS_REQUIRED_64: usize;
55
56    #[doc(hidden)]
57    fn _get32(get_reg: impl FnMut() -> RegValue) -> Self;
58
59    #[doc(hidden)]
60    fn _get64(get_reg: impl FnMut() -> RegValue) -> Self;
61
62    #[doc(hidden)]
63    fn _set32(self, set_reg: impl FnMut(RegValue));
64
65    #[doc(hidden)]
66    fn _set64(self, set_reg: impl FnMut(RegValue));
67}
68
69impl AbiTy for u32 {
70    const _REGS_REQUIRED_32: usize = 1;
71    const _REGS_REQUIRED_64: usize = 1;
72
73    fn _get32(mut get_reg: impl FnMut() -> RegValue) -> Self {
74        get_reg() as u32
75    }
76
77    fn _get64(mut get_reg: impl FnMut() -> RegValue) -> Self {
78        get_reg() as u32
79    }
80
81    fn _set32(self, mut set_reg: impl FnMut(RegValue)) {
82        set_reg(u64::from(self))
83    }
84
85    fn _set64(self, mut set_reg: impl FnMut(RegValue)) {
86        set_reg(u64::from(self))
87    }
88}
89
90impl AbiTy for i32 {
91    const _REGS_REQUIRED_32: usize = <u32 as AbiTy>::_REGS_REQUIRED_32;
92    const _REGS_REQUIRED_64: usize = <u32 as AbiTy>::_REGS_REQUIRED_64;
93
94    fn _get32(get_reg: impl FnMut() -> RegValue) -> Self {
95        <u32 as AbiTy>::_get32(get_reg) as i32
96    }
97
98    fn _get64(get_reg: impl FnMut() -> RegValue) -> Self {
99        <u32 as AbiTy>::_get64(get_reg) as i32
100    }
101
102    fn _set32(self, set_reg: impl FnMut(RegValue)) {
103        (self as u32)._set32(set_reg)
104    }
105
106    fn _set64(self, set_reg: impl FnMut(RegValue)) {
107        i64::from(self)._set64(set_reg)
108    }
109}
110
111impl AbiTy for u64 {
112    const _REGS_REQUIRED_32: usize = 2;
113    const _REGS_REQUIRED_64: usize = 1;
114
115    fn _get32(mut get_reg: impl FnMut() -> RegValue) -> Self {
116        let value_lo = get_reg();
117        let value_hi = get_reg();
118        debug_assert!(value_lo <= u64::from(u32::MAX));
119        debug_assert!(value_hi <= u64::from(u32::MAX));
120        value_lo | (value_hi << 32)
121    }
122
123    fn _get64(mut get_reg: impl FnMut() -> RegValue) -> Self {
124        get_reg()
125    }
126
127    fn _set32(self, mut set_reg: impl FnMut(RegValue)) {
128        set_reg(self);
129        set_reg(self >> 32);
130    }
131
132    fn _set64(self, mut set_reg: impl FnMut(RegValue)) {
133        set_reg(self);
134    }
135}
136
137impl AbiTy for i64 {
138    const _REGS_REQUIRED_32: usize = <u64 as AbiTy>::_REGS_REQUIRED_32;
139    const _REGS_REQUIRED_64: usize = <u64 as AbiTy>::_REGS_REQUIRED_64;
140
141    fn _get32(get_reg: impl FnMut() -> RegValue) -> Self {
142        <u64 as AbiTy>::_get32(get_reg) as i64
143    }
144
145    fn _get64(get_reg: impl FnMut() -> RegValue) -> Self {
146        <u64 as AbiTy>::_get64(get_reg) as i64
147    }
148
149    fn _set32(self, set_reg: impl FnMut(RegValue)) {
150        (self as u64)._set32(set_reg)
151    }
152
153    fn _set64(self, set_reg: impl FnMut(RegValue)) {
154        (self as u64)._set64(set_reg)
155    }
156}
157
158// `AbiTy` is deliberately not implemented for `usize`.
159
160/// A type which can be returned from a host function.
161pub trait ReturnTy<UserError>: Sized + 'static {
162    #[doc(hidden)]
163    const _REGS_REQUIRED_32: usize;
164
165    #[doc(hidden)]
166    const _REGS_REQUIRED_64: usize;
167
168    #[doc(hidden)]
169    fn _handle_return32(self, set_reg: impl FnMut(RegValue)) -> Result<(), UserError>;
170
171    #[doc(hidden)]
172    fn _handle_return64(self, set_reg: impl FnMut(RegValue)) -> Result<(), UserError>;
173}
174
175impl<UserError, T> ReturnTy<UserError> for T
176where
177    T: AbiTy,
178{
179    const _REGS_REQUIRED_32: usize = <T as AbiTy>::_REGS_REQUIRED_32;
180    const _REGS_REQUIRED_64: usize = <T as AbiTy>::_REGS_REQUIRED_64;
181
182    fn _handle_return32(self, set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
183        self._set32(set_reg);
184        Ok(())
185    }
186
187    fn _handle_return64(self, set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
188        self._set64(set_reg);
189        Ok(())
190    }
191}
192
193impl<UserError> ReturnTy<UserError> for () {
194    const _REGS_REQUIRED_32: usize = 0;
195    const _REGS_REQUIRED_64: usize = 0;
196
197    fn _handle_return32(self, _set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
198        Ok(())
199    }
200
201    fn _handle_return64(self, _set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
202        Ok(())
203    }
204}
205
206impl<UserError, E> ReturnTy<UserError> for Result<(), E>
207where
208    UserError: From<E>,
209    E: 'static,
210{
211    const _REGS_REQUIRED_32: usize = 0;
212    const _REGS_REQUIRED_64: usize = 0;
213
214    fn _handle_return32(self, _set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
215        Ok(self?)
216    }
217
218    fn _handle_return64(self, _set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
219        Ok(self?)
220    }
221}
222
223impl<UserError, T, E> ReturnTy<UserError> for Result<T, E>
224where
225    UserError: From<E>,
226    E: 'static,
227    T: AbiTy,
228{
229    const _REGS_REQUIRED_32: usize = <T as AbiTy>::_REGS_REQUIRED_32;
230    const _REGS_REQUIRED_64: usize = <T as AbiTy>::_REGS_REQUIRED_64;
231
232    fn _handle_return32(self, set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
233        self?._set32(set_reg);
234        Ok(())
235    }
236
237    fn _handle_return64(self, set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
238        self?._set64(set_reg);
239        Ok(())
240    }
241}
242
243pub trait FuncArgs: Send {
244    #[doc(hidden)]
245    const _REGS_REQUIRED_32: usize;
246    #[doc(hidden)]
247    const _REGS_REQUIRED_64: usize;
248
249    #[doc(hidden)]
250    fn _set(self, is_64_bit: bool, set_reg: impl FnMut(RegValue))
251    where
252        Self: Sized,
253    {
254        if is_64_bit {
255            self._set64(set_reg);
256        } else {
257            self._set32(set_reg);
258        }
259    }
260
261    #[doc(hidden)]
262    fn _set32(self, set_reg: impl FnMut(RegValue));
263
264    #[doc(hidden)]
265    fn _set64(self, set_reg: impl FnMut(RegValue));
266}
267
268pub trait FuncResult: Send + Sized {
269    #[doc(hidden)]
270    const _REGS_REQUIRED_32: usize;
271    #[doc(hidden)]
272    const _REGS_REQUIRED_64: usize;
273
274    #[doc(hidden)]
275    fn _get(is_64_bit: bool, get_reg: impl FnMut() -> RegValue) -> Self {
276        if is_64_bit {
277            Self::_get64(get_reg)
278        } else {
279            Self::_get32(get_reg)
280        }
281    }
282
283    #[doc(hidden)]
284    fn _get32(get_reg: impl FnMut() -> RegValue) -> Self;
285
286    #[doc(hidden)]
287    fn _get64(get_reg: impl FnMut() -> RegValue) -> Self;
288}
289
290impl FuncResult for () {
291    const _REGS_REQUIRED_32: usize = 0;
292    const _REGS_REQUIRED_64: usize = 0;
293
294    fn _get32(_: impl FnMut() -> RegValue) -> Self {}
295    fn _get64(_: impl FnMut() -> RegValue) -> Self {}
296}
297
298impl<T> FuncResult for T
299where
300    T: AbiTy,
301{
302    const _REGS_REQUIRED_32: usize = <T as AbiTy>::_REGS_REQUIRED_32;
303    const _REGS_REQUIRED_64: usize = <T as AbiTy>::_REGS_REQUIRED_64;
304
305    fn _get32(get_reg: impl FnMut() -> RegValue) -> Self {
306        <T as AbiTy>::_get32(get_reg)
307    }
308
309    fn _get64(get_reg: impl FnMut() -> RegValue) -> Self {
310        <T as AbiTy>::_get64(get_reg)
311    }
312}
313
// Generates, for a given list of argument type parameters:
//   * a `CallFn` implementation which reads the arguments from registers, invokes the
//     callback and stores its return value,
//   * `IntoCallFn` implementations for callbacks with and without a leading `Caller`,
//   * a `FuncArgs` implementation for the tuple of those arguments.
//
// The arms prefixed with `@` are internal helpers.
macro_rules! impl_into_extern_fn {
    // Panics if more registers are required than there are argument registers.
    (@check_reg_count $regs_required:expr) => {
        if $regs_required > Reg::ARG_REGS.len() {
            // TODO: We should probably print out which exact function it is.
            panic!("external call failed: too many registers required for arguments!");
        }
    };

    // Builds a closure which returns the values of consecutive argument registers,
    // starting from the first one.
    (@get_reg $caller:expr) => {{
        let mut reg_index = 0;
        let caller = &mut $caller;
        move || -> RegValue {
            let value = caller.instance.reg(Reg::ARG_REGS[reg_index]);
            reg_index += 1;
            value
        }
    }};

    // Invokes a callback which takes no arguments besides the `Caller`.
    (@call $is_64_bit:expr, $caller:expr, $callback:expr, ) => {{
        ($callback)($caller)
    }};

    // Invokes a callback with one or more arguments.
    //
    // The register count is validated first; then the arguments are read from the
    // registers left to right (tuple fields are evaluated in order), and finally the
    // callback is invoked.
    (@call $is_64_bit:expr, $caller:expr, $callback:expr, $($args:ident),+) => {{
        #[allow(non_snake_case)]
        let ($($args,)+) = {
            // The closure mutably borrows the caller, so it is confined to this block
            // to release the borrow before the caller is handed to the callback.
            let mut get_reg = impl_into_extern_fn!(@get_reg $caller);
            if $is_64_bit {
                impl_into_extern_fn!(@check_reg_count 0 $(+ $args::_REGS_REQUIRED_64)+);
                ($($args::_get64(&mut get_reg),)+)
            } else {
                impl_into_extern_fn!(@check_reg_count 0 $(+ $args::_REGS_REQUIRED_32)+);
                ($($args::_get32(&mut get_reg),)+)
            }
        };

        ($callback)($caller, $($args),+)
    }};

    ($arg_count:tt $($args:ident)*) => {
        impl<UserData, UserError, F, $($args,)* R> CallFn<UserData, UserError> for (F, UnsafePhantomData<(R, $($args),*)>)
            where
            F: Fn(Caller<'_, UserData>, $($args),*) -> R + Send + Sync + 'static,
            $($args: AbiTy,)*
            R: ReturnTy<UserError>,
        {
            fn call(&self, user_data: &mut UserData, instance: &mut RawInstance) -> Result<(), UserError> {
                let is_64_bit = instance.module().blob().is_64_bit();
                let result = {
                    #[allow(unused_mut)]
                    let mut caller = Caller {
                        user_data,
                        instance
                    };

                    impl_into_extern_fn!(@call is_64_bit, caller, self.0, $($args),*)
                };

                // Stores the return value into consecutive argument registers,
                // starting from the first one.
                let set_reg = {
                    let mut reg_index = 0;
                    move |value: RegValue| {
                        let reg = Reg::ARG_REGS[reg_index];
                        instance.set_reg(reg, value);
                        reg_index += 1;
                    }
                };

                if is_64_bit {
                    result._handle_return64(set_reg)
                } else {
                    result._handle_return32(set_reg)
                }
            }
        }

        impl<UserData, UserError, F, $($args,)* R> IntoCallFn<UserData, UserError, ($($args,)*), R> for F
        where
            F: Fn($($args),*) -> R + Send + Sync + 'static,
            $($args: AbiTy,)*
            R: ReturnTy<UserError>,
        {
            const _REGS_REQUIRED_32: usize = 0 $(+ $args::_REGS_REQUIRED_32)*;
            const _REGS_REQUIRED_64: usize = 0 $(+ $args::_REGS_REQUIRED_64)*;

            fn _into_extern_fn(self) -> CallFnArc<UserData, UserError> {
                // Adapt the callback to the `Caller`-taking form by ignoring the caller.
                #[allow(non_snake_case)]
                let callback = move |_caller: Caller<UserData>, $($args: $args),*| -> R {
                    self($($args),*)
                };
                CallFnArc(Arc::new((callback, UnsafePhantomData(PhantomData::<(R, $($args),*)>))))
            }
        }

        impl<UserData, UserError, F, $($args,)* R> IntoCallFn<UserData, UserError, (Caller<'_, UserData>, $($args,)*), R> for F
        where
            F: Fn(Caller<'_, UserData>, $($args),*) -> R + Send + Sync + 'static,
            $($args: AbiTy,)*
            R: ReturnTy<UserError>,
        {
            const _REGS_REQUIRED_32: usize = 0 $(+ $args::_REGS_REQUIRED_32)*;
            const _REGS_REQUIRED_64: usize = 0 $(+ $args::_REGS_REQUIRED_64)*;

            fn _into_extern_fn(self) -> CallFnArc<UserData, UserError> {
                CallFnArc(Arc::new((self, UnsafePhantomData(PhantomData::<(R, $($args),*)>))))
            }
        }

        impl<$($args: Send + AbiTy,)*> FuncArgs for ($($args,)*) {
            const _REGS_REQUIRED_32: usize = 0 $(+ $args::_REGS_REQUIRED_32)*;
            const _REGS_REQUIRED_64: usize = 0 $(+ $args::_REGS_REQUIRED_64)*;

            #[allow(unused_mut)]
            #[allow(unused_variables)]
            #[allow(non_snake_case)]
            fn _set32(self, mut set_reg: impl FnMut(RegValue)) {
                let ($($args,)*) = self;
                $($args._set32(&mut set_reg);)*
            }

            #[allow(unused_mut)]
            #[allow(unused_variables)]
            #[allow(non_snake_case)]
            fn _set64(self, mut set_reg: impl FnMut(RegValue)) {
                let ($($args,)*) = self;
                $($args._set64(&mut set_reg);)*
            }
        }
    };
}
555
// Instantiate the implementations for callbacks and argument tuples of zero up to
// six arguments; the leading number is the argument count.
impl_into_extern_fn!(0);
impl_into_extern_fn!(1 A0);
impl_into_extern_fn!(2 A0 A1);
impl_into_extern_fn!(3 A0 A1 A2);
impl_into_extern_fn!(4 A0 A1 A2 A3);
impl_into_extern_fn!(5 A0 A1 A2 A3 A4);
impl_into_extern_fn!(6 A0 A1 A2 A3 A4 A5);
563
/// A [`PhantomData`] wrapper which is unconditionally `Send` and `Sync`.
///
/// It lets a type mention `T` purely as a compile-time marker without inheriting
/// the thread-safety properties of `T`.
#[repr(transparent)]
struct UnsafePhantomData<T>(PhantomData<T>);

// SAFETY: No value of type `T` is ever stored; the parameter exists only at compile time,
// so sending the wrapper to another thread is sound whether or not `T` is `Send`.
unsafe impl<T> Send for UnsafePhantomData<T> {}

// SAFETY: No value of type `T` is ever stored; the parameter exists only at compile time,
// so sharing the wrapper between threads is sound whether or not `T` is `Sync`.
unsafe impl<T> Sync for UnsafePhantomData<T> {}
572
573struct DynamicFn<T, F> {
574    callback: F,
575    _phantom: UnsafePhantomData<T>,
576}
577
578impl<UserData, UserError, F> CallFn<UserData, UserError> for DynamicFn<UserData, F>
579where
580    F: Fn(Caller<'_, UserData>) -> Result<(), UserError> + Send + Sync + 'static,
581    UserData: 'static,
582{
583    fn call(&self, user_data: &mut UserData, instance: &mut RawInstance) -> Result<(), UserError> {
584        let caller = Caller { user_data, instance };
585
586        (self.callback)(caller)
587    }
588}
589
590#[non_exhaustive]
591pub struct Caller<'a, UserData = ()> {
592    pub user_data: &'a mut UserData,
593    pub instance: &'a mut RawInstance,
594}
595
596pub struct Linker<UserData = (), UserError = core::convert::Infallible> {
597    host_functions: LookupMap<Vec<u8>, CallFnArc<UserData, UserError>>,
598    #[allow(clippy::type_complexity)]
599    fallback_handler: Option<FallbackHandlerArc<UserData, UserError>>,
600    phantom: PhantomData<(UserData, UserError)>,
601}
602
603impl<UserData, UserError> Default for Linker<UserData, UserError> {
604    fn default() -> Self {
605        Self::new()
606    }
607}
608
609impl<UserData, UserError> Linker<UserData, UserError> {
610    pub fn new() -> Self {
611        Self {
612            host_functions: Default::default(),
613            fallback_handler: None,
614            phantom: PhantomData,
615        }
616    }
617
618    /// Defines a fallback external call handler, in case no other registered functions match.
619    pub fn define_fallback(&mut self, func: impl Fn(Caller<UserData>, u32) -> Result<(), UserError> + Send + Sync + 'static) {
620        self.fallback_handler = Some(Arc::new(func));
621    }
622
623    /// Defines a new untyped handler for external calls with a given symbol.
624    pub fn define_untyped(
625        &mut self,
626        symbol: impl AsRef<[u8]>,
627        func: impl Fn(Caller<UserData>) -> Result<(), UserError> + Send + Sync + 'static,
628    ) -> Result<&mut Self, Error>
629    where
630        UserData: 'static,
631    {
632        let symbol = symbol.as_ref();
633        if self.host_functions.contains_key(symbol) {
634            bail!(
635                "cannot register host function: host function was already registered: {}",
636                ProgramSymbol::new(symbol)
637            );
638        }
639
640        self.host_functions.insert(
641            symbol.to_owned(),
642            CallFnArc(Arc::new(DynamicFn {
643                callback: func,
644                _phantom: UnsafePhantomData(PhantomData),
645            })),
646        );
647
648        Ok(self)
649    }
650
651    /// Defines a new statically typed handler for external calls with a given symbol.
652    pub fn define_typed<Params, Args>(
653        &mut self,
654        symbol: impl AsRef<[u8]>,
655        func: impl IntoCallFn<UserData, UserError, Params, Args>,
656    ) -> Result<&mut Self, Error> {
657        let symbol = symbol.as_ref();
658        if self.host_functions.contains_key(symbol) {
659            bail!(
660                "cannot register host function: host function was already registered: {}",
661                ProgramSymbol::new(symbol)
662            );
663        }
664
665        self.host_functions.insert(symbol.to_owned(), func._into_extern_fn());
666        Ok(self)
667    }
668
669    /// Pre-instantiates a new module, resolving its imports and exports.
670    pub fn instantiate_pre(&self, module: &Module) -> Result<InstancePre<UserData, UserError>, Error> {
671        let mut exports = LookupMap::new();
672        for export in module.exports() {
673            match exports.entry(export.symbol().as_bytes().to_owned()) {
674                Entry::Occupied(_) => {
675                    if module.is_strict() {
676                        return Err(format!("duplicate export: {}", export.symbol()).into());
677                    } else {
678                        log::debug!("Duplicate export: {}", export.symbol());
679                        continue;
680                    }
681                }
682                Entry::Vacant(entry) => {
683                    entry.insert(export.program_counter());
684                }
685            }
686        }
687
688        let mut imports: Vec<Option<CallFnArc<UserData, UserError>>> = Vec::with_capacity(module.imports().len() as usize);
689        for symbol in module.imports() {
690            let Some(symbol) = symbol else {
691                if module.is_strict() {
692                    return Err("failed to parse an import".into());
693                } else {
694                    imports.push(None);
695                    continue;
696                }
697            };
698
699            let host_fn = if let Some(host_fn) = self.host_functions.get(symbol.as_bytes()) {
700                Some(host_fn.clone())
701            } else if self.fallback_handler.is_some() {
702                None
703            } else if module.is_strict() {
704                return Err(format!("missing host function: {}", symbol).into());
705            } else {
706                log::debug!("Missing host function: {}", symbol);
707                None
708            };
709
710            imports.push(host_fn);
711        }
712
713        assert_eq!(imports.len(), module.imports().len() as usize);
714        Ok(InstancePre(Arc::new(InstancePreState {
715            module: module.clone(),
716            imports,
717            exports,
718            fallback_handler: self.fallback_handler.clone(),
719        })))
720    }
721}
722
723struct InstancePreState<UserData, UserError> {
724    module: Module,
725    imports: Vec<Option<CallFnArc<UserData, UserError>>>,
726    exports: LookupMap<Vec<u8>, ProgramCounter>,
727    fallback_handler: Option<FallbackHandlerArc<UserData, UserError>>,
728}
729
730pub struct InstancePre<UserData = (), UserError = core::convert::Infallible>(Arc<InstancePreState<UserData, UserError>>);
731
732impl<UserData, UserError> Clone for InstancePre<UserData, UserError> {
733    fn clone(&self) -> Self {
734        Self(Arc::clone(&self.0))
735    }
736}
737
738pub struct Instance<UserData = (), UserError = core::convert::Infallible> {
739    instance: RawInstance,
740    pre: InstancePre<UserData, UserError>,
741}
742
743impl<UserData, UserError> core::ops::Deref for Instance<UserData, UserError> {
744    type Target = RawInstance;
745    fn deref(&self) -> &Self::Target {
746        &self.instance
747    }
748}
749
750impl<UserData, UserError> core::ops::DerefMut for Instance<UserData, UserError> {
751    fn deref_mut(&mut self) -> &mut Self::Target {
752        &mut self.instance
753    }
754}
755
756#[derive(Debug)]
757pub enum CallError<UserError = core::convert::Infallible> {
758    /// The execution finished abnormally with a trap.
759    Trap,
760
761    /// The execution ran out of gas.
762    NotEnoughGas,
763
764    /// The execution failed.
765    Error(Error),
766
767    /// The execution failed with a custom user error.
768    User(UserError),
769
770    /// The execution stepped through one instruction.
771    ///
772    /// Requires execution step-tracing to be enabled with [`ModuleConfig::set_step_tracing`](crate::ModuleConfig::set_step_tracing), otherwise is never emitted.
773    Step,
774}
775
776impl<UserData, UserError> InstancePre<UserData, UserError> {
777    pub fn instantiate(&self) -> Result<Instance<UserData, UserError>, Error> {
778        Ok(Instance {
779            instance: self.0.module.instantiate()?,
780            pre: self.clone(),
781        })
782    }
783
784    pub fn module(&self) -> &Module {
785        &self.0.module
786    }
787}
788
789pub trait EntryPoint {
790    #[doc(hidden)]
791    fn get(self, exports: &LookupMap<Vec<u8>, ProgramCounter>) -> Result<ProgramCounter, String>;
792}
793
794impl<'a> EntryPoint for &'a str {
795    fn get(self, exports: &LookupMap<Vec<u8>, ProgramCounter>) -> Result<ProgramCounter, String> {
796        exports
797            .get(self.as_bytes())
798            .copied()
799            .ok_or_else(|| format!("export not found: '{self}'"))
800    }
801}
802
803impl EntryPoint for String {
804    fn get(self, exports: &LookupMap<Vec<u8>, ProgramCounter>) -> Result<ProgramCounter, String> {
805        EntryPoint::get(self.as_str(), exports)
806    }
807}
808
809impl EntryPoint for ProgramCounter {
810    fn get(self, _: &LookupMap<Vec<u8>, ProgramCounter>) -> Result<ProgramCounter, String> {
811        Ok(self)
812    }
813}
814
815impl<UserData, UserError> Instance<UserData, UserError> {
816    /// Calls a given exported function with the given arguments.
817    pub fn call_typed<FnArgs>(
818        &mut self,
819        user_data: &mut UserData,
820        entry_point: impl EntryPoint,
821        args: FnArgs,
822    ) -> Result<(), CallError<UserError>>
823    where
824        FnArgs: FuncArgs,
825    {
826        let entry_point = entry_point
827            .get(&self.pre.0.exports)
828            .map_err(|error| CallError::Error(Error::from_display(error)))?;
829
830        self.instance.prepare_call_typed(entry_point, args);
831        self.continue_execution(user_data)
832    }
833
834    /// Continues execution.
835    pub fn continue_execution(&mut self, user_data: &mut UserData) -> Result<(), CallError<UserError>> {
836        loop {
837            let interrupt = self.instance.run().map_err(CallError::Error)?;
838            match interrupt {
839                InterruptKind::Finished => break Ok(()),
840                InterruptKind::Trap => break Err(CallError::Trap),
841                InterruptKind::Ecalli(hostcall) => {
842                    if let Some(host_fn) = self.pre.0.imports.get(hostcall as usize).and_then(|host_fn| host_fn.as_ref()) {
843                        host_fn.0.call(user_data, &mut self.instance).map_err(CallError::User)?;
844                    } else if let Some(ref fallback_handler) = self.pre.0.fallback_handler {
845                        let caller = Caller {
846                            user_data,
847                            instance: &mut self.instance,
848                        };
849
850                        fallback_handler(caller, hostcall).map_err(CallError::User)?;
851                    } else {
852                        log::debug!("Called a missing host function with ID = {}", hostcall);
853                        break Err(CallError::Trap);
854                    };
855                }
856                InterruptKind::NotEnoughGas => return Err(CallError::NotEnoughGas),
857                InterruptKind::Segfault(segfault) => {
858                    let module = self.instance.module().clone();
859                    if segfault.page_address >= module.memory_map().stack_address_low()
860                        && segfault.page_address + segfault.page_size <= module.memory_map().stack_address_high()
861                    {
862                        self.instance
863                            .zero_memory_with_memory_protection(segfault.page_address, segfault.page_size, MemoryProtection::ReadWrite)
864                            .map_err(|error| {
865                                CallError::Error(Error::from_display(format!(
866                                    "failed to zero memory when handling a segfault at 0x{:x}: {error}",
867                                    segfault.page_address
868                                )))
869                            })?;
870
871                        continue;
872                    }
873
874                    macro_rules! handle {
875                        ($range:ident, $data:ident, $protection:ident) => {{
876                            if segfault.page_address >= module.memory_map().$range().start
877                                && segfault.page_address + segfault.page_size <= module.memory_map().$range().end
878                            {
879                                let data_offset = (segfault.page_address - module.memory_map().$range().start) as usize;
880                                let data = module.blob().$data();
881                                let chunk_length = data.len().checked_sub(data_offset);
882                                let initial_protection = if chunk_length.is_some() {
883                                    MemoryProtection::ReadWrite
884                                } else {
885                                    MemoryProtection::$protection
886                                };
887
888                                self.instance
889                                    .zero_memory_with_memory_protection(segfault.page_address, segfault.page_size, initial_protection)
890                                    .map_err(|error| {
891                                        CallError::Error(Error::from_display(format!(
892                                            "failed to zero memory when handling a segfault at 0x{:x}: {error}",
893                                            segfault.page_address
894                                        )))
895                                    })?;
896
897                                if let Some(chunk_length) = chunk_length {
898                                    let chunk_length = core::cmp::min(chunk_length, segfault.page_size as usize);
899                                    self.instance
900                                        .write_memory(segfault.page_address, &data[data_offset..data_offset + chunk_length])
901                                        .map_err(|error| {
902                                            CallError::Error(Error::from_display(format!(
903                                                "failed to write memory when handling a segfault at 0x{:x}: {error}",
904                                                segfault.page_address
905                                            )))
906                                        })?;
907                                };
908
909                                if MemoryProtection::$protection == MemoryProtection::Read && initial_protection != MemoryProtection::Read {
910                                    self.instance
911                                        .protect_memory(segfault.page_address, segfault.page_size)
912                                        .map_err(|error| {
913                                            CallError::Error(Error::from_display(format!(
914                                                "failed to protect memory when handling a segfault at 0x{:x}: {error}",
915                                                segfault.page_address
916                                            )))
917                                        })?;
918                                }
919
920                                continue;
921                            }
922                        }};
923                    }
924
925                    handle!(ro_data_range, ro_data, Read);
926                    handle!(rw_data_range, rw_data, ReadWrite);
927
928                    log::debug!("Unexpected segfault: 0x{:x}", segfault.page_address);
929                    break Err(CallError::Trap);
930                }
931                InterruptKind::Step => break Err(CallError::Step),
932            }
933        }
934    }
935
936    /// A conveniance function to call [`Instance::call_typed`] and [`RawInstance::get_result_typed`] in a single function call.
937    pub fn call_typed_and_get_result<FnResult, FnArgs>(
938        &mut self,
939        user_data: &mut UserData,
940        entry_point: impl EntryPoint,
941        args: FnArgs,
942    ) -> Result<FnResult, CallError<UserError>>
943    where
944        FnArgs: FuncArgs,
945        FnResult: FuncResult,
946    {
947        self.call_typed(user_data, entry_point, args)?;
948        Ok(self.instance.get_result_typed::<FnResult>())
949    }
950}