1use fxhash::FxHashMap;
9use std::cell::Cell;
10
11const SMALL_ELEMS: usize = 12;
12
13#[derive(Clone, Debug)]
19enum AdaptiveMap {
20 Small {
21 len: u32,
22 keys: [u32; SMALL_ELEMS],
23 values: [u64; SMALL_ELEMS],
24 },
25 Large(FxHashMap<u32, u64>),
26}
27
28const INVALID: u32 = 0xffff_ffff;
29
30impl AdaptiveMap {
31 fn new() -> Self {
32 Self::Small {
33 len: 0,
34 keys: [INVALID; SMALL_ELEMS],
35 values: [0; SMALL_ELEMS],
36 }
37 }
38
39 #[inline(always)]
40 fn get_or_insert<'a>(&'a mut self, key: u32) -> &'a mut u64 {
41 let small_mode_idx = match self {
44 &mut Self::Small {
45 len,
46 ref mut keys,
47 ref values,
48 } => {
49 if let Some(i) = keys[..len as usize].iter().position(|&k| k == key) {
56 Some(i)
57 } else if len != SMALL_ELEMS as u32 {
58 debug_assert!(len < SMALL_ELEMS as u32);
59 None
60 } else if let Some(i) = values.iter().position(|&v| v == 0) {
61 keys[i] = key;
63 Some(i)
64 } else {
65 *self = Self::Large(keys.iter().copied().zip(values.iter().copied()).collect());
66 None
67 }
68 }
69 _ => None,
70 };
71
72 match self {
73 Self::Small { len, keys, values } => {
74 if let Some(i) = small_mode_idx {
78 return &mut values[i];
79 }
80 debug_assert!(*len < SMALL_ELEMS as u32);
83 let idx = *len as usize;
84 *len += 1;
85 keys[idx] = key;
86 values[idx] = 0;
87 &mut values[idx]
88 }
89 Self::Large(map) => map.entry(key).or_insert(0),
90 }
91 }
92
93 #[inline(always)]
94 fn get_mut(&mut self, key: u32) -> Option<&mut u64> {
95 match self {
96 &mut Self::Small {
97 len,
98 ref keys,
99 ref mut values,
100 } => {
101 for i in 0..len {
102 if keys[i as usize] == key {
103 return Some(&mut values[i as usize]);
104 }
105 }
106 None
107 }
108 &mut Self::Large(ref mut map) => map.get_mut(&key),
109 }
110 }
111 #[inline(always)]
112 fn get(&self, key: u32) -> Option<u64> {
113 match self {
114 &Self::Small {
115 len,
116 ref keys,
117 ref values,
118 } => {
119 for i in 0..len {
120 if keys[i as usize] == key {
121 let value = values[i as usize];
122 return Some(value);
123 }
124 }
125 None
126 }
127 &Self::Large(ref map) => {
128 let value = map.get(&key).cloned();
129 value
130 }
131 }
132 }
133 fn iter<'a>(&'a self) -> AdaptiveMapIter<'a> {
134 match self {
135 &Self::Small {
136 len,
137 ref keys,
138 ref values,
139 } => AdaptiveMapIter::Small(&keys[0..len as usize], &values[0..len as usize]),
140 &Self::Large(ref map) => AdaptiveMapIter::Large(map.iter()),
141 }
142 }
143
144 fn is_empty(&self) -> bool {
145 match self {
146 AdaptiveMap::Small { values, .. } => values.iter().all(|&value| value == 0),
147 AdaptiveMap::Large(m) => m.values().all(|&value| value == 0),
148 }
149 }
150}
151
/// Borrowing iterator over the `(key, value)` pairs of an `AdaptiveMap`.
///
/// The `Small` variant holds the remaining live keys and values as parallel
/// slices that are consumed from the front.
enum AdaptiveMapIter<'a> {
    Small(&'a [u32], &'a [u64]),
    Large(std::collections::hash_map::Iter<'a, u32, u64>),
}

impl<'a> Iterator for AdaptiveMapIter<'a> {
    type Item = (u32, u64);

    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        match self {
            Self::Large(entries) => entries.next().map(|(&key, &value)| (key, value)),
            Self::Small(keys, values) => {
                // Copy the `'a` slices out so the shortened tails keep lifetime `'a`.
                let key_slice: &'a [u32] = *keys;
                let value_slice: &'a [u64] = *values;
                let (&first_key, rest_keys) = key_slice.split_first()?;
                let first_value = value_slice[0];
                *keys = rest_keys;
                *values = &value_slice[1..];
                Some((first_key, first_value))
            }
        }
    }
}
177
178#[derive(Clone)]
181pub struct IndexSet {
182 elems: AdaptiveMap,
183 cache: Cell<(u32, u64)>,
184}
185
186const BITS_PER_WORD: usize = 64;
187
188impl IndexSet {
189 pub fn new() -> Self {
190 Self {
191 elems: AdaptiveMap::new(),
192 cache: Cell::new((INVALID, 0)),
193 }
194 }
195
196 #[inline(always)]
197 fn elem(&mut self, bit_index: usize) -> &mut u64 {
198 let word_index = (bit_index / BITS_PER_WORD) as u32;
199 if self.cache.get().0 == word_index {
200 self.cache.set((INVALID, 0));
201 }
202 self.elems.get_or_insert(word_index)
203 }
204
205 #[inline(always)]
206 fn maybe_elem_mut(&mut self, bit_index: usize) -> Option<&mut u64> {
207 let word_index = (bit_index / BITS_PER_WORD) as u32;
208 if self.cache.get().0 == word_index {
209 self.cache.set((INVALID, 0));
210 }
211 self.elems.get_mut(word_index)
212 }
213
214 #[inline(always)]
215 fn maybe_elem(&self, bit_index: usize) -> Option<u64> {
216 let word_index = (bit_index / BITS_PER_WORD) as u32;
217 if self.cache.get().0 == word_index {
218 Some(self.cache.get().1)
219 } else {
220 self.elems.get(word_index)
221 }
222 }
223
224 #[inline(always)]
225 pub fn set(&mut self, idx: usize, val: bool) {
226 let bit = idx % BITS_PER_WORD;
227 if val {
228 *self.elem(idx) |= 1 << bit;
229 } else if let Some(word) = self.maybe_elem_mut(idx) {
230 *word &= !(1 << bit);
231 }
232 }
233
234 pub fn assign(&mut self, other: &Self) {
235 self.elems = other.elems.clone();
236 self.cache = other.cache.clone();
237 }
238
239 #[inline(always)]
240 pub fn get(&self, idx: usize) -> bool {
241 let bit = idx % BITS_PER_WORD;
242 if let Some(word) = self.maybe_elem(idx) {
243 (word & (1 << bit)) != 0
244 } else {
245 false
246 }
247 }
248
249 pub fn union_with(&mut self, other: &Self) -> bool {
250 let mut changed = 0;
251 for (word_idx, bits) in other.elems.iter() {
252 if bits == 0 {
253 continue;
254 }
255 let word_idx = word_idx as usize;
256 let self_word = self.elem(word_idx * BITS_PER_WORD);
257 changed |= bits & !*self_word;
258 *self_word |= bits;
259 }
260 changed != 0
261 }
262
263 pub fn iter<'a>(&'a self) -> impl Iterator<Item = usize> + 'a {
264 self.elems.iter().flat_map(|(word_idx, bits)| {
265 let word_idx = word_idx as usize;
266 SetBitsIter(bits).map(move |i| BITS_PER_WORD * word_idx + i)
267 })
268 }
269
270 pub(crate) fn is_small(&self) -> bool {
273 match &self.elems {
274 &AdaptiveMap::Small { .. } => true,
275 _ => false,
276 }
277 }
278
279 pub(crate) fn is_empty(&self) -> bool {
281 self.elems.is_empty()
282 }
283}
284
/// Iterator over the positions of the set bits of a `u64`, lowest bit first.
#[derive(Clone, Debug)]
pub struct SetBitsIter(u64);

impl Iterator for SetBitsIter {
    type Item = usize;

    #[inline]
    fn next(&mut self) -> Option<usize> {
        std::num::NonZeroU64::new(self.0).map(|nz| {
            let bitidx = nz.trailing_zeros();
            // Clear the lowest set bit, i.e. the one being yielded now.
            self.0 &= self.0 - 1;
            bitidx as usize
        })
    }

    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exactly one item remains per set bit.
        let remaining = self.0.count_ones() as usize;
        (remaining, Some(remaining))
    }
}

/// `size_hint` is exact, so `len()` is the number of set bits remaining.
impl ExactSizeIterator for SetBitsIter {}

/// Once the word reaches zero it stays zero, so `next` keeps returning `None`.
impl std::iter::FusedIterator for SetBitsIter {}
302
303impl std::fmt::Debug for IndexSet {
304 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
305 let vals = self.iter().collect::<Vec<_>>();
306 write!(f, "{:?}", vals)
307 }
308}
309
#[cfg(test)]
mod test {
    use super::{IndexSet, SMALL_ELEMS};

    // These tests use `assert!` rather than `debug_assert!`, so they still
    // check something under `cargo test --release`.

    #[test]
    fn test_set_bits_iter() {
        let mut vec = IndexSet::new();
        let mut sum = 0;
        for i in 0..1024 {
            if i % 17 == 0 {
                vec.set(i, true);
                sum += i;
            }
        }

        let mut checksum = 0;
        for bit in vec.iter() {
            assert!(bit % 17 == 0);
            checksum += bit;
        }

        assert_eq!(sum, checksum);
    }

    #[test]
    fn test_expand_remove_zero_elems() {
        let mut vec = IndexSet::new();
        // Occupy every inline slot with a distinct word.
        for i in 0..SMALL_ELEMS {
            vec.set(64 * i, true);
        }
        // Zero out one word; its slot becomes reusable.
        vec.set(64 * 5, false);
        // A brand-new word should recycle that slot instead of forcing the
        // switch to hash-map storage.
        vec.set(64 * 100, true);
        assert!(vec.is_small());
        assert!(vec.get(64 * 100));
        assert!(!vec.get(64 * 5));
    }

    #[test]
    fn test_expand_to_large() {
        let mut set = IndexSet::new();
        // One more distinct non-zero word than fits inline.
        for i in 0..=SMALL_ELEMS {
            set.set(64 * i, true);
        }
        assert!(!set.is_small());
        for i in 0..=SMALL_ELEMS {
            assert!(set.get(64 * i));
        }
    }

    #[test]
    fn test_set_then_clear() {
        let mut set = IndexSet::new();
        set.set(130, true);
        assert!(set.get(130));
        set.set(130, false);
        assert!(!set.get(130));
        assert!(set.is_empty());
        // Clearing a bit in a word that was never stored is a no-op.
        set.set(5000, false);
        assert!(set.is_empty());
    }

    #[test]
    fn test_union_with_reports_change() {
        let mut a = IndexSet::new();
        let mut b = IndexSet::new();
        a.set(3, true);
        b.set(3, true);
        b.set(700, true);

        assert!(a.union_with(&b));
        assert!(a.get(3));
        assert!(a.get(700));
        assert!(!a.get(4));
        // Nothing new the second time around.
        assert!(!a.union_with(&b));
    }
}