wasm_instrument/stack_limiter/
mod.rs

1//! Contains the code for the stack height limiter instrumentation.
2
3use alloc::{vec, vec::Vec};
4use core::mem;
5use parity_wasm::{
6	builder,
7	elements::{self, Instruction, Instructions, Type},
8};
9
/// Macro to generate the preamble and postamble around a single `call`.
///
/// Expands to an array of instructions that replaces `call $callee_idx`:
///
/// 1. preamble: adds `$callee_stack_cost` to the stack height global and, if the resulting
///    height is greater than `$stack_limit` (unsigned `i32.gt_u` comparison), traps with
///    `unreachable`;
/// 2. the original `call $callee_idx`;
/// 3. postamble: subtracts `$callee_stack_cost` from the stack height global again.
///
/// Every argument is evaluated exactly once. All `parity_wasm` paths are spelled through
/// `$crate`, so the expansion does not depend on what the call site has imported.
macro_rules! instrument_call {
	($callee_idx: expr, $callee_stack_cost: expr, $stack_height_global_idx: expr, $stack_limit: expr) => {{
		use $crate::parity_wasm::elements::Instruction::*;
		// Bind the arguments once so that repeated uses below don't re-evaluate them.
		let global_idx = $stack_height_global_idx;
		let cost = $callee_stack_cost;
		// The comparison below is unsigned, so reinterpreting the bits as `i32` is lossless.
		let limit = $stack_limit as i32;
		[
			// stack_height += stack_cost(F)
			GetGlobal(global_idx),
			I32Const(cost),
			I32Add,
			SetGlobal(global_idx),
			// if stack_height > LIMIT: unreachable
			GetGlobal(global_idx),
			I32Const(limit),
			I32GtU,
			If($crate::parity_wasm::elements::BlockType::NoResult),
			Unreachable,
			End,
			// Original call
			Call($callee_idx),
			// stack_height -= stack_cost(F)
			GetGlobal(global_idx),
			I32Const(cost),
			I32Sub,
			SetGlobal(global_idx),
		]
	}};
}
37
38mod max_height;
39mod thunk;
40
/// State shared by the stack limiter instrumentation passes.
pub struct Context {
	/// Index of the global variable that tracks the current stack height.
	stack_height_global_idx: u32,
	/// Stack cost of every function, indexed by function index (imports included).
	func_stack_costs: Vec<u32>,
	/// Stack height above which instrumented calls trap.
	stack_limit: u32,
}

impl Context {
	/// Index, in the global index space, of the stack height global.
	fn stack_height_global_idx(&self) -> u32 {
		self.stack_height_global_idx
	}

	/// Stack cost of the function at `func_idx`, or `None` when the index lies outside the
	/// costs recorded in this context.
	fn stack_cost(&self, func_idx: u32) -> Option<u32> {
		let idx = func_idx as usize;
		self.func_stack_costs.get(idx).copied()
	}

	/// Stack height limit for this instrumentation run.
	fn stack_limit(&self) -> u32 {
		self.stack_limit
	}
}
63
64/// Inject the instumentation that makes stack overflows deterministic, by introducing
65/// an upper bound of the stack size.
66///
67/// This pass introduces a global mutable variable to track stack height,
68/// and instruments all calls with preamble and postamble.
69///
70/// Stack height is increased prior the call. Otherwise, the check would
71/// be made after the stack frame is allocated.
72///
73/// The preamble is inserted before the call. It increments
74/// the global stack height variable with statically determined "stack cost"
75/// of the callee. If after the increment the stack height exceeds
76/// the limit (specified by the `rules`) then execution traps.
77/// Otherwise, the call is executed.
78///
79/// The postamble is inserted after the call. The purpose of the postamble is to decrease
80/// the stack height by the "stack cost" of the callee function.
81///
82/// Note, that we can't instrument all possible ways to return from the function. The simplest
83/// example would be a trap issued by the host function.
84/// That means stack height global won't be equal to zero upon the next execution after such trap.
85///
86/// # Thunks
87///
88/// Because stack height is increased prior the call few problems arises:
89///
90/// - Stack height isn't increased upon an entry to the first function, i.e. exported function.
91/// - Start function is executed externally (similar to exported functions).
92/// - It is statically unknown what function will be invoked in an indirect call.
93///
94/// The solution for this problems is to generate a intermediate functions, called 'thunks', which
95/// will increase before and decrease the stack height after the call to original function, and
96/// then make exported function and table entries, start section to point to a corresponding thunks.
97///
98/// # Stack cost
99///
100/// Stack cost of the function is calculated as a sum of it's locals
101/// and the maximal height of the value stack.
102///
103/// All values are treated equally, as they have the same size.
104///
105/// The rationale is that this makes it possible to use the following very naive wasm executor:
106///
107/// - values are implemented by a union, so each value takes a size equal to the size of the largest
108///   possible value type this union can hold. (In MVP it is 8 bytes)
109/// - each value from the value stack is placed on the native stack.
110/// - each local variable and function argument is placed on the native stack.
111/// - arguments pushed by the caller are copied into callee stack rather than shared between the
112///   frames.
113/// - upon entry into the function entire stack frame is allocated.
114pub fn inject(
115	mut module: elements::Module,
116	stack_limit: u32,
117) -> Result<elements::Module, &'static str> {
118	let mut ctx = Context {
119		stack_height_global_idx: generate_stack_height_global(&mut module),
120		func_stack_costs: compute_stack_costs(&module)?,
121		stack_limit,
122	};
123
124	instrument_functions(&mut ctx, &mut module)?;
125	let module = thunk::generate_thunks(&mut ctx, module)?;
126
127	Ok(module)
128}
129
130/// Generate a new global that will be used for tracking current stack height.
131fn generate_stack_height_global(module: &mut elements::Module) -> u32 {
132	let global_entry = builder::global()
133		.value_type()
134		.i32()
135		.mutable()
136		.init_expr(Instruction::I32Const(0))
137		.build();
138
139	// Try to find an existing global section.
140	for section in module.sections_mut() {
141		if let elements::Section::Global(gs) = section {
142			gs.entries_mut().push(global_entry);
143			return (gs.entries().len() as u32) - 1
144		}
145	}
146
147	// Existing section not found, create one!
148	module
149		.sections_mut()
150		.push(elements::Section::Global(elements::GlobalSection::with_entries(vec![global_entry])));
151	0
152}
153
154/// Calculate stack costs for all functions.
155///
156/// Returns a vector with a stack cost for each function, including imports.
157fn compute_stack_costs(module: &elements::Module) -> Result<Vec<u32>, &'static str> {
158	let func_imports = module.import_count(elements::ImportCountType::Function);
159
160	// TODO: optimize!
161	(0..module.functions_space())
162		.map(|func_idx| {
163			if func_idx < func_imports {
164				// We can't calculate stack_cost of the import functions.
165				Ok(0)
166			} else {
167				compute_stack_cost(func_idx as u32, module)
168			}
169		})
170		.collect()
171}
172
173/// Stack cost of the given *defined* function is the sum of it's locals count (that is,
174/// number of arguments plus number of local variables) and the maximal stack
175/// height.
176fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, &'static str> {
177	// To calculate the cost of a function we need to convert index from
178	// function index space to defined function spaces.
179	let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
180	let defined_func_idx = func_idx
181		.checked_sub(func_imports)
182		.ok_or("This should be a index of a defined function")?;
183
184	let code_section =
185		module.code_section().ok_or("Due to validation code section should exists")?;
186	let body = &code_section
187		.bodies()
188		.get(defined_func_idx as usize)
189		.ok_or("Function body is out of bounds")?;
190
191	let mut locals_count: u32 = 0;
192	for local_group in body.locals() {
193		locals_count =
194			locals_count.checked_add(local_group.count()).ok_or("Overflow in local count")?;
195	}
196
197	let max_stack_height = max_height::compute(defined_func_idx, module)?;
198
199	locals_count
200		.checked_add(max_stack_height)
201		.ok_or("Overflow in adding locals_count and max_stack_height")
202}
203
204fn instrument_functions(
205	ctx: &mut Context,
206	module: &mut elements::Module,
207) -> Result<(), &'static str> {
208	for section in module.sections_mut() {
209		if let elements::Section::Code(code_section) = section {
210			for func_body in code_section.bodies_mut() {
211				let opcodes = func_body.code_mut();
212				instrument_function(ctx, opcodes)?;
213			}
214		}
215	}
216	Ok(())
217}
218
219/// This function searches `call` instructions and wrap each call
220/// with preamble and postamble.
221///
222/// Before:
223///
224/// ```text
225/// get_local 0
226/// get_local 1
227/// call 228
228/// drop
229/// ```
230///
231/// After:
232///
233/// ```text
234/// get_local 0
235/// get_local 1
236///
237/// < ... preamble ... >
238///
239/// call 228
240///
241/// < .. postamble ... >
242///
243/// drop
244/// ```
245fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(), &'static str> {
246	use Instruction::*;
247
248	struct InstrumentCall {
249		offset: usize,
250		callee: u32,
251		cost: u32,
252	}
253
254	let calls: Vec<_> = func
255		.elements()
256		.iter()
257		.enumerate()
258		.filter_map(|(offset, instruction)| {
259			if let Call(callee) = instruction {
260				ctx.stack_cost(*callee).and_then(|cost| {
261					if cost > 0 {
262						Some(InstrumentCall { callee: *callee, offset, cost })
263					} else {
264						None
265					}
266				})
267			} else {
268				None
269			}
270		})
271		.collect();
272
273	// The `instrumented_call!` contains the call itself. This is why we need to subtract one.
274	let len = func.elements().len() + calls.len() * (instrument_call!(0, 0, 0, 0).len() - 1);
275	let original_instrs = mem::replace(func.elements_mut(), Vec::with_capacity(len));
276	let new_instrs = func.elements_mut();
277
278	let mut calls = calls.into_iter().peekable();
279	for (original_pos, instr) in original_instrs.into_iter().enumerate() {
280		// whether there is some call instruction at this position that needs to be instrumented
281		let did_instrument = if let Some(call) = calls.peek() {
282			if call.offset == original_pos {
283				let new_seq = instrument_call!(
284					call.callee,
285					call.cost as i32,
286					ctx.stack_height_global_idx(),
287					ctx.stack_limit()
288				);
289				new_instrs.extend_from_slice(&new_seq);
290				true
291			} else {
292				false
293			}
294		} else {
295			false
296		};
297
298		if did_instrument {
299			calls.next();
300		} else {
301			new_instrs.push(instr);
302		}
303	}
304
305	if calls.next().is_some() {
306		return Err("Not all calls were used")
307	}
308
309	Ok(())
310}
311
312fn resolve_func_type(
313	func_idx: u32,
314	module: &elements::Module,
315) -> Result<&elements::FunctionType, &'static str> {
316	let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
317	let functions = module.function_section().map(|fs| fs.entries()).unwrap_or(&[]);
318
319	let func_imports = module.import_count(elements::ImportCountType::Function);
320	let sig_idx = if func_idx < func_imports as u32 {
321		module
322			.import_section()
323			.expect("function import count is not zero; import section must exists; qed")
324			.entries()
325			.iter()
326			.filter_map(|entry| match entry.external() {
327				elements::External::Function(idx) => Some(*idx),
328				_ => None,
329			})
330			.nth(func_idx as usize)
331			.expect(
332				"func_idx is less than function imports count;
333				nth function import must be `Some`;
334				qed",
335			)
336	} else {
337		functions
338			.get(func_idx as usize - func_imports)
339			.ok_or("Function at the specified index is not defined")?
340			.type_ref()
341	};
342	let Type::Function(ty) = types
343		.get(sig_idx as usize)
344		.ok_or("The signature as specified by a function isn't defined")?;
345	Ok(ty)
346}
347
#[cfg(test)]
mod tests {
	use super::*;
	use parity_wasm::elements;

	/// Compiles WAT text to wasm and deserializes it into a `parity_wasm` module.
	fn parse_wat(source: &str) -> elements::Module {
		let wasm = wat::parse_str(source).expect("Failed to wat2wasm");
		elements::deserialize_buffer(&wasm).expect("Failed to deserialize the module")
	}

	/// Serializes `module` and asserts that `wasmparser` accepts the resulting binary.
	fn validate_module(module: elements::Module) {
		let binary = elements::serialize(module).expect("Failed to serialize");
		wasmparser::validate(&binary).expect("Invalid module");
	}

	#[test]
	fn test_with_params_and_result() {
		let wat = r#"
(module
	(func (export "i32.add") (param i32 i32) (result i32)
		get_local 0
	get_local 1
	i32.add
	)
)
"#;
		let original = parse_wat(wat);
		let instrumented = inject(original, 1024).expect("Failed to inject stack counter");
		validate_module(instrumented);
	}
}