1use super::HashMap;
2use crate::frontend::FunctionBuilder;
3use alloc::vec::Vec;
4use core::convert::TryFrom;
5use cranelift_codegen::ir::condcodes::IntCC;
6use cranelift_codegen::ir::*;
7
/// Key type of switch entries.
///
/// Entries are compared against the switch value as unsigned integers, so `u128` can
/// hold the bit pattern of any value of an integer type up to 128 bits wide (signed
/// values are passed reinterpreted as unsigned, e.g. `-1i8 as u8 as u128`).
type EntryIndex = u128;
9
10#[derive(Debug, Default)]
43pub struct Switch {
44 cases: HashMap<EntryIndex, Block>,
45}
46
47impl Switch {
48 pub fn new() -> Self {
50 Self {
51 cases: HashMap::new(),
52 }
53 }
54
55 pub fn set_entry(&mut self, index: EntryIndex, block: Block) {
57 let prev = self.cases.insert(index, block);
58 assert!(
59 prev.is_none(),
60 "Tried to set the same entry {} twice",
61 index
62 );
63 }
64
65 pub fn entries(&self) -> &HashMap<EntryIndex, Block> {
67 &self.cases
68 }
69
70 fn collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange> {
79 log::trace!("build_contiguous_case_ranges before: {:#?}", self.cases);
80 let mut cases = self.cases.into_iter().collect::<Vec<(_, _)>>();
81 cases.sort_by_key(|&(index, _)| index);
82
83 let mut contiguous_case_ranges: Vec<ContiguousCaseRange> = vec![];
84 let mut last_index = None;
85 for (index, block) in cases {
86 match last_index {
87 None => contiguous_case_ranges.push(ContiguousCaseRange::new(index)),
88 Some(last_index) => {
89 if index > last_index + 1 {
90 contiguous_case_ranges.push(ContiguousCaseRange::new(index));
91 }
92 }
93 }
94 contiguous_case_ranges
95 .last_mut()
96 .unwrap()
97 .blocks
98 .push(block);
99 last_index = Some(index);
100 }
101
102 log::trace!(
103 "build_contiguous_case_ranges after: {:#?}",
104 contiguous_case_ranges
105 );
106
107 contiguous_case_ranges
108 }
109
110 fn build_search_tree<'a>(
112 bx: &mut FunctionBuilder,
113 val: Value,
114 otherwise: Block,
115 contiguous_case_ranges: &'a [ContiguousCaseRange],
116 ) {
117 if contiguous_case_ranges.is_empty() {
119 bx.ins().jump(otherwise, &[]);
120 return;
121 }
122
123 if contiguous_case_ranges.len() <= 3 {
125 Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
126 return;
127 }
128
129 let mut stack = Vec::new();
130 stack.push((None, contiguous_case_ranges));
131
132 while let Some((block, contiguous_case_ranges)) = stack.pop() {
133 if let Some(block) = block {
134 bx.switch_to_block(block);
135 }
136
137 if contiguous_case_ranges.len() <= 3 {
138 Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
139 } else {
140 let split_point = contiguous_case_ranges.len() / 2;
141 let (left, right) = contiguous_case_ranges.split_at(split_point);
142
143 let left_block = bx.create_block();
144 let right_block = bx.create_block();
145
146 let first_index = right[0].first_index;
147 let should_take_right_side =
148 icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
149 bx.ins()
150 .brif(should_take_right_side, right_block, &[], left_block, &[]);
151
152 bx.seal_block(left_block);
153 bx.seal_block(right_block);
154
155 stack.push((Some(left_block), left));
156 stack.push((Some(right_block), right));
157 }
158 }
159 }
160
161 fn build_search_branches<'a>(
163 bx: &mut FunctionBuilder,
164 val: Value,
165 otherwise: Block,
166 contiguous_case_ranges: &'a [ContiguousCaseRange],
167 ) {
168 for (ix, range) in contiguous_case_ranges.iter().enumerate().rev() {
169 let alternate = if ix == 0 {
170 otherwise
171 } else {
172 bx.create_block()
173 };
174
175 if range.first_index == 0 {
176 assert_eq!(alternate, otherwise);
177
178 if let Some(block) = range.single_block() {
179 bx.ins().brif(val, otherwise, &[], block, &[]);
180 } else {
181 Self::build_jump_table(bx, val, otherwise, 0, &range.blocks);
182 }
183 } else {
184 if let Some(block) = range.single_block() {
185 let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, range.first_index);
186 bx.ins().brif(is_good_val, block, &[], alternate, &[]);
187 } else {
188 let is_good_val = icmp_imm_u128(
189 bx,
190 IntCC::UnsignedGreaterThanOrEqual,
191 val,
192 range.first_index,
193 );
194 let jt_block = bx.create_block();
195 bx.ins().brif(is_good_val, jt_block, &[], alternate, &[]);
196 bx.seal_block(jt_block);
197 bx.switch_to_block(jt_block);
198 Self::build_jump_table(bx, val, otherwise, range.first_index, &range.blocks);
199 }
200 }
201
202 if alternate != otherwise {
203 bx.seal_block(alternate);
204 bx.switch_to_block(alternate);
205 }
206 }
207 }
208
209 fn build_jump_table(
210 bx: &mut FunctionBuilder,
211 val: Value,
212 otherwise: Block,
213 first_index: EntryIndex,
214 blocks: &[Block],
215 ) {
216 assert!(
219 u32::try_from(blocks.len()).is_ok(),
220 "Jump tables bigger than 2^32-1 are not yet supported"
221 );
222
223 let jt_data = JumpTableData::new(
224 bx.func.dfg.block_call(otherwise, &[]),
225 &blocks
226 .iter()
227 .map(|block| bx.func.dfg.block_call(*block, &[]))
228 .collect::<Vec<_>>(),
229 );
230 let jump_table = bx.create_jump_table(jt_data);
231
232 let discr = if first_index == 0 {
233 val
234 } else {
235 if let Ok(first_index) = u64::try_from(first_index) {
236 bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())
237 } else {
238 let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64);
239 let lsb = bx.ins().iconst(types::I64, lsb as i64);
240 let msb = bx.ins().iconst(types::I64, msb as i64);
241 let index = bx.ins().iconcat(lsb, msb);
242 bx.ins().isub(val, index)
243 }
244 };
245
246 let discr = match bx.func.dfg.value_type(discr).bits() {
247 bits if bits > 32 => {
248 let new_block = bx.create_block();
250 let bigger_than_u32 =
251 bx.ins()
252 .icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);
253 bx.ins()
254 .brif(bigger_than_u32, otherwise, &[], new_block, &[]);
255 bx.seal_block(new_block);
256 bx.switch_to_block(new_block);
257
258 bx.ins().ireduce(types::I32, discr)
260 }
261 bits if bits < 32 => bx.ins().uextend(types::I32, discr),
262 _ => discr,
263 };
264
265 bx.ins().br_table(discr, jump_table);
266 }
267
268 pub fn emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block) {
276 let max = self.cases.keys().max().copied().unwrap_or(0);
278 let val_ty = bx.func.dfg.value_type(val);
279 let val_ty_max = val_ty.bounds(false).1;
280 if max > val_ty_max {
281 panic!(
282 "The index type {} does not fit the maximum switch entry of {}",
283 val_ty, max
284 );
285 }
286
287 let contiguous_case_ranges = self.collect_contiguous_case_ranges();
288 Self::build_search_tree(bx, val, otherwise, &contiguous_case_ranges);
289 }
290}
291
292fn icmp_imm_u128(bx: &mut FunctionBuilder, cond: IntCC, x: Value, y: u128) -> Value {
293 if let Ok(index) = u64::try_from(y) {
294 bx.ins().icmp_imm(cond, x, index as i64)
295 } else {
296 let (lsb, msb) = (y as u64, (y >> 64) as u64);
297 let lsb = bx.ins().iconst(types::I64, lsb as i64);
298 let msb = bx.ins().iconst(types::I64, msb as i64);
299 let index = bx.ins().iconcat(lsb, msb);
300 bx.ins().icmp(cond, x, index)
301 }
302}
303
/// A run of switch cases whose entry indices are consecutive.
///
/// For example, the cases `10 => block1`, `11 => block2` and `12 => block7` form a
/// single range with `first_index` 10 whose `blocks` are block1, block2 and block7.
#[derive(Debug)]
struct ContiguousCaseRange {
    /// The entry index of the first case in the run.
    first_index: EntryIndex,

    /// The target blocks in ascending order of entry index: the block at position `i`
    /// is the target for entry index `first_index + i`.
    blocks: Vec<Block>,
}
322
323impl ContiguousCaseRange {
324 fn new(first_index: EntryIndex) -> Self {
325 Self {
326 first_index,
327 blocks: Vec::new(),
328 }
329 }
330
331 fn single_block(&self) -> Option<Block> {
333 if self.blocks.len() == 1 {
334 Some(self.blocks[0])
335 } else {
336 None
337 }
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;
    use crate::frontend::FunctionBuilderContext;
    use alloc::string::ToString;
    use cranelift_codegen::ir::Function;

    // Builds a function whose entry block switches on `iconst.i8 0`, with one new case
    // block per listed index (numbered from block1 upwards) and `block$default` as the
    // fallback. Returns the printed IR with the function header and closing brace
    // trimmed off.
    macro_rules! setup {
        ($default:expr, [$($index:expr,)*]) => {{
            let mut func = Function::new();
            let mut func_ctx = FunctionBuilderContext::new();
            {
                let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
                let block = bx.create_block();
                bx.switch_to_block(block);
                let val = bx.ins().iconst(types::I8, 0);
                // `switch` is never mutated when the index list is empty.
                #[allow(unused_mut)]
                let mut switch = Switch::new();
                $(
                    let block = bx.create_block();
                    switch.set_entry($index, block);
                )*
                switch.emit(&mut bx, val, Block::with_number($default).unwrap());
            }
            func
                .to_string()
                .trim_start_matches("function u0:0() fast {\n")
                .trim_end_matches("\n}\n")
                .to_string()
        }};
    }

    // Asserts that `$actual` equals `$expected`, printing a unified line diff
    // (expected vs. actual) when they differ.
    macro_rules! assert_eq_output {
        ($actual:ident, $expected:literal) => {
            assert_eq!(
                $actual,
                $expected,
                "\n{}",
                similar::TextDiff::from_lines($expected, &$actual)
                    .unified_diff()
                    .header("expected", "actual")
            )
        };
    }

    #[test]
    fn switch_empty() {
        // No entries: the switch is just a jump to the default block.
        let func = setup!(42, []);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    jump block42"
        );
    }

    #[test]
    fn switch_zero() {
        // A lone case at index 0 branches on the value itself.
        let func = setup!(0, [0,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    brif v0, block0, block1  ; v0 = 0"
        );
    }

    #[test]
    fn switch_single() {
        // A lone nonzero case becomes an equality test.
        let func = setup!(0, [1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 1  ; v0 = 0
    brif v1, block1, block0"
        );
    }

    #[test]
    fn switch_bool() {
        // Two consecutive cases from 0 form a jump table without rebasing.
        let func = setup!(0, [0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0  ; v0 = 0
    br_table v1, block0, [block1, block2]"
        );
    }

    #[test]
    fn switch_two_gap() {
        // Non-consecutive cases are tested one at a time, highest first.
        let func = setup!(0, [0, 2,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 2  ; v0 = 0
    brif v1, block2, block3

block3:
    brif.i8 v0, block0, block1  ; v0 = 0"
        );
    }

    #[test]
    fn switch_many() {
        // More than three ranges produce a binary search tree over the ranges.
        let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm uge v0, 7  ; v0 = 0
    brif v1, block9, block8

block9:
    v2 = icmp_imm.i8 uge v0, 10  ; v0 = 0
    brif v2, block11, block10

block11:
    v3 = iadd_imm.i8 v0, -10  ; v0 = 0
    v4 = uextend.i32 v3
    br_table v4, block0, [block5, block6, block7]

block10:
    v5 = icmp_imm.i8 eq v0, 7  ; v0 = 0
    brif v5, block4, block0

block8:
    v6 = icmp_imm.i8 eq v0, 5  ; v0 = 0
    brif v6, block3, block12

block12:
    v7 = uextend.i32 v0  ; v0 = 0
    br_table v7, block0, [block1, block2]"
        );
    }

    #[test]
    fn switch_min_index_value() {
        // `i8::MIN` is passed as its unsigned bit pattern, 128.
        let func = setup!(0, [i8::MIN as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 128  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        );
    }

    #[test]
    fn switch_max_index_value() {
        let func = setup!(0, [i8::MAX as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 127  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        )
    }

    #[test]
    fn switch_optimal_codegen() {
        // -1 is treated as the unsigned value 255, so 0 and 1 still share a jump table.
        let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 255  ; v0 = 0
    brif v1, block1, block4

block4:
    v2 = uextend.i32 v0  ; v0 = 0
    br_table v2, block0, [block2, block3]"
        );
    }

    #[test]
    #[should_panic(
        expected = "The index type i8 does not fit the maximum switch entry of 4683743612477887600"
    )]
    fn switch_rejects_small_inputs() {
        // The entry index exceeds the range of the `i8` value built by `setup!`, so
        // `emit` must reject it instead of emitting a truncated comparison.
        setup!(1, [0x4100_0000_00bf_d470,]);
    }

    #[test]
    fn switch_seal_generated_blocks() {
        let cases = &[vec![0, 1, 2], vec![0, 1, 2, 10, 11, 12, 20, 30, 40, 50]];

        for case in cases {
            for typ in &[types::I8, types::I16, types::I32, types::I64, types::I128] {
                eprintln!("Testing {:?} with keys: {:?}", typ, case);
                do_case(case, *typ);
            }
        }

        // Emits a switch over `keys` on a constant of type `typ`, then seals and
        // terminates every case block and the default block. NOTE(review): the test
        // relies on `finalize` detecting any block the switch itself left unsealed;
        // confirm that check in `FunctionBuilder::finalize`.
        fn do_case(keys: &[u128], typ: Type) {
            let mut func = Function::new();
            let mut builder_ctx = FunctionBuilderContext::new();
            let mut builder = FunctionBuilder::new(&mut func, &mut builder_ctx);

            let root_block = builder.create_block();
            let default_block = builder.create_block();
            let mut switch = Switch::new();

            let case_blocks = keys
                .iter()
                .map(|key| {
                    let block = builder.create_block();
                    switch.set_entry(*key, block);
                    block
                })
                .collect::<Vec<_>>();

            builder.seal_block(root_block);
            builder.switch_to_block(root_block);

            let val = builder.ins().iconst(typ, 1);
            switch.emit(&mut builder, val, default_block);

            for &block in case_blocks.iter().chain(std::iter::once(&default_block)) {
                builder.seal_block(block);
                builder.switch_to_block(block);
                builder.ins().return_(&[]);
            }

            builder.finalize();
        }
    }

    #[test]
    fn switch_64bit() {
        // An i64 discriminant is range-checked against u32::MAX before narrowing.
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = icmp_imm ugt v0, 0xffff_ffff  ; v0 = 0
    brif v1, block3, block4

block4:
    v2 = ireduce.i32 v0  ; v0 = 0
    br_table v2, block3, [block2, block1]"
        );
    }

    #[test]
    fn switch_128bit() {
        // Same as `switch_64bit`, but with an i128 discriminant.
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let val = bx.ins().uextend(types::I128, val);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0  ; v0 = 0
    v2 = icmp_imm ugt v1, 0xffff_ffff
    brif v2, block3, block4

block4:
    v3 = ireduce.i32 v1
    br_table v3, block3, [block2, block1]"
        );
    }
}