litep2p/substream/
mod.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2// Copyright 2023 litep2p developers
3//
4// Permission is hereby granted, free of charge, to any person obtaining a
5// copy of this software and associated documentation files (the "Software"),
6// to deal in the Software without restriction, including without limitation
7// the rights to use, copy, modify, merge, publish, distribute, sublicense,
8// and/or sell copies of the Software, and to permit persons to whom the
9// Software is furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20// DEALINGS IN THE SOFTWARE.
21
22//! Substream-related helper code.
23
24use crate::{
25    codec::ProtocolCodec, error::SubstreamError, transport::tcp, types::SubstreamId, PeerId,
26};
27
28#[cfg(feature = "quic")]
29use crate::transport::quic;
30#[cfg(feature = "webrtc")]
31use crate::transport::webrtc;
32#[cfg(feature = "websocket")]
33use crate::transport::websocket;
34
35use bytes::{Buf, Bytes, BytesMut};
36use futures::{Sink, Stream};
37use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
38use unsigned_varint::{decode, encode};
39
40use std::{
41    collections::{hash_map::Entry, HashMap, VecDeque},
42    fmt,
43    hash::Hash,
44    io::ErrorKind,
45    pin::Pin,
46    task::{Context, Poll},
47};
48
/// `tracing` target used by every log event emitted from this module.
const LOG_TARGET: &str = "litep2p::substream";
51
/// Forward `poll_flush()` to the transport substream wrapped by `$target`.
///
/// The mock substream is never expected to reach this path (`unreachable!()`).
macro_rules! poll_flush {
    ($target:expr, $cx:ident) => {{
        let target = $target;
        match target {
            SubstreamType::Tcp(inner) => Pin::new(inner).poll_flush($cx),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(inner) => Pin::new(inner).poll_flush($cx),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(inner) => Pin::new(inner).poll_flush($cx),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(inner) => Pin::new(inner).poll_flush($cx),
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
        }
    }};
}
67
/// Forward `poll_write()` of `$bytes` to the transport substream wrapped by `$target`.
///
/// The mock substream is never expected to reach this path (`unreachable!()`).
macro_rules! poll_write {
    ($target:expr, $cx:ident, $bytes:expr) => {{
        let target = $target;
        match target {
            SubstreamType::Tcp(inner) => Pin::new(inner).poll_write($cx, $bytes),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(inner) => Pin::new(inner).poll_write($cx, $bytes),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(inner) => Pin::new(inner).poll_write($cx, $bytes),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(inner) => Pin::new(inner).poll_write($cx, $bytes),
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
        }
    }};
}
83
/// Forward `poll_read()` into `$read_buf` to the transport substream wrapped by `$target`.
///
/// The mock substream is never expected to reach this path (`unreachable!()`).
macro_rules! poll_read {
    ($target:expr, $cx:ident, $read_buf:expr) => {{
        let target = $target;
        match target {
            SubstreamType::Tcp(inner) => Pin::new(inner).poll_read($cx, $read_buf),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(inner) => Pin::new(inner).poll_read($cx, $read_buf),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(inner) => Pin::new(inner).poll_read($cx, $read_buf),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(inner) => Pin::new(inner).poll_read($cx, $read_buf),
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
        }
    }};
}
99
/// Forward `poll_shutdown()` to the transport substream wrapped by `$target`.
///
/// For the mock substream, `poll_close()` is invoked and then the macro hits `todo!()`.
macro_rules! poll_shutdown {
    ($target:expr, $cx:ident) => {{
        let target = $target;
        match target {
            SubstreamType::Tcp(inner) => Pin::new(inner).poll_shutdown($cx),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(inner) => Pin::new(inner).poll_shutdown($cx),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(inner) => Pin::new(inner).poll_shutdown($cx),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(inner) => Pin::new(inner).poll_shutdown($cx),
            #[cfg(test)]
            SubstreamType::Mock(inner) => {
                let _ = Pin::new(inner).poll_close($cx);
                todo!();
            }
        }
    }};
}
118
/// In test builds, return the mock substream's own `poll_next()` result from the enclosing
/// function. This is a no-op in non-test builds.
macro_rules! delegate_poll_next {
    ($target:expr, $context:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $target {
            return Pin::new(mock).poll_next($context);
        }
    }};
}
127
/// In test builds, return the mock substream's own `poll_ready()` result from the enclosing
/// function. This is a no-op in non-test builds.
macro_rules! delegate_poll_ready {
    ($target:expr, $context:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $target {
            return Pin::new(mock).poll_ready($context);
        }
    }};
}
136
/// In test builds, hand `$frame` to the mock substream's `start_send()` and return its result
/// from the enclosing function. This is a no-op in non-test builds.
macro_rules! delegate_start_send {
    ($target:expr, $frame:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $target {
            return Pin::new(mock).start_send($frame);
        }
    }};
}
145
/// In test builds, return the mock substream's own `poll_flush()` result from the enclosing
/// function. This is a no-op in non-test builds.
macro_rules! delegate_poll_flush {
    ($target:expr, $context:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $target {
            return Pin::new(mock).poll_flush($context);
        }
    }};
}
154
/// Return `PermissionDenied` from the enclosing function when `$len` exceeds the optional
/// `$limit`; a `None` limit accepts any length.
macro_rules! check_size {
    ($limit:expr, $len:expr) => {{
        let exceeds_limit = match $limit {
            Some(limit) => $len > limit,
            None => false,
        };

        if exceeds_limit {
            return Err(SubstreamError::IoError(ErrorKind::PermissionDenied).into());
        }
    }};
}
164
165/// Substream type.
166enum SubstreamType {
167    Tcp(tcp::Substream),
168    #[cfg(feature = "websocket")]
169    WebSocket(websocket::Substream),
170    #[cfg(feature = "quic")]
171    Quic(quic::Substream),
172    #[cfg(feature = "webrtc")]
173    WebRtc(webrtc::Substream),
174    #[cfg(test)]
175    Mock(Box<dyn crate::mock::substream::Substream>),
176}
177
178impl fmt::Debug for SubstreamType {
179    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180        match self {
181            Self::Tcp(_) => write!(f, "Tcp"),
182            #[cfg(feature = "websocket")]
183            Self::WebSocket(_) => write!(f, "WebSocket"),
184            #[cfg(feature = "quic")]
185            Self::Quic(_) => write!(f, "Quic"),
186            #[cfg(feature = "webrtc")]
187            Self::WebRtc(_) => write!(f, "WebRtc"),
188            #[cfg(test)]
189            Self::Mock(_) => write!(f, "Mock"),
190        }
191    }
192}
193
/// Number of buffered outbound bytes at which the `Sink` implementation flushes before
/// accepting more data (64 KiB).
const BACKPRESSURE_BOUNDARY: usize = 64 * 1024;
196
197/// `Litep2p` substream type.
198///
199/// Implements [`tokio::io::AsyncRead`]/[`tokio::io::AsyncWrite`] traits which can be wrapped
200/// in a `Framed` to implement a custom codec.
201///
202/// In case a codec for the protocol was specified,
203/// [`Sink::send()`](futures::Sink)/[`Stream::next()`](futures::Stream) are also provided which
204/// implement the necessary framing to read/write codec-encoded messages from the underlying socket.
205pub struct Substream {
206    /// Remote peer ID.
207    peer: PeerId,
208
209    // Inner substream.
210    substream: SubstreamType,
211
212    /// Substream ID.
213    substream_id: SubstreamId,
214
215    /// Protocol codec.
216    codec: ProtocolCodec,
217
218    pending_out_frames: VecDeque<Bytes>,
219    pending_out_bytes: usize,
220    pending_out_frame: Option<Bytes>,
221
222    read_buffer: BytesMut,
223    offset: usize,
224    pending_frames: VecDeque<BytesMut>,
225    current_frame_size: Option<usize>,
226
227    size_vec: BytesMut,
228}
229
230impl fmt::Debug for Substream {
231    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232        f.debug_struct("Substream")
233            .field("peer", &self.peer)
234            .field("substream_id", &self.substream_id)
235            .field("codec", &self.codec)
236            .field("protocol", &self.substream)
237            .finish()
238    }
239}
240
241impl Substream {
242    /// Create new [`Substream`].
243    fn new(
244        peer: PeerId,
245        substream_id: SubstreamId,
246        substream: SubstreamType,
247        codec: ProtocolCodec,
248    ) -> Self {
249        Self {
250            peer,
251            substream,
252            codec,
253            substream_id,
254            read_buffer: BytesMut::zeroed(1024),
255            offset: 0usize,
256            pending_frames: VecDeque::new(),
257            current_frame_size: None,
258            pending_out_bytes: 0usize,
259            pending_out_frames: VecDeque::new(),
260            pending_out_frame: None,
261            size_vec: BytesMut::zeroed(10),
262        }
263    }
264
265    /// Create new [`Substream`] for TCP.
266    pub(crate) fn new_tcp(
267        peer: PeerId,
268        substream_id: SubstreamId,
269        substream: tcp::Substream,
270        codec: ProtocolCodec,
271    ) -> Self {
272        tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for tcp");
273
274        Self::new(peer, substream_id, SubstreamType::Tcp(substream), codec)
275    }
276
277    /// Create new [`Substream`] for WebSocket.
278    #[cfg(feature = "websocket")]
279    pub(crate) fn new_websocket(
280        peer: PeerId,
281        substream_id: SubstreamId,
282        substream: websocket::Substream,
283        codec: ProtocolCodec,
284    ) -> Self {
285        tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for websocket");
286
287        Self::new(
288            peer,
289            substream_id,
290            SubstreamType::WebSocket(substream),
291            codec,
292        )
293    }
294
295    /// Create new [`Substream`] for QUIC.
296    #[cfg(feature = "quic")]
297    pub(crate) fn new_quic(
298        peer: PeerId,
299        substream_id: SubstreamId,
300        substream: quic::Substream,
301        codec: ProtocolCodec,
302    ) -> Self {
303        tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for quic");
304
305        Self::new(peer, substream_id, SubstreamType::Quic(substream), codec)
306    }
307
308    /// Create new [`Substream`] for WebRTC.
309    #[cfg(feature = "webrtc")]
310    pub(crate) fn new_webrtc(
311        peer: PeerId,
312        substream_id: SubstreamId,
313        substream: webrtc::Substream,
314        codec: ProtocolCodec,
315    ) -> Self {
316        tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for webrtc");
317
318        Self::new(peer, substream_id, SubstreamType::WebRtc(substream), codec)
319    }
320
321    /// Create new [`Substream`] for mocking.
322    #[cfg(test)]
323    pub(crate) fn new_mock(
324        peer: PeerId,
325        substream_id: SubstreamId,
326        substream: Box<dyn crate::mock::substream::Substream>,
327    ) -> Self {
328        tracing::trace!(target: LOG_TARGET, ?peer, "create new substream for mocking");
329
330        Self::new(
331            peer,
332            substream_id,
333            SubstreamType::Mock(substream),
334            ProtocolCodec::Unspecified,
335        )
336    }
337
338    /// Close the substream.
339    pub async fn close(self) {
340        let _ = match self.substream {
341            SubstreamType::Tcp(mut substream) => substream.shutdown().await,
342            #[cfg(feature = "websocket")]
343            SubstreamType::WebSocket(mut substream) => substream.shutdown().await,
344            #[cfg(feature = "quic")]
345            SubstreamType::Quic(mut substream) => substream.shutdown().await,
346            #[cfg(feature = "webrtc")]
347            SubstreamType::WebRtc(mut substream) => substream.shutdown().await,
348            #[cfg(test)]
349            SubstreamType::Mock(mut substream) => {
350                let _ = futures::SinkExt::close(&mut substream).await;
351                Ok(())
352            }
353        };
354    }
355
356    /// Send identity payload to remote peer.
357    async fn send_identity_payload<T: AsyncWrite + Unpin>(
358        io: &mut T,
359        payload_size: usize,
360        payload: Bytes,
361    ) -> Result<(), SubstreamError> {
362        if payload.len() != payload_size {
363            return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
364        }
365
366        io.write_all(&payload).await.map_err(|_| SubstreamError::ConnectionClosed)?;
367
368        // Flush the stream.
369        io.flush().await.map_err(From::from)
370    }
371
372    /// Send unsigned varint payload to remote peer.
373    async fn send_unsigned_varint_payload<T: AsyncWrite + Unpin>(
374        io: &mut T,
375        bytes: Bytes,
376        max_size: Option<usize>,
377    ) -> Result<(), SubstreamError> {
378        if let Some(max_size) = max_size {
379            if bytes.len() > max_size {
380                return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
381            }
382        }
383
384        // Write the length of the frame.
385        let mut buffer = unsigned_varint::encode::usize_buffer();
386        let encoded_len = unsigned_varint::encode::usize(bytes.len(), &mut buffer).len();
387        io.write_all(&buffer[..encoded_len]).await?;
388
389        // Write the frame.
390        io.write_all(bytes.as_ref()).await?;
391
392        // Flush the stream.
393        io.flush().await.map_err(From::from)
394    }
395
396    /// Send framed data to remote peer.
397    ///
398    /// This function may be faster than the provided [`futures::Sink`] implementation for
399    /// [`Substream`] as it has direct access to the API of the underlying socket as opposed
400    /// to going through [`tokio::io::AsyncWrite`].
401    ///
402    /// # Cancel safety
403    ///
404    /// This method is not cancellation safe. If that is required, use the provided
405    /// [`futures::Sink`] implementation.
406    ///
407    /// # Panics
408    ///
409    /// Panics if no codec is provided.
410    pub async fn send_framed(&mut self, bytes: Bytes) -> Result<(), SubstreamError> {
411        tracing::trace!(
412            target: LOG_TARGET,
413            peer = ?self.peer,
414            codec = ?self.codec,
415            frame_len = ?bytes.len(),
416            "send framed"
417        );
418
419        match &mut self.substream {
420            #[cfg(test)]
421            SubstreamType::Mock(ref mut substream) =>
422                futures::SinkExt::send(substream, bytes).await,
423            SubstreamType::Tcp(ref mut substream) => match self.codec {
424                ProtocolCodec::Unspecified => panic!("codec is unspecified"),
425                ProtocolCodec::Identity(payload_size) =>
426                    Self::send_identity_payload(substream, payload_size, bytes).await,
427                ProtocolCodec::UnsignedVarint(max_size) =>
428                    Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
429            },
430            #[cfg(feature = "websocket")]
431            SubstreamType::WebSocket(ref mut substream) => match self.codec {
432                ProtocolCodec::Unspecified => panic!("codec is unspecified"),
433                ProtocolCodec::Identity(payload_size) =>
434                    Self::send_identity_payload(substream, payload_size, bytes).await,
435                ProtocolCodec::UnsignedVarint(max_size) =>
436                    Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
437            },
438            #[cfg(feature = "quic")]
439            SubstreamType::Quic(ref mut substream) => match self.codec {
440                ProtocolCodec::Unspecified => panic!("codec is unspecified"),
441                ProtocolCodec::Identity(payload_size) =>
442                    Self::send_identity_payload(substream, payload_size, bytes).await,
443                ProtocolCodec::UnsignedVarint(max_size) => {
444                    check_size!(max_size, bytes.len());
445
446                    let mut buffer = unsigned_varint::encode::usize_buffer();
447                    let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer);
448                    let len = BytesMut::from(len);
449
450                    substream.write_all_chunks(&mut [len.freeze(), bytes]).await
451                }
452            },
453            #[cfg(feature = "webrtc")]
454            SubstreamType::WebRtc(ref mut substream) => match self.codec {
455                ProtocolCodec::Unspecified => panic!("codec is unspecified"),
456                ProtocolCodec::Identity(payload_size) =>
457                    Self::send_identity_payload(substream, payload_size, bytes).await,
458                ProtocolCodec::UnsignedVarint(max_size) =>
459                    Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
460            },
461        }
462    }
463}
464
465impl tokio::io::AsyncRead for Substream {
466    fn poll_read(
467        mut self: Pin<&mut Self>,
468        cx: &mut Context<'_>,
469        buf: &mut tokio::io::ReadBuf<'_>,
470    ) -> Poll<std::io::Result<()>> {
471        poll_read!(&mut self.substream, cx, buf)
472    }
473}
474
475impl tokio::io::AsyncWrite for Substream {
476    fn poll_write(
477        mut self: Pin<&mut Self>,
478        cx: &mut Context<'_>,
479        buf: &[u8],
480    ) -> Poll<Result<usize, std::io::Error>> {
481        poll_write!(&mut self.substream, cx, buf)
482    }
483
484    fn poll_flush(
485        mut self: Pin<&mut Self>,
486        cx: &mut Context<'_>,
487    ) -> Poll<Result<(), std::io::Error>> {
488        poll_flush!(&mut self.substream, cx)
489    }
490
491    fn poll_shutdown(
492        mut self: Pin<&mut Self>,
493        cx: &mut Context<'_>,
494    ) -> Poll<Result<(), std::io::Error>> {
495        poll_shutdown!(&mut self.substream, cx)
496    }
497}
498
/// Reasons why a length prefix could not be decoded by `read_payload_size()`.
enum ReadError {
    /// The maximum number of varint bytes was buffered without a terminating byte.
    Overflow,
    /// No terminating byte yet, but more bytes may still complete the prefix.
    NotEnoughBytes,
    /// A terminating byte was found, but the varint failed to decode.
    DecodeError,
}
504
505// Return the payload size and the number of bytes it took to encode it
506fn read_payload_size(buffer: &[u8]) -> Result<(usize, usize), ReadError> {
507    let max_len = encode::usize_buffer().len();
508
509    for i in 0..std::cmp::min(buffer.len(), max_len) {
510        if decode::is_last(buffer[i]) {
511            match decode::usize(&buffer[..=i]) {
512                Err(_) => return Err(ReadError::DecodeError),
513                Ok(size) => return Ok((size.0, i + 1)),
514            }
515        }
516    }
517
518    match buffer.len() < max_len {
519        true => Err(ReadError::NotEnoughBytes),
520        false => Err(ReadError::Overflow),
521    }
522}
523
524impl Stream for Substream {
525    type Item = Result<BytesMut, SubstreamError>;
526
527    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
528        let this = Pin::into_inner(self);
529
530        // `MockSubstream` implements `Stream` so calls to `poll_next()` must be delegated
531        delegate_poll_next!(&mut this.substream, cx);
532
533        loop {
534            match this.codec {
535                ProtocolCodec::Identity(payload_size) => {
536                    let mut read_buf =
537                        ReadBuf::new(&mut this.read_buffer[this.offset..payload_size]);
538
539                    match futures::ready!(poll_read!(&mut this.substream, cx, &mut read_buf)) {
540                        Ok(_) => {
541                            let nread = read_buf.filled().len();
542                            if nread == 0 {
543                                tracing::trace!(
544                                    target: LOG_TARGET,
545                                    peer = ?this.peer,
546                                    "read zero bytes, substream closed"
547                                );
548                                return Poll::Ready(None);
549                            }
550
551                            this.offset = this.offset.saturating_add(nread);
552
553                            if this.offset == payload_size {
554                                let mut payload = std::mem::replace(
555                                    &mut this.read_buffer,
556                                    BytesMut::zeroed(payload_size),
557                                );
558                                payload.truncate(payload_size);
559                                this.offset = 0usize;
560
561                                return Poll::Ready(Some(Ok(payload)));
562                            }
563                        }
564                        Err(error) => return Poll::Ready(Some(Err(error.into()))),
565                    }
566                }
567                ProtocolCodec::UnsignedVarint(max_size) => {
568                    loop {
569                        // return all pending frames first
570                        if let Some(frame) = this.pending_frames.pop_front() {
571                            return Poll::Ready(Some(Ok(frame)));
572                        }
573
574                        match this.current_frame_size.take() {
575                            Some(frame_size) => {
576                                let mut read_buf =
577                                    ReadBuf::new(&mut this.read_buffer[this.offset..]);
578                                this.current_frame_size = Some(frame_size);
579
580                                match futures::ready!(poll_read!(
581                                    &mut this.substream,
582                                    cx,
583                                    &mut read_buf
584                                )) {
585                                    Err(_error) => return Poll::Ready(None),
586                                    Ok(_) => {
587                                        let nread = match read_buf.filled().len() {
588                                            0 => return Poll::Ready(None),
589                                            nread => nread,
590                                        };
591
592                                        this.offset += nread;
593
594                                        if this.offset == frame_size {
595                                            let out_frame = std::mem::replace(
596                                                &mut this.read_buffer,
597                                                BytesMut::new(),
598                                            );
599                                            this.offset = 0;
600                                            this.current_frame_size = None;
601
602                                            return Poll::Ready(Some(Ok(out_frame)));
603                                        } else {
604                                            this.current_frame_size = Some(frame_size);
605                                            continue;
606                                        }
607                                    }
608                                }
609                            }
610                            None => {
611                                let mut read_buf =
612                                    ReadBuf::new(&mut this.size_vec[this.offset..this.offset + 1]);
613
614                                match futures::ready!(poll_read!(
615                                    &mut this.substream,
616                                    cx,
617                                    &mut read_buf
618                                )) {
619                                    Err(_error) => return Poll::Ready(None),
620                                    Ok(_) => {
621                                        if read_buf.filled().is_empty() {
622                                            return Poll::Ready(None);
623                                        }
624                                        this.offset += 1;
625
626                                        match read_payload_size(&this.size_vec[..this.offset]) {
627                                            Err(ReadError::NotEnoughBytes) => continue,
628                                            Err(_) =>
629                                                return Poll::Ready(Some(Err(
630                                                    SubstreamError::ReadFailure(Some(
631                                                        this.substream_id,
632                                                    )),
633                                                ))),
634                                            Ok((size, num_bytes)) => {
635                                                debug_assert_eq!(num_bytes, this.offset);
636
637                                                if let Some(max_size) = max_size {
638                                                    if size > max_size {
639                                                        return Poll::Ready(Some(Err(
640                                                            SubstreamError::ReadFailure(Some(
641                                                                this.substream_id,
642                                                            )),
643                                                        )));
644                                                    }
645                                                }
646
647                                                this.offset = 0;
648                                                // Handle empty payloads detected as 0-length frame.
649                                                // The offset must be cleared to 0 to not interfere
650                                                // with next framing.
651                                                if size == 0 {
652                                                    return Poll::Ready(Some(Ok(BytesMut::new())));
653                                                }
654
655                                                this.current_frame_size = Some(size);
656                                                this.read_buffer = BytesMut::zeroed(size);
657                                            }
658                                        }
659                                    }
660                                }
661                            }
662                        }
663                    }
664                }
665                ProtocolCodec::Unspecified => panic!("codec is unspecified"),
666            }
667        }
668    }
669}
670
671// TODO: https://github.com/paritytech/litep2p/issues/341 this code can definitely be optimized
672impl Sink<Bytes> for Substream {
673    type Error = SubstreamError;
674
675    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
676        // `MockSubstream` implements `Sink` so calls to `poll_ready()` must be delegated
677        delegate_poll_ready!(&mut self.substream, cx);
678
679        if self.pending_out_bytes >= BACKPRESSURE_BOUNDARY {
680            // This attempts to empty 'pending_out_frames' into the socket.
681            match futures::Sink::poll_flush(self.as_mut(), cx) {
682                Poll::Ready(Ok(())) => {}
683                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
684                Poll::Pending => {
685                    // Still flushing. We cannot accept new data yet.
686                    return Poll::Pending;
687                }
688            }
689        }
690
691        Poll::Ready(Ok(()))
692    }
693
694    fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
695        // `MockSubstream` implements `Sink` so calls to `start_send()` must be delegated
696        delegate_start_send!(&mut self.substream, item);
697
698        tracing::trace!(
699            target: LOG_TARGET,
700            peer = ?self.peer,
701            substream_id = ?self.substream_id,
702            data_len = item.len(),
703            "Substream::start_send()",
704        );
705
706        match self.codec {
707            ProtocolCodec::Identity(payload_size) => {
708                if item.len() != payload_size {
709                    return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
710                }
711
712                self.pending_out_bytes += item.len();
713                self.pending_out_frames.push_back(item);
714            }
715            ProtocolCodec::UnsignedVarint(max_size) => {
716                check_size!(max_size, item.len());
717
718                let len = {
719                    let mut buffer = unsigned_varint::encode::usize_buffer();
720                    let len = unsigned_varint::encode::usize(item.len(), &mut buffer);
721                    BytesMut::from(len)
722                };
723
724                self.pending_out_bytes += len.len() + item.len();
725                self.pending_out_frames.push_back(len.freeze());
726                self.pending_out_frames.push_back(item);
727            }
728            ProtocolCodec::Unspecified => panic!("codec is unspecified"),
729        }
730
731        Ok(())
732    }
733
734    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
735        // `MockSubstream` implements `Sink` so calls to `poll_flush()` must be delegated
736        delegate_poll_flush!(&mut self.substream, cx);
737
738        loop {
739            let mut pending_frame = match self.pending_out_frame.take() {
740                Some(frame) => frame,
741                None => match self.pending_out_frames.pop_front() {
742                    Some(frame) => frame,
743                    None => break,
744                },
745            };
746
747            match poll_write!(&mut self.substream, cx, &pending_frame) {
748                Poll::Ready(Err(error)) => return Poll::Ready(Err(error.into())),
749                Poll::Pending => {
750                    self.pending_out_frame = Some(pending_frame);
751                    break;
752                }
753                Poll::Ready(Ok(nwritten)) => {
754                    pending_frame.advance(nwritten);
755
756                    // The number of pending bytes is reduced by the number of bytes written
757                    // to ensure that backpressure is properly handled.
758                    self.pending_out_bytes = self.pending_out_bytes.saturating_sub(nwritten);
759
760                    if !pending_frame.is_empty() {
761                        self.pending_out_frame = Some(pending_frame);
762                    }
763                }
764            }
765        }
766
767        poll_flush!(&mut self.substream, cx).map_err(From::from)
768    }
769
770    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
771        poll_shutdown!(&mut self.substream, cx).map_err(From::from)
772    }
773}
774
/// Key type accepted by [`SubstreamSet`].
///
/// Implemented automatically for every type that satisfies the listed bounds.
pub trait SubstreamSetKey: Hash + Unpin + fmt::Debug + PartialEq + Eq + Copy {}

impl<K> SubstreamSetKey for K where K: Hash + Unpin + fmt::Debug + PartialEq + Eq + Copy {}
779
780/// Substream set.
781// TODO: https://github.com/paritytech/litep2p/issues/342 remove this.
782#[derive(Debug, Default)]
783pub struct SubstreamSet<K, S>
784where
785    K: SubstreamSetKey,
786    S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
787{
788    substreams: HashMap<K, S>,
789}
790
791impl<K, S> SubstreamSet<K, S>
792where
793    K: SubstreamSetKey,
794    S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
795{
796    /// Create new [`SubstreamSet`].
797    pub fn new() -> Self {
798        Self {
799            substreams: HashMap::new(),
800        }
801    }
802
803    /// Add new substream to the set.
804    pub fn insert(&mut self, key: K, substream: S) {
805        match self.substreams.entry(key) {
806            Entry::Vacant(entry) => {
807                entry.insert(substream);
808            }
809            Entry::Occupied(_) => {
810                tracing::error!(?key, "substream already exists");
811                debug_assert!(false);
812            }
813        }
814    }
815
816    /// Remove substream from the set.
817    pub fn remove(&mut self, key: &K) -> Option<S> {
818        self.substreams.remove(key)
819    }
820
821    /// Get mutable reference to stored substream.
822    #[cfg(test)]
823    pub fn get_mut(&mut self, key: &K) -> Option<&mut S> {
824        self.substreams.get_mut(key)
825    }
826
827    /// Get size of [`SubstreamSet`].
828    pub fn len(&self) -> usize {
829        self.substreams.len()
830    }
831
832    /// Check if [`SubstreamSet`] is empty.
833    pub fn is_empty(&self) -> bool {
834        self.substreams.is_empty()
835    }
836}
837
838impl<K, S> Stream for SubstreamSet<K, S>
839where
840    K: SubstreamSetKey,
841    S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
842{
843    type Item = (K, <S as Stream>::Item);
844
845    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
846        let inner = Pin::into_inner(self);
847
848        for (key, mut substream) in inner.substreams.iter_mut() {
849            match Pin::new(&mut substream).poll_next(cx) {
850                Poll::Pending => continue,
851                Poll::Ready(Some(data)) => return Poll::Ready(Some((*key, data))),
852                Poll::Ready(None) =>
853                    return Poll::Ready(Some((*key, Err(SubstreamError::ConnectionClosed)))),
854            }
855        }
856
857        Poll::Pending
858    }
859}
860
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{mock::substream::MockSubstream, PeerId};
    use futures::{SinkExt, StreamExt};

    #[test]
    fn add_substream() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let substream = MockSubstream::new();
        set.insert(peer, substream);

        let peer = PeerId::random();
        let substream = MockSubstream::new();
        set.insert(peer, substream);

        // both substreams were stored under distinct keys
        assert_eq!(set.len(), 2);
        assert!(!set.is_empty());
    }

    #[test]
    #[should_panic]
    #[cfg(debug_assertions)]
    fn add_same_peer_twice() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let substream1 = MockSubstream::new();
        let substream2 = MockSubstream::new();

        set.insert(peer, substream1);
        set.insert(peer, substream2);
    }

    #[test]
    fn remove_substream() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer1 = PeerId::random();
        let substream1 = MockSubstream::new();
        set.insert(peer1, substream1);

        let peer2 = PeerId::random();
        let substream2 = MockSubstream::new();
        set.insert(peer2, substream2);

        assert_eq!(set.len(), 2);

        assert!(set.remove(&peer1).is_some());
        assert!(set.remove(&peer2).is_some());
        assert!(set.remove(&PeerId::random()).is_none());

        // every stored substream has been removed
        assert!(set.is_empty());
    }

    #[tokio::test]
    async fn poll_data_from_substream() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let mut substream = MockSubstream::new();
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..])))));
        substream.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer, substream);

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..]));

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"world"[..]));

        assert!(futures::poll!(set.next()).is_pending());
    }

    #[tokio::test]
    async fn substream_closed() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let mut substream = MockSubstream::new();
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        substream.expect_poll_next().times(1).return_once(|_| Poll::Ready(None));
        substream.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer, substream);

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..]));

        match set.next().await {
            Some((exited_peer, Err(SubstreamError::ConnectionClosed))) => {
                assert_eq!(peer, exited_peer);
            }
            _ => panic!("invalid event received"),
        }
    }

    #[tokio::test]
    async fn get_mut_substream() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let mut substream = MockSubstream::new();
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(())));
        substream.expect_start_send().times(1).return_once(|_| Ok(()));
        substream.expect_poll_flush().times(1).return_once(|_| Poll::Ready(Ok(())));
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..])))));
        substream.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer, substream);

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..]));

        let substream = set.get_mut(&peer).unwrap();
        substream.send(vec![1, 2, 3, 4].into()).await.unwrap();

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"world"[..]));

        // try to get non-existent substream
        assert!(set.get_mut(&PeerId::random()).is_none());
    }

    #[tokio::test]
    async fn poll_data_from_two_substreams() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        // prepare first substream
        let peer1 = PeerId::random();
        let mut substream1 = MockSubstream::new();
        substream1
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        substream1
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..])))));
        substream1.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer1, substream1);

        // prepare second substream
        let peer2 = PeerId::random();
        let mut substream2 = MockSubstream::new();
        substream2
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"siip"[..])))));
        substream2
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"huup"[..])))));
        substream2.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer2, substream2);

        // poll values
        let mut values = Vec::new();

        for _ in 0..4 {
            let value = set.next().await.unwrap();
            values.push((value.0, value.1.unwrap()));
        }

        // The order in which the set polls its substreams is unspecified, so any interleaving
        // of the two peers is acceptable as long as each peer's own messages arrive in order.
        let received_from = |peer: PeerId| -> Vec<BytesMut> {
            values
                .iter()
                .filter(|(sender, _)| *sender == peer)
                .map(|(_, data)| data.clone())
                .collect()
        };

        assert_eq!(
            received_from(peer1),
            vec![BytesMut::from(&b"hello"[..]), BytesMut::from(&b"world"[..])],
        );
        assert_eq!(
            received_from(peer2),
            vec![BytesMut::from(&b"siip"[..]), BytesMut::from(&b"huup"[..])],
        );

        // rest of the calls return `Poll::Pending`
        for _ in 0..10 {
            assert!(futures::poll!(set.next()).is_pending());
        }
    }
}