libp2p/
bandwidth.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::core::muxing::{StreamMuxer, StreamMuxerEvent};
22
23use futures::{
24    io::{IoSlice, IoSliceMut},
25    prelude::*,
26    ready,
27};
28use std::{
29    convert::TryFrom as _,
30    io,
31    pin::Pin,
32    sync::{
33        atomic::{AtomicU64, Ordering},
34        Arc,
35    },
36    task::{Context, Poll},
37};
38
/// Wraps around a [`StreamMuxer`] and counts the number of bytes that go through all the opened
/// streams.
///
/// The wrapper itself does not count anything; it hands a clone of its [`BandwidthSinks`] handle
/// to every substream it yields, and those substreams update the counters as they are read from
/// and written to.
#[derive(Clone)]
#[pin_project::pin_project]
pub(crate) struct BandwidthLogging<SMInner> {
    /// The wrapped muxer. Marked `#[pin]` so that `project()` exposes it as a pinned reference.
    #[pin]
    inner: SMInner,
    /// Shared byte counters, cloned into every substream produced by this muxer.
    sinks: Arc<BandwidthSinks>,
}
48
49impl<SMInner> BandwidthLogging<SMInner> {
50    /// Creates a new [`BandwidthLogging`] around the stream muxer.
51    pub(crate) fn new(inner: SMInner, sinks: Arc<BandwidthSinks>) -> Self {
52        Self { inner, sinks }
53    }
54}
55
56impl<SMInner> StreamMuxer for BandwidthLogging<SMInner>
57where
58    SMInner: StreamMuxer,
59{
60    type Substream = InstrumentedStream<SMInner::Substream>;
61    type Error = SMInner::Error;
62
63    fn poll(
64        self: Pin<&mut Self>,
65        cx: &mut Context<'_>,
66    ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
67        let this = self.project();
68        this.inner.poll(cx)
69    }
70
71    fn poll_inbound(
72        self: Pin<&mut Self>,
73        cx: &mut Context<'_>,
74    ) -> Poll<Result<Self::Substream, Self::Error>> {
75        let this = self.project();
76        let inner = ready!(this.inner.poll_inbound(cx)?);
77        let logged = InstrumentedStream {
78            inner,
79            sinks: this.sinks.clone(),
80        };
81        Poll::Ready(Ok(logged))
82    }
83
84    fn poll_outbound(
85        self: Pin<&mut Self>,
86        cx: &mut Context<'_>,
87    ) -> Poll<Result<Self::Substream, Self::Error>> {
88        let this = self.project();
89        let inner = ready!(this.inner.poll_outbound(cx)?);
90        let logged = InstrumentedStream {
91            inner,
92            sinks: this.sinks.clone(),
93        };
94        Poll::Ready(Ok(logged))
95    }
96
97    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
98        let this = self.project();
99        this.inner.poll_close(cx)
100    }
101}
102
/// Cumulative byte counters for all instrumented streams.
///
/// Two totals are kept: bytes received (`inbound`) and bytes sent (`outbound`). Only running
/// totals are stored; deriving a rate from them (e.g. by sampling over time) is left to the
/// caller.
#[derive(Debug)]
pub struct BandwidthSinks {
    /// Total number of bytes read so far.
    inbound: AtomicU64,
    /// Total number of bytes written so far.
    outbound: AtomicU64,
}

impl BandwidthSinks {
    /// Returns a new [`BandwidthSinks`] with both counters at zero.
    ///
    /// The value is returned inside an [`Arc`] because it is meant to be shared between the
    /// code updating the counters and the code reading them.
    pub(crate) fn new() -> Arc<Self> {
        Arc::new(Self {
            inbound: AtomicU64::new(0),
            outbound: AtomicU64::new(0),
        })
    }

    /// Returns the total number of bytes that have been downloaded on all the streams.
    ///
    /// > **Note**: This method is by design subject to race conditions. The returned value should
    /// >           only ever be used for statistics purposes.
    pub fn total_inbound(&self) -> u64 {
        // Relaxed: the value is a statistic and is not used to synchronise other memory.
        self.inbound.load(Ordering::Relaxed)
    }

    /// Returns the total number of bytes that have been uploaded on all the streams.
    ///
    /// > **Note**: This method is by design subject to race conditions. The returned value should
    /// >           only ever be used for statistics purposes.
    pub fn total_outbound(&self) -> u64 {
        // Relaxed: the value is a statistic and is not used to synchronise other memory.
        self.outbound.load(Ordering::Relaxed)
    }
}
134
/// Wraps around an [`AsyncRead`] + [`AsyncWrite`] and logs the bandwidth that goes through it.
///
/// Bytes returned by successful reads are added to the inbound counter of `sinks`, and bytes
/// accepted by successful writes are added to the outbound counter.
#[pin_project::pin_project]
pub(crate) struct InstrumentedStream<SMInner> {
    /// The wrapped stream. Marked `#[pin]` so that `project()` exposes it as a pinned reference.
    #[pin]
    inner: SMInner,
    /// Shared byte counters updated by the read and write implementations.
    sinks: Arc<BandwidthSinks>,
}
142
143impl<SMInner: AsyncRead> AsyncRead for InstrumentedStream<SMInner> {
144    fn poll_read(
145        self: Pin<&mut Self>,
146        cx: &mut Context<'_>,
147        buf: &mut [u8],
148    ) -> Poll<io::Result<usize>> {
149        let this = self.project();
150        let num_bytes = ready!(this.inner.poll_read(cx, buf))?;
151        this.sinks.inbound.fetch_add(
152            u64::try_from(num_bytes).unwrap_or(u64::max_value()),
153            Ordering::Relaxed,
154        );
155        Poll::Ready(Ok(num_bytes))
156    }
157
158    fn poll_read_vectored(
159        self: Pin<&mut Self>,
160        cx: &mut Context<'_>,
161        bufs: &mut [IoSliceMut<'_>],
162    ) -> Poll<io::Result<usize>> {
163        let this = self.project();
164        let num_bytes = ready!(this.inner.poll_read_vectored(cx, bufs))?;
165        this.sinks.inbound.fetch_add(
166            u64::try_from(num_bytes).unwrap_or(u64::max_value()),
167            Ordering::Relaxed,
168        );
169        Poll::Ready(Ok(num_bytes))
170    }
171}
172
173impl<SMInner: AsyncWrite> AsyncWrite for InstrumentedStream<SMInner> {
174    fn poll_write(
175        self: Pin<&mut Self>,
176        cx: &mut Context<'_>,
177        buf: &[u8],
178    ) -> Poll<io::Result<usize>> {
179        let this = self.project();
180        let num_bytes = ready!(this.inner.poll_write(cx, buf))?;
181        this.sinks.outbound.fetch_add(
182            u64::try_from(num_bytes).unwrap_or(u64::max_value()),
183            Ordering::Relaxed,
184        );
185        Poll::Ready(Ok(num_bytes))
186    }
187
188    fn poll_write_vectored(
189        self: Pin<&mut Self>,
190        cx: &mut Context<'_>,
191        bufs: &[IoSlice<'_>],
192    ) -> Poll<io::Result<usize>> {
193        let this = self.project();
194        let num_bytes = ready!(this.inner.poll_write_vectored(cx, bufs))?;
195        this.sinks.outbound.fetch_add(
196            u64::try_from(num_bytes).unwrap_or(u64::max_value()),
197            Ordering::Relaxed,
198        );
199        Poll::Ready(Ok(num_bytes))
200    }
201
202    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
203        let this = self.project();
204        this.inner.poll_flush(cx)
205    }
206
207    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
208        let this = self.project();
209        this.inner.poll_close(cx)
210    }
211}