// cranelift_codegen/unionfind.rs

//! Simple union-find data structure.

3use crate::trace;
4use cranelift_entity::{packed_option::ReservedValue, EntityRef, SecondaryMap};
5use std::hash::Hash;
6
7/// A union-find data structure. The data structure can allocate
8/// `Id`s, indicating eclasses, and can merge eclasses together.
9#[derive(Clone, Debug, PartialEq)]
10pub struct UnionFind<Idx: EntityRef> {
11    parent: SecondaryMap<Idx, Val<Idx>>,
12}
13
/// Per-element parent link. It is a thin wrapper around an `Idx` so that
/// it can carry its own `Default` (the reserved index) without touching
/// `Idx` itself.
#[derive(Debug, Clone, PartialEq)]
struct Val<Idx>(Idx);
16impl<Idx: EntityRef + ReservedValue> Default for Val<Idx> {
17    fn default() -> Self {
18        Self(Idx::reserved_value())
19    }
20}
21
22impl<Idx: EntityRef + Hash + std::fmt::Display + Ord + ReservedValue> UnionFind<Idx> {
23    /// Create a new `UnionFind` with the given capacity.
24    pub fn with_capacity(cap: usize) -> Self {
25        UnionFind {
26            parent: SecondaryMap::with_capacity(cap),
27        }
28    }
29
30    /// Add an `Idx` to the `UnionFind`, with its own equivalence class
31    /// initially. All `Idx`s must be added before being queried or
32    /// unioned.
33    pub fn add(&mut self, id: Idx) {
34        debug_assert!(id != Idx::reserved_value());
35        self.parent[id] = Val(id);
36    }
37
38    /// Find the canonical `Idx` of a given `Idx`.
39    pub fn find(&self, mut node: Idx) -> Idx {
40        while node != self.parent[node].0 {
41            node = self.parent[node].0;
42        }
43        node
44    }
45
46    /// Find the canonical `Idx` of a given `Idx`, updating the data
47    /// structure in the process so that future queries for this `Idx`
48    /// (and others in its chain up to the root of the equivalence
49    /// class) will be faster.
50    pub fn find_and_update(&mut self, mut node: Idx) -> Idx {
51        // "Path splitting" mutating find (Tarjan and Van Leeuwen).
52        debug_assert!(node != Idx::reserved_value());
53        while node != self.parent[node].0 {
54            let next = self.parent[self.parent[node].0].0;
55            debug_assert!(next != Idx::reserved_value());
56            self.parent[node] = Val(next);
57            node = next;
58        }
59        debug_assert!(node != Idx::reserved_value());
60        node
61    }
62
63    /// Merge the equivalence classes of the two `Idx`s.
64    pub fn union(&mut self, a: Idx, b: Idx) {
65        let a = self.find_and_update(a);
66        let b = self.find_and_update(b);
67        let (a, b) = (std::cmp::min(a, b), std::cmp::max(a, b));
68        if a != b {
69            // Always canonicalize toward lower IDs.
70            self.parent[b] = Val(a);
71            trace!("union: {}, {}", a, b);
72        }
73    }
74}