use crate::core::muxing::{StreamMuxer, StreamMuxerEvent};
use futures::{
io::{IoSlice, IoSliceMut},
prelude::*,
ready,
};
use std::{
convert::TryFrom as _,
io,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
task::{Context, Poll},
};
/// A [`StreamMuxer`] adapter that wraps every substream produced by `inner` in
/// an [`InstrumentedStream`], so that bytes moved over those substreams are
/// added to the shared [`BandwidthSinks`] counters.
#[derive(Clone)]
#[pin_project::pin_project]
pub(crate) struct BandwidthLogging<SMInner> {
    /// The muxer being instrumented; structurally pinned so it can be polled.
    #[pin]
    inner: SMInner,
    /// Counters handed (via `Arc` clones) to every substream this muxer yields.
    sinks: Arc<BandwidthSinks>,
}

impl<SMInner> BandwidthLogging<SMInner> {
    /// Wraps `inner` so that the traffic of its substreams is recorded in `sinks`.
    pub(crate) fn new(inner: SMInner, sinks: Arc<BandwidthSinks>) -> Self {
        BandwidthLogging { sinks, inner }
    }
}
impl<SMInner> StreamMuxer for BandwidthLogging<SMInner>
where
    SMInner: StreamMuxer,
{
    type Substream = InstrumentedStream<SMInner::Substream>;
    type Error = SMInner::Error;

    /// Forwards muxer-level events from the inner muxer unchanged.
    fn poll(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
        self.project().inner.poll(cx)
    }

    /// Polls the inner muxer for an inbound substream and, once one is ready,
    /// wraps it so its traffic is counted in the shared sinks. `Pending` and
    /// errors are passed through as-is.
    fn poll_inbound(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Self::Substream, Self::Error>> {
        let projected = self.project();
        let sinks = projected.sinks;
        // The closure only runs on `Ready(Ok(_))`, so the `Arc` is cloned
        // exactly when a substream is actually handed out.
        projected
            .inner
            .poll_inbound(cx)
            .map_ok(|substream| InstrumentedStream {
                inner: substream,
                sinks: sinks.clone(),
            })
    }

    /// Polls the inner muxer for an outbound substream and, once one is ready,
    /// wraps it so its traffic is counted in the shared sinks. `Pending` and
    /// errors are passed through as-is.
    fn poll_outbound(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Self::Substream, Self::Error>> {
        let projected = self.project();
        let sinks = projected.sinks;
        projected
            .inner
            .poll_outbound(cx)
            .map_ok(|substream| InstrumentedStream {
                inner: substream,
                sinks: sinks.clone(),
            })
    }

    /// Closes the inner muxer; no bandwidth bookkeeping is involved.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.project().inner.poll_close(cx)
    }
}
/// Shared byte counters fed by every [`InstrumentedStream`] of an instrumented
/// muxer.
///
/// Both counters are atomics, so they can be updated from the substreams and
/// read through [`total_inbound`](Self::total_inbound) /
/// [`total_outbound`](Self::total_outbound) concurrently. Totals accumulate via
/// `fetch_add`, which wraps on `u64` overflow.
#[derive(Debug)]
pub struct BandwidthSinks {
    /// Bytes successfully read from instrumented substreams.
    inbound: AtomicU64,
    /// Bytes successfully written to instrumented substreams.
    outbound: AtomicU64,
}

impl BandwidthSinks {
    /// Creates a fresh pair of counters, both starting at zero.
    ///
    /// The result is returned inside an [`Arc`] because the same sinks are
    /// shared between a muxer wrapper and all of its substreams.
    pub(crate) fn new() -> Arc<Self> {
        Arc::new(Self {
            inbound: AtomicU64::new(0),
            outbound: AtomicU64::new(0),
        })
    }

    /// Returns the number of bytes received so far.
    pub fn total_inbound(&self) -> u64 {
        // `Relaxed` is enough: the value is a standalone statistic and is not
        // used to order access to any other memory.
        self.inbound.load(Ordering::Relaxed)
    }

    /// Returns the number of bytes sent so far.
    pub fn total_outbound(&self) -> u64 {
        self.outbound.load(Ordering::Relaxed)
    }
}
/// A substream wrapper that forwards I/O to `inner` and adds the number of
/// bytes successfully read or written to the shared [`BandwidthSinks`].
#[pin_project::pin_project]
pub(crate) struct InstrumentedStream<SMInner> {
    /// The wrapped substream; structurally pinned so its poll methods can be
    /// called through the projection.
    #[pin]
    inner: SMInner,
    /// Counters that successful reads (inbound) and writes (outbound) are
    /// added to.
    sinks: Arc<BandwidthSinks>,
}
/// Forwards reads to the wrapped stream and adds every byte it reports as read
/// to the shared inbound counter. `Pending` and errors are passed through
/// without touching the counter.
impl<SMInner: AsyncRead> AsyncRead for InstrumentedStream<SMInner> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.project();
        let num_bytes = ready!(this.inner.poll_read(cx, buf))?;
        // `usize` -> `u64` cannot fail on supported targets; saturate rather
        // than panic if it ever does. `Relaxed` suffices for a standalone
        // counter that does not guard any other data.
        this.sinks.inbound.fetch_add(
            u64::try_from(num_bytes).unwrap_or(u64::MAX),
            Ordering::Relaxed,
        );
        Poll::Ready(Ok(num_bytes))
    }

    /// Vectored variant of [`poll_read`](Self::poll_read); forwarded
    /// explicitly so the inner stream's vectored implementation is used.
    fn poll_read_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &mut [IoSliceMut<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = self.project();
        let num_bytes = ready!(this.inner.poll_read_vectored(cx, bufs))?;
        this.sinks.inbound.fetch_add(
            u64::try_from(num_bytes).unwrap_or(u64::MAX),
            Ordering::Relaxed,
        );
        Poll::Ready(Ok(num_bytes))
    }
}
/// Forwards writes to the wrapped stream and adds the number of bytes it
/// reports as accepted to the shared outbound counter. Flushing and closing
/// are passed straight through.
impl<SMInner: AsyncWrite> AsyncWrite for InstrumentedStream<SMInner> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.project();
        let num_bytes = ready!(this.inner.poll_write(cx, buf))?;
        // Count what the inner stream reports as written (a partial write may
        // be less than `buf.len()`). The conversion saturates instead of
        // panicking; `Relaxed` suffices for a standalone counter.
        this.sinks.outbound.fetch_add(
            u64::try_from(num_bytes).unwrap_or(u64::MAX),
            Ordering::Relaxed,
        );
        Poll::Ready(Ok(num_bytes))
    }

    /// Vectored variant of [`poll_write`](Self::poll_write); forwarded
    /// explicitly so the inner stream's vectored implementation is used.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        let this = self.project();
        let num_bytes = ready!(this.inner.poll_write_vectored(cx, bufs))?;
        this.sinks.outbound.fetch_add(
            u64::try_from(num_bytes).unwrap_or(u64::MAX),
            Ordering::Relaxed,
        );
        Poll::Ready(Ok(num_bytes))
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.project().inner.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.project().inner.poll_close(cx)
    }
}