// libp2p_core/muxing/boxed.rs

1use crate::muxing::{StreamMuxer, StreamMuxerEvent};
2use futures::{AsyncRead, AsyncWrite};
3use pin_project::pin_project;
4use std::error::Error;
5use std::fmt;
6use std::io;
7use std::io::{IoSlice, IoSliceMut};
8use std::pin::Pin;
9use std::task::{Context, Poll};
10
11/// Abstract `StreamMuxer`.
12pub struct StreamMuxerBox {
13    inner: Pin<Box<dyn StreamMuxer<Substream = SubstreamBox, Error = io::Error> + Send>>,
14}
15
16impl fmt::Debug for StreamMuxerBox {
17    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
18        f.debug_struct("StreamMuxerBox").finish_non_exhaustive()
19    }
20}
21
22/// Abstract type for asynchronous reading and writing.
23///
24/// A [`SubstreamBox`] erases the concrete type it is given and only retains its `AsyncRead`
25/// and `AsyncWrite` capabilities.
26pub struct SubstreamBox(Pin<Box<dyn AsyncReadWrite + Send>>);
27
28#[pin_project]
29struct Wrap<T>
30where
31    T: StreamMuxer,
32{
33    #[pin]
34    inner: T,
35}
36
37impl<T> StreamMuxer for Wrap<T>
38where
39    T: StreamMuxer,
40    T::Substream: Send + 'static,
41    T::Error: Send + Sync + 'static,
42{
43    type Substream = SubstreamBox;
44    type Error = io::Error;
45
46    fn poll_inbound(
47        self: Pin<&mut Self>,
48        cx: &mut Context<'_>,
49    ) -> Poll<Result<Self::Substream, Self::Error>> {
50        self.project()
51            .inner
52            .poll_inbound(cx)
53            .map_ok(SubstreamBox::new)
54            .map_err(into_io_error)
55    }
56
57    fn poll_outbound(
58        self: Pin<&mut Self>,
59        cx: &mut Context<'_>,
60    ) -> Poll<Result<Self::Substream, Self::Error>> {
61        self.project()
62            .inner
63            .poll_outbound(cx)
64            .map_ok(SubstreamBox::new)
65            .map_err(into_io_error)
66    }
67
68    #[inline]
69    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
70        self.project().inner.poll_close(cx).map_err(into_io_error)
71    }
72
73    fn poll(
74        self: Pin<&mut Self>,
75        cx: &mut Context<'_>,
76    ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
77        self.project().inner.poll(cx).map_err(into_io_error)
78    }
79}
80
/// Converts a muxer error into an [`io::Error`].
///
/// If `err` already is an [`io::Error`], it is returned as-is. Its [`io::ErrorKind`]
/// (e.g. `ConnectionReset`, `UnexpectedEof`) therefore survives the conversion instead of
/// being flattened to [`io::ErrorKind::Other`].
///
/// Any other error type is wrapped with kind `Other`. The original value stays reachable
/// through [`io::Error::get_ref`] / [`io::Error::into_inner`].
fn into_io_error<E>(err: E) -> io::Error
where
    E: Error + Send + Sync + 'static,
{
    // Erase to a trait object so we can test at runtime whether `E` is `io::Error`.
    let boxed: Box<dyn Error + Send + Sync> = Box::new(err);
    match boxed.downcast::<io::Error>() {
        Ok(io_err) => *io_err,
        // `io::Error::new` accepts the box directly, so no second allocation happens.
        Err(other) => io::Error::new(io::ErrorKind::Other, other),
    }
}
87
88impl StreamMuxerBox {
89    /// Turns a stream muxer into a `StreamMuxerBox`.
90    pub fn new<T>(muxer: T) -> StreamMuxerBox
91    where
92        T: StreamMuxer + Send + 'static,
93        T::Substream: Send + 'static,
94        T::Error: Send + Sync + 'static,
95    {
96        let wrap = Wrap { inner: muxer };
97
98        StreamMuxerBox {
99            inner: Box::pin(wrap),
100        }
101    }
102
103    fn project(
104        self: Pin<&mut Self>,
105    ) -> Pin<&mut (dyn StreamMuxer<Substream = SubstreamBox, Error = io::Error> + Send)> {
106        self.get_mut().inner.as_mut()
107    }
108}
109
110impl StreamMuxer for StreamMuxerBox {
111    type Substream = SubstreamBox;
112    type Error = io::Error;
113
114    fn poll_inbound(
115        self: Pin<&mut Self>,
116        cx: &mut Context<'_>,
117    ) -> Poll<Result<Self::Substream, Self::Error>> {
118        self.project().poll_inbound(cx)
119    }
120
121    fn poll_outbound(
122        self: Pin<&mut Self>,
123        cx: &mut Context<'_>,
124    ) -> Poll<Result<Self::Substream, Self::Error>> {
125        self.project().poll_outbound(cx)
126    }
127
128    #[inline]
129    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
130        self.project().poll_close(cx)
131    }
132
133    fn poll(
134        self: Pin<&mut Self>,
135        cx: &mut Context<'_>,
136    ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
137        self.project().poll(cx)
138    }
139}
140
141impl SubstreamBox {
142    /// Construct a new [`SubstreamBox`] from something that implements [`AsyncRead`] and [`AsyncWrite`].
143    pub fn new<S: AsyncRead + AsyncWrite + Send + 'static>(stream: S) -> Self {
144        Self(Box::pin(stream))
145    }
146}
147
148impl fmt::Debug for SubstreamBox {
149    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
150        write!(f, "SubstreamBox({})", self.0.type_name())
151    }
152}
153
154/// Workaround because Rust does not allow `Box<dyn AsyncRead + AsyncWrite>`.
155trait AsyncReadWrite: AsyncRead + AsyncWrite {
156    /// Helper function to capture the erased inner type.
157    ///
158    /// Used to make the [`Debug`] implementation of [`SubstreamBox`] more useful.
159    fn type_name(&self) -> &'static str;
160}
161
162impl<S> AsyncReadWrite for S
163where
164    S: AsyncRead + AsyncWrite,
165{
166    fn type_name(&self) -> &'static str {
167        std::any::type_name::<S>()
168    }
169}
170
171impl AsyncRead for SubstreamBox {
172    fn poll_read(
173        mut self: Pin<&mut Self>,
174        cx: &mut Context<'_>,
175        buf: &mut [u8],
176    ) -> Poll<std::io::Result<usize>> {
177        self.0.as_mut().poll_read(cx, buf)
178    }
179
180    fn poll_read_vectored(
181        mut self: Pin<&mut Self>,
182        cx: &mut Context<'_>,
183        bufs: &mut [IoSliceMut<'_>],
184    ) -> Poll<std::io::Result<usize>> {
185        self.0.as_mut().poll_read_vectored(cx, bufs)
186    }
187}
188
189impl AsyncWrite for SubstreamBox {
190    fn poll_write(
191        mut self: Pin<&mut Self>,
192        cx: &mut Context<'_>,
193        buf: &[u8],
194    ) -> Poll<std::io::Result<usize>> {
195        self.0.as_mut().poll_write(cx, buf)
196    }
197
198    fn poll_write_vectored(
199        mut self: Pin<&mut Self>,
200        cx: &mut Context<'_>,
201        bufs: &[IoSlice<'_>],
202    ) -> Poll<std::io::Result<usize>> {
203        self.0.as_mut().poll_write_vectored(cx, bufs)
204    }
205
206    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
207        self.0.as_mut().poll_flush(cx)
208    }
209
210    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
211        self.0.as_mut().poll_close(cx)
212    }
213}