prioritized_metered_channel/
oneshot.rs1use std::{
20 ops::Deref,
21 pin::Pin,
22 task::{Context, Poll},
23};
24
25use futures::{
26 channel::oneshot::{self, Canceled, Cancellation},
27 future::{Fuse, FusedFuture},
28 prelude::*,
29};
30use futures_timer::Delay;
31
32use crate::{CoarseDuration, CoarseInstant};
33
/// Why a [`MeteredReceiver`] stopped waiting.
///
/// The explicit `u8` discriminants are part of the type's representation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum Reason {
	/// A value was received from the sender.
	Completion = 1,
	/// The inner oneshot reported cancellation instead of a value.
	Cancellation = 2,
	/// No value was received before the hard timeout elapsed.
	HardTimeout = 3,
}
42
43#[derive(Debug, Clone, PartialEq, Eq)]
45pub struct Measurements {
46 first_poll_till_end: CoarseDuration,
48 creation_till_end: CoarseDuration,
50 reason: Reason,
52}
53
54impl Measurements {
55 pub fn duration_since_first_poll(&self) -> &CoarseDuration {
58 &self.first_poll_till_end
59 }
60
61 pub fn duration_since_creation(&self) -> &CoarseDuration {
64 &self.creation_till_end
65 }
66
67 pub fn reason(&self) -> &Reason {
69 &self.reason
70 }
71}
72
73pub fn channel<T>(
75 name: &'static str,
76 soft_timeout: CoarseDuration,
77 hard_timeout: CoarseDuration,
78) -> (MeteredSender<T>, MeteredReceiver<T>) {
79 let (tx, rx) = oneshot::channel();
80
81 (
82 MeteredSender { inner: tx },
83 MeteredReceiver {
84 name,
85 inner: rx,
86 soft_timeout,
87 hard_timeout,
88 soft_timeout_fut: None,
89 hard_timeout_fut: None,
90 first_poll_timestamp: None,
91 creation_timestamp: CoarseInstant::now(),
92 },
93 )
94}
95
96#[allow(missing_docs)]
97#[derive(thiserror::Error, Debug)]
98pub enum Error {
99 #[error("Oneshot was canceled.")]
100 Canceled(#[source] Canceled, Measurements),
101 #[error("Oneshot did not receive a response within {}", CoarseDuration::as_f64(.0))]
102 HardTimeout(CoarseDuration, Measurements),
103}
104
105impl Measurable for Error {
106 fn measurements(&self) -> Measurements {
107 match self {
108 Self::Canceled(_, measurements) => measurements.clone(),
109 Self::HardTimeout(_, measurements) => measurements.clone(),
110 }
111 }
112}
113
114#[derive(Debug)]
116pub struct MeteredSender<T> {
117 inner: oneshot::Sender<(CoarseInstant, T)>,
118}
119
120impl<T> MeteredSender<T> {
121 pub fn send(self, t: T) -> Result<(), T> {
123 let Self { inner } = self;
124 inner.send((CoarseInstant::now(), t)).map_err(|(_, t)| t)
125 }
126
127 pub fn poll_canceled(&mut self, ctx: &mut Context<'_>) -> Poll<()> {
129 self.inner.poll_canceled(ctx)
130 }
131
132 pub fn cancellation(&mut self) -> Cancellation<'_, (CoarseInstant, T)> {
134 self.inner.cancellation()
135 }
136
137 pub fn is_canceled(&self) -> bool {
139 self.inner.is_canceled()
140 }
141
142 pub fn is_connected_to(&self, receiver: &MeteredReceiver<T>) -> bool {
144 self.inner.is_connected_to(&receiver.inner)
145 }
146}
147
148#[derive(Debug)]
150pub struct MeteredReceiver<T> {
151 name: &'static str,
152 inner: oneshot::Receiver<(CoarseInstant, T)>,
153 soft_timeout_fut: Option<Fuse<Delay>>,
155 soft_timeout: CoarseDuration,
156 hard_timeout_fut: Option<Delay>,
158 hard_timeout: CoarseDuration,
159 first_poll_timestamp: Option<CoarseInstant>,
161 creation_timestamp: CoarseInstant,
162}
163
164impl<T> MeteredReceiver<T> {
165 pub fn close(&mut self) {
166 self.inner.close()
167 }
168
169 pub fn try_recv(&mut self) -> Result<Option<OutputWithMeasurements<T>>, Error> {
176 match self.inner.try_recv() {
177 Ok(Some((when, value))) => {
178 let measurements = self.create_measurement(when, Reason::Completion);
179 Ok(Some(OutputWithMeasurements { value, measurements }))
180 },
181 Err(e) => {
182 let measurements = self.create_measurement(
183 self.first_poll_timestamp.unwrap_or_else(|| CoarseInstant::now()),
184 Reason::Cancellation,
185 );
186 Err(Error::Canceled(e, measurements))
187 },
188 Ok(None) => Ok(None),
189 }
190 }
191
192 fn create_measurement(&self, start: CoarseInstant, reason: Reason) -> Measurements {
196 let end = CoarseInstant::now();
197 Measurements {
198 first_poll_till_end: end - start,
200 creation_till_end: end - self.creation_timestamp,
201 reason,
202 }
203 }
204}
205
206impl<T> FusedFuture for MeteredReceiver<T> {
207 fn is_terminated(&self) -> bool {
208 self.inner.is_terminated()
209 }
210}
211
212impl<T> Future for MeteredReceiver<T> {
213 type Output = Result<OutputWithMeasurements<T>, Error>;
214
215 fn poll(
216 mut self: Pin<&mut Self>,
217 ctx: &mut Context<'_>,
218 ) -> Poll<Result<OutputWithMeasurements<T>, Error>> {
219 let first_poll_timestamp =
220 self.first_poll_timestamp.get_or_insert_with(|| CoarseInstant::now()).clone();
221
222 let soft_timeout = self.soft_timeout.clone().into();
223 let soft_timeout = self
224 .soft_timeout_fut
225 .get_or_insert_with(move || Delay::new(soft_timeout).fuse());
226
227 if Pin::new(soft_timeout).poll(ctx).is_ready() {
228 tracing::warn!(target: "oneshot", "Oneshot `{name}` exceeded the soft threshold", name = &self.name);
229 }
230
231 let hard_timeout = self.hard_timeout.clone().into();
232 let hard_timeout =
233 self.hard_timeout_fut.get_or_insert_with(move || Delay::new(hard_timeout));
234
235 if Pin::new(hard_timeout).poll(ctx).is_ready() {
236 let measurements = self.create_measurement(first_poll_timestamp, Reason::HardTimeout);
237 return Poll::Ready(Err(Error::HardTimeout(self.hard_timeout.clone(), measurements)))
238 }
239
240 match Pin::new(&mut self.inner).poll(ctx) {
241 Poll::Pending => Poll::Pending,
242 Poll::Ready(Err(e)) => {
243 let measurements =
244 self.create_measurement(first_poll_timestamp, Reason::Cancellation);
245 Poll::Ready(Err(Error::Canceled(e, measurements)))
246 },
247 Poll::Ready(Ok((ref sent_at_timestamp, value))) => {
248 let measurements =
249 self.create_measurement(sent_at_timestamp.clone(), Reason::Completion);
250 Poll::Ready(Ok(OutputWithMeasurements::<T> { value, measurements }))
251 },
252 }
253 }
254}
255
256pub trait Measurable {
258 fn measurements(&self) -> Measurements;
260}
261
262impl<T> Measurable for Result<OutputWithMeasurements<T>, Error> {
263 fn measurements(&self) -> Measurements {
264 match self {
265 Err(err) => err.measurements(),
266 Ok(val) => val.measurements(),
267 }
268 }
269}
270
271#[derive(Clone, Debug)]
277pub struct OutputWithMeasurements<T> {
278 value: T,
279 measurements: Measurements,
280}
281
282impl<T> Measurable for OutputWithMeasurements<T> {
283 fn measurements(&self) -> Measurements {
284 self.measurements.clone()
285 }
286}
287
288impl<T> OutputWithMeasurements<T> {
289 pub fn into(self) -> T {
293 self.value
294 }
295}
296
297impl<T> AsRef<T> for OutputWithMeasurements<T> {
298 fn as_ref(&self) -> &T {
299 &self.value
300 }
301}
302
303impl<T> Deref for OutputWithMeasurements<T> {
304 type Target = T;
305
306 fn deref(&self) -> &Self::Target {
307 &self.value
308 }
309}
310
#[cfg(test)]
mod tests {
	use assert_matches::assert_matches;
	use futures::{executor::ThreadPool, future::Either, task::SpawnExt};
	use log::LevelFilter;
	use std::time::Duration;

	use super::*;

	/// Upper bound on how long one sender/receiver scenario may run before it counts as hung.
	const WATCHDOG: Duration = Duration::from_secs(5);

	#[derive(Clone, PartialEq, Eq, Debug)]
	struct DummyItem {
		vals: [u8; 256],
	}

	// `[u8; 256]` has no `Default` impl, so this cannot be derived.
	impl Default for DummyItem {
		fn default() -> Self {
			Self { vals: [0u8; 256] }
		}
	}

	/// Run a sender future and a receiver future concurrently on a thread pool.
	///
	/// The channel uses a 1s soft timeout and a 3s hard timeout. If the two futures do not both
	/// finish within [`WATCHDOG`], the test fails. Previously the watchdog result was
	/// discarded, so a hung scenario passed silently.
	fn test_launch<S, R, FS, FR>(name: &'static str, gen_sender_test: S, gen_receiver_test: R)
	where
		S: Fn(MeteredSender<DummyItem>) -> FS,
		R: Fn(MeteredReceiver<DummyItem>) -> FR,
		FS: Future<Output = ()> + Send + 'static,
		FR: Future<Output = ()> + Send + 'static,
	{
		let _ = env_logger::builder().is_test(true).filter_level(LevelFilter::Trace).try_init();

		let pool = ThreadPool::new().unwrap();
		let (tx, rx) = channel(name, CoarseDuration::from_secs(1), CoarseDuration::from_secs(3));
		futures::executor::block_on(async move {
			let handle_receiver = pool.spawn_with_handle(gen_receiver_test(rx)).unwrap();
			let handle_sender = pool.spawn_with_handle(gen_sender_test(tx)).unwrap();
			let outcome = futures::future::select(
				futures::future::join(handle_sender, handle_receiver),
				Delay::new(WATCHDOG),
			)
			.await;
			assert!(
				matches!(outcome, Either::Left(_)),
				"test `{}` did not finish within {:?}",
				name,
				WATCHDOG,
			);
		});
	}

	#[test]
	fn easy() {
		test_launch(
			"easy",
			|tx| async move {
				tx.send(DummyItem::default()).unwrap();
			},
			|rx| async move {
				let x = rx.await.unwrap();
				let measurements = x.measurements();
				assert_eq!(x.as_ref(), &DummyItem::default());
				dbg!(measurements);
			},
		);
	}

	#[test]
	fn cancel_by_drop() {
		test_launch(
			"cancel_by_drop",
			|tx| async move {
				Delay::new(Duration::from_secs(2)).await;
				drop(tx);
			},
			|rx| async move {
				let result = rx.await;
				assert_matches!(result, Err(Error::Canceled(_, _)));
				dbg!(result.measurements());
			},
		);
	}

	#[test]
	fn starve_till_hard_timeout() {
		test_launch(
			"starve_till_timeout",
			|tx| async move {
				Delay::new(Duration::from_secs(4)).await;
				let _ = tx.send(DummyItem::default());
			},
			|rx| async move {
				let result = rx.await;
				assert_matches!(&result, e @ &Err(Error::HardTimeout(_, _)) => {
					println!("{:?}", e);
				});
				dbg!(result.measurements());
			},
		);
	}

	#[test]
	fn starve_till_soft_timeout_then_food() {
		test_launch(
			"starve_till_soft_timeout_then_food",
			|tx| async move {
				Delay::new(Duration::from_secs(2)).await;
				let _ = tx.send(DummyItem::default());
			},
			|rx| async move {
				let result = rx.await;
				assert_matches!(result, Ok(_));
				dbg!(result.measurements());
			},
		);
	}
}