use std::{
ops::Deref,
pin::Pin,
task::{Context, Poll},
};
use futures::{
channel::oneshot::{self, Canceled, Cancellation},
future::{Fuse, FusedFuture},
prelude::*,
};
use futures_timer::Delay;
use crate::{CoarseDuration, CoarseInstant};
/// Why a metered oneshot receiver stopped waiting.
///
/// The enum is `#[repr(u8)]` with explicit discriminants, so `reason as u8`
/// yields `1`, `2` or `3` respectively.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum Reason {
	/// A value was received from the sender.
	Completion = 1,
	/// The channel was canceled (e.g. the sender was dropped) before a value arrived.
	Cancellation = 2,
	/// The hard timeout elapsed before a value arrived.
	HardTimeout = 3,
}
/// Timing information recorded at the moment a metered oneshot receiver resolves.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Measurements {
	// Time from the start point to resolution. The start point is the send timestamp
	// for a completed value; otherwise it is the receiver's first poll (or "now" when
	// it was never polled, yielding a near-zero duration).
	first_poll_till_end: CoarseDuration,
	// Time from channel creation to resolution.
	creation_till_end: CoarseDuration,
	// Why the receiver resolved.
	reason: Reason,
}

impl Measurements {
	/// Elapsed time from the measurement's start point until resolution.
	///
	/// NOTE(review): despite the name, for [`Reason::Completion`] the start point is
	/// the timestamp at which the sender sent the value, not the first poll.
	pub fn duration_since_first_poll(&self) -> &CoarseDuration {
		let Self { first_poll_till_end, .. } = self;
		first_poll_till_end
	}

	/// Elapsed time from channel creation until resolution.
	pub fn duration_since_creation(&self) -> &CoarseDuration {
		let Self { creation_till_end, .. } = self;
		creation_till_end
	}

	/// Why the receiver resolved.
	pub fn reason(&self) -> &Reason {
		let Self { reason, .. } = self;
		reason
	}
}
/// Create a metered oneshot channel.
///
/// * `name` – identifies the channel in the soft-timeout warning logged by the receiver.
/// * `soft_timeout` – once this much time has passed since the receiver's first poll,
///   a warning is logged (timers are created lazily on the first poll).
/// * `hard_timeout` – once this much time has passed since the receiver's first poll,
///   the receiver resolves to [`Error::HardTimeout`].
///
/// The creation timestamp used for [`Measurements::duration_since_creation`] is taken here.
pub fn channel<T>(
	name: &'static str,
	soft_timeout: CoarseDuration,
	hard_timeout: CoarseDuration,
) -> (MeteredSender<T>, MeteredReceiver<T>) {
	let (sender_half, receiver_half) = oneshot::channel();
	let metered_receiver = MeteredReceiver {
		name,
		inner: receiver_half,
		soft_timeout,
		hard_timeout,
		// Timers and the first-poll timestamp are filled in lazily by `poll`.
		soft_timeout_fut: None,
		hard_timeout_fut: None,
		first_poll_timestamp: None,
		creation_timestamp: CoarseInstant::now(),
	};
	let metered_sender = MeteredSender { inner: sender_half };
	(metered_sender, metered_receiver)
}
/// Failure modes of a [`MeteredReceiver`]; every variant carries the [`Measurements`]
/// taken when the receiver resolved.
// The tuple fields are self-describing, hence no per-field docs.
#[allow(missing_docs)]
#[derive(thiserror::Error, Debug)]
pub enum Error {
	/// The channel was canceled (e.g. the sender was dropped) before a value arrived.
	#[error("Oneshot was canceled.")]
	Canceled(#[source] Canceled, Measurements),
	/// No value arrived before the configured hard timeout (first field) elapsed.
	#[error("Oneshot did not receive a response within {}", CoarseDuration::as_f64(.0))]
	HardTimeout(CoarseDuration, Measurements),
}

impl Measurable for Error {
	fn measurements(&self) -> Measurements {
		// Both variants store their measurements in the second field.
		match self {
			Self::Canceled(_, recorded) | Self::HardTimeout(_, recorded) => recorded.clone(),
		}
	}
}
/// Sending half of a metered oneshot channel.
///
/// Every value is stamped with the time it was sent, so the receiver can measure delivery latency.
#[derive(Debug)]
pub struct MeteredSender<T> {
	inner: oneshot::Sender<(CoarseInstant, T)>,
}

impl<T> MeteredSender<T> {
	/// Send `t` to the receiver, paired with the current timestamp.
	///
	/// Returns `Err(t)` with the original value if the underlying oneshot send fails.
	pub fn send(self, t: T) -> Result<(), T> {
		let stamped = (CoarseInstant::now(), t);
		match self.inner.send(stamped) {
			Ok(()) => Ok(()),
			// Strip the timestamp again so the caller gets back exactly what it passed in.
			Err((_sent_at, unsent)) => Err(unsent),
		}
	}

	/// Poll whether the receiving half has gone away; delegates to the inner oneshot sender.
	pub fn poll_canceled(&mut self, ctx: &mut Context<'_>) -> Poll<()> {
		oneshot::Sender::poll_canceled(&mut self.inner, ctx)
	}

	/// Future that resolves once the receiving half has gone away; delegates to the inner sender.
	pub fn cancellation(&mut self) -> Cancellation<'_, (CoarseInstant, T)> {
		oneshot::Sender::cancellation(&mut self.inner)
	}

	/// Whether the receiving half has gone away; delegates to the inner sender.
	pub fn is_canceled(&self) -> bool {
		oneshot::Sender::is_canceled(&self.inner)
	}

	/// Whether this sender and `receiver` are the two halves of the same channel.
	pub fn is_connected_to(&self, receiver: &MeteredReceiver<T>) -> bool {
		oneshot::Sender::is_connected_to(&self.inner, &receiver.inner)
	}
}
#[derive(Debug)]
pub struct MeteredReceiver<T> {
name: &'static str,
inner: oneshot::Receiver<(CoarseInstant, T)>,
soft_timeout_fut: Option<Fuse<Delay>>,
soft_timeout: CoarseDuration,
hard_timeout_fut: Option<Delay>,
hard_timeout: CoarseDuration,
first_poll_timestamp: Option<CoarseInstant>,
creation_timestamp: CoarseInstant,
}
impl<T> MeteredReceiver<T> {
pub fn close(&mut self) {
self.inner.close()
}
pub fn try_recv(&mut self) -> Result<Option<OutputWithMeasurements<T>>, Error> {
match self.inner.try_recv() {
Ok(Some((when, value))) => {
let measurements = self.create_measurement(when, Reason::Completion);
Ok(Some(OutputWithMeasurements { value, measurements }))
},
Err(e) => {
let measurements = self.create_measurement(
self.first_poll_timestamp.unwrap_or_else(|| CoarseInstant::now()),
Reason::Cancellation,
);
Err(Error::Canceled(e, measurements))
},
Ok(None) => Ok(None),
}
}
fn create_measurement(&self, start: CoarseInstant, reason: Reason) -> Measurements {
let end = CoarseInstant::now();
Measurements {
first_poll_till_end: end - start,
creation_till_end: end - self.creation_timestamp,
reason,
}
}
}
impl<T> FusedFuture for MeteredReceiver<T> {
fn is_terminated(&self) -> bool {
self.inner.is_terminated()
}
}
impl<T> Future for MeteredReceiver<T> {
type Output = Result<OutputWithMeasurements<T>, Error>;
fn poll(
mut self: Pin<&mut Self>,
ctx: &mut Context<'_>,
) -> Poll<Result<OutputWithMeasurements<T>, Error>> {
let first_poll_timestamp =
self.first_poll_timestamp.get_or_insert_with(|| CoarseInstant::now()).clone();
let soft_timeout = self.soft_timeout.clone().into();
let soft_timeout = self
.soft_timeout_fut
.get_or_insert_with(move || Delay::new(soft_timeout).fuse());
if Pin::new(soft_timeout).poll(ctx).is_ready() {
tracing::warn!(target: "oneshot", "Oneshot `{name}` exceeded the soft threshold", name = &self.name);
}
let hard_timeout = self.hard_timeout.clone().into();
let hard_timeout =
self.hard_timeout_fut.get_or_insert_with(move || Delay::new(hard_timeout));
if Pin::new(hard_timeout).poll(ctx).is_ready() {
let measurements = self.create_measurement(first_poll_timestamp, Reason::HardTimeout);
return Poll::Ready(Err(Error::HardTimeout(self.hard_timeout.clone(), measurements)))
}
match Pin::new(&mut self.inner).poll(ctx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(e)) => {
let measurements =
self.create_measurement(first_poll_timestamp, Reason::Cancellation);
Poll::Ready(Err(Error::Canceled(e, measurements)))
},
Poll::Ready(Ok((ref sent_at_timestamp, value))) => {
let measurements =
self.create_measurement(sent_at_timestamp.clone(), Reason::Completion);
Poll::Ready(Ok(OutputWithMeasurements::<T> { value, measurements }))
},
}
}
}
/// Types that can report the [`Measurements`] recorded for a metered oneshot.
pub trait Measurable {
	/// Return an owned copy of the recorded measurements.
	fn measurements(&self) -> Measurements;
}

impl<T> Measurable for Result<OutputWithMeasurements<T>, Error> {
	fn measurements(&self) -> Measurements {
		// Both the success and the error side carry measurements; forward to whichever is present.
		self.as_ref()
			.map_or_else(|failure| failure.measurements(), |output| output.measurements())
	}
}
#[derive(Clone, Debug)]
pub struct OutputWithMeasurements<T> {
value: T,
measurements: Measurements,
}
impl<T> Measurable for OutputWithMeasurements<T> {
fn measurements(&self) -> Measurements {
self.measurements.clone()
}
}
impl<T> OutputWithMeasurements<T> {
pub fn into(self) -> T {
self.value
}
}
impl<T> AsRef<T> for OutputWithMeasurements<T> {
fn as_ref(&self) -> &T {
&self.value
}
}
impl<T> Deref for OutputWithMeasurements<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.value
}
}
#[cfg(test)]
mod tests {
	use assert_matches::assert_matches;
	use futures::{executor::ThreadPool, task::SpawnExt};
	use log::LevelFilter;
	use std::time::Duration;

	use super::*;

	/// Upper bound on how long a single `test_launch` scenario may run before it is
	/// considered hung.
	const SCENARIO_DEADLINE: Duration = Duration::from_secs(5);

	/// Non-trivial payload moved through the channel in every scenario.
	#[derive(Clone, PartialEq, Eq, Debug)]
	struct DummyItem {
		vals: [u8; 256],
	}

	impl Default for DummyItem {
		fn default() -> Self {
			Self { vals: [0u8; 256] }
		}
	}

	/// Run a sender scenario and a receiver scenario concurrently on a thread pool, against
	/// a fresh channel with a 1s soft timeout and a 3s hard timeout.
	///
	/// Panics if either side panics (the remote handle re-raises it). Also panics if the pair
	/// does not finish within `SCENARIO_DEADLINE`, so a hung scenario cannot silently pass.
	fn test_launch<S, R, FS, FR>(name: &'static str, gen_sender_test: S, gen_receiver_test: R)
	where
		S: Fn(MeteredSender<DummyItem>) -> FS,
		R: Fn(MeteredReceiver<DummyItem>) -> FR,
		FS: Future<Output = ()> + Send + 'static,
		FR: Future<Output = ()> + Send + 'static,
	{
		let _ = env_logger::builder().is_test(true).filter_level(LevelFilter::Trace).try_init();
		let pool = ThreadPool::new().unwrap();
		let (tx, rx) = channel(name, CoarseDuration::from_secs(1), CoarseDuration::from_secs(3));
		futures::executor::block_on(async move {
			let handle_receiver = pool.spawn_with_handle(gen_receiver_test(rx)).unwrap();
			let handle_sender = pool.spawn_with_handle(gen_sender_test(tx)).unwrap();
			let outcome = futures::future::select(
				futures::future::join(handle_sender, handle_receiver),
				Delay::new(SCENARIO_DEADLINE),
			)
			.await;
			if let futures::future::Either::Right(_) = outcome {
				panic!("test scenario `{}` did not finish within {:?}", name, SCENARIO_DEADLINE);
			}
		});
	}

	#[test]
	fn easy() {
		test_launch(
			"easy",
			|tx| async move {
				tx.send(DummyItem::default()).unwrap();
			},
			|rx| async move {
				let x = rx.await.unwrap();
				let measurements = x.measurements();
				assert_eq!(x.as_ref(), &DummyItem::default());
				assert_eq!(*measurements.reason(), Reason::Completion);
				dbg!(measurements);
			},
		);
	}

	#[test]
	fn cancel_by_drop() {
		test_launch(
			"cancel_by_drop",
			|tx| async move {
				Delay::new(Duration::from_secs(2)).await;
				drop(tx);
			},
			|rx| async move {
				let result = rx.await;
				assert_matches!(result, Err(Error::Canceled(_, _)));
				assert_eq!(*result.measurements().reason(), Reason::Cancellation);
				dbg!(result.measurements());
			},
		);
	}

	#[test]
	fn starve_till_hard_timeout() {
		test_launch(
			"starve_till_timeout",
			|tx| async move {
				Delay::new(Duration::from_secs(4)).await;
				let _ = tx.send(DummyItem::default());
			},
			|rx| async move {
				let result = rx.await;
				assert_matches!(&result, e @ &Err(Error::HardTimeout(_, _)) => {
					println!("{:?}", e);
				});
				assert_eq!(*result.measurements().reason(), Reason::HardTimeout);
				dbg!(result.measurements());
			},
		);
	}

	#[test]
	fn starve_till_soft_timeout_then_food() {
		test_launch(
			"starve_till_soft_timeout_then_food",
			|tx| async move {
				Delay::new(Duration::from_secs(2)).await;
				let _ = tx.send(DummyItem::default());
			},
			|rx| async move {
				let result = rx.await;
				assert_matches!(result, Ok(_));
				assert_eq!(*result.measurements().reason(), Reason::Completion);
				dbg!(result.measurements());
			},
		);
	}

	#[test]
	fn try_recv_before_and_after_send() {
		let (tx, mut rx) = channel::<DummyItem>(
			"try_recv_before_and_after_send",
			CoarseDuration::from_secs(1),
			CoarseDuration::from_secs(3),
		);
		assert_matches!(rx.try_recv(), Ok(None));
		tx.send(DummyItem::default()).unwrap();
		let output = rx.try_recv().unwrap().expect("a value was sent before this call");
		assert_eq!(*output.measurements().reason(), Reason::Completion);
		let value: DummyItem = OutputWithMeasurements::into(output);
		assert_eq!(value, DummyItem::default());
	}
}