wasm_instrument/stack_limiter/
thunk.rs1#[cfg(not(features = "std"))]
2use alloc::collections::BTreeMap as Map;
3use alloc::vec::Vec;
4use parity_wasm::{
5 builder,
6 elements::{self, FunctionType, Internal},
7};
8#[cfg(features = "std")]
9use std::collections::HashMap as Map;
10
11use super::{resolve_func_type, Context};
12
13struct Thunk {
14 signature: FunctionType,
15 idx: Option<u32>,
17 callee_stack_cost: u32,
18}
19
20pub fn generate_thunks(
21 ctx: &mut Context,
22 module: elements::Module,
23) -> Result<elements::Module, &'static str> {
24 let mut replacement_map: Map<u32, Thunk> = {
26 let exports = module.export_section().map(|es| es.entries()).unwrap_or(&[]);
27 let elem_segments = module.elements_section().map(|es| es.entries()).unwrap_or(&[]);
28 let start_func_idx = module.start_section();
29
30 let exported_func_indices = exports.iter().filter_map(|entry| match entry.internal() {
31 Internal::Function(function_idx) => Some(*function_idx),
32 _ => None,
33 });
34 let table_func_indices =
35 elem_segments.iter().flat_map(|segment| segment.members()).cloned();
36
37 let mut replacement_map: Map<u32, Thunk> = Map::new();
39
40 for func_idx in exported_func_indices
41 .chain(table_func_indices)
42 .chain(start_func_idx.into_iter())
43 {
44 let callee_stack_cost = ctx.stack_cost(func_idx).ok_or("function index isn't found")?;
45
46 if callee_stack_cost != 0 {
48 replacement_map.insert(
49 func_idx,
50 Thunk {
51 signature: resolve_func_type(func_idx, &module)?.clone(),
52 idx: None,
53 callee_stack_cost,
54 },
55 );
56 }
57 }
58
59 replacement_map
60 };
61
62 let mut next_func_idx = module.functions_space() as u32;
66
67 let mut mbuilder = builder::from_module(module);
68 for (func_idx, thunk) in replacement_map.iter_mut() {
69 let instrumented_call = instrument_call!(
70 *func_idx,
71 thunk.callee_stack_cost as i32,
72 ctx.stack_height_global_idx(),
73 ctx.stack_limit()
74 );
75 let mut thunk_body: Vec<elements::Instruction> =
80 Vec::with_capacity(thunk.signature.params().len() + instrumented_call.len() + 1);
81
82 for (arg_idx, _) in thunk.signature.params().iter().enumerate() {
83 thunk_body.push(elements::Instruction::GetLocal(arg_idx as u32));
84 }
85 thunk_body.extend_from_slice(&instrumented_call);
86 thunk_body.push(elements::Instruction::End);
87
88 mbuilder = mbuilder
91 .function()
92 .signature()
94 .with_params(thunk.signature.params().to_vec())
95 .with_results(thunk.signature.results().to_vec())
96 .build()
97 .body()
98 .with_instructions(elements::Instructions::new(thunk_body))
99 .build()
100 .build();
101
102 thunk.idx = Some(next_func_idx);
103 next_func_idx += 1;
104 }
105 let mut module = mbuilder.build();
106
107 let fixup = |function_idx: &mut u32| {
111 if let Some(thunk) = replacement_map.get(function_idx) {
114 *function_idx =
115 thunk.idx.expect("At this point an index must be assigned to each thunk");
116 }
117 };
118
119 for section in module.sections_mut() {
120 match section {
121 elements::Section::Export(export_section) =>
122 for entry in export_section.entries_mut() {
123 if let Internal::Function(function_idx) = entry.internal_mut() {
124 fixup(function_idx)
125 }
126 },
127 elements::Section::Element(elem_section) =>
128 for segment in elem_section.entries_mut() {
129 for function_idx in segment.members_mut() {
130 fixup(function_idx)
131 }
132 },
133 elements::Section::Start(start_idx) => fixup(start_idx),
134 _ => {},
135 }
136 }
137
138 Ok(module)
139}