1use std::future::Future;
2use std::hash::Hash;
3use std::mem;
4use std::pin::Pin;
5use std::task::{Context, Poll, Waker};
6use std::time::Duration;
7
8use futures_timer::Delay;
9use futures_util::future::BoxFuture;
10use futures_util::stream::FuturesUnordered;
11use futures_util::{FutureExt, StreamExt};
12
13use crate::Timeout;
14
/// A bounded map of boxed futures, each keyed by an `ID` and raced against a timeout.
///
/// At most `capacity` futures are held at once, and every future is wrapped in a
/// `TimeoutFuture` whose delay is `timeout`.
pub struct FuturesMap<ID, O> {
    /// Time budget given to each future when it is pushed.
    timeout: Duration,
    /// Maximum number of futures the map holds at the same time.
    capacity: usize,
    /// In-flight futures, each tagged with its ID and wrapped in a timeout.
    inner: FuturesUnordered<TaggedFuture<ID, TimeoutFuture<BoxFuture<'static, O>>>>,
    /// Waker registered by `poll_unpin` when the map had no futures; `try_push` wakes it.
    empty_waker: Option<Waker>,
    /// Waker registered by `poll_ready_unpin` when the map was at capacity.
    full_waker: Option<Waker>,
}
25
/// Error returned by `FuturesMap::try_push`.
#[derive(PartialEq, Debug)]
pub enum PushError<F> {
    /// The map already holds `capacity` futures; carries the rejected (boxed) future,
    /// which was not inserted.
    BeyondCapacity(F),
    /// A future with the same ID was already present; the new future took its place and
    /// this carries the old future that was replaced.
    ReplacedFuture(F),
}
34
35impl<ID, O> FuturesMap<ID, O> {
36 pub fn new(timeout: Duration, capacity: usize) -> Self {
37 Self {
38 timeout,
39 capacity,
40 inner: Default::default(),
41 empty_waker: None,
42 full_waker: None,
43 }
44 }
45}
46
47impl<ID, O> FuturesMap<ID, O>
48where
49 ID: Clone + Hash + Eq + Send + Unpin + 'static,
50{
51 pub fn try_push<F>(&mut self, future_id: ID, future: F) -> Result<(), PushError<BoxFuture<O>>>
59 where
60 F: Future<Output = O> + Send + 'static,
61 {
62 if self.inner.len() >= self.capacity {
63 return Err(PushError::BeyondCapacity(future.boxed()));
64 }
65
66 if let Some(waker) = self.empty_waker.take() {
67 waker.wake();
68 }
69
70 match self.inner.iter_mut().find(|tagged| tagged.tag == future_id) {
71 None => {
72 self.inner.push(TaggedFuture {
73 tag: future_id,
74 inner: TimeoutFuture {
75 inner: future.boxed(),
76 timeout: Delay::new(self.timeout),
77 },
78 });
79
80 Ok(())
81 }
82 Some(existing) => {
83 let old_future = mem::replace(
84 &mut existing.inner,
85 TimeoutFuture {
86 inner: future.boxed(),
87 timeout: Delay::new(self.timeout),
88 },
89 );
90
91 Err(PushError::ReplacedFuture(old_future.inner))
92 }
93 }
94 }
95
96 pub fn is_empty(&self) -> bool {
97 self.inner.is_empty()
98 }
99
100 #[allow(unknown_lints, clippy::needless_pass_by_ref_mut)] pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> {
102 if self.inner.len() < self.capacity {
103 return Poll::Ready(());
104 }
105
106 self.full_waker = Some(cx.waker().clone());
107
108 Poll::Pending
109 }
110
111 pub fn poll_unpin(&mut self, cx: &mut Context<'_>) -> Poll<(ID, Result<O, Timeout>)> {
112 let maybe_result = futures_util::ready!(self.inner.poll_next_unpin(cx));
113
114 match maybe_result {
115 None => {
116 self.empty_waker = Some(cx.waker().clone());
117 Poll::Pending
118 }
119 Some((id, Ok(output))) => Poll::Ready((id, Ok(output))),
120 Some((id, Err(_timeout))) => Poll::Ready((id, Err(Timeout::new(self.timeout)))),
121 }
122 }
123}
124
125struct TimeoutFuture<F> {
126 inner: F,
127 timeout: Delay,
128}
129
130impl<F> Future for TimeoutFuture<F>
131where
132 F: Future + Unpin,
133{
134 type Output = Result<F::Output, ()>;
135
136 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
137 if self.timeout.poll_unpin(cx).is_ready() {
138 return Poll::Ready(Err(()));
139 }
140
141 self.inner.poll_unpin(cx).map(Ok)
142 }
143}
144
/// Wraps a future together with an identifying tag.
///
/// When the wrapped future completes, this yields `(tag, output)`; the tag is cloned
/// so the wrapper itself keeps its copy.
struct TaggedFuture<T, F> {
    tag: T,
    inner: F,
}

impl<T, F> Future for TaggedFuture<T, F>
where
    T: Clone + Unpin,
    F: Future + Unpin,
{
    type Output = (T, F::Output);

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let TaggedFuture { tag, inner } = self.get_mut();

        match Pin::new(inner).poll(cx) {
            Poll::Ready(output) => Poll::Ready((tag.clone(), output)),
            Poll::Pending => Poll::Pending,
        }
    }
}
163
#[cfg(test)]
mod tests {
    use std::future::{pending, poll_fn, ready};
    use std::pin::Pin;
    use std::time::Instant;

    use super::*;

    #[test]
    fn cannot_push_more_than_capacity_tasks() {
        let mut futures = FuturesMap::new(Duration::from_secs(10), 1);

        assert!(futures.try_push("ID_1", ready(())).is_ok());
        // `matches!` only evaluates to a bool; without `assert!` the check is discarded.
        assert!(matches!(
            futures.try_push("ID_2", ready(())),
            Err(PushError::BeyondCapacity(_))
        ));
    }

    #[test]
    fn cannot_push_the_same_id_few_times() {
        let mut futures = FuturesMap::new(Duration::from_secs(10), 5);

        assert!(futures.try_push("ID", ready(())).is_ok());
        assert!(matches!(
            futures.try_push("ID", ready(())),
            Err(PushError::ReplacedFuture(_))
        ));
    }

    #[tokio::test]
    async fn futures_timeout() {
        let mut futures = FuturesMap::new(Duration::from_millis(100), 1);

        let _ = futures.try_push("ID", pending::<()>());
        Delay::new(Duration::from_millis(150)).await;
        let (_, result) = poll_fn(|cx| futures.poll_unpin(cx)).await;

        assert!(result.is_err())
    }

    /// Runs `NUM_FUTURES` futures of `DELAY` each through a map of capacity 1, so they
    /// must execute one after another and take at least `DELAY * NUM_FUTURES` in total.
    #[tokio::test]
    async fn backpressure() {
        const DELAY: Duration = Duration::from_millis(100);
        const NUM_FUTURES: u32 = 10;

        let start = Instant::now();
        Task::new(DELAY, NUM_FUTURES, 1).await;
        let duration = start.elapsed();

        assert!(duration >= DELAY * NUM_FUTURES);
    }

    /// Test driver: keeps pushing `Delay` futures into a bounded map whenever it has
    /// room, until `num_futures` of them have completed.
    struct Task {
        /// Duration of each pushed `Delay` future.
        future: Duration,
        num_futures: usize,
        num_processed: usize,
        inner: FuturesMap<u8, ()>,
    }

    impl Task {
        fn new(future: Duration, num_futures: u32, capacity: usize) -> Self {
            Self {
                future,
                num_futures: num_futures as usize,
                num_processed: 0,
                inner: FuturesMap::new(Duration::from_secs(60), capacity),
            }
        }
    }

    impl Future for Task {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let this = self.get_mut();

            while this.num_processed < this.num_futures {
                if let Poll::Ready((_, result)) = this.inner.poll_unpin(cx) {
                    if result.is_err() {
                        panic!("Timeout is great than future delay")
                    }

                    this.num_processed += 1;
                    continue;
                }

                if let Poll::Ready(()) = this.inner.poll_ready_unpin(cx) {
                    // Every push reuses ID 1; with capacity 1 a readiness signal means the
                    // previous future already finished, so this is a fresh insert.
                    let maybe_future = this.inner.try_push(1u8, Delay::new(this.future));
                    assert!(maybe_future.is_ok(), "we polled for readiness");

                    continue;
                }

                return Poll::Pending;
            }

            Poll::Ready(())
        }
    }
}