1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
24
25use futures::{future, prelude::*, ready};
26use libp2p_core::muxing::{StreamMuxer, StreamMuxerEvent};
27use libp2p_core::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeInfo};
28use std::collections::VecDeque;
29use std::io::{IoSlice, IoSliceMut};
30use std::task::Waker;
31use std::{
32 io, iter,
33 pin::Pin,
34 task::{Context, Poll},
35};
36use thiserror::Error;
37use yamux::ConnectionError;
38
39#[derive(Debug)]
41pub struct Muxer<C> {
42 connection: yamux::Connection<C>,
43 inbound_stream_buffer: VecDeque<Stream>,
53 inbound_stream_waker: Option<Waker>,
55}
56
/// Upper bound on the number of inbound streams held in
/// `Muxer::inbound_stream_buffer`; once the buffer holds this many, further
/// inbound streams accepted in `StreamMuxer::poll` are dropped.
const MAX_BUFFERED_INBOUND_STREAMS: usize = 256_usize;
62
63impl<C> Muxer<C>
64where
65 C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
66{
67 fn new(io: C, cfg: yamux::Config, mode: yamux::Mode) -> Self {
69 Muxer {
70 connection: yamux::Connection::new(io, cfg, mode),
71 inbound_stream_buffer: VecDeque::default(),
72 inbound_stream_waker: None,
73 }
74 }
75}
76
77impl<C> StreamMuxer for Muxer<C>
78where
79 C: AsyncRead + AsyncWrite + Unpin + 'static,
80{
81 type Substream = Stream;
82 type Error = Error;
83
84 fn poll_inbound(
85 mut self: Pin<&mut Self>,
86 cx: &mut Context<'_>,
87 ) -> Poll<Result<Self::Substream, Self::Error>> {
88 if let Some(stream) = self.inbound_stream_buffer.pop_front() {
89 return Poll::Ready(Ok(stream));
90 }
91
92 if let Poll::Ready(res) = self.poll_inner(cx) {
93 return Poll::Ready(res);
94 }
95
96 self.inbound_stream_waker = Some(cx.waker().clone());
97 Poll::Pending
98 }
99
100 fn poll_outbound(
101 mut self: Pin<&mut Self>,
102 cx: &mut Context<'_>,
103 ) -> Poll<Result<Self::Substream, Self::Error>> {
104 let stream = ready!(self.connection.poll_new_outbound(cx).map_err(Error)?);
105
106 Poll::Ready(Ok(Stream(stream)))
107 }
108
109 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110 ready!(self.connection.poll_close(cx).map_err(Error)?);
111
112 Poll::Ready(Ok(()))
113 }
114
115 fn poll(
116 self: Pin<&mut Self>,
117 cx: &mut Context<'_>,
118 ) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
119 let this = self.get_mut();
120
121 let inbound_stream = ready!(this.poll_inner(cx))?;
122
123 if this.inbound_stream_buffer.len() >= MAX_BUFFERED_INBOUND_STREAMS {
124 log::warn!("dropping {} because buffer is full", inbound_stream.0);
125 drop(inbound_stream);
126 } else {
127 this.inbound_stream_buffer.push_back(inbound_stream);
128
129 if let Some(waker) = this.inbound_stream_waker.take() {
130 waker.wake()
131 }
132 }
133
134 cx.waker().wake_by_ref();
136 Poll::Pending
137 }
138}
139
140#[derive(Debug)]
142pub struct Stream(yamux::Stream);
143
144impl AsyncRead for Stream {
145 fn poll_read(
146 mut self: Pin<&mut Self>,
147 cx: &mut Context<'_>,
148 buf: &mut [u8],
149 ) -> Poll<io::Result<usize>> {
150 Pin::new(&mut self.0).poll_read(cx, buf)
151 }
152
153 fn poll_read_vectored(
154 mut self: Pin<&mut Self>,
155 cx: &mut Context<'_>,
156 bufs: &mut [IoSliceMut<'_>],
157 ) -> Poll<io::Result<usize>> {
158 Pin::new(&mut self.0).poll_read_vectored(cx, bufs)
159 }
160}
161
162impl AsyncWrite for Stream {
163 fn poll_write(
164 mut self: Pin<&mut Self>,
165 cx: &mut Context<'_>,
166 buf: &[u8],
167 ) -> Poll<io::Result<usize>> {
168 Pin::new(&mut self.0).poll_write(cx, buf)
169 }
170
171 fn poll_write_vectored(
172 mut self: Pin<&mut Self>,
173 cx: &mut Context<'_>,
174 bufs: &[IoSlice<'_>],
175 ) -> Poll<io::Result<usize>> {
176 Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
177 }
178
179 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
180 Pin::new(&mut self.0).poll_flush(cx)
181 }
182
183 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
184 Pin::new(&mut self.0).poll_close(cx)
185 }
186}
187
188impl<C> Muxer<C>
189where
190 C: AsyncRead + AsyncWrite + Unpin + 'static,
191{
192 fn poll_inner(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream, Error>> {
193 let stream = ready!(self.connection.poll_next_inbound(cx))
194 .transpose()
195 .map_err(Error)?
196 .map(Stream)
197 .ok_or(Error(ConnectionError::Closed))?;
198
199 Poll::Ready(Ok(stream))
200 }
201}
202
203#[derive(Debug, Clone)]
205pub struct Config {
206 inner: yamux::Config,
207 mode: Option<yamux::Mode>,
208}
209
210pub struct WindowUpdateMode(yamux::WindowUpdateMode);
213
214impl WindowUpdateMode {
215 pub fn on_receive() -> Self {
228 WindowUpdateMode(yamux::WindowUpdateMode::OnReceive)
229 }
230
231 pub fn on_read() -> Self {
246 WindowUpdateMode(yamux::WindowUpdateMode::OnRead)
247 }
248}
249
250impl Config {
251 pub fn client() -> Self {
254 Self {
255 mode: Some(yamux::Mode::Client),
256 ..Default::default()
257 }
258 }
259
260 pub fn server() -> Self {
263 Self {
264 mode: Some(yamux::Mode::Server),
265 ..Default::default()
266 }
267 }
268
269 pub fn set_receive_window_size(&mut self, num_bytes: u32) -> &mut Self {
271 self.inner.set_receive_window(num_bytes);
272 self
273 }
274
275 pub fn set_max_buffer_size(&mut self, num_bytes: usize) -> &mut Self {
277 self.inner.set_max_buffer_size(num_bytes);
278 self
279 }
280
281 pub fn set_max_num_streams(&mut self, num_streams: usize) -> &mut Self {
283 self.inner.set_max_num_streams(num_streams);
284 self
285 }
286
287 pub fn set_window_update_mode(&mut self, mode: WindowUpdateMode) -> &mut Self {
290 self.inner.set_window_update_mode(mode.0);
291 self
292 }
293}
294
295impl Default for Config {
296 fn default() -> Self {
297 let mut inner = yamux::Config::default();
298 inner.set_read_after_close(false);
301 Config { inner, mode: None }
302 }
303}
304
305impl UpgradeInfo for Config {
306 type Info = &'static str;
307 type InfoIter = iter::Once<Self::Info>;
308
309 fn protocol_info(&self) -> Self::InfoIter {
310 iter::once("/yamux/1.0.0")
311 }
312}
313
314impl<C> InboundUpgrade<C> for Config
315where
316 C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
317{
318 type Output = Muxer<C>;
319 type Error = io::Error;
320 type Future = future::Ready<Result<Self::Output, Self::Error>>;
321
322 fn upgrade_inbound(self, io: C, _: Self::Info) -> Self::Future {
323 let mode = self.mode.unwrap_or(yamux::Mode::Server);
324 future::ready(Ok(Muxer::new(io, self.inner, mode)))
325 }
326}
327
328impl<C> OutboundUpgrade<C> for Config
329where
330 C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
331{
332 type Output = Muxer<C>;
333 type Error = io::Error;
334 type Future = future::Ready<Result<Self::Output, Self::Error>>;
335
336 fn upgrade_outbound(self, io: C, _: Self::Info) -> Self::Future {
337 let mode = self.mode.unwrap_or(yamux::Mode::Client);
338 future::ready(Ok(Muxer::new(io, self.inner, mode)))
339 }
340}
341
342#[derive(Debug, Error)]
344#[error(transparent)]
345pub struct Error(yamux::ConnectionError);
346
347impl From<Error> for io::Error {
348 fn from(err: Error) -> Self {
349 match err.0 {
350 yamux::ConnectionError::Io(e) => e,
351 e => io::Error::new(io::ErrorKind::Other, e),
352 }
353 }
354}