futures_bounded/
map.rs

1use std::future::Future;
2use std::hash::Hash;
3use std::mem;
4use std::pin::Pin;
5use std::task::{Context, Poll, Waker};
6use std::time::Duration;
7
8use futures_timer::Delay;
9use futures_util::future::BoxFuture;
10use futures_util::stream::FuturesUnordered;
11use futures_util::{FutureExt, StreamExt};
12
13use crate::Timeout;
14
15/// Represents a map of [`Future`]s.
16///
17/// Each future must finish within the specified time and the map never outgrows its capacity.
18pub struct FuturesMap<ID, O> {
19    timeout: Duration,
20    capacity: usize,
21    inner: FuturesUnordered<TaggedFuture<ID, TimeoutFuture<BoxFuture<'static, O>>>>,
22    empty_waker: Option<Waker>,
23    full_waker: Option<Waker>,
24}
25
/// Outcome of a push that did not simply add a new future.
#[derive(Debug, PartialEq)]
pub enum PushError<F> {
    /// The collection is already at capacity; holds the future that was passed in and not
    /// inserted.
    BeyondCapacity(F),
    /// A future with the same ID was already present and has been replaced (see
    /// [`FuturesMap::try_push`]); holds the previous future.
    ReplacedFuture(F),
}
34
35impl<ID, O> FuturesMap<ID, O> {
36    pub fn new(timeout: Duration, capacity: usize) -> Self {
37        Self {
38            timeout,
39            capacity,
40            inner: Default::default(),
41            empty_waker: None,
42            full_waker: None,
43        }
44    }
45}
46
47impl<ID, O> FuturesMap<ID, O>
48where
49    ID: Clone + Hash + Eq + Send + Unpin + 'static,
50{
51    /// Push a future into the map.
52    ///
53    /// This method inserts the given future with defined `future_id` to the set.
54    /// If the length of the map is equal to the capacity, this method returns [PushError::BeyondCapacity],
55    /// that contains the passed future. In that case, the future is not inserted to the map.
56    /// If a future with the given `future_id` already exists, then the old future will be replaced by a new one.
57    /// In that case, the returned error [PushError::ReplacedFuture] contains the old future.
58    pub fn try_push<F>(&mut self, future_id: ID, future: F) -> Result<(), PushError<BoxFuture<O>>>
59    where
60        F: Future<Output = O> + Send + 'static,
61    {
62        if self.inner.len() >= self.capacity {
63            return Err(PushError::BeyondCapacity(future.boxed()));
64        }
65
66        if let Some(waker) = self.empty_waker.take() {
67            waker.wake();
68        }
69
70        match self.inner.iter_mut().find(|tagged| tagged.tag == future_id) {
71            None => {
72                self.inner.push(TaggedFuture {
73                    tag: future_id,
74                    inner: TimeoutFuture {
75                        inner: future.boxed(),
76                        timeout: Delay::new(self.timeout),
77                    },
78                });
79
80                Ok(())
81            }
82            Some(existing) => {
83                let old_future = mem::replace(
84                    &mut existing.inner,
85                    TimeoutFuture {
86                        inner: future.boxed(),
87                        timeout: Delay::new(self.timeout),
88                    },
89                );
90
91                Err(PushError::ReplacedFuture(old_future.inner))
92            }
93        }
94    }
95
96    pub fn is_empty(&self) -> bool {
97        self.inner.is_empty()
98    }
99
100    #[allow(unknown_lints, clippy::needless_pass_by_ref_mut)] // &mut Context is idiomatic.
101    pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> {
102        if self.inner.len() < self.capacity {
103            return Poll::Ready(());
104        }
105
106        self.full_waker = Some(cx.waker().clone());
107
108        Poll::Pending
109    }
110
111    pub fn poll_unpin(&mut self, cx: &mut Context<'_>) -> Poll<(ID, Result<O, Timeout>)> {
112        let maybe_result = futures_util::ready!(self.inner.poll_next_unpin(cx));
113
114        match maybe_result {
115            None => {
116                self.empty_waker = Some(cx.waker().clone());
117                Poll::Pending
118            }
119            Some((id, Ok(output))) => Poll::Ready((id, Ok(output))),
120            Some((id, Err(_timeout))) => Poll::Ready((id, Err(Timeout::new(self.timeout)))),
121        }
122    }
123}
124
125struct TimeoutFuture<F> {
126    inner: F,
127    timeout: Delay,
128}
129
130impl<F> Future for TimeoutFuture<F>
131where
132    F: Future + Unpin,
133{
134    type Output = Result<F::Output, ()>;
135
136    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137        if self.timeout.poll_unpin(cx).is_ready() {
138            return Poll::Ready(Err(()));
139        }
140
141        self.inner.poll_unpin(cx).map(Ok)
142    }
143}
144
/// Pairs a future with a tag that is handed back alongside the future's output.
struct TaggedFuture<T, F> {
    /// Identifier cloned into the output once the future completes.
    tag: T,
    /// The wrapped future.
    inner: F,
}
149
150impl<T, F> Future for TaggedFuture<T, F>
151where
152    T: Clone + Unpin,
153    F: Future + Unpin,
154{
155    type Output = (T, F::Output);
156
157    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
158        let output = futures_util::ready!(self.inner.poll_unpin(cx));
159
160        Poll::Ready((self.tag.clone(), output))
161    }
162}
163
#[cfg(test)]
mod tests {
    use std::future::{pending, poll_fn, ready};
    use std::pin::Pin;
    use std::time::Instant;

    use super::*;

    #[test]
    fn cannot_push_more_than_capacity_tasks() {
        let mut futures = FuturesMap::new(Duration::from_secs(10), 1);

        assert!(futures.try_push("ID_1", ready(())).is_ok());
        // `matches!` only evaluates to a `bool`; it has to be asserted for the test to check
        // anything.
        assert!(matches!(
            futures.try_push("ID_2", ready(())),
            Err(PushError::BeyondCapacity(_))
        ));
    }

    #[test]
    fn cannot_push_the_same_id_few_times() {
        let mut futures = FuturesMap::new(Duration::from_secs(10), 5);

        assert!(futures.try_push("ID", ready(())).is_ok());
        assert!(matches!(
            futures.try_push("ID", ready(())),
            Err(PushError::ReplacedFuture(_))
        ));
    }

    #[tokio::test]
    async fn futures_timeout() {
        let mut futures = FuturesMap::new(Duration::from_millis(100), 1);

        let _ = futures.try_push("ID", pending::<()>());
        Delay::new(Duration::from_millis(150)).await;
        let (_, result) = poll_fn(|cx| futures.poll_unpin(cx)).await;

        assert!(result.is_err())
    }

    // Each future causes a delay, `Task` only has a capacity of 1, meaning they must be processed in sequence.
    // We stop after NUM_FUTURES tasks, meaning the overall execution must at least take DELAY * NUM_FUTURES.
    #[tokio::test]
    async fn backpressure() {
        const DELAY: Duration = Duration::from_millis(100);
        const NUM_FUTURES: u32 = 10;

        let start = Instant::now();
        Task::new(DELAY, NUM_FUTURES, 1).await;
        let duration = start.elapsed();

        assert!(duration >= DELAY * NUM_FUTURES);
    }

    struct Task {
        future: Duration,
        num_futures: usize,
        num_processed: usize,
        inner: FuturesMap<u8, ()>,
    }

    impl Task {
        fn new(future: Duration, num_futures: u32, capacity: usize) -> Self {
            Self {
                future,
                num_futures: num_futures as usize,
                num_processed: 0,
                inner: FuturesMap::new(Duration::from_secs(60), capacity),
            }
        }
    }

    impl Future for Task {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.get_mut();

            while this.num_processed < this.num_futures {
                if let Poll::Ready((_, result)) = this.inner.poll_unpin(cx) {
                    if result.is_err() {
                        panic!("Timeout is great than future delay")
                    }

                    this.num_processed += 1;
                    continue;
                }

                if let Poll::Ready(()) = this.inner.poll_ready_unpin(cx) {
                    // We push the constant future's ID to prove that user can use the same ID
                    // if the future was finished
                    let maybe_future = this.inner.try_push(1u8, Delay::new(this.future));
                    assert!(maybe_future.is_ok(), "we polled for readiness");

                    continue;
                }

                return Poll::Pending;
            }

            Poll::Ready(())
        }
    }
}