litep2p/transport/tcp/substream.rs

1// Copyright 2023 litep2p developers
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::{protocol::Permit, BandwidthSink};
22
23use tokio::io::{AsyncRead, AsyncWrite};
24use tokio_util::compat::Compat;
25
26use std::{
27    io,
28    pin::Pin,
29    task::{Context, Poll},
30};
31
/// Substream that holds the inner substream provided by the transport
/// and a permit which keeps the connection open.
///
/// `BandwidthSink` is used to meter inbound/outbound bytes: the `AsyncRead`/`AsyncWrite`
/// impls for this type delegate to `io` and report the transferred byte counts to it.
///
/// Rust drops struct fields in declaration order, so `io` is dropped before `_permit`.
/// NOTE(review): confirm whether anything relies on the permit outliving the stream
/// before reordering these fields.
#[derive(Debug)]
pub struct Substream {
    /// Underlying socket: a yamux stream wrapped in `Compat` so it can be driven
    /// through tokio's `AsyncRead`/`AsyncWrite` traits.
    io: Compat<crate::yamux::Stream>,

    /// Bandwidth sink credited with bytes read from and written to `io`.
    bandwidth_sink: BandwidthSink,

    /// Connection permit. It is never read. It is stored only so that it is released
    /// when the substream itself is dropped.
    _permit: Permit,
}
47
48impl Substream {
49    /// Create new [`Substream`].
50    pub fn new(
51        io: Compat<crate::yamux::Stream>,
52        bandwidth_sink: BandwidthSink,
53        _permit: Permit,
54    ) -> Self {
55        Self {
56            io,
57            bandwidth_sink,
58            _permit,
59        }
60    }
61}
62
63impl AsyncRead for Substream {
64    fn poll_read(
65        mut self: Pin<&mut Self>,
66        cx: &mut Context<'_>,
67        buf: &mut tokio::io::ReadBuf<'_>,
68    ) -> Poll<io::Result<()>> {
69        match futures::ready!(Pin::new(&mut self.io).poll_read(cx, buf)) {
70            Err(error) => Poll::Ready(Err(error)),
71            Ok(res) => {
72                self.bandwidth_sink.increase_inbound(buf.filled().len());
73                Poll::Ready(Ok(res))
74            }
75        }
76    }
77}
78
79impl AsyncWrite for Substream {
80    fn poll_write(
81        mut self: Pin<&mut Self>,
82        cx: &mut Context<'_>,
83        buf: &[u8],
84    ) -> Poll<Result<usize, io::Error>> {
85        match futures::ready!(Pin::new(&mut self.io).poll_write(cx, buf)) {
86            Err(error) => Poll::Ready(Err(error)),
87            Ok(nwritten) => {
88                self.bandwidth_sink.increase_outbound(nwritten);
89                Poll::Ready(Ok(nwritten))
90            }
91        }
92    }
93
94    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
95        Pin::new(&mut self.io).poll_flush(cx)
96    }
97
98    fn poll_shutdown(
99        mut self: Pin<&mut Self>,
100        cx: &mut Context<'_>,
101    ) -> Poll<Result<(), io::Error>> {
102        Pin::new(&mut self.io).poll_shutdown(cx)
103    }
104
105    fn poll_write_vectored(
106        mut self: Pin<&mut Self>,
107        cx: &mut Context<'_>,
108        bufs: &[io::IoSlice<'_>],
109    ) -> Poll<Result<usize, io::Error>> {
110        match futures::ready!(Pin::new(&mut self.io).poll_write_vectored(cx, bufs)) {
111            Err(error) => Poll::Ready(Err(error)),
112            Ok(nwritten) => {
113                self.bandwidth_sink.increase_outbound(nwritten);
114                Poll::Ready(Ok(nwritten))
115            }
116        }
117    }
118
119    fn is_write_vectored(&self) -> bool {
120        self.io.is_write_vectored()
121    }
122}