1use crate::core::muxing::{StreamMuxer, StreamMuxerEvent};
22
23use futures::{
24 io::{IoSlice, IoSliceMut},
25 prelude::*,
26 ready,
27};
28use std::{
29 convert::TryFrom as _,
30 io,
31 pin::Pin,
32 sync::{
33 atomic::{AtomicU64, Ordering},
34 Arc,
35 },
36 task::{Context, Poll},
37};
38
39#[derive(Clone)]
42#[pin_project::pin_project]
43pub(crate) struct BandwidthLogging<SMInner> {
44 #[pin]
45 inner: SMInner,
46 sinks: Arc<BandwidthSinks>,
47}
48
49impl<SMInner> BandwidthLogging<SMInner> {
50 pub(crate) fn new(inner: SMInner, sinks: Arc<BandwidthSinks>) -> Self {
52 Self { inner, sinks }
53 }
54}
55
56impl<SMInner> StreamMuxer for BandwidthLogging<SMInner>
57where
58 SMInner: StreamMuxer,
59{
60 type Substream = InstrumentedStream<SMInner::Substream>;
61 type Error = SMInner::Error;
62
63 fn poll(
64 self: Pin<&mut Self>,
65 cx: &mut Context<'_>,
66 ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
67 let this = self.project();
68 this.inner.poll(cx)
69 }
70
71 fn poll_inbound(
72 self: Pin<&mut Self>,
73 cx: &mut Context<'_>,
74 ) -> Poll<Result<Self::Substream, Self::Error>> {
75 let this = self.project();
76 let inner = ready!(this.inner.poll_inbound(cx)?);
77 let logged = InstrumentedStream {
78 inner,
79 sinks: this.sinks.clone(),
80 };
81 Poll::Ready(Ok(logged))
82 }
83
84 fn poll_outbound(
85 self: Pin<&mut Self>,
86 cx: &mut Context<'_>,
87 ) -> Poll<Result<Self::Substream, Self::Error>> {
88 let this = self.project();
89 let inner = ready!(this.inner.poll_outbound(cx)?);
90 let logged = InstrumentedStream {
91 inner,
92 sinks: this.sinks.clone(),
93 };
94 Poll::Ready(Ok(logged))
95 }
96
97 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
98 let this = self.project();
99 this.inner.poll_close(cx)
100 }
101}
102
/// Pair of atomic byte counters, one per direction.
///
/// In this file the counters are incremented by the `AsyncRead` / `AsyncWrite` implementations
/// of `InstrumentedStream` and read back through [`BandwidthSinks::total_inbound`] and
/// [`BandwidthSinks::total_outbound`].
pub struct BandwidthSinks {
    /// Running total of bytes recorded for reads.
    inbound: AtomicU64,
    /// Running total of bytes recorded for writes.
    outbound: AtomicU64,
}

impl BandwidthSinks {
    /// Returns a new set of sinks behind an [`Arc`], with both counters at zero.
    pub(crate) fn new() -> Arc<Self> {
        let zeroed = BandwidthSinks {
            inbound: AtomicU64::default(),
            outbound: AtomicU64::default(),
        };
        Arc::new(zeroed)
    }

    /// Returns the current value of the inbound counter.
    ///
    /// The value is read with [`Ordering::Relaxed`], so it is a snapshot that carries no
    /// ordering guarantee relative to other memory operations.
    pub fn total_inbound(&self) -> u64 {
        let counter = &self.inbound;
        counter.load(Ordering::Relaxed)
    }

    /// Returns the current value of the outbound counter.
    ///
    /// The value is read with [`Ordering::Relaxed`], so it is a snapshot that carries no
    /// ordering guarantee relative to other memory operations.
    pub fn total_outbound(&self) -> u64 {
        let counter = &self.outbound;
        counter.load(Ordering::Relaxed)
    }
}
134
/// Wraps a stream together with a handle to the shared [`BandwidthSinks`].
///
/// The `AsyncRead` and `AsyncWrite` implementations for this type add the byte count of each
/// successful read to `sinks.inbound` and of each successful write to `sinks.outbound`.
#[pin_project::pin_project]
pub(crate) struct InstrumentedStream<SMInner> {
    /// The wrapped stream; marked `#[pin]` so it can be polled through the pin projection.
    #[pin]
    inner: SMInner,
    /// Shared byte counters updated on every successful read or write.
    sinks: Arc<BandwidthSinks>,
}
142
143impl<SMInner: AsyncRead> AsyncRead for InstrumentedStream<SMInner> {
144 fn poll_read(
145 self: Pin<&mut Self>,
146 cx: &mut Context<'_>,
147 buf: &mut [u8],
148 ) -> Poll<io::Result<usize>> {
149 let this = self.project();
150 let num_bytes = ready!(this.inner.poll_read(cx, buf))?;
151 this.sinks.inbound.fetch_add(
152 u64::try_from(num_bytes).unwrap_or(u64::max_value()),
153 Ordering::Relaxed,
154 );
155 Poll::Ready(Ok(num_bytes))
156 }
157
158 fn poll_read_vectored(
159 self: Pin<&mut Self>,
160 cx: &mut Context<'_>,
161 bufs: &mut [IoSliceMut<'_>],
162 ) -> Poll<io::Result<usize>> {
163 let this = self.project();
164 let num_bytes = ready!(this.inner.poll_read_vectored(cx, bufs))?;
165 this.sinks.inbound.fetch_add(
166 u64::try_from(num_bytes).unwrap_or(u64::max_value()),
167 Ordering::Relaxed,
168 );
169 Poll::Ready(Ok(num_bytes))
170 }
171}
172
173impl<SMInner: AsyncWrite> AsyncWrite for InstrumentedStream<SMInner> {
174 fn poll_write(
175 self: Pin<&mut Self>,
176 cx: &mut Context<'_>,
177 buf: &[u8],
178 ) -> Poll<io::Result<usize>> {
179 let this = self.project();
180 let num_bytes = ready!(this.inner.poll_write(cx, buf))?;
181 this.sinks.outbound.fetch_add(
182 u64::try_from(num_bytes).unwrap_or(u64::max_value()),
183 Ordering::Relaxed,
184 );
185 Poll::Ready(Ok(num_bytes))
186 }
187
188 fn poll_write_vectored(
189 self: Pin<&mut Self>,
190 cx: &mut Context<'_>,
191 bufs: &[IoSlice<'_>],
192 ) -> Poll<io::Result<usize>> {
193 let this = self.project();
194 let num_bytes = ready!(this.inner.poll_write_vectored(cx, bufs))?;
195 this.sinks.outbound.fetch_add(
196 u64::try_from(num_bytes).unwrap_or(u64::max_value()),
197 Ordering::Relaxed,
198 );
199 Poll::Ready(Ok(num_bytes))
200 }
201
202 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
203 let this = self.project();
204 this.inner.poll_flush(cx)
205 }
206
207 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
208 let this = self.project();
209 this.inner.poll_close(cx)
210 }
211}