// wasm_instrument/stack_limiter/max_height.rs

use super::resolve_func_type;
use alloc::vec::Vec;
use parity_wasm::elements::{self, BlockType, Type};

#[cfg(feature = "sign_ext")]
use parity_wasm::elements::SignExtInstruction;
/// Base height of the value stack charged for every function activation.
///
/// `Stack::new` starts counting from this value, so every height reported by
/// `compute` includes it. NOTE(review): presumably this models per-call context kept
/// by the host even for functions without operands; confirm against the limiter docs.
const ACTIVATION_FRAME_COST: u32 = 2;

/// Control-stack frame: one per enclosing `block`/`loop`/`if`, plus the implicit
/// frame of the function body.
#[derive(Debug)]
struct Frame {
	/// Set once an instruction that never passes control to its successor
	/// (`unreachable`, `br`, `br_table`, `return`) has been seen in this frame.
	/// From then on the frame's value stack is polymorphic and heights observed
	/// while it is the innermost frame are not recorded.
	is_polymorphic: bool,

	/// Number of values left on the value stack when the frame is exited via `end`.
	end_arity: u32,

	/// Number of values consumed by a branch that targets this frame.
	///
	/// Differs from `end_arity` for a `loop`: branching to it jumps back to the
	/// loop header, which takes no values.
	branch_arity: u32,

	/// Height of the value stack at the moment the frame was entered.
	start_height: u32,
}

/// Abstract interpretation state: the height of the value stack together with the
/// control stack of enclosing frames. Only heights are tracked, not value types.
struct Stack {
	/// Current number of values on the value stack, including `ACTIVATION_FRAME_COST`.
	height: u32,
	/// Enclosing control frames; the innermost frame is the last element.
	control_stack: Vec<Frame>,
}

impl Stack {
	/// Creates an empty control stack with the value-stack height set to
	/// `ACTIVATION_FRAME_COST`.
	fn new() -> Stack {
		Stack { height: ACTIVATION_FRAME_COST, control_stack: Vec::new() }
	}

	/// Returns the current height of the value stack.
	fn height(&self) -> u32 {
		self.height
	}

	/// Returns the frame `rel_depth` levels below the top of the control stack
	/// (`0` is the innermost frame), i.e. the frame targeted by a branch to label
	/// `rel_depth`.
	///
	/// Fails if the control stack is empty or has no frame at that depth.
	fn frame(&self, rel_depth: u32) -> Result<&Frame, &'static str> {
		let last_idx = self.control_stack.len().checked_sub(1).ok_or("control stack is empty")?;
		let idx = last_idx.checked_sub(rel_depth as usize).ok_or("control stack out-of-bounds")?;
		Ok(&self.control_stack[idx])
	}

	/// Marks the rest of the innermost frame as unreachable, making its value
	/// stack polymorphic.
	///
	/// Fails if the control stack is empty.
	fn mark_unreachable(&mut self) -> Result<(), &'static str> {
		let top_frame = self.control_stack.last_mut().ok_or("stack must be non-empty")?;
		top_frame.is_polymorphic = true;
		Ok(())
	}

	/// Pushes `frame` onto the control stack, making it the innermost frame.
	fn push_frame(&mut self, frame: Frame) {
		self.control_stack.push(frame);
	}

	/// Pops and returns the innermost frame.
	///
	/// Fails if the control stack is empty.
	fn pop_frame(&mut self) -> Result<Frame, &'static str> {
		self.control_stack.pop().ok_or("stack must be non-empty")
	}

	/// Sets the height of the value stack to `new_height`.
	fn trunc(&mut self, new_height: u32) {
		self.height = new_height;
	}

	/// Pushes `value_count` values onto the value stack.
	///
	/// Fails if the height would overflow `u32`.
	fn push_values(&mut self, value_count: u32) -> Result<(), &'static str> {
		self.height = self.height.checked_add(value_count).ok_or("stack overflow")?;
		Ok(())
	}

	/// Pops `value_count` values from the value stack.
	///
	/// Only values pushed since the innermost frame was entered can be popped; the
	/// values below its `start_height` belong to an enclosing frame. If the innermost
	/// frame is polymorphic, a shortfall is satisfied by the polymorphic stack and the
	/// height bottoms out at `start_height` instead of dipping into the enclosing frame.
	///
	/// Fails if the control stack is empty (and `value_count` is non-zero), if the
	/// height is already below the innermost frame's `start_height`, or if a reachable
	/// frame tries to pop more values than it has pushed.
	fn pop_values(&mut self, value_count: u32) -> Result<(), &'static str> {
		if value_count == 0 {
			return Ok(())
		}

		// Copy the fields out so that `self.height` can be updated below.
		let (start_height, is_polymorphic) = {
			let top_frame = self.frame(0)?;
			(top_frame.start_height, top_frame.is_polymorphic)
		};
		let pushed_in_frame = self.height.checked_sub(start_height).ok_or("stack underflow")?;

		if value_count <= pushed_in_frame {
			self.height -= value_count;
		} else if is_polymorphic {
			self.height = start_height;
		} else {
			return Err("trying to pop more values than pushed")
		}

		Ok(())
	}
}

124pub fn compute(func_idx: u32, module: &elements::Module) -> Result<u32, &'static str> {
126 use parity_wasm::elements::Instruction::*;
127
128 let func_section = module.function_section().ok_or("No function section")?;
129 let code_section = module.code_section().ok_or("No code section")?;
130 let type_section = module.type_section().ok_or("No type section")?;
131
132 let func_sig_idx = func_section
134 .entries()
135 .get(func_idx as usize)
136 .ok_or("Function is not found in func section")?
137 .type_ref();
138 let Type::Function(func_signature) = type_section
139 .types()
140 .get(func_sig_idx as usize)
141 .ok_or("Function is not found in func section")?;
142 let body = code_section
143 .bodies()
144 .get(func_idx as usize)
145 .ok_or("Function body for the index isn't found")?;
146 let instructions = body.code();
147
148 let mut stack = Stack::new();
149 let mut max_height: u32 = 0;
150 let mut pc = 0;
151
152 let func_arity = func_signature.results().len() as u32;
155 stack.push_frame(Frame {
156 is_polymorphic: false,
157 end_arity: func_arity,
158 branch_arity: func_arity,
159 start_height: 0,
160 });
161
162 loop {
163 if pc >= instructions.elements().len() {
164 break
165 }
166
167 if stack.height() > max_height && !stack.frame(0)?.is_polymorphic {
171 max_height = stack.height();
172 }
173
174 let opcode = &instructions.elements()[pc];
175
176 match opcode {
177 Nop => {},
178 Block(ty) | Loop(ty) | If(ty) => {
179 let end_arity = if *ty == BlockType::NoResult { 0 } else { 1 };
180 let branch_arity = if let Loop(_) = *opcode { 0 } else { end_arity };
181 if let If(_) = *opcode {
182 stack.pop_values(1)?;
183 }
184 let height = stack.height();
185 stack.push_frame(Frame {
186 is_polymorphic: false,
187 end_arity,
188 branch_arity,
189 start_height: height,
190 });
191 },
192 Else => {
193 },
196 End => {
197 let frame = stack.pop_frame()?;
198 stack.trunc(frame.start_height);
199 stack.push_values(frame.end_arity)?;
200 },
201 Unreachable => {
202 stack.mark_unreachable()?;
203 },
204 Br(target) => {
205 let target_arity = stack.frame(*target)?.branch_arity;
207 stack.pop_values(target_arity)?;
208
209 stack.mark_unreachable()?;
212 },
213 BrIf(target) => {
214 let target_arity = stack.frame(*target)?.branch_arity;
216 stack.pop_values(target_arity)?;
217
218 stack.pop_values(1)?;
220
221 stack.push_values(target_arity)?;
223 },
224 BrTable(br_table_data) => {
225 let arity_of_default = stack.frame(br_table_data.default)?.branch_arity;
226
227 for target in &*br_table_data.table {
229 let arity = stack.frame(*target)?.branch_arity;
230 if arity != arity_of_default {
231 return Err("Arity of all jump-targets must be equal")
232 }
233 }
234
235 stack.pop_values(arity_of_default)?;
238
239 stack.mark_unreachable()?;
242 },
243 Return => {
244 stack.pop_values(func_arity)?;
247 stack.mark_unreachable()?;
248 },
249 Call(idx) => {
250 let ty = resolve_func_type(*idx, module)?;
251
252 stack.pop_values(ty.params().len() as u32)?;
254
255 let callee_arity = ty.results().len() as u32;
257 stack.push_values(callee_arity)?;
258 },
259 CallIndirect(x, _) => {
260 let Type::Function(ty) =
261 type_section.types().get(*x as usize).ok_or("Type not found")?;
262
263 stack.pop_values(1)?;
265
266 stack.pop_values(ty.params().len() as u32)?;
268
269 let callee_arity = ty.results().len() as u32;
271 stack.push_values(callee_arity)?;
272 },
273 Drop => {
274 stack.pop_values(1)?;
275 },
276 Select => {
277 stack.pop_values(2)?;
279 stack.pop_values(1)?;
280
281 stack.push_values(1)?;
283 },
284 GetLocal(_) => {
285 stack.push_values(1)?;
286 },
287 SetLocal(_) => {
288 stack.pop_values(1)?;
289 },
290 TeeLocal(_) => {
291 stack.pop_values(1)?;
294 stack.push_values(1)?;
295 },
296 GetGlobal(_) => {
297 stack.push_values(1)?;
298 },
299 SetGlobal(_) => {
300 stack.pop_values(1)?;
301 },
302 I32Load(_, _) |
303 I64Load(_, _) |
304 F32Load(_, _) |
305 F64Load(_, _) |
306 I32Load8S(_, _) |
307 I32Load8U(_, _) |
308 I32Load16S(_, _) |
309 I32Load16U(_, _) |
310 I64Load8S(_, _) |
311 I64Load8U(_, _) |
312 I64Load16S(_, _) |
313 I64Load16U(_, _) |
314 I64Load32S(_, _) |
315 I64Load32U(_, _) => {
316 stack.pop_values(1)?;
319 stack.push_values(1)?;
320 },
321
322 I32Store(_, _) |
323 I64Store(_, _) |
324 F32Store(_, _) |
325 F64Store(_, _) |
326 I32Store8(_, _) |
327 I32Store16(_, _) |
328 I64Store8(_, _) |
329 I64Store16(_, _) |
330 I64Store32(_, _) => {
331 stack.pop_values(2)?;
333 },
334
335 CurrentMemory(_) => {
336 stack.push_values(1)?;
338 },
339 GrowMemory(_) => {
340 stack.pop_values(1)?;
342 stack.push_values(1)?;
343 },
344
345 I32Const(_) | I64Const(_) | F32Const(_) | F64Const(_) => {
346 stack.push_values(1)?;
348 },
349
350 I32Eqz | I64Eqz => {
351 stack.pop_values(1)?;
354 stack.push_values(1)?;
355 },
356
357 I32Eq | I32Ne | I32LtS | I32LtU | I32GtS | I32GtU | I32LeS | I32LeU | I32GeS |
358 I32GeU | I64Eq | I64Ne | I64LtS | I64LtU | I64GtS | I64GtU | I64LeS | I64LeU |
359 I64GeS | I64GeU | F32Eq | F32Ne | F32Lt | F32Gt | F32Le | F32Ge | F64Eq | F64Ne |
360 F64Lt | F64Gt | F64Le | F64Ge => {
361 stack.pop_values(2)?;
363 stack.push_values(1)?;
364 },
365
366 I32Clz | I32Ctz | I32Popcnt | I64Clz | I64Ctz | I64Popcnt | F32Abs | F32Neg |
367 F32Ceil | F32Floor | F32Trunc | F32Nearest | F32Sqrt | F64Abs | F64Neg | F64Ceil |
368 F64Floor | F64Trunc | F64Nearest | F64Sqrt => {
369 stack.pop_values(1)?;
371 stack.push_values(1)?;
372 },
373
374 I32Add | I32Sub | I32Mul | I32DivS | I32DivU | I32RemS | I32RemU | I32And | I32Or |
375 I32Xor | I32Shl | I32ShrS | I32ShrU | I32Rotl | I32Rotr | I64Add | I64Sub |
376 I64Mul | I64DivS | I64DivU | I64RemS | I64RemU | I64And | I64Or | I64Xor | I64Shl |
377 I64ShrS | I64ShrU | I64Rotl | I64Rotr | F32Add | F32Sub | F32Mul | F32Div |
378 F32Min | F32Max | F32Copysign | F64Add | F64Sub | F64Mul | F64Div | F64Min |
379 F64Max | F64Copysign => {
380 stack.pop_values(2)?;
382 stack.push_values(1)?;
383 },
384
385 I32WrapI64 | I32TruncSF32 | I32TruncUF32 | I32TruncSF64 | I32TruncUF64 |
386 I64ExtendSI32 | I64ExtendUI32 | I64TruncSF32 | I64TruncUF32 | I64TruncSF64 |
387 I64TruncUF64 | F32ConvertSI32 | F32ConvertUI32 | F32ConvertSI64 | F32ConvertUI64 |
388 F32DemoteF64 | F64ConvertSI32 | F64ConvertUI32 | F64ConvertSI64 | F64ConvertUI64 |
389 F64PromoteF32 | I32ReinterpretF32 | I64ReinterpretF64 | F32ReinterpretI32 |
390 F64ReinterpretI64 => {
391 stack.pop_values(1)?;
393 stack.push_values(1)?;
394 },
395
396 #[cfg(feature = "sign_ext")]
397 SignExt(SignExtInstruction::I32Extend8S) |
398 SignExt(SignExtInstruction::I32Extend16S) |
399 SignExt(SignExtInstruction::I64Extend8S) |
400 SignExt(SignExtInstruction::I64Extend16S) |
401 SignExt(SignExtInstruction::I64Extend32S) => {
402 stack.pop_values(1)?;
403 stack.push_values(1)?;
404 },
405 }
406 pc += 1;
407 }
408
409 Ok(max_height)
410}
411
412#[cfg(test)]
413mod tests {
414 use super::*;
415 use parity_wasm::elements;
416
417 fn parse_wat(source: &str) -> elements::Module {
418 elements::deserialize_buffer(&wat::parse_str(source).expect("Failed to wat2wasm"))
419 .expect("Failed to deserialize the module")
420 }
421
422 #[test]
423 fn simple_test() {
424 let module = parse_wat(
425 r#"
426(module
427 (func
428 i32.const 1
429 i32.const 2
430 i32.const 3
431 drop
432 drop
433 drop
434 )
435)
436"#,
437 );
438
439 let height = compute(0, &module).unwrap();
440 assert_eq!(height, 3 + ACTIVATION_FRAME_COST);
441 }
442
443 #[test]
444 fn implicit_and_explicit_return() {
445 let module = parse_wat(
446 r#"
447(module
448 (func (result i32)
449 i32.const 0
450 return
451 )
452)
453"#,
454 );
455
456 let height = compute(0, &module).unwrap();
457 assert_eq!(height, 1 + ACTIVATION_FRAME_COST);
458 }
459
460 #[test]
461 fn dont_count_in_unreachable() {
462 let module = parse_wat(
463 r#"
464(module
465 (memory 0)
466 (func (result i32)
467 unreachable
468 grow_memory
469 )
470)
471"#,
472 );
473
474 let height = compute(0, &module).unwrap();
475 assert_eq!(height, ACTIVATION_FRAME_COST);
476 }
477
478 #[test]
479 fn yet_another_test() {
480 let module = parse_wat(
481 r#"
482(module
483 (memory 0)
484 (func
485 ;; Push two values and then pop them.
486 ;; This will make max depth to be equal to 2.
487 i32.const 0
488 i32.const 1
489 drop
490 drop
491
492 ;; Code after `unreachable` shouldn't have an effect
493 ;; on the max depth.
494 unreachable
495 i32.const 0
496 i32.const 1
497 i32.const 2
498 )
499)
500"#,
501 );
502
503 let height = compute(0, &module).unwrap();
504 assert_eq!(height, 2 + ACTIVATION_FRAME_COST);
505 }
506
507 #[test]
508 fn call_indirect() {
509 let module = parse_wat(
510 r#"
511(module
512 (table $ptr 1 1 funcref)
513 (elem $ptr (i32.const 0) func 1)
514 (func $main
515 (call_indirect (i32.const 0))
516 (call_indirect (i32.const 0))
517 (call_indirect (i32.const 0))
518 )
519 (func $callee
520 i64.const 42
521 drop
522 )
523)
524"#,
525 );
526
527 let height = compute(0, &module).unwrap();
528 assert_eq!(height, 1 + ACTIVATION_FRAME_COST);
529 }
530
531 #[test]
532 fn breaks() {
533 let module = parse_wat(
534 r#"
535(module
536 (func $main
537 block (result i32)
538 block (result i32)
539 i32.const 99
540 br 1
541 end
542 end
543 drop
544 )
545)
546"#,
547 );
548
549 let height = compute(0, &module).unwrap();
550 assert_eq!(height, 1 + ACTIVATION_FRAME_COST);
551 }
552
553 #[test]
554 fn if_else_works() {
555 let module = parse_wat(
556 r#"
557(module
558 (func $main
559 i32.const 7
560 i32.const 1
561 if (result i32)
562 i32.const 42
563 else
564 i32.const 99
565 end
566 i32.const 97
567 drop
568 drop
569 drop
570 )
571)
572"#,
573 );
574
575 let height = compute(0, &module).unwrap();
576 assert_eq!(height, 3 + ACTIVATION_FRAME_COST);
577 }
578}