libp2p/
bandwidth.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21#![allow(deprecated)]
22
23use crate::core::muxing::{StreamMuxer, StreamMuxerEvent};
24
25use futures::{
26    io::{IoSlice, IoSliceMut},
27    prelude::*,
28    ready,
29};
30use std::{
31    convert::TryFrom as _,
32    io,
33    pin::Pin,
34    sync::{
35        atomic::{AtomicU64, Ordering},
36        Arc,
37    },
38    task::{Context, Poll},
39};
40
41/// Wraps around a [`StreamMuxer`] and counts the number of bytes that go through all the opened
42/// streams.
43#[derive(Clone)]
44#[pin_project::pin_project]
45pub(crate) struct BandwidthLogging<SMInner> {
46    #[pin]
47    inner: SMInner,
48    sinks: Arc<BandwidthSinks>,
49}
50
51impl<SMInner> BandwidthLogging<SMInner> {
52    /// Creates a new [`BandwidthLogging`] around the stream muxer.
53    pub(crate) fn new(inner: SMInner, sinks: Arc<BandwidthSinks>) -> Self {
54        Self { inner, sinks }
55    }
56}
57
58impl<SMInner> StreamMuxer for BandwidthLogging<SMInner>
59where
60    SMInner: StreamMuxer,
61{
62    type Substream = InstrumentedStream<SMInner::Substream>;
63    type Error = SMInner::Error;
64
65    fn poll(
66        self: Pin<&mut Self>,
67        cx: &mut Context<'_>,
68    ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
69        let this = self.project();
70        this.inner.poll(cx)
71    }
72
73    fn poll_inbound(
74        self: Pin<&mut Self>,
75        cx: &mut Context<'_>,
76    ) -> Poll<Result<Self::Substream, Self::Error>> {
77        let this = self.project();
78        let inner = ready!(this.inner.poll_inbound(cx)?);
79        let logged = InstrumentedStream {
80            inner,
81            sinks: this.sinks.clone(),
82        };
83        Poll::Ready(Ok(logged))
84    }
85
86    fn poll_outbound(
87        self: Pin<&mut Self>,
88        cx: &mut Context<'_>,
89    ) -> Poll<Result<Self::Substream, Self::Error>> {
90        let this = self.project();
91        let inner = ready!(this.inner.poll_outbound(cx)?);
92        let logged = InstrumentedStream {
93            inner,
94            sinks: this.sinks.clone(),
95        };
96        Poll::Ready(Ok(logged))
97    }
98
99    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100        let this = self.project();
101        this.inner.poll_close(cx)
102    }
103}
104
/// Running totals of the bytes transferred over the instrumented streams.
///
/// Only cumulative byte counts are recorded: bytes read go into `inbound` and bytes written go
/// into `outbound`. No average or rate is computed here; callers that need a bandwidth figure
/// must derive it themselves, e.g. by sampling the totals at known points in time.
#[deprecated(
    note = "Use `libp2p::SwarmBuilder::with_bandwidth_metrics` or `libp2p_metrics::BandwidthTransport` instead."
)]
#[derive(Debug)]
pub struct BandwidthSinks {
    /// Total number of bytes read from all instrumented streams.
    inbound: AtomicU64,
    /// Total number of bytes written to all instrumented streams.
    outbound: AtomicU64,
}
113
114impl BandwidthSinks {
115    /// Returns a new [`BandwidthSinks`].
116    pub(crate) fn new() -> Arc<Self> {
117        Arc::new(Self {
118            inbound: AtomicU64::new(0),
119            outbound: AtomicU64::new(0),
120        })
121    }
122
123    /// Returns the total number of bytes that have been downloaded on all the streams.
124    ///
125    /// > **Note**: This method is by design subject to race conditions. The returned value should
126    /// >           only ever be used for statistics purposes.
127    pub fn total_inbound(&self) -> u64 {
128        self.inbound.load(Ordering::Relaxed)
129    }
130
131    /// Returns the total number of bytes that have been uploaded on all the streams.
132    ///
133    /// > **Note**: This method is by design subject to race conditions. The returned value should
134    /// >           only ever be used for statistics purposes.
135    pub fn total_outbound(&self) -> u64 {
136        self.outbound.load(Ordering::Relaxed)
137    }
138}
139
140/// Wraps around an [`AsyncRead`] + [`AsyncWrite`] and logs the bandwidth that goes through it.
141#[pin_project::pin_project]
142pub(crate) struct InstrumentedStream<SMInner> {
143    #[pin]
144    inner: SMInner,
145    sinks: Arc<BandwidthSinks>,
146}
147
148impl<SMInner: AsyncRead> AsyncRead for InstrumentedStream<SMInner> {
149    fn poll_read(
150        self: Pin<&mut Self>,
151        cx: &mut Context<'_>,
152        buf: &mut [u8],
153    ) -> Poll<io::Result<usize>> {
154        let this = self.project();
155        let num_bytes = ready!(this.inner.poll_read(cx, buf))?;
156        this.sinks.inbound.fetch_add(
157            u64::try_from(num_bytes).unwrap_or(u64::MAX),
158            Ordering::Relaxed,
159        );
160        Poll::Ready(Ok(num_bytes))
161    }
162
163    fn poll_read_vectored(
164        self: Pin<&mut Self>,
165        cx: &mut Context<'_>,
166        bufs: &mut [IoSliceMut<'_>],
167    ) -> Poll<io::Result<usize>> {
168        let this = self.project();
169        let num_bytes = ready!(this.inner.poll_read_vectored(cx, bufs))?;
170        this.sinks.inbound.fetch_add(
171            u64::try_from(num_bytes).unwrap_or(u64::MAX),
172            Ordering::Relaxed,
173        );
174        Poll::Ready(Ok(num_bytes))
175    }
176}
177
178impl<SMInner: AsyncWrite> AsyncWrite for InstrumentedStream<SMInner> {
179    fn poll_write(
180        self: Pin<&mut Self>,
181        cx: &mut Context<'_>,
182        buf: &[u8],
183    ) -> Poll<io::Result<usize>> {
184        let this = self.project();
185        let num_bytes = ready!(this.inner.poll_write(cx, buf))?;
186        this.sinks.outbound.fetch_add(
187            u64::try_from(num_bytes).unwrap_or(u64::MAX),
188            Ordering::Relaxed,
189        );
190        Poll::Ready(Ok(num_bytes))
191    }
192
193    fn poll_write_vectored(
194        self: Pin<&mut Self>,
195        cx: &mut Context<'_>,
196        bufs: &[IoSlice<'_>],
197    ) -> Poll<io::Result<usize>> {
198        let this = self.project();
199        let num_bytes = ready!(this.inner.poll_write_vectored(cx, bufs))?;
200        this.sinks.outbound.fetch_add(
201            u64::try_from(num_bytes).unwrap_or(u64::MAX),
202            Ordering::Relaxed,
203        );
204        Poll::Ready(Ok(num_bytes))
205    }
206
207    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
208        let this = self.project();
209        this.inner.poll_flush(cx)
210    }
211
212    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
213        let this = self.project();
214        this.inner.poll_close(cx)
215    }
216}