partial_sort/
lib.rs

//! # partial_sort
//!
//! [![Build Status](https://github.com/sundy-li/partial_sort/actions/workflows/Build.yml/badge.svg)](https://github.com/sundy-li/partial_sort/actions/workflows/Build.yml)
//! [![](https://img.shields.io/crates/v/partial_sort.svg)](https://crates.io/crates/partial_sort)
//! [![](https://img.shields.io/crates/d/partial_sort.svg)](https://crates.io/crates/partial_sort)
//! [![](https://img.shields.io/crates/dv/partial_sort.svg)](https://crates.io/crates/partial_sort)
//! [![](https://docs.rs/partial_sort/badge.svg)](https://docs.rs/partial_sort/)
//!
//! partial_sort is a Rust version of C++'s
//! [std::partial_sort](https://en.cppreference.com/w/cpp/algorithm/partial_sort): it rearranges
//! a slice so that its first `last` elements are the smallest ones, in ascending order, while
//! the order of the remaining elements is left unspecified.
//!
//! ```toml
//! [dependencies]
//! partial_sort = "0.1.2"
//! ```
//!
//! # Example
//! ```
//! # use partial_sort::PartialSort;
//!
//! let mut vec = vec![4, 4, 3, 3, 1, 1, 2, 2];
//! vec.partial_sort(4, |a, b| a.cmp(b));
//! assert_eq!(&vec[0..4], &[1, 1, 2, 2]);
//! println!("{:?}", vec);
//! ```

#![crate_type = "lib"]
#![crate_name = "partial_sort"]
#![cfg_attr(feature = "nightly", feature(test))]

use std::cmp::Ordering;
use std::cmp::Ordering::Less;
use std::{mem, ptr};
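
/// Extension trait that provides a C++-style `partial_sort` method on slices
/// (and on anything that dereferences to a slice, such as `Vec<T>`).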
pub trait PartialSort {
    type Item;
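
    /// Rearranges the slice so that `self[..last]` holds the `last` smallest
    /// elements according to `cmp`, in ascending order; the order of the
    /// remaining elements is unspecified.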
    fn partial_sort<F>(&mut self, last: usize, cmp: F)
    where
        F: FnMut(&Self::Item, &Self::Item) -> Ordering;
}

impl<T> PartialSort for [T] {
    type Item = T;

    fn partial_sort<F>(&mut self, last: usize, mut cmp: F)
    where
        F: FnMut(&Self::Item, &Self::Item) -> Ordering,
    {
        partial_sort(self, last, |a, b| cmp(a, b) == Less);
    }
}
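
/// Free-function form of [`PartialSort::partial_sort`] that takes a boolean
/// "less than" predicate instead of an [`Ordering`] comparator.
///
/// After the call, `v[..last]` holds the `last` smallest elements of `v`
/// (according to `is_less`) in ascending order; the order of `v[last..]` is
/// unspecified.
///
/// # Panics
///
/// Panics if `last > v.len()`.
///
/// # Example
///
/// ```
/// let mut v = vec![5, 2, 9, 1, 7];
/// partial_sort::partial_sort(&mut v, 2, |a, b| a < b);
/// assert_eq!(&v[0..2], &[1, 2]);
/// ```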
pub fn partial_sort<T, F>(v: &mut [T], last: usize, mut is_less: F)
where
    F: FnMut(&T, &T) -> bool,
{
    assert!(last <= v.len());

    // Nothing to sort; returning early also avoids needlessly permuting `v`
    // in the scan below.
    if last == 0 {
        return;
    }

    // Phase 1: build a max-heap over the first `last` elements.
    make_heap(v, last, &mut is_less);

    // Phase 2: scan the tail; whenever an element is smaller than the heap
    // maximum, move it into the heap in place of that maximum.
    for i in last..v.len() {
        if is_less(&v[i], &v[0]) {
            v.swap(0, i);
            adjust_heap(v, 0, last, &mut is_less);
        }
    }

    // Phase 3: the heap now holds the `last` smallest elements; sort them.
    sort_heap(v, last, &mut is_less);
}
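
/// Builds a max-heap (with respect to `is_less`) over `v[..last]` bottom-up,
/// sifting down each internal node from the last one up to the root
/// (Floyd's heap construction).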
#[inline]
fn make_heap<T, F>(v: &mut [T], last: usize, is_less: &mut F)
where
    F: FnMut(&T, &T) -> bool,
{
    if last < 2 {
        return;
    }

    let len = last;
    // Index of the last node that has at least one child.
    let mut parent = (len - 2) / 2;

    loop {
        adjust_heap(v, parent, len, is_less);
        if parent == 0 {
            return;
        }
        parent -= 1;
    }
}

/// Sifts the hole at `hole_index` down the max-heap `v[..len]`: larger
/// children are moved up until the value saved from `hole_index` can be
/// dropped back into place.
#[inline]
fn adjust_heap<T, F>(v: &mut [T], hole_index: usize, len: usize, is_less: &mut F)
where
    F: FnMut(&T, &T) -> bool,
{
    let mut left_child = hole_index * 2 + 1;

    // SAFETY: `hole_index` points to a properly initialized value of type T.
    let mut tmp = unsafe { mem::ManuallyDrop::new(ptr::read(&v[hole_index])) };
    let mut hole = InsertionHole {
        src: &mut *tmp,
        dest: &mut v[hole_index],
    };
    // Panic safety:
    //
    // If `is_less` panics at any point during the process, `hole` will get dropped and fill the
    // hole in `v` with the value saved in `tmp`, thus ensuring that `v` still holds every
    // object it initially held exactly once.

    // SAFETY:
    // `src` and `dest` always point to properly initialized values of type T;
    // `src` is valid for reads of `size_of::<T>()` bytes;
    // `dest` is valid for writes of `size_of::<T>()` bytes;
    // both `src` and `dest` are properly aligned.
    unsafe {
        while left_child < len {
            // SAFETY: `left_child` and `left_child + 1` are within [0, len).
            if left_child + 1 < len {
                // Pick the larger of the two children.
                left_child += usize::from(is_less(
                    v.get_unchecked(left_child),
                    v.get_unchecked(left_child + 1),
                ));
            }

            // SAFETY: `left_child` and `hole.dest` point to properly
            // initialized values of type T.
            if is_less(&*tmp, v.get_unchecked(left_child)) {
                ptr::copy_nonoverlapping(&v[left_child], hole.dest, 1);
                hole.dest = &mut v[left_child];
            } else {
                break;
            }

            left_child = left_child * 2 + 1;
        }
    }

    // This helper is adapted from the standard library's slice sort.
    // When dropped, copies from `src` into `dest`.
    struct InsertionHole<T> {
        src: *mut T,
        dest: *mut T,
    }

    impl<T> Drop for InsertionHole<T> {
        fn drop(&mut self) {
            // SAFETY:
            // `src` and `dest` point to properly initialized values of type T;
            // `src` is valid for reads of `size_of::<T>()` bytes, `dest` is
            // valid for writes of `size_of::<T>()` bytes, and both are
            // properly aligned.
            unsafe {
                ptr::copy_nonoverlapping(self.src, self.dest, 1);
            }
        }
    }
}
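
/// Heapsort extraction phase: repeatedly swaps the heap maximum `v[0]` to the
/// end of the shrinking heap `v[..last]` and re-heapifies, leaving `v[..last]`
/// in ascending order.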
#[inline]
fn sort_heap<T, F>(v: &mut [T], mut last: usize, is_less: &mut F)
where
    F: FnMut(&T, &T) -> bool,
{
    while last > 1 {
        last -= 1;
        v.swap(0, last);
        adjust_heap(v, 0, last, is_less);
    }
}

#[cfg(test)]
mod tests {
    use rand::Rng;
    use std::cmp::Ordering;
    use std::fmt;
    use std::sync::Arc;

    use crate::PartialSort;

    #[test]
    fn empty_test() {
        let mut before: Vec<u32> = vec![4, 4, 3, 3, 1, 1, 2, 2];
        before.partial_sort(0, |a, b| a.cmp(b));
    }

    #[test]
    fn single_test() {
        let mut before: Vec<u32> = vec![4, 4, 3, 3, 1, 1, 2, 2];
        let last = 6;
        let mut d = before.clone();
        d.sort();

        before.partial_sort(last, |a, b| a.cmp(b));
        assert_eq!(&d[0..last], &before.as_slice()[0..last]);
    }
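
    // A small extra example (not in the original suite): reversing the
    // comparator selects the largest elements, in descending order.
    #[test]
    fn descending_test() {
        let mut before: Vec<u32> = vec![4, 4, 3, 3, 1, 1, 2, 2];
        before.partial_sort(4, |a, b| b.cmp(a));
        assert_eq!(&before[0..4], &[4, 4, 3, 3]);
    }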

    #[test]
    fn sorted_strings_test() {
        let mut before: Vec<&str> = vec![
            "a", "cat", "mat", "on", "sat", "the", "xxx", "xxxx", "fdadfdsf",
        ];
        let last = 6;
        let mut d = before.clone();
        d.sort();

        before.partial_sort(last, |a, b| a.cmp(b));
        assert_eq!(&d[0..last], &before.as_slice()[0..last]);
    }

    #[test]
    fn sorted_ref_test() {
        trait TModel: fmt::Debug + Send + Sync {
            fn size(&self) -> usize;
        }

        struct ModelFoo {
            size: usize,
        }

        impl TModel for ModelFoo {
            fn size(&self) -> usize {
                self.size
            }
        }
        impl fmt::Debug for ModelFoo {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                write!(f, "ModelFoo[{}]", self.size)
            }
        }

        struct ModelBar {
            size: usize,
        }

        impl TModel for ModelBar {
            fn size(&self) -> usize {
                self.size
            }
        }
        impl fmt::Debug for ModelBar {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                write!(f, "ModelBar[{}]", self.size)
            }
        }

        type ModelRef = Arc<dyn TModel>;

        /// Compares two models by their `size()`.
        fn cmp_model(a: &dyn TModel, b: &dyn TModel) -> Ordering {
            a.size().cmp(&b.size())
        }

        let mut before: Vec<(i32, ModelRef)> = vec![
            (1i32, Arc::new(ModelBar { size: 100 })),
            (1i32, Arc::new(ModelFoo { size: 99 })),
            (1i32, Arc::new(ModelFoo { size: 101 })),
            (1i32, Arc::new(ModelBar { size: 104 })),
            (1i32, Arc::new(ModelBar { size: 10 })),
            (1i32, Arc::new(ModelBar { size: 24 })),
            (1i32, Arc::new(ModelBar { size: 34 })),
            (1i32, Arc::new(ModelBar { size: 114 })),
        ];

        let last = 6;
        let mut d = before.clone();
        d.sort_by(|a, b| cmp_model(a.1.as_ref(), b.1.as_ref()));

        before.partial_sort(last, |a, b| cmp_model(a.1.as_ref(), b.1.as_ref()));

        d[0..last].iter().zip(&before[0..last]).for_each(|(a, b)| {
            assert_eq!(a.0, b.0);
            assert_eq!(a.1.size(), b.1.size());
        });
    }

    /// Creates a random initial vector, partially sorts it, and verifies the
    /// result against std's `sort`.
    #[test]
    fn sorted_random_u64_test() {
        let mut rng = rand::thread_rng();
        let vec_size = 1025;
        let partial_size = (rng.gen::<u64>() % vec_size) as usize;
        let mut data = (0u64..vec_size)
            .map(|_| rng.gen::<u64>())
            .collect::<Vec<u64>>();
        let mut d = data.clone();
        d.sort();

        data.partial_sort(partial_size, |a, b| a.cmp(b));
        assert_eq!(&d[0..partial_size], &data.as_slice()[0..partial_size]);
    }

    #[test]
    #[ignore]
    fn sorted_expensive_random_u64_test() {
        for _ in 0..100 {
            let mut rng = rand::thread_rng();
            let vec_size = 1025;
            let partial_size = (rng.gen::<u64>() % vec_size) as usize;
            let mut data = (0u64..vec_size)
                .map(|_| rng.gen::<u64>())
                .collect::<Vec<u64>>();
            let mut d = data.clone();
            d.sort();

            data.partial_sort(partial_size, |a, b| a.cmp(b));
            assert_eq!(&d[0..partial_size], &data.as_slice()[0..partial_size]);
        }
    }
}