1use crate::api::{MemoryProtection, RegValue};
2use crate::error::bail;
3use crate::program::ProgramSymbol;
4use crate::{Error, InterruptKind, Module, ProgramCounter, RawInstance, Reg};
5use alloc::borrow::ToOwned;
6use alloc::format;
7use alloc::string::String;
8use alloc::sync::Arc;
9use alloc::vec::Vec;
10use core::marker::PhantomData;
11
12#[cfg(not(feature = "std"))]
13use alloc::collections::btree_map::Entry;
14#[cfg(not(feature = "std"))]
15use alloc::collections::BTreeMap as LookupMap;
16
17#[cfg(feature = "std")]
18use std::collections::hash_map::Entry;
19#[cfg(feature = "std")]
20use std::collections::HashMap as LookupMap;
21
/// Object-safe interface through which a registered host function is invoked.
///
/// `call` receives the embedder's user data and the instance that made the host
/// call, and reports the host function's failure as `UserError`. The `Send + Sync`
/// supertraits let host functions be shared behind an `Arc`.
trait CallFn<UserData, UserError>: Send + Sync {
    fn call(&self, user_data: &mut UserData, instance: &mut RawInstance) -> Result<(), UserError>;
}
25
/// A shared, type-erased handle to a host function.
///
/// Cloning only bumps the inner `Arc`'s reference count, so the same host
/// function can be handed from the `Linker` to every `InstancePre` it creates.
#[repr(transparent)]
pub struct CallFnArc<UserData, UserError>(Arc<dyn CallFn<UserData, UserError>>);
28
/// Shared handler for host calls that have no registered host function; it is
/// given the caller and the raw host call index.
type FallbackHandlerArc<UserData, UserError> = Arc<dyn Fn(Caller<UserData>, u32) -> Result<(), UserError> + Send + Sync + 'static>;
30
31impl<UserData, UserError> Clone for CallFnArc<UserData, UserError> {
32 fn clone(&self) -> Self {
33 Self(Arc::clone(&self.0))
34 }
35}
36
/// Conversion from a Rust function or closure into a host function.
///
/// Implemented by `impl_into_extern_fn!` for `Fn` closures taking up to six
/// `AbiTy` arguments, optionally preceded by a `Caller`, whose return type
/// implements `ReturnTy`. `Params` distinguishes the with-`Caller` and
/// without-`Caller` implementations; `Result` is the callback's return type.
pub trait IntoCallFn<UserData, UserError, Params, Result>: Send + Sync + 'static {
    /// Number of argument registers the parameters occupy on a 32-bit target.
    #[doc(hidden)]
    const _REGS_REQUIRED_32: usize;

    /// Number of argument registers the parameters occupy on a 64-bit target.
    #[doc(hidden)]
    const _REGS_REQUIRED_64: usize;

    /// Wraps `self` into a type-erased host function.
    #[doc(hidden)]
    fn _into_extern_fn(self) -> CallFnArc<UserData, UserError>;
}
47
/// A primitive value that can be passed through guest registers.
///
/// `_get*` build a value by pulling successive registers from `get_reg`;
/// `_set*` write a value by pushing successive registers into `set_reg`. The
/// `32`/`64` variants select the register layout for 32-bit and 64-bit modules.
pub trait AbiTy: Sized + Send + 'static {
    /// Number of registers a value occupies on a 32-bit target.
    #[doc(hidden)]
    const _REGS_REQUIRED_32: usize;

    /// Number of registers a value occupies on a 64-bit target.
    #[doc(hidden)]
    const _REGS_REQUIRED_64: usize;

    /// Reads a value using the 32-bit register layout.
    #[doc(hidden)]
    fn _get32(get_reg: impl FnMut() -> RegValue) -> Self;

    /// Reads a value using the 64-bit register layout.
    #[doc(hidden)]
    fn _get64(get_reg: impl FnMut() -> RegValue) -> Self;

    /// Writes the value using the 32-bit register layout.
    #[doc(hidden)]
    fn _set32(self, set_reg: impl FnMut(RegValue));

    /// Writes the value using the 64-bit register layout.
    #[doc(hidden)]
    fn _set64(self, set_reg: impl FnMut(RegValue));
}
68
69impl AbiTy for u32 {
70 const _REGS_REQUIRED_32: usize = 1;
71 const _REGS_REQUIRED_64: usize = 1;
72
73 fn _get32(mut get_reg: impl FnMut() -> RegValue) -> Self {
74 get_reg() as u32
75 }
76
77 fn _get64(mut get_reg: impl FnMut() -> RegValue) -> Self {
78 get_reg() as u32
79 }
80
81 fn _set32(self, mut set_reg: impl FnMut(RegValue)) {
82 set_reg(u64::from(self))
83 }
84
85 fn _set64(self, mut set_reg: impl FnMut(RegValue)) {
86 set_reg(u64::from(self))
87 }
88}
89
90impl AbiTy for i32 {
91 const _REGS_REQUIRED_32: usize = <u32 as AbiTy>::_REGS_REQUIRED_32;
92 const _REGS_REQUIRED_64: usize = <u32 as AbiTy>::_REGS_REQUIRED_64;
93
94 fn _get32(get_reg: impl FnMut() -> RegValue) -> Self {
95 <u32 as AbiTy>::_get32(get_reg) as i32
96 }
97
98 fn _get64(get_reg: impl FnMut() -> RegValue) -> Self {
99 <u32 as AbiTy>::_get64(get_reg) as i32
100 }
101
102 fn _set32(self, set_reg: impl FnMut(RegValue)) {
103 (self as u32)._set32(set_reg)
104 }
105
106 fn _set64(self, set_reg: impl FnMut(RegValue)) {
107 i64::from(self)._set64(set_reg)
108 }
109}
110
111impl AbiTy for u64 {
112 const _REGS_REQUIRED_32: usize = 2;
113 const _REGS_REQUIRED_64: usize = 1;
114
115 fn _get32(mut get_reg: impl FnMut() -> RegValue) -> Self {
116 let value_lo = get_reg();
117 let value_hi = get_reg();
118 debug_assert!(value_lo <= u64::from(u32::MAX));
119 debug_assert!(value_hi <= u64::from(u32::MAX));
120 value_lo | (value_hi << 32)
121 }
122
123 fn _get64(mut get_reg: impl FnMut() -> RegValue) -> Self {
124 get_reg()
125 }
126
127 fn _set32(self, mut set_reg: impl FnMut(RegValue)) {
128 set_reg(self);
129 set_reg(self >> 32);
130 }
131
132 fn _set64(self, mut set_reg: impl FnMut(RegValue)) {
133 set_reg(self);
134 }
135}
136
137impl AbiTy for i64 {
138 const _REGS_REQUIRED_32: usize = <u64 as AbiTy>::_REGS_REQUIRED_32;
139 const _REGS_REQUIRED_64: usize = <u64 as AbiTy>::_REGS_REQUIRED_64;
140
141 fn _get32(get_reg: impl FnMut() -> RegValue) -> Self {
142 <u64 as AbiTy>::_get32(get_reg) as i64
143 }
144
145 fn _get64(get_reg: impl FnMut() -> RegValue) -> Self {
146 <u64 as AbiTy>::_get64(get_reg) as i64
147 }
148
149 fn _set32(self, set_reg: impl FnMut(RegValue)) {
150 (self as u64)._set32(set_reg)
151 }
152
153 fn _set64(self, set_reg: impl FnMut(RegValue)) {
154 (self as u64)._set64(set_reg)
155 }
156}
157
/// A type a host function may return.
///
/// `_handle_return*` writes any returned value into registers through `set_reg`
/// and converts a returned error into `UserError`. Implemented for every
/// `AbiTy`, for `()`, and for `Result<(), E>` / `Result<T, E>` where
/// `UserError: From<E>`.
pub trait ReturnTy<UserError>: Sized + 'static {
    /// Number of registers the returned value occupies on a 32-bit target.
    #[doc(hidden)]
    const _REGS_REQUIRED_32: usize;

    /// Number of registers the returned value occupies on a 64-bit target.
    #[doc(hidden)]
    const _REGS_REQUIRED_64: usize;

    /// Writes the return value using the 32-bit register layout.
    #[doc(hidden)]
    fn _handle_return32(self, set_reg: impl FnMut(RegValue)) -> Result<(), UserError>;

    /// Writes the return value using the 64-bit register layout.
    #[doc(hidden)]
    fn _handle_return64(self, set_reg: impl FnMut(RegValue)) -> Result<(), UserError>;
}
174
175impl<UserError, T> ReturnTy<UserError> for T
176where
177 T: AbiTy,
178{
179 const _REGS_REQUIRED_32: usize = <T as AbiTy>::_REGS_REQUIRED_32;
180 const _REGS_REQUIRED_64: usize = <T as AbiTy>::_REGS_REQUIRED_64;
181
182 fn _handle_return32(self, set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
183 self._set32(set_reg);
184 Ok(())
185 }
186
187 fn _handle_return64(self, set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
188 self._set64(set_reg);
189 Ok(())
190 }
191}
192
/// Returning `()` writes no registers and never fails.
impl<UserError> ReturnTy<UserError> for () {
    const _REGS_REQUIRED_32: usize = 0;
    const _REGS_REQUIRED_64: usize = 0;

    fn _handle_return32(self, _set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
        Ok(())
    }

    fn _handle_return64(self, _set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
        Ok(())
    }
}
205
206impl<UserError, E> ReturnTy<UserError> for Result<(), E>
207where
208 UserError: From<E>,
209 E: 'static,
210{
211 const _REGS_REQUIRED_32: usize = 0;
212 const _REGS_REQUIRED_64: usize = 0;
213
214 fn _handle_return32(self, _set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
215 Ok(self?)
216 }
217
218 fn _handle_return64(self, _set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
219 Ok(self?)
220 }
221}
222
223impl<UserError, T, E> ReturnTy<UserError> for Result<T, E>
224where
225 UserError: From<E>,
226 E: 'static,
227 T: AbiTy,
228{
229 const _REGS_REQUIRED_32: usize = <T as AbiTy>::_REGS_REQUIRED_32;
230 const _REGS_REQUIRED_64: usize = <T as AbiTy>::_REGS_REQUIRED_64;
231
232 fn _handle_return32(self, set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
233 self?._set32(set_reg);
234 Ok(())
235 }
236
237 fn _handle_return64(self, set_reg: impl FnMut(RegValue)) -> Result<(), UserError> {
238 self?._set64(set_reg);
239 Ok(())
240 }
241}
242
243pub trait FuncArgs: Send {
244 #[doc(hidden)]
245 const _REGS_REQUIRED_32: usize;
246 #[doc(hidden)]
247 const _REGS_REQUIRED_64: usize;
248
249 #[doc(hidden)]
250 fn _set(self, is_64_bit: bool, set_reg: impl FnMut(RegValue))
251 where
252 Self: Sized,
253 {
254 if is_64_bit {
255 self._set64(set_reg);
256 } else {
257 self._set32(set_reg);
258 }
259 }
260
261 #[doc(hidden)]
262 fn _set32(self, set_reg: impl FnMut(RegValue));
263
264 #[doc(hidden)]
265 fn _set64(self, set_reg: impl FnMut(RegValue));
266}
267
268pub trait FuncResult: Send + Sized {
269 #[doc(hidden)]
270 const _REGS_REQUIRED_32: usize;
271 #[doc(hidden)]
272 const _REGS_REQUIRED_64: usize;
273
274 #[doc(hidden)]
275 fn _get(is_64_bit: bool, get_reg: impl FnMut() -> RegValue) -> Self {
276 if is_64_bit {
277 Self::_get64(get_reg)
278 } else {
279 Self::_get32(get_reg)
280 }
281 }
282
283 #[doc(hidden)]
284 fn _get32(get_reg: impl FnMut() -> RegValue) -> Self;
285
286 #[doc(hidden)]
287 fn _get64(get_reg: impl FnMut() -> RegValue) -> Self;
288}
289
/// `()` as a call result: occupies no registers and reads nothing.
impl FuncResult for () {
    const _REGS_REQUIRED_32: usize = 0;
    const _REGS_REQUIRED_64: usize = 0;

    fn _get32(_: impl FnMut() -> RegValue) -> Self {}
    fn _get64(_: impl FnMut() -> RegValue) -> Self {}
}
297
298impl<T> FuncResult for T
299where
300 T: AbiTy,
301{
302 const _REGS_REQUIRED_32: usize = <T as AbiTy>::_REGS_REQUIRED_32;
303 const _REGS_REQUIRED_64: usize = <T as AbiTy>::_REGS_REQUIRED_64;
304
305 fn _get32(get_reg: impl FnMut() -> RegValue) -> Self {
306 <T as AbiTy>::_get32(get_reg)
307 }
308
309 fn _get64(get_reg: impl FnMut() -> RegValue) -> Self {
310 <T as AbiTy>::_get64(get_reg)
311 }
312}
313
// Generates the host-function glue for a single arity. Invoked as
// `impl_into_extern_fn!(N A0 A1 ...)`: `N` is the argument count (not used by
// the expansion) and `A0 ...` name the generic argument types.
//
// The last rule emits, for that arity:
// - a `CallFn` impl for `(F, UnsafePhantomData<(R, A0, ...)>)`. It decodes the
//   arguments from `Reg::ARG_REGS`, invokes `F` with a `Caller`, and writes the
//   return value back into `Reg::ARG_REGS` starting at index 0;
// - an `IntoCallFn` impl for `Fn(A0, ...) -> R` (wrapped in a closure that
//   ignores the `Caller`), and one for `Fn(Caller, A0, ...) -> R`;
// - a `FuncArgs` impl for the tuple `(A0, ...)`.
//
// The `@`-prefixed rules are internal helpers for that expansion.
macro_rules! impl_into_extern_fn {
    // Panics (at call time) when the arguments need more registers than
    // `Reg::ARG_REGS` provides.
    (@check_reg_count $regs_required:expr) => {
        if $regs_required > Reg::ARG_REGS.len() {
            panic!("external call failed: too many registers required for arguments!");
        }
    };

    // No arguments: call the callback with only the caller.
    (@call $is_64_bit:expr, $caller:expr, $callback:expr, ) => {{
        ($callback)($caller)
    }};

    // Evaluates to a closure that returns the caller instance's
    // `Reg::ARG_REGS[0]`, `[1]`, ... on successive calls.
    (@get_reg $caller:expr) => {{
        let mut reg_index = 0;
        let caller = &mut $caller;
        move || -> RegValue {
            let value = caller.instance.reg(Reg::ARG_REGS[reg_index]);
            reg_index += 1;
            value
        }
    }};

    // One argument: check the register budget, decode it with the 32- or
    // 64-bit layout, then call.
    (@call $is_64_bit:expr, $caller:expr, $callback:expr, $a0:ident) => {{
        let cb = impl_into_extern_fn!(@get_reg $caller);
        let a0;
        if $is_64_bit {
            impl_into_extern_fn!(@check_reg_count $a0::_REGS_REQUIRED_64);
            a0 = $a0::_get64(cb);
        } else {
            impl_into_extern_fn!(@check_reg_count $a0::_REGS_REQUIRED_32);
            a0 = $a0::_get32(cb);
        }

        ($callback)($caller, a0)
    }};

    // Two to six arguments follow the same pattern. All arguments are decoded in
    // order from one register reader. Earlier arguments borrow it (`&mut cb`),
    // so the next argument continues at the following register; the last
    // argument takes it by value.
    (@call $is_64_bit:expr, $caller:expr, $callback:expr, $a0:ident, $a1:ident) => {{
        let mut cb = impl_into_extern_fn!(@get_reg $caller);
        let a0;
        let a1;
        if $is_64_bit {
            impl_into_extern_fn!(@check_reg_count $a0::_REGS_REQUIRED_64 + $a1::_REGS_REQUIRED_64);
            a0 = $a0::_get64(&mut cb);
            a1 = $a1::_get64(cb);
        } else {
            impl_into_extern_fn!(@check_reg_count $a0::_REGS_REQUIRED_32 + $a1::_REGS_REQUIRED_32);
            a0 = $a0::_get32(&mut cb);
            a1 = $a1::_get32(cb);
        }

        ($callback)($caller, a0, a1)
    }};

    (@call $is_64_bit:expr, $caller:expr, $callback:expr, $a0:ident, $a1:ident, $a2:ident) => {{
        let mut cb = impl_into_extern_fn!(@get_reg $caller);
        let a0;
        let a1;
        let a2;
        if $is_64_bit {
            impl_into_extern_fn!(@check_reg_count $a0::_REGS_REQUIRED_64 + $a1::_REGS_REQUIRED_64 + $a2::_REGS_REQUIRED_64);
            a0 = $a0::_get64(&mut cb);
            a1 = $a1::_get64(&mut cb);
            a2 = $a2::_get64(cb);
        } else {
            impl_into_extern_fn!(@check_reg_count $a0::_REGS_REQUIRED_32 + $a1::_REGS_REQUIRED_32 + $a2::_REGS_REQUIRED_32);
            a0 = $a0::_get32(&mut cb);
            a1 = $a1::_get32(&mut cb);
            a2 = $a2::_get32(cb);
        }

        ($callback)($caller, a0, a1, a2)
    }};

    (@call $is_64_bit:expr, $caller:expr, $callback:expr, $a0:ident, $a1:ident, $a2:ident, $a3:ident) => {{
        let mut cb = impl_into_extern_fn!(@get_reg $caller);
        let a0;
        let a1;
        let a2;
        let a3;
        if $is_64_bit {
            impl_into_extern_fn!(@check_reg_count $a0::_REGS_REQUIRED_64 + $a1::_REGS_REQUIRED_64 + $a2::_REGS_REQUIRED_64 + $a3::_REGS_REQUIRED_64);
            a0 = $a0::_get64(&mut cb);
            a1 = $a1::_get64(&mut cb);
            a2 = $a2::_get64(&mut cb);
            a3 = $a3::_get64(cb);
        } else {
            impl_into_extern_fn!(@check_reg_count $a0::_REGS_REQUIRED_32 + $a1::_REGS_REQUIRED_32 + $a2::_REGS_REQUIRED_32 + $a3::_REGS_REQUIRED_32);
            a0 = $a0::_get32(&mut cb);
            a1 = $a1::_get32(&mut cb);
            a2 = $a2::_get32(&mut cb);
            a3 = $a3::_get32(cb);
        }

        ($callback)($caller, a0, a1, a2, a3)
    }};

    (@call $is_64_bit:expr, $caller:expr, $callback:expr, $a0:ident, $a1:ident, $a2:ident, $a3:ident, $a4:ident) => {{
        let mut cb = impl_into_extern_fn!(@get_reg $caller);
        let a0;
        let a1;
        let a2;
        let a3;
        let a4;
        if $is_64_bit {
            impl_into_extern_fn!(@check_reg_count $a0::_REGS_REQUIRED_64 + $a1::_REGS_REQUIRED_64 + $a2::_REGS_REQUIRED_64 + $a3::_REGS_REQUIRED_64 + $a4::_REGS_REQUIRED_64);
            a0 = $a0::_get64(&mut cb);
            a1 = $a1::_get64(&mut cb);
            a2 = $a2::_get64(&mut cb);
            a3 = $a3::_get64(&mut cb);
            a4 = $a4::_get64(cb);
        } else {
            impl_into_extern_fn!(@check_reg_count $a0::_REGS_REQUIRED_32 + $a1::_REGS_REQUIRED_32 + $a2::_REGS_REQUIRED_32 + $a3::_REGS_REQUIRED_32 + $a4::_REGS_REQUIRED_32);
            a0 = $a0::_get32(&mut cb);
            a1 = $a1::_get32(&mut cb);
            a2 = $a2::_get32(&mut cb);
            a3 = $a3::_get32(&mut cb);
            a4 = $a4::_get32(cb);
        }

        ($callback)($caller, a0, a1, a2, a3, a4)
    }};

    (@call $is_64_bit:expr, $caller:expr, $callback:expr, $a0:ident, $a1:ident, $a2:ident, $a3:ident, $a4:ident, $a5:ident) => {{
        let mut cb = impl_into_extern_fn!(@get_reg $caller);
        let a0;
        let a1;
        let a2;
        let a3;
        let a4;
        let a5;
        if $is_64_bit {
            impl_into_extern_fn!(@check_reg_count $a0::_REGS_REQUIRED_64 + $a1::_REGS_REQUIRED_64 + $a2::_REGS_REQUIRED_64 + $a3::_REGS_REQUIRED_64 + $a4::_REGS_REQUIRED_64 + $a5::_REGS_REQUIRED_64);
            a0 = $a0::_get64(&mut cb);
            a1 = $a1::_get64(&mut cb);
            a2 = $a2::_get64(&mut cb);
            a3 = $a3::_get64(&mut cb);
            a4 = $a4::_get64(&mut cb);
            a5 = $a5::_get64(cb);
        } else {
            impl_into_extern_fn!(@check_reg_count $a0::_REGS_REQUIRED_32 + $a1::_REGS_REQUIRED_32 + $a2::_REGS_REQUIRED_32 + $a3::_REGS_REQUIRED_32 + $a4::_REGS_REQUIRED_32 + $a5::_REGS_REQUIRED_32);
            a0 = $a0::_get32(&mut cb);
            a1 = $a1::_get32(&mut cb);
            a2 = $a2::_get32(&mut cb);
            a3 = $a3::_get32(&mut cb);
            a4 = $a4::_get32(&mut cb);
            a5 = $a5::_get32(cb);
        }

        ($callback)($caller, a0, a1, a2, a3, a4, a5)
    }};

    // Public entry point: emits the impls for one arity.
    ($arg_count:tt $($args:ident)*) => {
        impl<UserData, UserError, F, $($args,)* R> CallFn<UserData, UserError> for (F, UnsafePhantomData<(R, $($args),*)>)
        where
            F: Fn(Caller<'_, UserData>, $($args),*) -> R + Send + Sync + 'static,
            $($args: AbiTy,)*
            R: ReturnTy<UserError>,
        {
            fn call(&self, user_data: &mut UserData, instance: &mut RawInstance) -> Result<(), UserError> {
                let is_64_bit = instance.module().blob().is_64_bit();
                // Decode the arguments and run the callback; `caller` (and the
                // borrows of `user_data`/`instance` it holds) ends with this block.
                let result = {
                    // `caller` is only mutated by `@get_reg` when there are arguments.
                    #[allow(unused_mut)]
                    let mut caller = Caller {
                        user_data,
                        instance
                    };

                    impl_into_extern_fn!(@call is_64_bit, caller, self.0, $($args),*)
                };

                // Register writer for the return value: fills `Reg::ARG_REGS`
                // from index 0 onwards.
                let set_reg = {
                    let mut reg_index = 0;
                    move |value: RegValue| {
                        let reg = Reg::ARG_REGS[reg_index];
                        instance.set_reg(reg, value);
                        reg_index += 1;
                    }
                };

                if is_64_bit {
                    result._handle_return64(set_reg)
                } else {
                    result._handle_return32(set_reg)
                }
            }
        }

        // Callbacks that don't take a `Caller`.
        impl<UserData, UserError, F, $($args,)* R> IntoCallFn<UserData, UserError, ($($args,)*), R> for F
        where
            F: Fn($($args),*) -> R + Send + Sync + 'static,
            $($args: AbiTy,)*
            R: ReturnTy<UserError>,
        {
            const _REGS_REQUIRED_32: usize = 0 $(+ $args::_REGS_REQUIRED_32)*;
            const _REGS_REQUIRED_64: usize = 0 $(+ $args::_REGS_REQUIRED_64)*;

            fn _into_extern_fn(self) -> CallFnArc<UserData, UserError> {
                // Adapt to the `(Caller, A0, ...)` shape the `CallFn` impl expects,
                // discarding the caller. The argument type names double as the
                // parameter names, hence the lint allowance.
                #[allow(non_snake_case)]
                let callback = move |_caller: Caller<UserData>, $($args: $args),*| -> R {
                    self($($args),*)
                };
                CallFnArc(Arc::new((callback, UnsafePhantomData(PhantomData::<(R, $($args),*)>))))
            }
        }

        // Callbacks that take a leading `Caller`.
        impl<UserData, UserError, F, $($args,)* R> IntoCallFn<UserData, UserError, (Caller<'_, UserData>, $($args,)*), R> for F
        where
            F: Fn(Caller<'_, UserData>, $($args),*) -> R + Send + Sync + 'static,
            $($args: AbiTy,)*
            R: ReturnTy<UserError>,
        {
            const _REGS_REQUIRED_32: usize = 0 $(+ $args::_REGS_REQUIRED_32)*;
            const _REGS_REQUIRED_64: usize = 0 $(+ $args::_REGS_REQUIRED_64)*;

            fn _into_extern_fn(self) -> CallFnArc<UserData, UserError> {
                CallFnArc(Arc::new((self, UnsafePhantomData(PhantomData::<(R, $($args),*)>))))
            }
        }

        // Argument tuples for calls into the guest: each element is written in
        // order through the same register writer.
        impl<$($args: Send + AbiTy,)*> FuncArgs for ($($args,)*) {
            const _REGS_REQUIRED_32: usize = 0 $(+ $args::_REGS_REQUIRED_32)*;
            const _REGS_REQUIRED_64: usize = 0 $(+ $args::_REGS_REQUIRED_64)*;

            #[allow(unused_mut)]
            #[allow(unused_variables)]
            #[allow(non_snake_case)]
            fn _set32(self, mut set_reg: impl FnMut(RegValue)) {
                let ($($args,)*) = self;
                $($args._set32(&mut set_reg);)*
            }

            #[allow(unused_mut)]
            #[allow(unused_variables)]
            #[allow(non_snake_case)]
            fn _set64(self, mut set_reg: impl FnMut(RegValue)) {
                let ($($args,)*) = self;
                $($args._set64(&mut set_reg);)*
            }
        }
    };
}
555
// Instantiate the host-function and argument-tuple glue for 0 through 6 arguments.
impl_into_extern_fn!(0);
impl_into_extern_fn!(1 A0);
impl_into_extern_fn!(2 A0 A1);
impl_into_extern_fn!(3 A0 A1 A2);
impl_into_extern_fn!(4 A0 A1 A2 A3);
impl_into_extern_fn!(5 A0 A1 A2 A3 A4);
impl_into_extern_fn!(6 A0 A1 A2 A3 A4 A5);
563
/// A `PhantomData` marker that is unconditionally `Send` and `Sync`.
///
/// Host-function wrappers use it to mention their argument/return types (or
/// `UserData`) without inheriting those types' auto traits. The wrappers are
/// stored as `dyn CallFn`, which requires `Send + Sync`.
#[repr(transparent)]
struct UnsafePhantomData<T>(PhantomData<T>);

// SAFETY: `UnsafePhantomData` is a zero-sized marker that never stores, creates
// or hands out a value of type `T`, so sending it to another thread cannot move
// a `T` across threads.
unsafe impl<T> Send for UnsafePhantomData<T> {}

// SAFETY: as above, there is no `T` reachable through a shared reference to this
// marker, so sharing it between threads cannot share a `T`.
unsafe impl<T> Sync for UnsafePhantomData<T> {}
572
/// Host-function wrapper used by `Linker::define_untyped`: `callback` receives
/// only a `Caller` and handles registers itself. `T` is the `UserData` type.
struct DynamicFn<T, F> {
    callback: F,
    _phantom: UnsafePhantomData<T>,
}
577
578impl<UserData, UserError, F> CallFn<UserData, UserError> for DynamicFn<UserData, F>
579where
580 F: Fn(Caller<'_, UserData>) -> Result<(), UserError> + Send + Sync + 'static,
581 UserData: 'static,
582{
583 fn call(&self, user_data: &mut UserData, instance: &mut RawInstance) -> Result<(), UserError> {
584 let caller = Caller { user_data, instance };
585
586 (self.callback)(caller)
587 }
588}
589
/// Context handed to a host function: the embedder's user data and the
/// instance that made the host call.
#[non_exhaustive]
pub struct Caller<'a, UserData = ()> {
    /// The user data passed to `Instance::call_typed` / `continue_execution`.
    pub user_data: &'a mut UserData,
    /// The calling instance; typed host functions read their arguments from and
    /// write their results to its `Reg::ARG_REGS`.
    pub instance: &'a mut RawInstance,
}
595
/// Collects host functions and links modules against them.
///
/// `UserData` reaches host functions through `Caller::user_data`; `UserError` is
/// the error type host functions may return.
pub struct Linker<UserData = (), UserError = core::convert::Infallible> {
    /// Registered host functions, keyed by import symbol bytes.
    host_functions: LookupMap<Vec<u8>, CallFnArc<UserData, UserError>>,
    /// Handler for host calls without a registered host function, if one was defined.
    #[allow(clippy::type_complexity)]
    fallback_handler: Option<FallbackHandlerArc<UserData, UserError>>,
    phantom: PhantomData<(UserData, UserError)>,
}
602
603impl<UserData, UserError> Default for Linker<UserData, UserError> {
604 fn default() -> Self {
605 Self::new()
606 }
607}
608
609impl<UserData, UserError> Linker<UserData, UserError> {
610 pub fn new() -> Self {
611 Self {
612 host_functions: Default::default(),
613 fallback_handler: None,
614 phantom: PhantomData,
615 }
616 }
617
618 pub fn define_fallback(&mut self, func: impl Fn(Caller<UserData>, u32) -> Result<(), UserError> + Send + Sync + 'static) {
620 self.fallback_handler = Some(Arc::new(func));
621 }
622
623 pub fn define_untyped(
625 &mut self,
626 symbol: impl AsRef<[u8]>,
627 func: impl Fn(Caller<UserData>) -> Result<(), UserError> + Send + Sync + 'static,
628 ) -> Result<&mut Self, Error>
629 where
630 UserData: 'static,
631 {
632 let symbol = symbol.as_ref();
633 if self.host_functions.contains_key(symbol) {
634 bail!(
635 "cannot register host function: host function was already registered: {}",
636 ProgramSymbol::new(symbol)
637 );
638 }
639
640 self.host_functions.insert(
641 symbol.to_owned(),
642 CallFnArc(Arc::new(DynamicFn {
643 callback: func,
644 _phantom: UnsafePhantomData(PhantomData),
645 })),
646 );
647
648 Ok(self)
649 }
650
651 pub fn define_typed<Params, Args>(
653 &mut self,
654 symbol: impl AsRef<[u8]>,
655 func: impl IntoCallFn<UserData, UserError, Params, Args>,
656 ) -> Result<&mut Self, Error> {
657 let symbol = symbol.as_ref();
658 if self.host_functions.contains_key(symbol) {
659 bail!(
660 "cannot register host function: host function was already registered: {}",
661 ProgramSymbol::new(symbol)
662 );
663 }
664
665 self.host_functions.insert(symbol.to_owned(), func._into_extern_fn());
666 Ok(self)
667 }
668
669 pub fn instantiate_pre(&self, module: &Module) -> Result<InstancePre<UserData, UserError>, Error> {
671 let mut exports = LookupMap::new();
672 for export in module.exports() {
673 match exports.entry(export.symbol().as_bytes().to_owned()) {
674 Entry::Occupied(_) => {
675 if module.is_strict() {
676 return Err(format!("duplicate export: {}", export.symbol()).into());
677 } else {
678 log::debug!("Duplicate export: {}", export.symbol());
679 continue;
680 }
681 }
682 Entry::Vacant(entry) => {
683 entry.insert(export.program_counter());
684 }
685 }
686 }
687
688 let mut imports: Vec<Option<CallFnArc<UserData, UserError>>> = Vec::with_capacity(module.imports().len() as usize);
689 for symbol in module.imports() {
690 let Some(symbol) = symbol else {
691 if module.is_strict() {
692 return Err("failed to parse an import".into());
693 } else {
694 imports.push(None);
695 continue;
696 }
697 };
698
699 let host_fn = if let Some(host_fn) = self.host_functions.get(symbol.as_bytes()) {
700 Some(host_fn.clone())
701 } else if self.fallback_handler.is_some() {
702 None
703 } else if module.is_strict() {
704 return Err(format!("missing host function: {}", symbol).into());
705 } else {
706 log::debug!("Missing host function: {}", symbol);
707 None
708 };
709
710 imports.push(host_fn);
711 }
712
713 assert_eq!(imports.len(), module.imports().len() as usize);
714 Ok(InstancePre(Arc::new(InstancePreState {
715 module: module.clone(),
716 imports,
717 exports,
718 fallback_handler: self.fallback_handler.clone(),
719 })))
720 }
721}
722
/// Shared state behind an `InstancePre`: the module plus its resolved imports and exports.
struct InstancePreState<UserData, UserError> {
    module: Module,
    /// Host function for each import, indexed by host call number; `None` if unresolved.
    imports: Vec<Option<CallFnArc<UserData, UserError>>>,
    /// Export symbol bytes to entry-point program counter.
    exports: LookupMap<Vec<u8>, ProgramCounter>,
    /// The linker's fallback handler at the time `instantiate_pre` was called.
    fallback_handler: Option<FallbackHandlerArc<UserData, UserError>>,
}
729
/// A linked module, ready to be instantiated; cloning only clones an `Arc`.
pub struct InstancePre<UserData = (), UserError = core::convert::Infallible>(Arc<InstancePreState<UserData, UserError>>);
731
732impl<UserData, UserError> Clone for InstancePre<UserData, UserError> {
733 fn clone(&self) -> Self {
734 Self(Arc::clone(&self.0))
735 }
736}
737
/// A running instance together with the linking information used to service
/// its host calls.
pub struct Instance<UserData = (), UserError = core::convert::Infallible> {
    instance: RawInstance,
    pre: InstancePre<UserData, UserError>,
}
742
743impl<UserData, UserError> core::ops::Deref for Instance<UserData, UserError> {
744 type Target = RawInstance;
745 fn deref(&self) -> &Self::Target {
746 &self.instance
747 }
748}
749
750impl<UserData, UserError> core::ops::DerefMut for Instance<UserData, UserError> {
751 fn deref_mut(&mut self) -> &mut Self::Target {
752 &mut self.instance
753 }
754}
755
/// Why a call into the guest did not finish successfully.
#[derive(Debug)]
pub enum CallError<UserError = core::convert::Infallible> {
    /// The guest trapped. This covers an explicit trap, a call to an unresolved
    /// host function when no fallback handler is set, and a segfault outside the
    /// regions the instance services.
    Trap,

    /// The instance returned `InterruptKind::NotEnoughGas`.
    NotEnoughGas,

    /// An engine-level error: running the instance failed, the entry point was
    /// not found, or a segfault could not be serviced.
    Error(Error),

    /// A host function (or the fallback handler) returned an error.
    User(UserError),

    /// The instance returned `InterruptKind::Step`.
    Step,
}
775
776impl<UserData, UserError> InstancePre<UserData, UserError> {
777 pub fn instantiate(&self) -> Result<Instance<UserData, UserError>, Error> {
778 Ok(Instance {
779 instance: self.0.module.instantiate()?,
780 pre: self.clone(),
781 })
782 }
783
784 pub fn module(&self) -> &Module {
785 &self.0.module
786 }
787}
788
/// Something that resolves to the program counter a call starts at.
///
/// Implemented for export names (`&str`, `String`) and for raw `ProgramCounter`s.
pub trait EntryPoint {
    /// Resolves `self` against the module's exports, or returns an error message.
    #[doc(hidden)]
    fn get(self, exports: &LookupMap<Vec<u8>, ProgramCounter>) -> Result<ProgramCounter, String>;
}
793
794impl<'a> EntryPoint for &'a str {
795 fn get(self, exports: &LookupMap<Vec<u8>, ProgramCounter>) -> Result<ProgramCounter, String> {
796 exports
797 .get(self.as_bytes())
798 .copied()
799 .ok_or_else(|| format!("export not found: '{self}'"))
800 }
801}
802
803impl EntryPoint for String {
804 fn get(self, exports: &LookupMap<Vec<u8>, ProgramCounter>) -> Result<ProgramCounter, String> {
805 EntryPoint::get(self.as_str(), exports)
806 }
807}
808
809impl EntryPoint for ProgramCounter {
810 fn get(self, _: &LookupMap<Vec<u8>, ProgramCounter>) -> Result<ProgramCounter, String> {
811 Ok(self)
812 }
813}
814
815impl<UserData, UserError> Instance<UserData, UserError> {
816 pub fn call_typed<FnArgs>(
818 &mut self,
819 user_data: &mut UserData,
820 entry_point: impl EntryPoint,
821 args: FnArgs,
822 ) -> Result<(), CallError<UserError>>
823 where
824 FnArgs: FuncArgs,
825 {
826 let entry_point = entry_point
827 .get(&self.pre.0.exports)
828 .map_err(|error| CallError::Error(Error::from_display(error)))?;
829
830 self.instance.prepare_call_typed(entry_point, args);
831 self.continue_execution(user_data)
832 }
833
834 pub fn continue_execution(&mut self, user_data: &mut UserData) -> Result<(), CallError<UserError>> {
836 loop {
837 let interrupt = self.instance.run().map_err(CallError::Error)?;
838 match interrupt {
839 InterruptKind::Finished => break Ok(()),
840 InterruptKind::Trap => break Err(CallError::Trap),
841 InterruptKind::Ecalli(hostcall) => {
842 if let Some(host_fn) = self.pre.0.imports.get(hostcall as usize).and_then(|host_fn| host_fn.as_ref()) {
843 host_fn.0.call(user_data, &mut self.instance).map_err(CallError::User)?;
844 } else if let Some(ref fallback_handler) = self.pre.0.fallback_handler {
845 let caller = Caller {
846 user_data,
847 instance: &mut self.instance,
848 };
849
850 fallback_handler(caller, hostcall).map_err(CallError::User)?;
851 } else {
852 log::debug!("Called a missing host function with ID = {}", hostcall);
853 break Err(CallError::Trap);
854 };
855 }
856 InterruptKind::NotEnoughGas => return Err(CallError::NotEnoughGas),
857 InterruptKind::Segfault(segfault) => {
858 let module = self.instance.module().clone();
859 if segfault.page_address >= module.memory_map().stack_address_low()
860 && segfault.page_address + segfault.page_size <= module.memory_map().stack_address_high()
861 {
862 self.instance
863 .zero_memory_with_memory_protection(segfault.page_address, segfault.page_size, MemoryProtection::ReadWrite)
864 .map_err(|error| {
865 CallError::Error(Error::from_display(format!(
866 "failed to zero memory when handling a segfault at 0x{:x}: {error}",
867 segfault.page_address
868 )))
869 })?;
870
871 continue;
872 }
873
874 macro_rules! handle {
875 ($range:ident, $data:ident, $protection:ident) => {{
876 if segfault.page_address >= module.memory_map().$range().start
877 && segfault.page_address + segfault.page_size <= module.memory_map().$range().end
878 {
879 let data_offset = (segfault.page_address - module.memory_map().$range().start) as usize;
880 let data = module.blob().$data();
881 let chunk_length = data.len().checked_sub(data_offset);
882 let initial_protection = if chunk_length.is_some() {
883 MemoryProtection::ReadWrite
884 } else {
885 MemoryProtection::$protection
886 };
887
888 self.instance
889 .zero_memory_with_memory_protection(segfault.page_address, segfault.page_size, initial_protection)
890 .map_err(|error| {
891 CallError::Error(Error::from_display(format!(
892 "failed to zero memory when handling a segfault at 0x{:x}: {error}",
893 segfault.page_address
894 )))
895 })?;
896
897 if let Some(chunk_length) = chunk_length {
898 let chunk_length = core::cmp::min(chunk_length, segfault.page_size as usize);
899 self.instance
900 .write_memory(segfault.page_address, &data[data_offset..data_offset + chunk_length])
901 .map_err(|error| {
902 CallError::Error(Error::from_display(format!(
903 "failed to write memory when handling a segfault at 0x{:x}: {error}",
904 segfault.page_address
905 )))
906 })?;
907 };
908
909 if MemoryProtection::$protection == MemoryProtection::Read && initial_protection != MemoryProtection::Read {
910 self.instance
911 .protect_memory(segfault.page_address, segfault.page_size)
912 .map_err(|error| {
913 CallError::Error(Error::from_display(format!(
914 "failed to protect memory when handling a segfault at 0x{:x}: {error}",
915 segfault.page_address
916 )))
917 })?;
918 }
919
920 continue;
921 }
922 }};
923 }
924
925 handle!(ro_data_range, ro_data, Read);
926 handle!(rw_data_range, rw_data, ReadWrite);
927
928 log::debug!("Unexpected segfault: 0x{:x}", segfault.page_address);
929 break Err(CallError::Trap);
930 }
931 InterruptKind::Step => break Err(CallError::Step),
932 }
933 }
934 }
935
936 pub fn call_typed_and_get_result<FnResult, FnArgs>(
938 &mut self,
939 user_data: &mut UserData,
940 entry_point: impl EntryPoint,
941 args: FnArgs,
942 ) -> Result<FnResult, CallError<UserError>>
943 where
944 FnArgs: FuncArgs,
945 FnResult: FuncResult,
946 {
947 self.call_typed(user_data, entry_point, args)?;
948 Ok(self.instance.get_result_typed::<FnResult>())
949 }
950}