cranelift_frontend/
switch.rs

1use super::HashMap;
2use crate::frontend::FunctionBuilder;
3use alloc::vec::Vec;
4use core::convert::TryFrom;
5use cranelift_codegen::ir::condcodes::IntCC;
6use cranelift_codegen::ir::*;
7
/// Index type of a switch entry. It is `u128` so that every unsigned value of the integer
/// types a switch can be emitted on, up to and including `I128`, can be used as an entry.
type EntryIndex = u128;
9
10/// Unlike with `br_table`, `Switch` cases may be sparse or non-0-based.
11/// They emit efficient code using branches, jump tables, or a combination of both.
12///
13/// # Example
14///
15/// ```rust
16/// # use cranelift_codegen::ir::types::*;
17/// # use cranelift_codegen::ir::{UserFuncName, Function, Signature, InstBuilder};
18/// # use cranelift_codegen::isa::CallConv;
19/// # use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext, Switch};
20/// #
21/// # let mut sig = Signature::new(CallConv::SystemV);
22/// # let mut fn_builder_ctx = FunctionBuilderContext::new();
23/// # let mut func = Function::with_name_signature(UserFuncName::user(0, 0), sig);
24/// # let mut builder = FunctionBuilder::new(&mut func, &mut fn_builder_ctx);
25/// #
26/// # let entry = builder.create_block();
27/// # builder.switch_to_block(entry);
28/// #
29/// let block0 = builder.create_block();
30/// let block1 = builder.create_block();
31/// let block2 = builder.create_block();
32/// let fallback = builder.create_block();
33///
34/// let val = builder.ins().iconst(I32, 1);
35///
36/// let mut switch = Switch::new();
37/// switch.set_entry(0, block0);
38/// switch.set_entry(1, block1);
39/// switch.set_entry(7, block2);
40/// switch.emit(&mut builder, val, fallback);
41/// ```
42#[derive(Debug, Default)]
43pub struct Switch {
44    cases: HashMap<EntryIndex, Block>,
45}
46
47impl Switch {
48    /// Create a new empty switch
49    pub fn new() -> Self {
50        Self {
51            cases: HashMap::new(),
52        }
53    }
54
55    /// Set a switch entry
56    pub fn set_entry(&mut self, index: EntryIndex, block: Block) {
57        let prev = self.cases.insert(index, block);
58        assert!(
59            prev.is_none(),
60            "Tried to set the same entry {} twice",
61            index
62        );
63    }
64
65    /// Get a reference to all existing entries
66    pub fn entries(&self) -> &HashMap<EntryIndex, Block> {
67        &self.cases
68    }
69
70    /// Turn the `cases` `HashMap` into a list of `ContiguousCaseRange`s.
71    ///
72    /// # Postconditions
73    ///
74    /// * Every entry will be represented.
75    /// * The `ContiguousCaseRange`s will not overlap.
76    /// * Between two `ContiguousCaseRange`s there will be at least one entry index.
77    /// * No `ContiguousCaseRange`s will be empty.
78    fn collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange> {
79        log::trace!("build_contiguous_case_ranges before: {:#?}", self.cases);
80        let mut cases = self.cases.into_iter().collect::<Vec<(_, _)>>();
81        cases.sort_by_key(|&(index, _)| index);
82
83        let mut contiguous_case_ranges: Vec<ContiguousCaseRange> = vec![];
84        let mut last_index = None;
85        for (index, block) in cases {
86            match last_index {
87                None => contiguous_case_ranges.push(ContiguousCaseRange::new(index)),
88                Some(last_index) => {
89                    if index > last_index + 1 {
90                        contiguous_case_ranges.push(ContiguousCaseRange::new(index));
91                    }
92                }
93            }
94            contiguous_case_ranges
95                .last_mut()
96                .unwrap()
97                .blocks
98                .push(block);
99            last_index = Some(index);
100        }
101
102        log::trace!(
103            "build_contiguous_case_ranges after: {:#?}",
104            contiguous_case_ranges
105        );
106
107        contiguous_case_ranges
108    }
109
110    /// Binary search for the right `ContiguousCaseRange`.
111    fn build_search_tree<'a>(
112        bx: &mut FunctionBuilder,
113        val: Value,
114        otherwise: Block,
115        contiguous_case_ranges: &'a [ContiguousCaseRange],
116    ) {
117        // If no switch cases were added to begin with, we can just emit `jump otherwise`.
118        if contiguous_case_ranges.is_empty() {
119            bx.ins().jump(otherwise, &[]);
120            return;
121        }
122
123        // Avoid allocation in the common case
124        if contiguous_case_ranges.len() <= 3 {
125            Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
126            return;
127        }
128
129        let mut stack = Vec::new();
130        stack.push((None, contiguous_case_ranges));
131
132        while let Some((block, contiguous_case_ranges)) = stack.pop() {
133            if let Some(block) = block {
134                bx.switch_to_block(block);
135            }
136
137            if contiguous_case_ranges.len() <= 3 {
138                Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
139            } else {
140                let split_point = contiguous_case_ranges.len() / 2;
141                let (left, right) = contiguous_case_ranges.split_at(split_point);
142
143                let left_block = bx.create_block();
144                let right_block = bx.create_block();
145
146                let first_index = right[0].first_index;
147                let should_take_right_side =
148                    icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
149                bx.ins()
150                    .brif(should_take_right_side, right_block, &[], left_block, &[]);
151
152                bx.seal_block(left_block);
153                bx.seal_block(right_block);
154
155                stack.push((Some(left_block), left));
156                stack.push((Some(right_block), right));
157            }
158        }
159    }
160
161    /// Linear search for the right `ContiguousCaseRange`.
162    fn build_search_branches<'a>(
163        bx: &mut FunctionBuilder,
164        val: Value,
165        otherwise: Block,
166        contiguous_case_ranges: &'a [ContiguousCaseRange],
167    ) {
168        for (ix, range) in contiguous_case_ranges.iter().enumerate().rev() {
169            let alternate = if ix == 0 {
170                otherwise
171            } else {
172                bx.create_block()
173            };
174
175            if range.first_index == 0 {
176                assert_eq!(alternate, otherwise);
177
178                if let Some(block) = range.single_block() {
179                    bx.ins().brif(val, otherwise, &[], block, &[]);
180                } else {
181                    Self::build_jump_table(bx, val, otherwise, 0, &range.blocks);
182                }
183            } else {
184                if let Some(block) = range.single_block() {
185                    let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, range.first_index);
186                    bx.ins().brif(is_good_val, block, &[], alternate, &[]);
187                } else {
188                    let is_good_val = icmp_imm_u128(
189                        bx,
190                        IntCC::UnsignedGreaterThanOrEqual,
191                        val,
192                        range.first_index,
193                    );
194                    let jt_block = bx.create_block();
195                    bx.ins().brif(is_good_val, jt_block, &[], alternate, &[]);
196                    bx.seal_block(jt_block);
197                    bx.switch_to_block(jt_block);
198                    Self::build_jump_table(bx, val, otherwise, range.first_index, &range.blocks);
199                }
200            }
201
202            if alternate != otherwise {
203                bx.seal_block(alternate);
204                bx.switch_to_block(alternate);
205            }
206        }
207    }
208
209    fn build_jump_table(
210        bx: &mut FunctionBuilder,
211        val: Value,
212        otherwise: Block,
213        first_index: EntryIndex,
214        blocks: &[Block],
215    ) {
216        // There are currently no 128bit systems supported by rustc, but once we do ensure that
217        // we don't silently ignore a part of the jump table for 128bit integers on 128bit systems.
218        assert!(
219            u32::try_from(blocks.len()).is_ok(),
220            "Jump tables bigger than 2^32-1 are not yet supported"
221        );
222
223        let jt_data = JumpTableData::new(
224            bx.func.dfg.block_call(otherwise, &[]),
225            &blocks
226                .iter()
227                .map(|block| bx.func.dfg.block_call(*block, &[]))
228                .collect::<Vec<_>>(),
229        );
230        let jump_table = bx.create_jump_table(jt_data);
231
232        let discr = if first_index == 0 {
233            val
234        } else {
235            if let Ok(first_index) = u64::try_from(first_index) {
236                bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())
237            } else {
238                let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64);
239                let lsb = bx.ins().iconst(types::I64, lsb as i64);
240                let msb = bx.ins().iconst(types::I64, msb as i64);
241                let index = bx.ins().iconcat(lsb, msb);
242                bx.ins().isub(val, index)
243            }
244        };
245
246        let discr = match bx.func.dfg.value_type(discr).bits() {
247            bits if bits > 32 => {
248                // Check for overflow of cast to u32. This is the max supported jump table entries.
249                let new_block = bx.create_block();
250                let bigger_than_u32 =
251                    bx.ins()
252                        .icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);
253                bx.ins()
254                    .brif(bigger_than_u32, otherwise, &[], new_block, &[]);
255                bx.seal_block(new_block);
256                bx.switch_to_block(new_block);
257
258                // Cast to i32, as br_table is not implemented for i64/i128
259                bx.ins().ireduce(types::I32, discr)
260            }
261            bits if bits < 32 => bx.ins().uextend(types::I32, discr),
262            _ => discr,
263        };
264
265        bx.ins().br_table(discr, jump_table);
266    }
267
268    /// Build the switch
269    ///
270    /// # Arguments
271    ///
272    /// * The function builder to emit to
273    /// * The value to switch on
274    /// * The default block
275    pub fn emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block) {
276        // Validate that the type of `val` is sufficiently wide to address all cases.
277        let max = self.cases.keys().max().copied().unwrap_or(0);
278        let val_ty = bx.func.dfg.value_type(val);
279        let val_ty_max = val_ty.bounds(false).1;
280        if max > val_ty_max {
281            panic!(
282                "The index type {} does not fit the maximum switch entry of {}",
283                val_ty, max
284            );
285        }
286
287        let contiguous_case_ranges = self.collect_contiguous_case_ranges();
288        Self::build_search_tree(bx, val, otherwise, &contiguous_case_ranges);
289    }
290}
291
292fn icmp_imm_u128(bx: &mut FunctionBuilder, cond: IntCC, x: Value, y: u128) -> Value {
293    if let Ok(index) = u64::try_from(y) {
294        bx.ins().icmp_imm(cond, x, index as i64)
295    } else {
296        let (lsb, msb) = (y as u64, (y >> 64) as u64);
297        let lsb = bx.ins().iconst(types::I64, lsb as i64);
298        let msb = bx.ins().iconst(types::I64, msb as i64);
299        let index = bx.ins().iconcat(lsb, msb);
300        bx.ins().icmp(cond, x, index)
301    }
302}
303
/// This represents a contiguous range of cases to switch on: the entry indices
/// `first_index`, `first_index + 1`, ... each mapped to a block.
///
/// For example 10 => block1, 11 => block2, 12 => block7 will be represented as:
///
/// ```plain
/// ContiguousCaseRange {
///     first_index: 10,
///     blocks: vec![Block::from_u32(1), Block::from_u32(2), Block::from_u32(7)]
/// }
/// ```
#[derive(Debug)]
struct ContiguousCaseRange {
    /// The entry index of the first case. E.g. 10 when the entry indexes are 10, 11, 12 and 13.
    first_index: EntryIndex,

    /// The blocks to jump to sorted in ascending order of entry index; `blocks[i]` is the
    /// target for entry index `first_index + i`.
    blocks: Vec<Block>,
}
322
323impl ContiguousCaseRange {
324    fn new(first_index: EntryIndex) -> Self {
325        Self {
326            first_index,
327            blocks: Vec::new(),
328        }
329    }
330
331    /// Returns `Some` block when there is only a single block in this range.
332    fn single_block(&self) -> Option<Block> {
333        if self.blocks.len() == 1 {
334            Some(self.blocks[0])
335        } else {
336            None
337        }
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frontend::FunctionBuilderContext;
    use alloc::string::ToString;
    use cranelift_codegen::ir::Function;

    // Builds a function whose entry block switches on `iconst.i8 0`, with one freshly created
    // block per listed entry index (created in the listed order) and block number `$default`
    // as the fallback. Returns the printed function with the `function u0:0() fast {` header
    // and the closing brace stripped. Block numbers in the expected outputs below follow this
    // creation order.
    macro_rules! setup {
        ($default:expr, [$($index:expr,)*]) => {{
            let mut func = Function::new();
            let mut func_ctx = FunctionBuilderContext::new();
            {
                let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
                let block = bx.create_block();
                bx.switch_to_block(block);
                let val = bx.ins().iconst(types::I8, 0);
                // `switch` is never mutated when the entry list is empty.
                #[allow(unused_mut)]
                let mut switch = Switch::new();
                $(
                    let block = bx.create_block();
                    switch.set_entry($index, block);
                )*
                switch.emit(&mut bx, val, Block::with_number($default).unwrap());
            }
            func
                .to_string()
                .trim_start_matches("function u0:0() fast {\n")
                .trim_end_matches("\n}\n")
                .to_string()
        }};
    }

    // Like `assert_eq!`, but on failure also prints a unified diff of expected vs. actual.
    macro_rules! assert_eq_output {
        ($actual:ident, $expected:literal) => {
            assert_eq!(
                $actual,
                $expected,
                "\n{}",
                similar::TextDiff::from_lines($expected, &$actual)
                    .unified_diff()
                    .header("expected", "actual")
            )
        };
    }

    // No entries: the switch is just an unconditional jump to the fallback.
    #[test]
    fn switch_empty() {
        let func = setup!(42, []);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    jump block42"
        );
    }

    // A single entry at index 0 branches on the value itself.
    #[test]
    fn switch_zero() {
        let func = setup!(0, [0,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    brif v0, block0, block1  ; v0 = 0"
        );
    }

    // A single non-zero entry becomes an equality comparison.
    #[test]
    fn switch_single() {
        let func = setup!(0, [1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 1  ; v0 = 0
    brif v1, block1, block0"
        );
    }

    // Two contiguous entries starting at 0 become a jump table without rebasing.
    #[test]
    fn switch_bool() {
        let func = setup!(0, [0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0  ; v0 = 0
    br_table v1, block0, [block1, block2]"
        );
    }

    // Two single-entry ranges separated by a gap are tested highest-first.
    #[test]
    fn switch_two_gap() {
        let func = setup!(0, [0, 2,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 2  ; v0 = 0
    brif v1, block2, block3

block3:
    brif.i8 v0, block0, block1  ; v0 = 0"
        );
    }

    // Four ranges ({0, 1}, {5}, {7}, {10, 11, 12}) exceed the linear-search limit of three, so
    // a binary split on `uge 7` is emitted before the per-range checks.
    #[test]
    fn switch_many() {
        let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm uge v0, 7  ; v0 = 0
    brif v1, block9, block8

block9:
    v2 = icmp_imm.i8 uge v0, 10  ; v0 = 0
    brif v2, block11, block10

block11:
    v3 = iadd_imm.i8 v0, -10  ; v0 = 0
    v4 = uextend.i32 v3
    br_table v4, block0, [block5, block6, block7]

block10:
    v5 = icmp_imm.i8 eq v0, 7  ; v0 = 0
    brif v5, block4, block0

block8:
    v6 = icmp_imm.i8 eq v0, 5  ; v0 = 0
    brif v6, block3, block12

block12:
    v7 = uextend.i32 v0  ; v0 = 0
    br_table v7, block0, [block1, block2]"
        );
    }

    // Index 128 (the bit pattern of `i8::MIN`) is compared as an unsigned value.
    #[test]
    fn switch_min_index_value() {
        let func = setup!(0, [i8::MIN as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 128  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        );
    }

    #[test]
    fn switch_max_index_value() {
        let func = setup!(0, [i8::MAX as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 127  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        )
    }

    // Index 255 is isolated; 0 and 1 form a contiguous range that becomes a jump table.
    #[test]
    fn switch_optimal_codegen() {
        let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 255  ; v0 = 0
    brif v1, block1, block4

block4:
    v2 = uextend.i32 v0  ; v0 = 0
    br_table v2, block0, [block2, block3]"
        );
    }

    #[test]
    #[should_panic(
        expected = "The index type i8 does not fit the maximum switch entry of 4683743612477887600"
    )]
    fn switch_rejects_small_inputs() {
        // This is a regression test for a bug that we found where we would emit a cmp
        // with a type that was not able to fully represent a large index.
        //
        // See: https://github.com/bytecodealliance/wasmtime/pull/4502#issuecomment-1191961677
        setup!(1, [0x4100_0000_00bf_d470,]);
    }

    // `finalize` panics if any block is left unsealed, so this checks that every block the
    // switch creates gets sealed, for both linear and binary-search layouts and all integer
    // widths.
    #[test]
    fn switch_seal_generated_blocks() {
        let cases = &[vec![0, 1, 2], vec![0, 1, 2, 10, 11, 12, 20, 30, 40, 50]];

        for case in cases {
            for typ in &[types::I8, types::I16, types::I32, types::I64, types::I128] {
                eprintln!("Testing {:?} with keys: {:?}", typ, case);
                do_case(case, *typ);
            }
        }

        fn do_case(keys: &[u128], typ: Type) {
            let mut func = Function::new();
            let mut builder_ctx = FunctionBuilderContext::new();
            let mut builder = FunctionBuilder::new(&mut func, &mut builder_ctx);

            let root_block = builder.create_block();
            let default_block = builder.create_block();
            let mut switch = Switch::new();

            let case_blocks = keys
                .iter()
                .map(|key| {
                    let block = builder.create_block();
                    switch.set_entry(*key, block);
                    block
                })
                .collect::<Vec<_>>();

            builder.seal_block(root_block);
            builder.switch_to_block(root_block);

            let val = builder.ins().iconst(typ, 1);
            switch.emit(&mut builder, val, default_block);

            // Give every target block a terminator so the function is complete.
            for &block in case_blocks.iter().chain(std::iter::once(&default_block)) {
                builder.seal_block(block);
                builder.switch_to_block(block);
                builder.ins().return_(&[]);
            }

            builder.finalize(); // Will panic if some blocks are not sealed
        }
    }

    // 64-bit values are range-checked against `u32::MAX` before being reduced to `i32` for
    // `br_table`.
    #[test]
    fn switch_64bit() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = icmp_imm ugt v0, 0xffff_ffff  ; v0 = 0
    brif v1, block3, block4

block4:
    v2 = ireduce.i32 v0  ; v0 = 0
    br_table v2, block3, [block2, block1]"
        );
    }

    // Same as `switch_64bit`, but switching on an `i128` value.
    #[test]
    fn switch_128bit() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let val = bx.ins().uextend(types::I128, val);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0  ; v0 = 0
    v2 = icmp_imm ugt v1, 0xffff_ffff
    brif v2, block3, block4

block4:
    v3 = ireduce.i32 v1
    br_table v3, block3, [block2, block1]"
        );
    }
}