wasm_instrument/stack_limiter/
mod.rs1use alloc::{vec, vec::Vec};
4use core::mem;
5use parity_wasm::{
6 builder,
7 elements::{self, Instruction, Instructions, Type},
8};
9
/// Expands to the instruction sequence that replaces a single `call $callee_idx`.
///
/// The emitted sequence:
/// 1. adds `$callee_stack_cost` to the stack-height global `$stack_height_global_idx`;
/// 2. traps with `unreachable` if the new height is greater than `$stack_limit`
///    (compared as unsigned 32-bit integers via `i32.gt_u`, so limits above `i32::MAX`
///    keep their meaning after the `as i32` reinterpretation);
/// 3. performs the original `call $callee_idx`;
/// 4. subtracts `$callee_stack_cost` from the global again.
///
/// Every wasm type is named through `$crate::parity_wasm` so the expansion does not depend
/// on what the call site happens to have imported.
macro_rules! instrument_call {
    ($callee_idx: expr, $callee_stack_cost: expr, $stack_height_global_idx: expr, $stack_limit: expr) => {{
        use $crate::parity_wasm::elements::{BlockType, Instruction::*};
        [
            // stack_height += callee_stack_cost
            GetGlobal($stack_height_global_idx),
            I32Const($callee_stack_cost),
            I32Add,
            SetGlobal($stack_height_global_idx),
            // if stack_height > stack_limit (unsigned): trap
            GetGlobal($stack_height_global_idx),
            I32Const($stack_limit as i32),
            I32GtU,
            If(BlockType::NoResult),
            Unreachable,
            End,
            // the original call
            Call($callee_idx),
            // stack_height -= callee_stack_cost
            GetGlobal($stack_height_global_idx),
            I32Const($callee_stack_cost),
            I32Sub,
            SetGlobal($stack_height_global_idx),
        ]
    }};
}
37
38mod max_height;
39mod thunk;
40
/// State shared by the stack-limiter passes (call instrumentation and thunk generation).
pub struct Context {
    /// Height above which an instrumented call traps; compared unsigned against the global.
    stack_limit: u32,
    /// Index of the mutable `i32` global that holds the current stack height; it is used as
    /// the operand of `get_global`/`set_global` in the instrumentation sequence.
    stack_height_global_idx: u32,
    /// Stack cost per function, indexed by function index (imported functions included).
    func_stack_costs: Vec<u32>,
}
46
47impl Context {
48 fn stack_height_global_idx(&self) -> u32 {
50 self.stack_height_global_idx
51 }
52
53 fn stack_cost(&self, func_idx: u32) -> Option<u32> {
55 self.func_stack_costs.get(func_idx as usize).cloned()
56 }
57
58 fn stack_limit(&self) -> u32 {
60 self.stack_limit
61 }
62}
63
64pub fn inject(
115 mut module: elements::Module,
116 stack_limit: u32,
117) -> Result<elements::Module, &'static str> {
118 let mut ctx = Context {
119 stack_height_global_idx: generate_stack_height_global(&mut module),
120 func_stack_costs: compute_stack_costs(&module)?,
121 stack_limit,
122 };
123
124 instrument_functions(&mut ctx, &mut module)?;
125 let module = thunk::generate_thunks(&mut ctx, module)?;
126
127 Ok(module)
128}
129
130fn generate_stack_height_global(module: &mut elements::Module) -> u32 {
132 let global_entry = builder::global()
133 .value_type()
134 .i32()
135 .mutable()
136 .init_expr(Instruction::I32Const(0))
137 .build();
138
139 for section in module.sections_mut() {
141 if let elements::Section::Global(gs) = section {
142 gs.entries_mut().push(global_entry);
143 return (gs.entries().len() as u32) - 1
144 }
145 }
146
147 module
149 .sections_mut()
150 .push(elements::Section::Global(elements::GlobalSection::with_entries(vec![global_entry])));
151 0
152}
153
154fn compute_stack_costs(module: &elements::Module) -> Result<Vec<u32>, &'static str> {
158 let func_imports = module.import_count(elements::ImportCountType::Function);
159
160 (0..module.functions_space())
162 .map(|func_idx| {
163 if func_idx < func_imports {
164 Ok(0)
166 } else {
167 compute_stack_cost(func_idx as u32, module)
168 }
169 })
170 .collect()
171}
172
173fn compute_stack_cost(func_idx: u32, module: &elements::Module) -> Result<u32, &'static str> {
177 let func_imports = module.import_count(elements::ImportCountType::Function) as u32;
180 let defined_func_idx = func_idx
181 .checked_sub(func_imports)
182 .ok_or("This should be a index of a defined function")?;
183
184 let code_section =
185 module.code_section().ok_or("Due to validation code section should exists")?;
186 let body = &code_section
187 .bodies()
188 .get(defined_func_idx as usize)
189 .ok_or("Function body is out of bounds")?;
190
191 let mut locals_count: u32 = 0;
192 for local_group in body.locals() {
193 locals_count =
194 locals_count.checked_add(local_group.count()).ok_or("Overflow in local count")?;
195 }
196
197 let max_stack_height = max_height::compute(defined_func_idx, module)?;
198
199 locals_count
200 .checked_add(max_stack_height)
201 .ok_or("Overflow in adding locals_count and max_stack_height")
202}
203
204fn instrument_functions(
205 ctx: &mut Context,
206 module: &mut elements::Module,
207) -> Result<(), &'static str> {
208 for section in module.sections_mut() {
209 if let elements::Section::Code(code_section) = section {
210 for func_body in code_section.bodies_mut() {
211 let opcodes = func_body.code_mut();
212 instrument_function(ctx, opcodes)?;
213 }
214 }
215 }
216 Ok(())
217}
218
219fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(), &'static str> {
246 use Instruction::*;
247
248 struct InstrumentCall {
249 offset: usize,
250 callee: u32,
251 cost: u32,
252 }
253
254 let calls: Vec<_> = func
255 .elements()
256 .iter()
257 .enumerate()
258 .filter_map(|(offset, instruction)| {
259 if let Call(callee) = instruction {
260 ctx.stack_cost(*callee).and_then(|cost| {
261 if cost > 0 {
262 Some(InstrumentCall { callee: *callee, offset, cost })
263 } else {
264 None
265 }
266 })
267 } else {
268 None
269 }
270 })
271 .collect();
272
273 let len = func.elements().len() + calls.len() * (instrument_call!(0, 0, 0, 0).len() - 1);
275 let original_instrs = mem::replace(func.elements_mut(), Vec::with_capacity(len));
276 let new_instrs = func.elements_mut();
277
278 let mut calls = calls.into_iter().peekable();
279 for (original_pos, instr) in original_instrs.into_iter().enumerate() {
280 let did_instrument = if let Some(call) = calls.peek() {
282 if call.offset == original_pos {
283 let new_seq = instrument_call!(
284 call.callee,
285 call.cost as i32,
286 ctx.stack_height_global_idx(),
287 ctx.stack_limit()
288 );
289 new_instrs.extend_from_slice(&new_seq);
290 true
291 } else {
292 false
293 }
294 } else {
295 false
296 };
297
298 if did_instrument {
299 calls.next();
300 } else {
301 new_instrs.push(instr);
302 }
303 }
304
305 if calls.next().is_some() {
306 return Err("Not all calls were used")
307 }
308
309 Ok(())
310}
311
312fn resolve_func_type(
313 func_idx: u32,
314 module: &elements::Module,
315) -> Result<&elements::FunctionType, &'static str> {
316 let types = module.type_section().map(|ts| ts.types()).unwrap_or(&[]);
317 let functions = module.function_section().map(|fs| fs.entries()).unwrap_or(&[]);
318
319 let func_imports = module.import_count(elements::ImportCountType::Function);
320 let sig_idx = if func_idx < func_imports as u32 {
321 module
322 .import_section()
323 .expect("function import count is not zero; import section must exists; qed")
324 .entries()
325 .iter()
326 .filter_map(|entry| match entry.external() {
327 elements::External::Function(idx) => Some(*idx),
328 _ => None,
329 })
330 .nth(func_idx as usize)
331 .expect(
332 "func_idx is less than function imports count;
333 nth function import must be `Some`;
334 qed",
335 )
336 } else {
337 functions
338 .get(func_idx as usize - func_imports)
339 .ok_or("Function at the specified index is not defined")?
340 .type_ref()
341 };
342 let Type::Function(ty) = types
343 .get(sig_idx as usize)
344 .ok_or("The signature as specified by a function isn't defined")?;
345 Ok(ty)
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use parity_wasm::elements;

    /// Compiles WAT text to binary and deserializes it into a `parity_wasm` module.
    fn parse_wat(source: &str) -> elements::Module {
        let binary = wat::parse_str(source).expect("Failed to wat2wasm");
        elements::deserialize_buffer(&binary).expect("Failed to deserialize the module")
    }

    /// Serializes `module` and asserts that `wasmparser` accepts the resulting binary.
    fn validate_module(module: elements::Module) {
        let binary = elements::serialize(module).expect("Failed to serialize");
        wasmparser::validate(&binary).expect("Invalid module");
    }

    #[test]
    fn test_with_params_and_result() {
        const SOURCE: &str = r#"
(module
    (func (export "i32.add") (param i32 i32) (result i32)
        get_local 0
        get_local 1
        i32.add
    )
)
"#;

        let original = parse_wat(SOURCE);
        let instrumented = inject(original, 1024).expect("Failed to inject stack counter");
        validate_module(instrumented);
    }
}