litep2p/substream/
mod.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2// Copyright 2023 litep2p developers
3//
4// Permission is hereby granted, free of charge, to any person obtaining a
5// copy of this software and associated documentation files (the "Software"),
6// to deal in the Software without restriction, including without limitation
7// the rights to use, copy, modify, merge, publish, distribute, sublicense,
8// and/or sell copies of the Software, and to permit persons to whom the
9// Software is furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20// DEALINGS IN THE SOFTWARE.
21
22//! Substream-related helper code.
23
24use crate::{
25    codec::ProtocolCodec, error::SubstreamError, transport::tcp, types::SubstreamId, PeerId,
26};
27
28#[cfg(feature = "quic")]
29use crate::transport::quic;
30#[cfg(feature = "webrtc")]
31use crate::transport::webrtc;
32#[cfg(feature = "websocket")]
33use crate::transport::websocket;
34
35use bytes::{Buf, Bytes, BytesMut};
36use futures::{Sink, Stream};
37use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
38use unsigned_varint::{decode, encode};
39
40use std::{
41    collections::{hash_map::Entry, HashMap, VecDeque},
42    fmt,
43    hash::Hash,
44    io::ErrorKind,
45    pin::Pin,
46    task::{Context, Poll},
47};
48
49/// Logging target for the file.
50const LOG_TARGET: &str = "litep2p::substream";
51
/// Drive `poll_flush()` on whichever transport backs `$substream` (a `&mut SubstreamType`).
///
/// Reaching this macro with the test-only mock variant panics via `unreachable!()`.
macro_rules! poll_flush {
    ($substream:expr, $cx:ident) => {{
        match $substream {
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
            SubstreamType::Tcp(io) => Pin::new(io).poll_flush($cx),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(io) => Pin::new(io).poll_flush($cx),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(io) => Pin::new(io).poll_flush($cx),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(io) => Pin::new(io).poll_flush($cx),
        }
    }};
}

/// Drive `poll_write()` with `$frame` on whichever transport backs `$substream`.
///
/// Reaching this macro with the test-only mock variant panics via `unreachable!()`.
macro_rules! poll_write {
    ($substream:expr, $cx:ident, $frame:expr) => {{
        match $substream {
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
            SubstreamType::Tcp(io) => Pin::new(io).poll_write($cx, $frame),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(io) => Pin::new(io).poll_write($cx, $frame),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(io) => Pin::new(io).poll_write($cx, $frame),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(io) => Pin::new(io).poll_write($cx, $frame),
        }
    }};
}

/// Drive `poll_read()` into `$buffer` (a `&mut ReadBuf`) on whichever transport backs
/// `$substream`.
///
/// Reaching this macro with the test-only mock variant panics via `unreachable!()`.
macro_rules! poll_read {
    ($substream:expr, $cx:ident, $buffer:expr) => {{
        match $substream {
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
            SubstreamType::Tcp(io) => Pin::new(io).poll_read($cx, $buffer),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(io) => Pin::new(io).poll_read($cx, $buffer),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(io) => Pin::new(io).poll_read($cx, $buffer),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(io) => Pin::new(io).poll_read($cx, $buffer),
        }
    }};
}
99
/// Drive `poll_shutdown()` on whichever transport backs `$substream` (a `&mut SubstreamType`).
///
/// For the test-only mock variant the mock's `Sink::poll_close()` result is forwarded, with
/// any sink error mapped to an `std::io::Error` of kind [`ErrorKind::Other`]. The previous
/// `todo!()` made every shutdown of a mock substream panic.
macro_rules! poll_shutdown {
    ($substream:expr, $cx:ident) => {{
        match $substream {
            SubstreamType::Tcp(substream) => Pin::new(substream).poll_shutdown($cx),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(substream) => Pin::new(substream).poll_shutdown($cx),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(substream) => Pin::new(substream).poll_shutdown($cx),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(substream) => Pin::new(substream).poll_shutdown($cx),
            #[cfg(test)]
            SubstreamType::Mock(substream) => Pin::new(substream)
                .poll_close($cx)
                .map_err(|_| std::io::Error::from(ErrorKind::Other)),
        }
    }};
}
118
/// Test-only: when `$substream` is the mock variant, return the mock's `Stream::poll_next()`
/// result from the enclosing function. In non-test builds this is a no-op.
macro_rules! delegate_poll_next {
    ($substream:expr, $cx:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $substream {
            let mock = Pin::new(mock);
            return mock.poll_next($cx);
        }
    }};
}

/// Test-only: when `$substream` is the mock variant, return the mock's `Sink::poll_ready()`
/// result from the enclosing function. In non-test builds this is a no-op.
macro_rules! delegate_poll_ready {
    ($substream:expr, $cx:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $substream {
            let mock = Pin::new(mock);
            return mock.poll_ready($cx);
        }
    }};
}

/// Test-only: when `$substream` is the mock variant, return the mock's `Sink::start_send()`
/// result for `$item` from the enclosing function. In non-test builds this is a no-op.
macro_rules! delegate_start_send {
    ($substream:expr, $item:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $substream {
            let mock = Pin::new(mock);
            return mock.start_send($item);
        }
    }};
}

/// Test-only: when `$substream` is the mock variant, return the mock's `Sink::poll_flush()`
/// result from the enclosing function. In non-test builds this is a no-op.
macro_rules! delegate_poll_flush {
    ($substream:expr, $cx:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $substream {
            let mock = Pin::new(mock);
            return mock.poll_flush($cx);
        }
    }};
}
154
/// Return early with a `PermissionDenied` I/O error when an optional size limit is exceeded.
///
/// `$max_size` is an `Option<usize>`. When it is `Some(limit)` and `$size > limit`, the
/// enclosing function returns `Err(SubstreamError::IoError(ErrorKind::PermissionDenied).into())`.
/// `$size` is only evaluated when a limit is present.
macro_rules! check_size {
    ($max_size:expr, $size:expr) => {{
        if matches!($max_size, Some(limit) if $size > limit) {
            return Err(SubstreamError::IoError(ErrorKind::PermissionDenied).into());
        }
    }};
}
164
165/// Substream type.
166enum SubstreamType {
167    Tcp(tcp::Substream),
168    #[cfg(feature = "websocket")]
169    WebSocket(websocket::Substream),
170    #[cfg(feature = "quic")]
171    Quic(quic::Substream),
172    #[cfg(feature = "webrtc")]
173    WebRtc(webrtc::Substream),
174    #[cfg(test)]
175    Mock(Box<dyn crate::mock::substream::Substream>),
176}
177
178impl fmt::Debug for SubstreamType {
179    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180        match self {
181            Self::Tcp(_) => write!(f, "Tcp"),
182            #[cfg(feature = "websocket")]
183            Self::WebSocket(_) => write!(f, "WebSocket"),
184            #[cfg(feature = "quic")]
185            Self::Quic(_) => write!(f, "Quic"),
186            #[cfg(feature = "webrtc")]
187            Self::WebRtc(_) => write!(f, "WebRtc"),
188            #[cfg(test)]
189            Self::Mock(_) => write!(f, "Mock"),
190        }
191    }
192}
193
194/// Backpressure boundary for `Sink`.
195const BACKPRESSURE_BOUNDARY: usize = 65536;
196
197/// `Litep2p` substream type.
198///
199/// Implements [`tokio::io::AsyncRead`]/[`tokio::io::AsyncWrite`] traits which can be wrapped
200/// in a `Framed` to implement a custom codec.
201///
202/// In case a codec for the protocol was specified,
203/// [`Sink::send()`](futures::Sink)/[`Stream::next()`](futures::Stream) are also provided which
204/// implement the necessary framing to read/write codec-encoded messages from the underlying socket.
205pub struct Substream {
206    /// Remote peer ID.
207    peer: PeerId,
208
209    // Inner substream.
210    substream: SubstreamType,
211
212    /// Substream ID.
213    substream_id: SubstreamId,
214
215    /// Protocol codec.
216    codec: ProtocolCodec,
217
218    pending_out_frames: VecDeque<Bytes>,
219    pending_out_bytes: usize,
220    pending_out_frame: Option<Bytes>,
221
222    read_buffer: BytesMut,
223    offset: usize,
224    pending_frames: VecDeque<BytesMut>,
225    current_frame_size: Option<usize>,
226
227    size_vec: BytesMut,
228}
229
230impl fmt::Debug for Substream {
231    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232        f.debug_struct("Substream")
233            .field("peer", &self.peer)
234            .field("substream_id", &self.substream_id)
235            .field("codec", &self.codec)
236            .field("protocol", &self.substream)
237            .finish()
238    }
239}
240
241impl Substream {
242    /// Create new [`Substream`].
243    fn new(
244        peer: PeerId,
245        substream_id: SubstreamId,
246        substream: SubstreamType,
247        codec: ProtocolCodec,
248    ) -> Self {
249        Self {
250            peer,
251            substream,
252            codec,
253            substream_id,
254            read_buffer: BytesMut::zeroed(1024),
255            offset: 0usize,
256            pending_frames: VecDeque::new(),
257            current_frame_size: None,
258            pending_out_bytes: 0usize,
259            pending_out_frames: VecDeque::new(),
260            pending_out_frame: None,
261            size_vec: BytesMut::zeroed(10),
262        }
263    }
264
265    /// Create new [`Substream`] for TCP.
266    pub(crate) fn new_tcp(
267        peer: PeerId,
268        substream_id: SubstreamId,
269        substream: tcp::Substream,
270        codec: ProtocolCodec,
271    ) -> Self {
272        tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for tcp");
273
274        Self::new(peer, substream_id, SubstreamType::Tcp(substream), codec)
275    }
276
277    /// Create new [`Substream`] for WebSocket.
278    #[cfg(feature = "websocket")]
279    pub(crate) fn new_websocket(
280        peer: PeerId,
281        substream_id: SubstreamId,
282        substream: websocket::Substream,
283        codec: ProtocolCodec,
284    ) -> Self {
285        tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for websocket");
286
287        Self::new(
288            peer,
289            substream_id,
290            SubstreamType::WebSocket(substream),
291            codec,
292        )
293    }
294
295    /// Create new [`Substream`] for QUIC.
296    #[cfg(feature = "quic")]
297    pub(crate) fn new_quic(
298        peer: PeerId,
299        substream_id: SubstreamId,
300        substream: quic::Substream,
301        codec: ProtocolCodec,
302    ) -> Self {
303        tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for quic");
304
305        Self::new(peer, substream_id, SubstreamType::Quic(substream), codec)
306    }
307
308    /// Create new [`Substream`] for WebRTC.
309    #[cfg(feature = "webrtc")]
310    pub(crate) fn new_webrtc(
311        peer: PeerId,
312        substream_id: SubstreamId,
313        substream: webrtc::Substream,
314        codec: ProtocolCodec,
315    ) -> Self {
316        tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for webrtc");
317
318        Self::new(peer, substream_id, SubstreamType::WebRtc(substream), codec)
319    }
320
321    /// Create new [`Substream`] for mocking.
322    #[cfg(test)]
323    pub(crate) fn new_mock(
324        peer: PeerId,
325        substream_id: SubstreamId,
326        substream: Box<dyn crate::mock::substream::Substream>,
327    ) -> Self {
328        tracing::trace!(target: LOG_TARGET, ?peer, "create new substream for mocking");
329
330        Self::new(
331            peer,
332            substream_id,
333            SubstreamType::Mock(substream),
334            ProtocolCodec::Unspecified,
335        )
336    }
337
338    /// Close the substream.
339    pub async fn close(self) {
340        let _ = match self.substream {
341            SubstreamType::Tcp(mut substream) => substream.shutdown().await,
342            #[cfg(feature = "websocket")]
343            SubstreamType::WebSocket(mut substream) => substream.shutdown().await,
344            #[cfg(feature = "quic")]
345            SubstreamType::Quic(mut substream) => substream.shutdown().await,
346            #[cfg(feature = "webrtc")]
347            SubstreamType::WebRtc(mut substream) => substream.shutdown().await,
348            #[cfg(test)]
349            SubstreamType::Mock(mut substream) => {
350                let _ = futures::SinkExt::close(&mut substream).await;
351                Ok(())
352            }
353        };
354    }
355
356    /// Send identity payload to remote peer.
357    async fn send_identity_payload<T: AsyncWrite + Unpin>(
358        io: &mut T,
359        payload_size: usize,
360        payload: Bytes,
361    ) -> Result<(), SubstreamError> {
362        if payload.len() != payload_size {
363            return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
364        }
365
366        io.write_all(&payload).await.map_err(|_| SubstreamError::ConnectionClosed)?;
367
368        // Flush the stream.
369        io.flush().await.map_err(From::from)
370    }
371
372    /// Send unsigned varint payload to remote peer.
373    async fn send_unsigned_varint_payload<T: AsyncWrite + Unpin>(
374        io: &mut T,
375        bytes: Bytes,
376        max_size: Option<usize>,
377    ) -> Result<(), SubstreamError> {
378        if let Some(max_size) = max_size {
379            if bytes.len() > max_size {
380                return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
381            }
382        }
383
384        // Write the length of the frame.
385        let mut buffer = unsigned_varint::encode::usize_buffer();
386        let encoded_len = unsigned_varint::encode::usize(bytes.len(), &mut buffer).len();
387        io.write_all(&buffer[..encoded_len]).await?;
388
389        // Write the frame.
390        io.write_all(bytes.as_ref()).await?;
391
392        // Flush the stream.
393        io.flush().await.map_err(From::from)
394    }
395
396    /// Send framed data to remote peer.
397    ///
398    /// This function may be faster than the provided [`futures::Sink`] implementation for
399    /// [`Substream`] as it has direct access to the API of the underlying socket as opposed
400    /// to going through [`tokio::io::AsyncWrite`].
401    ///
402    /// # Cancel safety
403    ///
404    /// This method is not cancellation safe. If that is required, use the provided
405    /// [`futures::Sink`] implementation.
406    ///
407    /// # Panics
408    ///
409    /// Panics if no codec is provided.
410    pub async fn send_framed(&mut self, bytes: Bytes) -> Result<(), SubstreamError> {
411        tracing::trace!(
412            target: LOG_TARGET,
413            peer = ?self.peer,
414            codec = ?self.codec,
415            frame_len = ?bytes.len(),
416            "send framed"
417        );
418
419        match &mut self.substream {
420            #[cfg(test)]
421            SubstreamType::Mock(ref mut substream) =>
422                futures::SinkExt::send(substream, bytes).await,
423            SubstreamType::Tcp(ref mut substream) => match self.codec {
424                ProtocolCodec::Unspecified => panic!("codec is unspecified"),
425                ProtocolCodec::Identity(payload_size) =>
426                    Self::send_identity_payload(substream, payload_size, bytes).await,
427                ProtocolCodec::UnsignedVarint(max_size) =>
428                    Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
429            },
430            #[cfg(feature = "websocket")]
431            SubstreamType::WebSocket(ref mut substream) => match self.codec {
432                ProtocolCodec::Unspecified => panic!("codec is unspecified"),
433                ProtocolCodec::Identity(payload_size) =>
434                    Self::send_identity_payload(substream, payload_size, bytes).await,
435                ProtocolCodec::UnsignedVarint(max_size) =>
436                    Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
437            },
438            #[cfg(feature = "quic")]
439            SubstreamType::Quic(ref mut substream) => match self.codec {
440                ProtocolCodec::Unspecified => panic!("codec is unspecified"),
441                ProtocolCodec::Identity(payload_size) =>
442                    Self::send_identity_payload(substream, payload_size, bytes).await,
443                ProtocolCodec::UnsignedVarint(max_size) => {
444                    check_size!(max_size, bytes.len());
445
446                    let mut buffer = unsigned_varint::encode::usize_buffer();
447                    let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer);
448                    let len = BytesMut::from(len);
449
450                    substream.write_all_chunks(&mut [len.freeze(), bytes]).await
451                }
452            },
453            #[cfg(feature = "webrtc")]
454            SubstreamType::WebRtc(ref mut substream) => match self.codec {
455                ProtocolCodec::Unspecified => panic!("codec is unspecified"),
456                ProtocolCodec::Identity(payload_size) =>
457                    Self::send_identity_payload(substream, payload_size, bytes).await,
458                ProtocolCodec::UnsignedVarint(max_size) =>
459                    Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
460            },
461        }
462    }
463}
464
465impl tokio::io::AsyncRead for Substream {
466    fn poll_read(
467        mut self: Pin<&mut Self>,
468        cx: &mut Context<'_>,
469        buf: &mut tokio::io::ReadBuf<'_>,
470    ) -> Poll<std::io::Result<()>> {
471        poll_read!(&mut self.substream, cx, buf)
472    }
473}
474
475impl tokio::io::AsyncWrite for Substream {
476    fn poll_write(
477        mut self: Pin<&mut Self>,
478        cx: &mut Context<'_>,
479        buf: &[u8],
480    ) -> Poll<Result<usize, std::io::Error>> {
481        poll_write!(&mut self.substream, cx, buf)
482    }
483
484    fn poll_flush(
485        mut self: Pin<&mut Self>,
486        cx: &mut Context<'_>,
487    ) -> Poll<Result<(), std::io::Error>> {
488        poll_flush!(&mut self.substream, cx)
489    }
490
491    fn poll_shutdown(
492        mut self: Pin<&mut Self>,
493        cx: &mut Context<'_>,
494    ) -> Poll<Result<(), std::io::Error>> {
495        poll_shutdown!(&mut self.substream, cx)
496    }
497}
498
499enum ReadError {
500    Overflow,
501    NotEnoughBytes,
502    DecodeError,
503}
504
505// Return the payload size and the number of bytes it took to encode it
506fn read_payload_size(buffer: &[u8]) -> Result<(usize, usize), ReadError> {
507    let max_len = encode::usize_buffer().len();
508
509    for i in 0..std::cmp::min(buffer.len(), max_len) {
510        if decode::is_last(buffer[i]) {
511            match decode::usize(&buffer[..=i]) {
512                Err(_) => return Err(ReadError::DecodeError),
513                Ok(size) => return Ok((size.0, i + 1)),
514            }
515        }
516    }
517
518    match buffer.len() < max_len {
519        true => Err(ReadError::NotEnoughBytes),
520        false => Err(ReadError::Overflow),
521    }
522}
523
524impl Stream for Substream {
525    type Item = Result<BytesMut, SubstreamError>;
526
527    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
528        let this = Pin::into_inner(self);
529
530        // `MockSubstream` implements `Stream` so calls to `poll_next()` must be delegated
531        delegate_poll_next!(&mut this.substream, cx);
532
533        loop {
534            match this.codec {
535                ProtocolCodec::Identity(payload_size) => {
536                    let mut read_buf =
537                        ReadBuf::new(&mut this.read_buffer[this.offset..payload_size]);
538
539                    match futures::ready!(poll_read!(&mut this.substream, cx, &mut read_buf)) {
540                        Ok(_) => {
541                            let nread = read_buf.filled().len();
542                            if nread == 0 {
543                                tracing::trace!(
544                                    target: LOG_TARGET,
545                                    peer = ?this.peer,
546                                    "read zero bytes, substream closed"
547                                );
548                                return Poll::Ready(None);
549                            }
550
551                            if nread == payload_size {
552                                let mut payload = std::mem::replace(
553                                    &mut this.read_buffer,
554                                    BytesMut::zeroed(payload_size),
555                                );
556                                payload.truncate(payload_size);
557                                this.offset = 0usize;
558
559                                return Poll::Ready(Some(Ok(payload)));
560                            } else {
561                                this.offset += read_buf.filled().len();
562                            }
563                        }
564                        Err(error) => return Poll::Ready(Some(Err(error.into()))),
565                    }
566                }
567                ProtocolCodec::UnsignedVarint(max_size) => {
568                    loop {
569                        // return all pending frames first
570                        if let Some(frame) = this.pending_frames.pop_front() {
571                            return Poll::Ready(Some(Ok(frame)));
572                        }
573
574                        match this.current_frame_size.take() {
575                            Some(frame_size) => {
576                                let mut read_buf =
577                                    ReadBuf::new(&mut this.read_buffer[this.offset..]);
578                                this.current_frame_size = Some(frame_size);
579
580                                match futures::ready!(poll_read!(
581                                    &mut this.substream,
582                                    cx,
583                                    &mut read_buf
584                                )) {
585                                    Err(_error) => return Poll::Ready(None),
586                                    Ok(_) => {
587                                        let nread = match read_buf.filled().len() {
588                                            0 => return Poll::Ready(None),
589                                            nread => nread,
590                                        };
591
592                                        this.offset += nread;
593
594                                        if this.offset == frame_size {
595                                            let out_frame = std::mem::replace(
596                                                &mut this.read_buffer,
597                                                BytesMut::new(),
598                                            );
599                                            this.offset = 0;
600                                            this.current_frame_size = None;
601
602                                            return Poll::Ready(Some(Ok(out_frame)));
603                                        } else {
604                                            this.current_frame_size = Some(frame_size);
605                                            continue;
606                                        }
607                                    }
608                                }
609                            }
610                            None => {
611                                let mut read_buf =
612                                    ReadBuf::new(&mut this.size_vec[this.offset..this.offset + 1]);
613
614                                match futures::ready!(poll_read!(
615                                    &mut this.substream,
616                                    cx,
617                                    &mut read_buf
618                                )) {
619                                    Err(_error) => return Poll::Ready(None),
620                                    Ok(_) => {
621                                        if read_buf.filled().is_empty() {
622                                            return Poll::Ready(None);
623                                        }
624                                        this.offset += 1;
625
626                                        match read_payload_size(&this.size_vec[..this.offset]) {
627                                            Err(ReadError::NotEnoughBytes) => continue,
628                                            Err(_) =>
629                                                return Poll::Ready(Some(Err(
630                                                    SubstreamError::ReadFailure(Some(
631                                                        this.substream_id,
632                                                    )),
633                                                ))),
634                                            Ok((size, num_bytes)) => {
635                                                debug_assert_eq!(num_bytes, this.offset);
636
637                                                if let Some(max_size) = max_size {
638                                                    if size > max_size {
639                                                        return Poll::Ready(Some(Err(
640                                                            SubstreamError::ReadFailure(Some(
641                                                                this.substream_id,
642                                                            )),
643                                                        )));
644                                                    }
645                                                }
646
647                                                this.offset = 0;
648                                                // Handle empty payloads detected as 0-length frame.
649                                                // The offset must be cleared to 0 to not interfere
650                                                // with next framing.
651                                                if size == 0 {
652                                                    return Poll::Ready(Some(Ok(BytesMut::new())));
653                                                }
654
655                                                this.current_frame_size = Some(size);
656                                                this.read_buffer = BytesMut::zeroed(size);
657                                            }
658                                        }
659                                    }
660                                }
661                            }
662                        }
663                    }
664                }
665                ProtocolCodec::Unspecified => panic!("codec is unspecified"),
666            }
667        }
668    }
669}
670
671// TODO: https://github.com/paritytech/litep2p/issues/341 this code can definitely be optimized
672impl Sink<Bytes> for Substream {
673    type Error = SubstreamError;
674
675    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
676        // `MockSubstream` implements `Sink` so calls to `poll_ready()` must be delegated
677        delegate_poll_ready!(&mut self.substream, cx);
678
679        if self.pending_out_bytes >= BACKPRESSURE_BOUNDARY {
680            return poll_flush!(&mut self.substream, cx).map_err(From::from);
681        }
682
683        Poll::Ready(Ok(()))
684    }
685
686    fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
687        // `MockSubstream` implements `Sink` so calls to `start_send()` must be delegated
688        delegate_start_send!(&mut self.substream, item);
689
690        tracing::trace!(
691            target: LOG_TARGET,
692            peer = ?self.peer,
693            substream_id = ?self.substream_id,
694            data_len = item.len(),
695            "Substream::start_send()",
696        );
697
698        match self.codec {
699            ProtocolCodec::Identity(payload_size) => {
700                if item.len() != payload_size {
701                    return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
702                }
703
704                self.pending_out_bytes += item.len();
705                self.pending_out_frames.push_back(item);
706            }
707            ProtocolCodec::UnsignedVarint(max_size) => {
708                check_size!(max_size, item.len());
709
710                let len = {
711                    let mut buffer = unsigned_varint::encode::usize_buffer();
712                    let len = unsigned_varint::encode::usize(item.len(), &mut buffer);
713                    BytesMut::from(len)
714                };
715
716                self.pending_out_bytes += len.len() + item.len();
717                self.pending_out_frames.push_back(len.freeze());
718                self.pending_out_frames.push_back(item);
719            }
720            ProtocolCodec::Unspecified => panic!("codec is unspecified"),
721        }
722
723        Ok(())
724    }
725
726    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
727        // `MockSubstream` implements `Sink` so calls to `poll_flush()` must be delegated
728        delegate_poll_flush!(&mut self.substream, cx);
729
730        loop {
731            let mut pending_frame = match self.pending_out_frame.take() {
732                Some(frame) => frame,
733                None => match self.pending_out_frames.pop_front() {
734                    Some(frame) => frame,
735                    None => break,
736                },
737            };
738
739            match poll_write!(&mut self.substream, cx, &pending_frame) {
740                Poll::Ready(Err(error)) => return Poll::Ready(Err(error.into())),
741                Poll::Pending => {
742                    self.pending_out_frame = Some(pending_frame);
743                    break;
744                }
745                Poll::Ready(Ok(nwritten)) => {
746                    pending_frame.advance(nwritten);
747
748                    if !pending_frame.is_empty() {
749                        self.pending_out_frame = Some(pending_frame);
750                    }
751                }
752            }
753        }
754
755        poll_flush!(&mut self.substream, cx).map_err(From::from)
756    }
757
758    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
759        poll_shutdown!(&mut self.substream, cx).map_err(From::from)
760    }
761}
762
/// Marker trait for types usable as keys of a `SubstreamSet`.
///
/// It is implemented automatically for every type that satisfies the bounds.
pub trait SubstreamSetKey: Copy + Eq + PartialEq + Hash + Unpin + fmt::Debug {}

impl<K> SubstreamSetKey for K where K: Copy + Eq + PartialEq + Hash + Unpin + fmt::Debug {}
767
768/// Substream set.
769// TODO: https://github.com/paritytech/litep2p/issues/342 remove this.
770#[derive(Debug, Default)]
771pub struct SubstreamSet<K, S>
772where
773    K: SubstreamSetKey,
774    S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
775{
776    substreams: HashMap<K, S>,
777}
778
779impl<K, S> SubstreamSet<K, S>
780where
781    K: SubstreamSetKey,
782    S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
783{
784    /// Create new [`SubstreamSet`].
785    pub fn new() -> Self {
786        Self {
787            substreams: HashMap::new(),
788        }
789    }
790
791    /// Add new substream to the set.
792    pub fn insert(&mut self, key: K, substream: S) {
793        match self.substreams.entry(key) {
794            Entry::Vacant(entry) => {
795                entry.insert(substream);
796            }
797            Entry::Occupied(_) => {
798                tracing::error!(?key, "substream already exists");
799                debug_assert!(false);
800            }
801        }
802    }
803
804    /// Remove substream from the set.
805    pub fn remove(&mut self, key: &K) -> Option<S> {
806        self.substreams.remove(key)
807    }
808
809    /// Get mutable reference to stored substream.
810    #[cfg(test)]
811    pub fn get_mut(&mut self, key: &K) -> Option<&mut S> {
812        self.substreams.get_mut(key)
813    }
814
815    /// Get size of [`SubstreamSet`].
816    pub fn len(&self) -> usize {
817        self.substreams.len()
818    }
819
820    /// Check if [`SubstreamSet`] is empty.
821    pub fn is_empty(&self) -> bool {
822        self.substreams.is_empty()
823    }
824}
825
826impl<K, S> Stream for SubstreamSet<K, S>
827where
828    K: SubstreamSetKey,
829    S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
830{
831    type Item = (K, <S as Stream>::Item);
832
833    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
834        let inner = Pin::into_inner(self);
835
836        for (key, mut substream) in inner.substreams.iter_mut() {
837            match Pin::new(&mut substream).poll_next(cx) {
838                Poll::Pending => continue,
839                Poll::Ready(Some(data)) => return Poll::Ready(Some((*key, data))),
840                Poll::Ready(None) =>
841                    return Poll::Ready(Some((*key, Err(SubstreamError::ConnectionClosed)))),
842            }
843        }
844
845        Poll::Pending
846    }
847}
848
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{mock::substream::MockSubstream, PeerId};
    use futures::{SinkExt, StreamExt};

    #[test]
    fn add_substream() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();
        assert!(set.is_empty());

        let peer = PeerId::random();
        let substream = MockSubstream::new();
        set.insert(peer, substream);

        let peer = PeerId::random();
        let substream = MockSubstream::new();
        set.insert(peer, substream);

        // both substreams must have been stored
        assert_eq!(set.len(), 2);
        assert!(!set.is_empty());
    }

    #[test]
    #[should_panic]
    #[cfg(debug_assertions)]
    fn add_same_peer_twice() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let substream1 = MockSubstream::new();
        let substream2 = MockSubstream::new();

        set.insert(peer, substream1);
        set.insert(peer, substream2);
    }

    #[test]
    fn remove_substream() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer1 = PeerId::random();
        let substream1 = MockSubstream::new();
        set.insert(peer1, substream1);

        let peer2 = PeerId::random();
        let substream2 = MockSubstream::new();
        set.insert(peer2, substream2);

        assert!(set.remove(&peer1).is_some());
        assert!(set.remove(&peer2).is_some());
        assert!(set.remove(&PeerId::random()).is_none());

        // every inserted substream has been taken out again
        assert!(set.is_empty());
    }

    #[tokio::test]
    async fn poll_data_from_substream() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let mut substream = MockSubstream::new();
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..])))));
        substream.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer, substream);

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..]));

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"world"[..]));

        assert!(futures::poll!(set.next()).is_pending());
    }

    #[tokio::test]
    async fn substream_closed() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let mut substream = MockSubstream::new();
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        substream.expect_poll_next().times(1).return_once(|_| Poll::Ready(None));
        substream.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer, substream);

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..]));

        match set.next().await {
            Some((exited_peer, Err(SubstreamError::ConnectionClosed))) => {
                assert_eq!(peer, exited_peer);
            }
            _ => panic!("invalid event received"),
        }
    }

    #[tokio::test]
    async fn get_mut_substream() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let mut substream = MockSubstream::new();
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(())));
        substream.expect_start_send().times(1).return_once(|_| Ok(()));
        substream.expect_poll_flush().times(1).return_once(|_| Poll::Ready(Ok(())));
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..])))));
        substream.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer, substream);

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..]));

        let substream = set.get_mut(&peer).unwrap();
        substream.send(vec![1, 2, 3, 4].into()).await.unwrap();

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"world"[..]));

        // try to get non-existent substream
        assert!(set.get_mut(&PeerId::random()).is_none());
    }

    #[tokio::test]
    async fn poll_data_from_two_substreams() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        // prepare first substream
        let peer1 = PeerId::random();
        let mut substream1 = MockSubstream::new();
        substream1
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        substream1
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..])))));
        substream1.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer1, substream1);

        // prepare second substream
        let peer2 = PeerId::random();
        let mut substream2 = MockSubstream::new();
        substream2
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"siip"[..])))));
        substream2
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"huup"[..])))));
        substream2.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer2, substream2);

        // the set does not guarantee an order across substreams, only within one substream
        let expected: Vec<Vec<(PeerId, BytesMut)>> = vec![
            vec![
                (peer1, BytesMut::from(&b"hello"[..])),
                (peer1, BytesMut::from(&b"world"[..])),
                (peer2, BytesMut::from(&b"siip"[..])),
                (peer2, BytesMut::from(&b"huup"[..])),
            ],
            vec![
                (peer1, BytesMut::from(&b"hello"[..])),
                (peer2, BytesMut::from(&b"siip"[..])),
                (peer1, BytesMut::from(&b"world"[..])),
                (peer2, BytesMut::from(&b"huup"[..])),
            ],
            vec![
                (peer2, BytesMut::from(&b"siip"[..])),
                (peer2, BytesMut::from(&b"huup"[..])),
                (peer1, BytesMut::from(&b"hello"[..])),
                (peer1, BytesMut::from(&b"world"[..])),
            ],
            vec![
                (peer1, BytesMut::from(&b"hello"[..])),
                (peer2, BytesMut::from(&b"siip"[..])),
                (peer2, BytesMut::from(&b"huup"[..])),
                (peer1, BytesMut::from(&b"world"[..])),
            ],
        ];

        // poll values
        let mut values = Vec::new();

        for _ in 0..4 {
            let value = set.next().await.unwrap();
            values.push((value.0, value.1.unwrap()));
        }

        assert!(
            expected.contains(&values),
            "invalid set generated: {:?}",
            values
        );

        // rest of the calls return `Poll::Pending`
        for _ in 0..10 {
            assert!(futures::poll!(set.next()).is_pending());
        }
    }
}