#[cfg(feature = "async_channel")]
use async_channel::{
bounded as bounded_channel, Receiver, Sender, TryRecvError, TrySendError as ChannelTrySendError,
};
#[cfg(feature = "futures_channel")]
use futures::{
channel::mpsc::channel as bounded_channel,
channel::mpsc::{Receiver, Sender, TryRecvError, TrySendError as FuturesTrySendError},
sink::SinkExt,
};
use futures::{
stream::Stream,
task::{Context, Poll},
};
use std::{pin::Pin, result};
use super::{prepare_with_tof, MaybeTimeOfFlight, Meter};
/// Creates a metered bounded channel with a single (bulk) queue.
///
/// `capacity` is handed unchanged to the backend's bounded constructor
/// (`async_channel::bounded` or `futures::channel::mpsc::channel`, depending on the
/// enabled feature). Both returned halves share one [`Meter`]; no priority channel
/// is created, so priority sends fall back to the bulk queue.
pub fn channel<T>(capacity: usize) -> (MeteredSender<T>, MeteredReceiver<T>) {
    let (bulk_tx, bulk_rx) = bounded_channel::<MaybeTimeOfFlight<T>>(capacity);
    let meter = Meter::default();
    (
        MeteredSender { meter: meter.clone(), bulk_channel: bulk_tx, priority_channel: None },
        MeteredReceiver { meter, bulk_channel: bulk_rx, priority_channel: None },
    )
}
/// Creates a metered bounded channel backed by two queues: bulk and priority.
///
/// Each capacity is handed unchanged to the backend's bounded constructor. The
/// receiver drains the priority queue before the bulk queue. Both halves share one
/// [`Meter`].
pub fn channel_with_priority<T>(
    capacity_bulk: usize,
    capacity_priority: usize,
) -> (MeteredSender<T>, MeteredReceiver<T>) {
    let (bulk_tx, bulk_rx) = bounded_channel::<MaybeTimeOfFlight<T>>(capacity_bulk);
    let (priority_tx, priority_rx) = bounded_channel::<MaybeTimeOfFlight<T>>(capacity_priority);
    let meter = Meter::default();
    let sender = MeteredSender {
        meter: meter.clone(),
        bulk_channel: bulk_tx,
        priority_channel: Some(priority_tx),
    };
    let receiver = MeteredReceiver {
        meter,
        bulk_channel: bulk_rx,
        priority_channel: Some(priority_rx),
    };
    (sender, receiver)
}
/// Receiving half of a metered bounded channel.
///
/// Wraps the bulk receiver and, for channels created by `channel_with_priority`, a
/// second receiver for the priority queue. The [`Meter`] is shared with the sender side.
#[derive(Debug)]
pub struct MeteredReceiver<T> {
    // Shared with every `MeteredSender` created alongside this receiver.
    meter: Meter,
    // Queue for regular (non-priority) messages.
    bulk_channel: Receiver<MaybeTimeOfFlight<T>>,
    // `Some` only for `channel_with_priority`; checked before the bulk queue.
    priority_channel: Option<Receiver<MaybeTimeOfFlight<T>>>,
}
/// Error returned by the awaiting `MeteredSender::send` / `priority_send`.
#[derive(thiserror::Error, Debug)]
pub enum SendError<T> {
    /// The channel is closed; the undelivered message is handed back.
    #[error("Bounded channel has been closed")]
    Closed(T),
    /// The channel is closed and the undelivered message is not available.
    ///
    /// Produced by the `futures_channel` backend's awaiting send, which discards the
    /// underlying send error.
    #[error("Bounded channel has been closed and the original message is lost")]
    Terminated,
}
impl<T> SendError<T> {
    /// Returns the undelivered message, or `None` for [`SendError::Terminated`],
    /// which does not carry it.
    pub fn into_inner(self) -> Option<T> {
        if let Self::Closed(msg) = self {
            Some(msg)
        } else {
            None
        }
    }
}
/// Error returned by the non-blocking `MeteredSender::try_send` / `try_priority_send`.
///
/// Both variants hand back the message that could not be sent.
#[derive(thiserror::Error, Debug)]
pub enum TrySendError<T> {
    /// The channel is closed (disconnected).
    #[error("Bounded channel has been closed")]
    Closed(T),
    /// The channel is open but has no free capacity.
    #[error("Bounded channel is full")]
    Full(T),
}
#[cfg(feature = "async_channel")]
impl<T> From<ChannelTrySendError<MaybeTimeOfFlight<T>>> for TrySendError<T> {
    /// Maps an `async_channel` try-send failure onto [`TrySendError`], unwrapping the
    /// time-of-flight envelope and keeping the failure kind.
    fn from(error: ChannelTrySendError<MaybeTimeOfFlight<T>>) -> Self {
        let (is_full, envelope) = match error {
            ChannelTrySendError::Full(envelope) => (true, envelope),
            ChannelTrySendError::Closed(envelope) => (false, envelope),
        };
        let payload: T = envelope.into();
        if is_full {
            Self::Full(payload)
        } else {
            Self::Closed(payload)
        }
    }
}
#[cfg(feature = "async_channel")]
impl<T> From<ChannelTrySendError<T>> for TrySendError<T> {
    /// Maps an `async_channel` try-send failure onto [`TrySendError`], keeping the
    /// failure kind and the message.
    fn from(error: ChannelTrySendError<T>) -> Self {
        let (is_full, payload) = match error {
            ChannelTrySendError::Full(payload) => (true, payload),
            ChannelTrySendError::Closed(payload) => (false, payload),
        };
        if is_full {
            Self::Full(payload)
        } else {
            Self::Closed(payload)
        }
    }
}
#[cfg(feature = "futures_channel")]
impl<T> From<FuturesTrySendError<MaybeTimeOfFlight<T>>> for TrySendError<T> {
    /// Maps a `futures` mpsc try-send failure onto [`TrySendError`], unwrapping the
    /// time-of-flight envelope: disconnected becomes `Closed`, anything else `Full`.
    fn from(error: FuturesTrySendError<MaybeTimeOfFlight<T>>) -> Self {
        if error.is_disconnected() {
            Self::Closed(error.into_inner().into())
        } else {
            Self::Full(error.into_inner().into())
        }
    }
}
#[cfg(feature = "futures_channel")]
impl<T> From<FuturesTrySendError<T>> for TrySendError<T> {
    /// Maps a `futures` mpsc try-send failure onto [`TrySendError`]: disconnected
    /// becomes `Closed`, anything else `Full`.
    fn from(error: FuturesTrySendError<T>) -> Self {
        if error.is_disconnected() {
            Self::Closed(error.into_inner())
        } else {
            Self::Full(error.into_inner())
        }
    }
}
impl<T> TrySendError<T> {
    /// Consumes the error and hands back the message that could not be sent.
    pub fn into_inner(self) -> T {
        match self {
            Self::Closed(msg) | Self::Full(msg) => msg,
        }
    }

    /// `true` if the send failed only because the channel had no free capacity.
    pub fn is_full(&self) -> bool {
        matches!(self, Self::Full(_))
    }

    /// `true` if the send failed because the channel is closed.
    pub fn is_disconnected(&self) -> bool {
        matches!(self, Self::Closed(_))
    }

    /// Maps the carried message with `f`, keeping the failure kind.
    pub fn transform_inner<U, F>(self, f: F) -> TrySendError<U>
    where
        F: FnOnce(T) -> U,
    {
        let full = self.is_full();
        let mapped = f(self.into_inner());
        if full {
            TrySendError::Full(mapped)
        } else {
            TrySendError::Closed(mapped)
        }
    }

    /// Fallible variant of [`Self::transform_inner`].
    ///
    /// # Errors
    /// Propagates the error returned by `f`.
    pub fn try_transform_inner<U, F, E>(self, f: F) -> std::result::Result<TrySendError<U>, E>
    where
        F: FnOnce(T) -> std::result::Result<U, E>,
        E: std::fmt::Debug + std::error::Error + Send + Sync + 'static,
    {
        let full = self.is_full();
        let mapped = f(self.into_inner())?;
        Ok(if full { TrySendError::Full(mapped) } else { TrySendError::Closed(mapped) })
    }
}
/// Error returned by `MeteredReceiver::recv` when no message can ever arrive.
#[derive(thiserror::Error, PartialEq, Eq, Clone, Copy, Debug)]
pub struct RecvError {}

impl std::fmt::Display for RecvError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("receiving from an empty and closed channel")
    }
}
#[cfg(feature = "async_channel")]
impl From<async_channel::RecvError> for RecvError {
    /// The backend error carries no data, so it maps onto the unit-like [`RecvError`].
    fn from(_backend_err: async_channel::RecvError) -> Self {
        Self {}
    }
}
impl<T> std::ops::Deref for MeteredReceiver<T> {
    type Target = Receiver<MaybeTimeOfFlight<T>>;

    /// Exposes the underlying bulk receiver; the priority receiver is not reachable
    /// through this impl.
    fn deref(&self) -> &Receiver<MaybeTimeOfFlight<T>> {
        &self.bulk_channel
    }
}
impl<T> std::ops::DerefMut for MeteredReceiver<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.bulk_channel
}
}
impl<T> Stream for MeteredReceiver<T> {
    type Item = T;

    /// Yields the next message, always preferring the priority channel.
    ///
    /// An exhausted (closed and drained) priority channel does not end the stream.
    /// All senders hold both channels, so the priority side reports end-of-stream at
    /// the same moment the bulk side is closed, but the bulk side may still have
    /// buffered messages. Termination is therefore decided by the bulk channel alone.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if let Some(priority_channel) = &mut self.priority_channel {
            // `Ready(None)` and `Pending` both fall through to the bulk channel; in the
            // `Pending` case the priority waker stays registered.
            if let Poll::Ready(Some(value)) = Receiver::poll_next(Pin::new(priority_channel), cx) {
                return Poll::Ready(self.maybe_meter_tof(Some(value)))
            }
        }
        match Receiver::poll_next(Pin::new(&mut self.bulk_channel), cx) {
            Poll::Ready(maybe_value) => Poll::Ready(self.maybe_meter_tof(maybe_value)),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Reports the bulk channel's size hint only.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.bulk_channel.size_hint()
    }
}
impl<T> MeteredReceiver<T> {
    /// Records receive metrics for one message and unwraps it.
    ///
    /// Notes the receive on the meter and, when the message carries a send timestamp,
    /// its time of flight.
    fn meter_tof(&mut self, value: MaybeTimeOfFlight<T>) -> T {
        self.meter.note_received();
        match value {
            MaybeTimeOfFlight::<T>::WithTimeOfFlight(value, tof_start) => {
                let duration = tof_start.elapsed();
                self.meter.note_time_of_flight(duration);
                value
            },
            MaybeTimeOfFlight::<T>::Bare(value) => value,
        }
        .into()
    }

    /// Applies [`Self::meter_tof`] to a received message, if there is one.
    ///
    /// `None` (end of stream) passes through without touching the meter, so polling
    /// a closed channel does not inflate the received count.
    fn maybe_meter_tof(&mut self, maybe_value: Option<MaybeTimeOfFlight<T>>) -> Option<T> {
        maybe_value.map(|value| self.meter_tof(value))
    }

    /// Polls the priority channel (if any) first, then the bulk channel.
    ///
    /// An exhausted priority channel falls through to the bulk channel, which may still
    /// hold buffered messages; end-of-stream is decided by the bulk channel.
    #[cfg(feature = "async_channel")]
    fn poll_prioritized(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
        if let Some(priority_channel) = &mut self.priority_channel {
            if let Poll::Ready(Some(value)) = Receiver::poll_next(Pin::new(priority_channel), cx) {
                return Poll::Ready(Some(self.meter_tof(value)))
            }
        }
        match Receiver::poll_next(Pin::new(&mut self.bulk_channel), cx) {
            Poll::Ready(maybe_value) => Poll::Ready(self.maybe_meter_tof(maybe_value)),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Returns the shared meter.
    ///
    /// With the `async_channel` feature, the current combined queue length is recorded on
    /// the meter first.
    pub fn meter(&self) -> &Meter {
        #[cfg(feature = "async_channel")]
        self.meter.note_channel_len(self.len());
        &self.meter
    }

    /// Attempts to receive without waiting, preferring the priority channel.
    ///
    /// Returns `Ok(Some(_))` for a message and `Ok(None)` once the bulk channel is closed
    /// and drained. A closed priority channel does not short-circuit to `Ok(None)`,
    /// because buffered bulk messages may remain.
    ///
    /// # Errors
    /// Returns the bulk channel's [`TryRecvError`] when no message is ready.
    #[cfg(feature = "futures_channel")]
    pub fn try_next(&mut self) -> Result<Option<T>, TryRecvError> {
        if let Some(priority_channel) = &mut self.priority_channel {
            if let Ok(Some(value)) = priority_channel.try_next() {
                return Ok(Some(self.meter_tof(value)))
            }
        }
        match self.bulk_channel.try_next()? {
            Some(value) => Ok(Some(self.meter_tof(value))),
            None => Ok(None),
        }
    }

    /// Attempts to receive without waiting, preferring the priority channel.
    ///
    /// Returns `Ok(Some(_))` for a message and `Ok(None)` once the bulk channel is closed
    /// and drained. A closed priority channel does not short-circuit to `Ok(None)`,
    /// because buffered bulk messages may remain.
    ///
    /// # Errors
    /// Returns [`TryRecvError::Empty`] when no message is ready.
    #[cfg(feature = "async_channel")]
    pub fn try_next(&mut self) -> Result<Option<T>, TryRecvError> {
        if let Some(priority_channel) = &mut self.priority_channel {
            if let Ok(value) = priority_channel.try_recv() {
                return Ok(Some(self.meter_tof(value)))
            }
        }
        match self.bulk_channel.try_recv() {
            Ok(value) => Ok(Some(self.meter_tof(value))),
            Err(TryRecvError::Empty) => Err(TryRecvError::Empty),
            Err(TryRecvError::Closed) => Ok(None),
        }
    }

    /// Waits for the next message, preferring the priority channel.
    ///
    /// Both channels are polled while waiting, so a priority message that arrives while
    /// the bulk channel is empty still wakes this future.
    ///
    /// # Errors
    /// Returns [`RecvError`] when no priority message is ready and the bulk channel is
    /// closed and drained.
    #[cfg(feature = "async_channel")]
    pub async fn recv(&mut self) -> Result<T, RecvError> {
        futures::future::poll_fn(|cx| self.poll_prioritized(cx)).await.ok_or(RecvError {})
    }

    /// Attempts to receive without waiting, preferring the priority channel.
    ///
    /// # Errors
    /// Returns the bulk channel's [`TryRecvError`] when no message is ready (a closed
    /// priority channel alone is not reported as an error).
    #[cfg(feature = "async_channel")]
    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
        if let Some(priority_channel) = &mut self.priority_channel {
            if let Ok(value) = priority_channel.try_recv() {
                return Ok(self.meter_tof(value))
            }
        }
        let value = self.bulk_channel.try_recv()?;
        Ok(self.meter_tof(value))
    }

    /// Number of messages currently queued across the bulk and priority channels.
    #[cfg(feature = "async_channel")]
    pub fn len(&self) -> usize {
        self.bulk_channel.len() + self.priority_channel.as_ref().map_or(0, |c| c.len())
    }

    /// Queue length as derived from the shared meter's counters.
    #[cfg(feature = "futures_channel")]
    pub fn len(&self) -> usize {
        self.meter.calculate_channel_len()
    }
}
impl<T> futures::stream::FusedStream for MeteredReceiver<T> {
    /// Terminated only when the bulk channel and, if present, the priority channel
    /// are both terminated.
    fn is_terminated(&self) -> bool {
        if !self.bulk_channel.is_terminated() {
            return false
        }
        match &self.priority_channel {
            Some(priority) => priority.is_terminated(),
            None => true,
        }
    }
}
/// Sending half of a metered bounded channel.
///
/// Holds the bulk sender and, for channels created by `channel_with_priority`, a
/// priority sender. Cloning shares the same [`Meter`] and underlying channels.
#[derive(Debug)]
pub struct MeteredSender<T> {
    // Shared with the receiver and all clones of this sender.
    meter: Meter,
    // Queue for regular (non-priority) messages.
    bulk_channel: Sender<MaybeTimeOfFlight<T>>,
    // `Some` only for `channel_with_priority`; priority sends fall back to bulk otherwise.
    priority_channel: Option<Sender<MaybeTimeOfFlight<T>>>,
}
// Written by hand because `#[derive(Clone)]` would add a `T: Clone` bound that the
// fields do not need.
impl<T> Clone for MeteredSender<T> {
    fn clone(&self) -> Self {
        let meter = self.meter.clone();
        let bulk_channel = self.bulk_channel.clone();
        let priority_channel = self.priority_channel.clone();
        Self { meter, bulk_channel, priority_channel }
    }
}
impl<T> std::ops::Deref for MeteredSender<T> {
    type Target = Sender<MaybeTimeOfFlight<T>>;

    /// Exposes the underlying bulk sender; the priority sender is not reachable
    /// through this impl.
    fn deref(&self) -> &Sender<MaybeTimeOfFlight<T>> {
        &self.bulk_channel
    }
}
impl<T> std::ops::DerefMut for MeteredSender<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.bulk_channel
}
}
impl<T> MeteredSender<T> {
    /// Returns the meter shared with the receiver and all cloned senders.
    ///
    /// With the `async_channel` feature, the current combined queue length is recorded on
    /// the meter first.
    pub fn meter(&self) -> &Meter {
        #[cfg(feature = "async_channel")]
        self.meter.note_channel_len(self.len());
        &self.meter
    }

    /// Sends `msg` on the bulk channel, waiting for free capacity if it is full.
    ///
    /// # Errors
    /// Returns [`SendError`] if the channel is closed.
    pub async fn send(&mut self, msg: T) -> result::Result<(), SendError<T>>
    where
        Self: Unpin,
    {
        self.send_inner(msg, false).await
    }

    /// Sends `msg` on the priority channel, waiting for free capacity if it is full.
    ///
    /// Falls back to the bulk channel when this sender has no priority channel.
    ///
    /// # Errors
    /// Returns [`SendError`] if the channel is closed.
    pub async fn priority_send(&mut self, msg: T) -> result::Result<(), SendError<T>>
    where
        Self: Unpin,
    {
        self.send_inner(msg, true).await
    }

    /// Shared body of `send` / `priority_send`.
    ///
    /// Tries a non-blocking send first and only falls back to an awaiting send (noted as
    /// "blocked" on the meter) when the channel is full.
    async fn send_inner(
        &mut self,
        msg: T,
        use_priority_channel: bool,
    ) -> result::Result<(), SendError<T>>
    where
        Self: Unpin,
    {
        let attempt =
            if use_priority_channel { self.try_priority_send(msg) } else { self.try_send(msg) };
        let send_err = match attempt {
            Ok(()) => return Ok(()),
            Err(send_err) => send_err,
        };
        if !send_err.is_full() {
            return Err(SendError::Closed(send_err.into_inner()))
        }
        self.meter.note_blocked();
        // The awaiting send below is counted up front; `send_to_channel` retracts the
        // count if that send fails.
        self.meter.note_sent();
        // NOTE(review): `TrySendError<T>` only carries `T`, so any time-of-flight marker
        // from the failed attempt is lost and the message is re-wrapped via `From<T>`.
        // Confirm that skipping TOF sampling for blocked sends is intended.
        let msg = send_err.into_inner().into();
        self.send_to_channel(msg, use_priority_channel).await
    }

    /// Awaits capacity and sends an already-wrapped message.
    ///
    /// On failure the "sent" count noted by the caller is retracted.
    #[cfg(feature = "async_channel")]
    async fn send_to_channel(
        &mut self,
        msg: MaybeTimeOfFlight<T>,
        use_priority_channel: bool,
    ) -> result::Result<(), SendError<T>> {
        let channel = if use_priority_channel {
            self.priority_channel.as_mut().unwrap_or(&mut self.bulk_channel)
        } else {
            &mut self.bulk_channel
        };
        let outcome = channel.send(msg).await;
        outcome.map_err(|err| {
            self.meter.retract_sent();
            SendError::Closed(err.0.into())
        })
    }

    /// Awaits capacity and sends an already-wrapped message.
    ///
    /// On failure the "sent" count noted by the caller is retracted. The backend error is
    /// discarded, hence [`SendError::Terminated`].
    #[cfg(feature = "futures_channel")]
    async fn send_to_channel(
        &mut self,
        msg: MaybeTimeOfFlight<T>,
        use_priority_channel: bool,
    ) -> result::Result<(), SendError<T>> {
        let channel = if use_priority_channel {
            self.priority_channel.as_mut().unwrap_or(&mut self.bulk_channel)
        } else {
            &mut self.bulk_channel
        };
        let outcome = channel.send(msg).await;
        outcome.map_err(|_| {
            self.meter.retract_sent();
            SendError::Terminated
        })
    }

    /// Attempts to send `msg` on the bulk channel without waiting.
    ///
    /// # Errors
    /// Returns [`TrySendError::Full`] or [`TrySendError::Closed`] with the message; the
    /// meter's sent count is retracted in that case.
    pub fn try_send(&mut self, msg: T) -> result::Result<(), TrySendError<T>> {
        let msg = prepare_with_tof(&self.meter, msg);
        self.bulk_channel.try_send(msg).map_err(|e| {
            self.meter.retract_sent();
            TrySendError::from(e)
        })
    }

    /// Attempts to send `msg` on the priority channel without waiting.
    ///
    /// Falls back to [`Self::try_send`] when this sender has no priority channel.
    ///
    /// # Errors
    /// Same as [`Self::try_send`].
    pub fn try_priority_send(&mut self, msg: T) -> result::Result<(), TrySendError<T>> {
        match self.priority_channel.as_mut() {
            Some(priority_channel) => {
                let msg = prepare_with_tof(&self.meter, msg);
                priority_channel.try_send(msg).map_err(|e| {
                    self.meter.retract_sent();
                    TrySendError::from(e)
                })
            },
            None => self.try_send(msg),
        }
    }

    /// Number of messages currently queued across the bulk and priority channels.
    #[cfg(feature = "async_channel")]
    pub fn len(&self) -> usize {
        self.bulk_channel.len() + self.priority_channel.as_ref().map_or(0, |c| c.len())
    }

    /// Queue length as derived from the shared meter's counters.
    #[cfg(feature = "futures_channel")]
    pub fn len(&self) -> usize {
        self.meter.calculate_channel_len()
    }
}