// wasm_instrument/stack_limiter/thunk.rs

1#[cfg(not(features = "std"))]
2use alloc::collections::BTreeMap as Map;
3use alloc::vec::Vec;
4use parity_wasm::{
5	builder,
6	elements::{self, FunctionType, Internal},
7};
8#[cfg(features = "std")]
9use std::collections::HashMap as Map;
10
11use super::{resolve_func_type, Context};
12
13struct Thunk {
14	signature: FunctionType,
15	// Index in function space of this thunk.
16	idx: Option<u32>,
17	callee_stack_cost: u32,
18}
19
20pub fn generate_thunks(
21	ctx: &mut Context,
22	module: elements::Module,
23) -> Result<elements::Module, &'static str> {
24	// First, we need to collect all function indices that should be replaced by thunks
25	let mut replacement_map: Map<u32, Thunk> = {
26		let exports = module.export_section().map(|es| es.entries()).unwrap_or(&[]);
27		let elem_segments = module.elements_section().map(|es| es.entries()).unwrap_or(&[]);
28		let start_func_idx = module.start_section();
29
30		let exported_func_indices = exports.iter().filter_map(|entry| match entry.internal() {
31			Internal::Function(function_idx) => Some(*function_idx),
32			_ => None,
33		});
34		let table_func_indices =
35			elem_segments.iter().flat_map(|segment| segment.members()).cloned();
36
37		// Replacement map is at least export section size.
38		let mut replacement_map: Map<u32, Thunk> = Map::new();
39
40		for func_idx in exported_func_indices
41			.chain(table_func_indices)
42			.chain(start_func_idx.into_iter())
43		{
44			let callee_stack_cost = ctx.stack_cost(func_idx).ok_or("function index isn't found")?;
45
46			// Don't generate a thunk if stack_cost of a callee is zero.
47			if callee_stack_cost != 0 {
48				replacement_map.insert(
49					func_idx,
50					Thunk {
51						signature: resolve_func_type(func_idx, &module)?.clone(),
52						idx: None,
53						callee_stack_cost,
54					},
55				);
56			}
57		}
58
59		replacement_map
60	};
61
62	// Then, we generate a thunk for each original function.
63
64	// Save current func_idx
65	let mut next_func_idx = module.functions_space() as u32;
66
67	let mut mbuilder = builder::from_module(module);
68	for (func_idx, thunk) in replacement_map.iter_mut() {
69		let instrumented_call = instrument_call!(
70			*func_idx,
71			thunk.callee_stack_cost as i32,
72			ctx.stack_height_global_idx(),
73			ctx.stack_limit()
74		);
75		// Thunk body consist of:
76		//  - argument pushing
77		//  - instrumented call
78		//  - end
79		let mut thunk_body: Vec<elements::Instruction> =
80			Vec::with_capacity(thunk.signature.params().len() + instrumented_call.len() + 1);
81
82		for (arg_idx, _) in thunk.signature.params().iter().enumerate() {
83			thunk_body.push(elements::Instruction::GetLocal(arg_idx as u32));
84		}
85		thunk_body.extend_from_slice(&instrumented_call);
86		thunk_body.push(elements::Instruction::End);
87
88		// TODO: Don't generate a signature, but find an existing one.
89
90		mbuilder = mbuilder
91			.function()
92			// Signature of the thunk should match the original function signature.
93			.signature()
94			.with_params(thunk.signature.params().to_vec())
95			.with_results(thunk.signature.results().to_vec())
96			.build()
97			.body()
98			.with_instructions(elements::Instructions::new(thunk_body))
99			.build()
100			.build();
101
102		thunk.idx = Some(next_func_idx);
103		next_func_idx += 1;
104	}
105	let mut module = mbuilder.build();
106
107	// And finally, fixup thunks in export and table sections.
108
109	// Fixup original function index to a index of a thunk generated earlier.
110	let fixup = |function_idx: &mut u32| {
111		// Check whether this function is in replacement_map, since
112		// we can skip thunk generation (e.g. if stack_cost of function is 0).
113		if let Some(thunk) = replacement_map.get(function_idx) {
114			*function_idx =
115				thunk.idx.expect("At this point an index must be assigned to each thunk");
116		}
117	};
118
119	for section in module.sections_mut() {
120		match section {
121			elements::Section::Export(export_section) =>
122				for entry in export_section.entries_mut() {
123					if let Internal::Function(function_idx) = entry.internal_mut() {
124						fixup(function_idx)
125					}
126				},
127			elements::Section::Element(elem_section) =>
128				for segment in elem_section.entries_mut() {
129					for function_idx in segment.members_mut() {
130						fixup(function_idx)
131					}
132				},
133			elements::Section::Start(start_idx) => fixup(start_idx),
134			_ => {},
135		}
136	}
137
138	Ok(module)
139}