litep2p/substream/
mod.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2// Copyright 2023 litep2p developers
3//
4// Permission is hereby granted, free of charge, to any person obtaining a
5// copy of this software and associated documentation files (the "Software"),
6// to deal in the Software without restriction, including without limitation
7// the rights to use, copy, modify, merge, publish, distribute, sublicense,
8// and/or sell copies of the Software, and to permit persons to whom the
9// Software is furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20// DEALINGS IN THE SOFTWARE.
21
22//! Substream-related helper code.
23
24use crate::{
25    codec::ProtocolCodec, error::SubstreamError, transport::tcp, types::SubstreamId, PeerId,
26};
27
28#[cfg(feature = "quic")]
29use crate::transport::quic;
30#[cfg(feature = "webrtc")]
31use crate::transport::webrtc;
32#[cfg(feature = "websocket")]
33use crate::transport::websocket;
34
35use bytes::{Buf, Bytes, BytesMut};
36use futures::{Sink, Stream};
37use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
38use unsigned_varint::{decode, encode};
39
40use std::{
41    collections::{hash_map::Entry, HashMap, VecDeque},
42    fmt,
43    hash::Hash,
44    io::ErrorKind,
45    pin::Pin,
46    task::{Context, Poll},
47};
48
49/// Logging target for the file.
50const LOG_TARGET: &str = "litep2p::substream";
51
/// Forward `poll_flush()` to the transport substream held by a `SubstreamType`.
///
/// `$substream` must evaluate to a mutable reference to a `SubstreamType` and `$cx` is the
/// identifier of the task context. The test-only `Mock` variant is not supported here and
/// panics via `unreachable!()`.
macro_rules! poll_flush {
    ($substream:expr, $cx:ident) => {{
        match $substream {
            SubstreamType::Tcp(io) => Pin::new(io).poll_flush($cx),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(io) => Pin::new(io).poll_flush($cx),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(io) => Pin::new(io).poll_flush($cx),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(io) => Pin::new(io).poll_flush($cx),
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
        }
    }};
}
67
/// Forward `poll_write()` to the transport substream held by a `SubstreamType`.
///
/// `$frame` is the byte slice handed to the transport. The test-only `Mock` variant is not
/// supported here and panics via `unreachable!()`.
macro_rules! poll_write {
    ($substream:expr, $cx:ident, $frame:expr) => {{
        match $substream {
            SubstreamType::Tcp(io) => Pin::new(io).poll_write($cx, $frame),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(io) => Pin::new(io).poll_write($cx, $frame),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(io) => Pin::new(io).poll_write($cx, $frame),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(io) => Pin::new(io).poll_write($cx, $frame),
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
        }
    }};
}
83
/// Forward `poll_read()` to the transport substream held by a `SubstreamType`.
///
/// `$buffer` is the read buffer handed to the transport. The test-only `Mock` variant is not
/// supported here and panics via `unreachable!()`.
macro_rules! poll_read {
    ($substream:expr, $cx:ident, $buffer:expr) => {{
        match $substream {
            SubstreamType::Tcp(io) => Pin::new(io).poll_read($cx, $buffer),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(io) => Pin::new(io).poll_read($cx, $buffer),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(io) => Pin::new(io).poll_read($cx, $buffer),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(io) => Pin::new(io).poll_read($cx, $buffer),
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
        }
    }};
}
99
/// Forward `poll_shutdown()` to the transport substream held by a `SubstreamType`.
///
/// For the test-only `Mock` variant the mock's `poll_close()` is invoked, its result is
/// discarded and the arm then hits `todo!()`.
macro_rules! poll_shutdown {
    ($substream:expr, $cx:ident) => {{
        match $substream {
            SubstreamType::Tcp(io) => Pin::new(io).poll_shutdown($cx),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(io) => Pin::new(io).poll_shutdown($cx),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(io) => Pin::new(io).poll_shutdown($cx),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(io) => Pin::new(io).poll_shutdown($cx),
            #[cfg(test)]
            SubstreamType::Mock(mock) => {
                let _ = Pin::new(mock).poll_close($cx);
                todo!();
            }
        }
    }};
}
118
/// In test builds, return early from the enclosing function with the mock's `poll_next()`
/// result when `$substream` is the `Mock` variant.
///
/// Outside of test builds the macro expands to an empty block.
macro_rules! delegate_poll_next {
    ($substream:expr, $cx:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $substream {
            return Pin::new(mock).poll_next($cx);
        }
    }};
}
127
/// In test builds, return early from the enclosing function with the mock's `poll_ready()`
/// result when `$substream` is the `Mock` variant.
///
/// Outside of test builds the macro expands to an empty block.
macro_rules! delegate_poll_ready {
    ($substream:expr, $cx:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $substream {
            return Pin::new(mock).poll_ready($cx);
        }
    }};
}
136
/// In test builds, return early from the enclosing function with the mock's `start_send()`
/// result when `$substream` is the `Mock` variant.
///
/// Outside of test builds the macro expands to an empty block.
macro_rules! delegate_start_send {
    ($substream:expr, $item:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $substream {
            return Pin::new(mock).start_send($item);
        }
    }};
}
145
/// In test builds, return early from the enclosing function with the mock's `poll_flush()`
/// result when `$substream` is the `Mock` variant.
///
/// Outside of test builds the macro expands to an empty block.
macro_rules! delegate_poll_flush {
    ($substream:expr, $cx:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $substream {
            return Pin::new(mock).poll_flush($cx);
        }
    }};
}
154
/// Return `SubstreamError::IoError(ErrorKind::PermissionDenied)` from the enclosing function
/// when `$size` exceeds the optional limit `$max_size`.
///
/// Nothing happens when `$max_size` is `None` or when the size is within the limit.
macro_rules! check_size {
    ($max_size:expr, $size:expr) => {{
        match $max_size {
            Some(limit) if $size > limit => {
                return Err(SubstreamError::IoError(ErrorKind::PermissionDenied).into());
            }
            _ => {}
        }
    }};
}
164
165/// Substream type.
166enum SubstreamType {
167    Tcp(tcp::Substream),
168    #[cfg(feature = "websocket")]
169    WebSocket(websocket::Substream),
170    #[cfg(feature = "quic")]
171    Quic(quic::Substream),
172    #[cfg(feature = "webrtc")]
173    WebRtc(webrtc::Substream),
174    #[cfg(test)]
175    Mock(Box<dyn crate::mock::substream::Substream>),
176}
177
178impl fmt::Debug for SubstreamType {
179    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180        match self {
181            Self::Tcp(_) => write!(f, "Tcp"),
182            #[cfg(feature = "websocket")]
183            Self::WebSocket(_) => write!(f, "WebSocket"),
184            #[cfg(feature = "quic")]
185            Self::Quic(_) => write!(f, "Quic"),
186            #[cfg(feature = "webrtc")]
187            Self::WebRtc(_) => write!(f, "WebRtc"),
188            #[cfg(test)]
189            Self::Mock(_) => write!(f, "Mock"),
190        }
191    }
192}
193
194/// Backpressure boundary for `Sink`.
195const BACKPRESSURE_BOUNDARY: usize = 65536;
196
197/// `Litep2p` substream type.
198///
199/// Implements [`tokio::io::AsyncRead`]/[`tokio::io::AsyncWrite`] traits which can be wrapped
200/// in a `Framed` to implement a custom codec.
201///
202/// In case a codec for the protocol was specified,
203/// [`Sink::send()`](futures::Sink)/[`Stream::next()`](futures::Stream) are also provided which
204/// implement the necessary framing to read/write codec-encoded messages from the underlying socket.
205pub struct Substream {
206    /// Remote peer ID.
207    peer: PeerId,
208
209    // Inner substream.
210    substream: SubstreamType,
211
212    /// Substream ID.
213    substream_id: SubstreamId,
214
215    /// Protocol codec.
216    codec: ProtocolCodec,
217
218    pending_out_frames: VecDeque<Bytes>,
219    pending_out_bytes: usize,
220    pending_out_frame: Option<Bytes>,
221
222    read_buffer: BytesMut,
223    offset: usize,
224    pending_frames: VecDeque<BytesMut>,
225    current_frame_size: Option<usize>,
226
227    size_vec: BytesMut,
228}
229
230impl fmt::Debug for Substream {
231    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232        f.debug_struct("Substream")
233            .field("peer", &self.peer)
234            .field("substream_id", &self.substream_id)
235            .field("codec", &self.codec)
236            .field("protocol", &self.substream)
237            .finish()
238    }
239}
240
241impl Substream {
242    /// Create new [`Substream`].
243    fn new(
244        peer: PeerId,
245        substream_id: SubstreamId,
246        substream: SubstreamType,
247        codec: ProtocolCodec,
248    ) -> Self {
249        Self {
250            peer,
251            substream,
252            codec,
253            substream_id,
254            read_buffer: BytesMut::zeroed(1024),
255            offset: 0usize,
256            pending_frames: VecDeque::new(),
257            current_frame_size: None,
258            pending_out_bytes: 0usize,
259            pending_out_frames: VecDeque::new(),
260            pending_out_frame: None,
261            size_vec: BytesMut::zeroed(10),
262        }
263    }
264
265    /// Create new [`Substream`] for TCP.
266    pub(crate) fn new_tcp(
267        peer: PeerId,
268        substream_id: SubstreamId,
269        substream: tcp::Substream,
270        codec: ProtocolCodec,
271    ) -> Self {
272        tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for tcp");
273
274        Self::new(peer, substream_id, SubstreamType::Tcp(substream), codec)
275    }
276
277    /// Create new [`Substream`] for WebSocket.
278    #[cfg(feature = "websocket")]
279    pub(crate) fn new_websocket(
280        peer: PeerId,
281        substream_id: SubstreamId,
282        substream: websocket::Substream,
283        codec: ProtocolCodec,
284    ) -> Self {
285        tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for websocket");
286
287        Self::new(
288            peer,
289            substream_id,
290            SubstreamType::WebSocket(substream),
291            codec,
292        )
293    }
294
295    /// Create new [`Substream`] for QUIC.
296    #[cfg(feature = "quic")]
297    pub(crate) fn new_quic(
298        peer: PeerId,
299        substream_id: SubstreamId,
300        substream: quic::Substream,
301        codec: ProtocolCodec,
302    ) -> Self {
303        tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for quic");
304
305        Self::new(peer, substream_id, SubstreamType::Quic(substream), codec)
306    }
307
308    /// Create new [`Substream`] for WebRTC.
309    #[cfg(feature = "webrtc")]
310    pub(crate) fn new_webrtc(
311        peer: PeerId,
312        substream_id: SubstreamId,
313        substream: webrtc::Substream,
314        codec: ProtocolCodec,
315    ) -> Self {
316        tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for webrtc");
317
318        Self::new(peer, substream_id, SubstreamType::WebRtc(substream), codec)
319    }
320
321    /// Create new [`Substream`] for mocking.
322    #[cfg(test)]
323    pub(crate) fn new_mock(
324        peer: PeerId,
325        substream_id: SubstreamId,
326        substream: Box<dyn crate::mock::substream::Substream>,
327    ) -> Self {
328        tracing::trace!(target: LOG_TARGET, ?peer, "create new substream for mocking");
329
330        Self::new(
331            peer,
332            substream_id,
333            SubstreamType::Mock(substream),
334            ProtocolCodec::Unspecified,
335        )
336    }
337
338    /// Close the substream.
339    pub async fn close(self) {
340        let _ = match self.substream {
341            SubstreamType::Tcp(mut substream) => substream.shutdown().await,
342            #[cfg(feature = "websocket")]
343            SubstreamType::WebSocket(mut substream) => substream.shutdown().await,
344            #[cfg(feature = "quic")]
345            SubstreamType::Quic(mut substream) => substream.shutdown().await,
346            #[cfg(feature = "webrtc")]
347            SubstreamType::WebRtc(mut substream) => substream.shutdown().await,
348            #[cfg(test)]
349            SubstreamType::Mock(mut substream) => {
350                let _ = futures::SinkExt::close(&mut substream).await;
351                Ok(())
352            }
353        };
354    }
355
356    /// Send identity payload to remote peer.
357    async fn send_identity_payload<T: AsyncWrite + Unpin>(
358        io: &mut T,
359        payload_size: usize,
360        payload: Bytes,
361    ) -> Result<(), SubstreamError> {
362        if payload.len() != payload_size {
363            return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
364        }
365
366        io.write_all(&payload).await.map_err(|_| SubstreamError::ConnectionClosed)?;
367
368        // Flush the stream.
369        io.flush().await.map_err(From::from)
370    }
371
372    /// Send unsigned varint payload to remote peer.
373    async fn send_unsigned_varint_payload<T: AsyncWrite + Unpin>(
374        io: &mut T,
375        bytes: Bytes,
376        max_size: Option<usize>,
377    ) -> Result<(), SubstreamError> {
378        if let Some(max_size) = max_size {
379            if bytes.len() > max_size {
380                return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
381            }
382        }
383
384        // Write the length of the frame.
385        let mut buffer = unsigned_varint::encode::usize_buffer();
386        let encoded_len = unsigned_varint::encode::usize(bytes.len(), &mut buffer).len();
387        io.write_all(&buffer[..encoded_len]).await?;
388
389        // Write the frame.
390        io.write_all(bytes.as_ref()).await?;
391
392        // Flush the stream.
393        io.flush().await.map_err(From::from)
394    }
395
396    /// Send framed data to remote peer.
397    ///
398    /// This function may be faster than the provided [`futures::Sink`] implementation for
399    /// [`Substream`] as it has direct access to the API of the underlying socket as opposed
400    /// to going through [`tokio::io::AsyncWrite`].
401    ///
402    /// # Cancel safety
403    ///
404    /// This method is not cancellation safe. If that is required, use the provided
405    /// [`futures::Sink`] implementation.
406    ///
407    /// # Panics
408    ///
409    /// Panics if no codec is provided.
410    pub async fn send_framed(&mut self, bytes: Bytes) -> Result<(), SubstreamError> {
411        tracing::trace!(
412            target: LOG_TARGET,
413            peer = ?self.peer,
414            codec = ?self.codec,
415            frame_len = ?bytes.len(),
416            "send framed"
417        );
418
419        match &mut self.substream {
420            #[cfg(test)]
421            SubstreamType::Mock(ref mut substream) =>
422                futures::SinkExt::send(substream, bytes).await.map_err(Into::into),
423            SubstreamType::Tcp(ref mut substream) => match self.codec {
424                ProtocolCodec::Unspecified => panic!("codec is unspecified"),
425                ProtocolCodec::Identity(payload_size) =>
426                    Self::send_identity_payload(substream, payload_size, bytes).await,
427                ProtocolCodec::UnsignedVarint(max_size) =>
428                    Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
429            },
430            #[cfg(feature = "websocket")]
431            SubstreamType::WebSocket(ref mut substream) => match self.codec {
432                ProtocolCodec::Unspecified => panic!("codec is unspecified"),
433                ProtocolCodec::Identity(payload_size) =>
434                    Self::send_identity_payload(substream, payload_size, bytes).await,
435                ProtocolCodec::UnsignedVarint(max_size) =>
436                    Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
437            },
438            #[cfg(feature = "quic")]
439            SubstreamType::Quic(ref mut substream) => match self.codec {
440                ProtocolCodec::Unspecified => panic!("codec is unspecified"),
441                ProtocolCodec::Identity(payload_size) =>
442                    Self::send_identity_payload(substream, payload_size, bytes).await,
443                ProtocolCodec::UnsignedVarint(max_size) => {
444                    check_size!(max_size, bytes.len());
445
446                    let mut buffer = unsigned_varint::encode::usize_buffer();
447                    let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer);
448                    let len = BytesMut::from(len);
449
450                    substream.write_all_chunks(&mut [len.freeze(), bytes]).await
451                }
452            },
453            #[cfg(feature = "webrtc")]
454            SubstreamType::WebRtc(ref mut substream) => match self.codec {
455                ProtocolCodec::Unspecified => panic!("codec is unspecified"),
456                ProtocolCodec::Identity(payload_size) =>
457                    Self::send_identity_payload(substream, payload_size, bytes).await,
458                ProtocolCodec::UnsignedVarint(max_size) =>
459                    Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
460            },
461        }
462    }
463}
464
465impl tokio::io::AsyncRead for Substream {
466    fn poll_read(
467        mut self: Pin<&mut Self>,
468        cx: &mut Context<'_>,
469        buf: &mut tokio::io::ReadBuf<'_>,
470    ) -> Poll<std::io::Result<()>> {
471        poll_read!(&mut self.substream, cx, buf)
472    }
473}
474
475impl tokio::io::AsyncWrite for Substream {
476    fn poll_write(
477        mut self: Pin<&mut Self>,
478        cx: &mut Context<'_>,
479        buf: &[u8],
480    ) -> Poll<Result<usize, std::io::Error>> {
481        poll_write!(&mut self.substream, cx, buf)
482    }
483
484    fn poll_flush(
485        mut self: Pin<&mut Self>,
486        cx: &mut Context<'_>,
487    ) -> Poll<Result<(), std::io::Error>> {
488        poll_flush!(&mut self.substream, cx)
489    }
490
491    fn poll_shutdown(
492        mut self: Pin<&mut Self>,
493        cx: &mut Context<'_>,
494    ) -> Poll<Result<(), std::io::Error>> {
495        poll_shutdown!(&mut self.substream, cx)
496    }
497}
498
499enum ReadError {
500    Overflow,
501    NotEnoughBytes,
502    DecodeError,
503}
504
505// Return the payload size and the number of bytes it took to encode it
506fn read_payload_size(buffer: &[u8]) -> Result<(usize, usize), ReadError> {
507    let max_len = encode::usize_buffer().len();
508
509    for i in 0..std::cmp::min(buffer.len(), max_len) {
510        if decode::is_last(buffer[i]) {
511            match decode::usize(&buffer[..=i]) {
512                Err(_) => return Err(ReadError::DecodeError),
513                Ok(size) => return Ok((size.0, i + 1)),
514            }
515        }
516    }
517
518    match buffer.len() < max_len {
519        true => Err(ReadError::NotEnoughBytes),
520        false => Err(ReadError::Overflow),
521    }
522}
523
524impl Stream for Substream {
525    type Item = Result<BytesMut, SubstreamError>;
526
527    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
528        let this = Pin::into_inner(self);
529
530        // `MockSubstream` implements `Stream` so calls to `poll_next()` must be delegated
531        delegate_poll_next!(&mut this.substream, cx);
532
533        loop {
534            match this.codec {
535                ProtocolCodec::Identity(payload_size) => {
536                    let mut read_buf =
537                        ReadBuf::new(&mut this.read_buffer[this.offset..payload_size]);
538
539                    match futures::ready!(poll_read!(&mut this.substream, cx, &mut read_buf)) {
540                        Ok(_) => {
541                            let nread = read_buf.filled().len();
542                            if nread == 0 {
543                                tracing::trace!(
544                                    target: LOG_TARGET,
545                                    peer = ?this.peer,
546                                    "read zero bytes, substream closed"
547                                );
548                                return Poll::Ready(None);
549                            }
550
551                            if nread == payload_size {
552                                let mut payload = std::mem::replace(
553                                    &mut this.read_buffer,
554                                    BytesMut::zeroed(payload_size),
555                                );
556                                payload.truncate(payload_size);
557                                this.offset = 0usize;
558
559                                return Poll::Ready(Some(Ok(payload)));
560                            } else {
561                                this.offset += read_buf.filled().len();
562                            }
563                        }
564                        Err(error) => return Poll::Ready(Some(Err(error.into()))),
565                    }
566                }
567                ProtocolCodec::UnsignedVarint(max_size) => {
568                    loop {
569                        // return all pending frames first
570                        if let Some(frame) = this.pending_frames.pop_front() {
571                            return Poll::Ready(Some(Ok(frame)));
572                        }
573
574                        match this.current_frame_size.take() {
575                            Some(frame_size) => {
576                                let mut read_buf =
577                                    ReadBuf::new(&mut this.read_buffer[this.offset..]);
578                                this.current_frame_size = Some(frame_size);
579
580                                match futures::ready!(poll_read!(
581                                    &mut this.substream,
582                                    cx,
583                                    &mut read_buf
584                                )) {
585                                    Err(_error) => return Poll::Ready(None),
586                                    Ok(_) => {
587                                        let nread = match read_buf.filled().len() {
588                                            0 => return Poll::Ready(None),
589                                            nread => nread,
590                                        };
591
592                                        this.offset += nread;
593
594                                        if this.offset == frame_size {
595                                            let out_frame = std::mem::replace(
596                                                &mut this.read_buffer,
597                                                BytesMut::new(),
598                                            );
599                                            this.offset = 0;
600                                            this.current_frame_size = None;
601
602                                            return Poll::Ready(Some(Ok(out_frame)));
603                                        } else {
604                                            this.current_frame_size = Some(frame_size);
605                                            continue;
606                                        }
607                                    }
608                                }
609                            }
610                            None => {
611                                let mut read_buf =
612                                    ReadBuf::new(&mut this.size_vec[this.offset..this.offset + 1]);
613
614                                match futures::ready!(poll_read!(
615                                    &mut this.substream,
616                                    cx,
617                                    &mut read_buf
618                                )) {
619                                    Err(_error) => return Poll::Ready(None),
620                                    Ok(_) => {
621                                        if read_buf.filled().is_empty() {
622                                            return Poll::Ready(None);
623                                        }
624                                        this.offset += 1;
625
626                                        match read_payload_size(&this.size_vec[..this.offset]) {
627                                            Err(ReadError::NotEnoughBytes) => continue,
628                                            Err(_) =>
629                                                return Poll::Ready(Some(Err(
630                                                    SubstreamError::ReadFailure(Some(
631                                                        this.substream_id,
632                                                    )),
633                                                ))),
634                                            Ok((size, num_bytes)) => {
635                                                debug_assert_eq!(num_bytes, this.offset);
636
637                                                if let Some(max_size) = max_size {
638                                                    if size > max_size {
639                                                        return Poll::Ready(Some(Err(
640                                                            SubstreamError::ReadFailure(Some(
641                                                                this.substream_id,
642                                                            )),
643                                                        )));
644                                                    }
645                                                }
646
647                                                this.offset = 0;
648                                                this.current_frame_size = Some(size);
649                                                this.read_buffer = BytesMut::zeroed(size);
650                                            }
651                                        }
652                                    }
653                                }
654                            }
655                        }
656                    }
657                }
658                ProtocolCodec::Unspecified => panic!("codec is unspecified"),
659            }
660        }
661    }
662}
663
664// TODO: this code can definitely be optimized
665impl Sink<Bytes> for Substream {
666    type Error = SubstreamError;
667
668    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
669        // `MockSubstream` implements `Sink` so calls to `poll_ready()` must be delegated
670        delegate_poll_ready!(&mut self.substream, cx);
671
672        if self.pending_out_bytes >= BACKPRESSURE_BOUNDARY {
673            return poll_flush!(&mut self.substream, cx).map_err(From::from);
674        }
675
676        Poll::Ready(Ok(()))
677    }
678
679    fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
680        // `MockSubstream` implements `Sink` so calls to `start_send()` must be delegated
681        delegate_start_send!(&mut self.substream, item);
682
683        tracing::trace!(
684            target: LOG_TARGET,
685            peer = ?self.peer,
686            substream_id = ?self.substream_id,
687            data_len = item.len(),
688            "Substream::start_send()",
689        );
690
691        match self.codec {
692            ProtocolCodec::Identity(payload_size) => {
693                if item.len() != payload_size {
694                    return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
695                }
696
697                self.pending_out_bytes += item.len();
698                self.pending_out_frames.push_back(item);
699            }
700            ProtocolCodec::UnsignedVarint(max_size) => {
701                check_size!(max_size, item.len());
702
703                let len = {
704                    let mut buffer = unsigned_varint::encode::usize_buffer();
705                    let len = unsigned_varint::encode::usize(item.len(), &mut buffer);
706                    BytesMut::from(len)
707                };
708
709                self.pending_out_bytes += len.len() + item.len();
710                self.pending_out_frames.push_back(len.freeze());
711                self.pending_out_frames.push_back(item);
712            }
713            ProtocolCodec::Unspecified => panic!("codec is unspecified"),
714        }
715
716        Ok(())
717    }
718
719    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
720        // `MockSubstream` implements `Sink` so calls to `poll_flush()` must be delegated
721        delegate_poll_flush!(&mut self.substream, cx);
722
723        loop {
724            let mut pending_frame = match self.pending_out_frame.take() {
725                Some(frame) => frame,
726                None => match self.pending_out_frames.pop_front() {
727                    Some(frame) => frame,
728                    None => break,
729                },
730            };
731
732            match poll_write!(&mut self.substream, cx, &pending_frame) {
733                Poll::Ready(Err(error)) => return Poll::Ready(Err(error.into())),
734                Poll::Pending => {
735                    self.pending_out_frame = Some(pending_frame);
736                    break;
737                }
738                Poll::Ready(Ok(nwritten)) => {
739                    pending_frame.advance(nwritten);
740
741                    if !pending_frame.is_empty() {
742                        self.pending_out_frame = Some(pending_frame);
743                    }
744                }
745            }
746        }
747
748        poll_flush!(&mut self.substream, cx).map_err(From::from)
749    }
750
751    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
752        poll_shutdown!(&mut self.substream, cx).map_err(From::from)
753    }
754}
755
/// Substream set key.
///
/// Implemented automatically for every type that satisfies all of the listed bounds.
pub trait SubstreamSetKey: Copy + Eq + PartialEq + Hash + fmt::Debug + Unpin {}

impl<K> SubstreamSetKey for K where K: Copy + Eq + PartialEq + Hash + fmt::Debug + Unpin {}
760
761/// Substream set.
762#[derive(Debug, Default)]
763pub struct SubstreamSet<K, S>
764where
765    K: SubstreamSetKey,
766    S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
767{
768    substreams: HashMap<K, S>,
769}
770
771impl<K, S> SubstreamSet<K, S>
772where
773    K: SubstreamSetKey,
774    S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
775{
776    /// Create new [`SubstreamSet`].
777    pub fn new() -> Self {
778        Self {
779            substreams: HashMap::new(),
780        }
781    }
782
783    /// Add new substream to the set.
784    pub fn insert(&mut self, key: K, substream: S) {
785        match self.substreams.entry(key) {
786            Entry::Vacant(entry) => {
787                entry.insert(substream);
788            }
789            Entry::Occupied(_) => {
790                tracing::error!(?key, "substream already exists");
791                debug_assert!(false);
792            }
793        }
794    }
795
796    /// Remove substream from the set.
797    pub fn remove(&mut self, key: &K) -> Option<S> {
798        self.substreams.remove(key)
799    }
800
801    /// Get mutable reference to stored substream.
802    #[cfg(test)]
803    pub fn get_mut(&mut self, key: &K) -> Option<&mut S> {
804        self.substreams.get_mut(key)
805    }
806
807    /// Get size of [`SubstreamSet`].
808    pub fn len(&self) -> usize {
809        self.substreams.len()
810    }
811
812    /// Check if [`SubstreamSet`] is empty.
813    pub fn is_empty(&self) -> bool {
814        self.substreams.is_empty()
815    }
816}
817
818impl<K, S> Stream for SubstreamSet<K, S>
819where
820    K: SubstreamSetKey,
821    S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
822{
823    type Item = (K, <S as Stream>::Item);
824
825    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
826        let inner = Pin::into_inner(self);
827
828        // TODO: poll the streams more randomly
829        for (key, mut substream) in inner.substreams.iter_mut() {
830            match Pin::new(&mut substream).poll_next(cx) {
831                Poll::Pending => continue,
832                Poll::Ready(Some(data)) => return Poll::Ready(Some((*key, data))),
833                Poll::Ready(None) =>
834                    return Poll::Ready(Some((*key, Err(SubstreamError::ConnectionClosed)))),
835            }
836        }
837
838        Poll::Pending
839    }
840}
841
842#[cfg(test)]
843mod tests {
844    use super::*;
845    use crate::{mock::substream::MockSubstream, PeerId};
846    use futures::{SinkExt, StreamExt};
847
848    #[test]
849    fn add_substream() {
850        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();
851
852        let peer = PeerId::random();
853        let substream = MockSubstream::new();
854        set.insert(peer, substream);
855
856        let peer = PeerId::random();
857        let substream = MockSubstream::new();
858        set.insert(peer, substream);
859    }
860
861    #[test]
862    #[should_panic]
863    #[cfg(debug_assertions)]
864    fn add_same_peer_twice() {
865        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();
866
867        let peer = PeerId::random();
868        let substream1 = MockSubstream::new();
869        let substream2 = MockSubstream::new();
870
871        set.insert(peer, substream1);
872        set.insert(peer, substream2);
873    }
874
875    #[test]
876    fn remove_substream() {
877        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();
878
879        let peer1 = PeerId::random();
880        let substream1 = MockSubstream::new();
881        set.insert(peer1, substream1);
882
883        let peer2 = PeerId::random();
884        let substream2 = MockSubstream::new();
885        set.insert(peer2, substream2);
886
887        assert!(set.remove(&peer1).is_some());
888        assert!(set.remove(&peer2).is_some());
889        assert!(set.remove(&PeerId::random()).is_none());
890    }
891
892    #[tokio::test]
893    async fn poll_data_from_substream() {
894        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();
895
896        let peer = PeerId::random();
897        let mut substream = MockSubstream::new();
898        substream
899            .expect_poll_next()
900            .times(1)
901            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
902        substream
903            .expect_poll_next()
904            .times(1)
905            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..])))));
906        substream.expect_poll_next().returning(|_| Poll::Pending);
907        set.insert(peer, substream);
908
909        let value = set.next().await.unwrap();
910        assert_eq!(value.0, peer);
911        assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..]));
912
913        let value = set.next().await.unwrap();
914        assert_eq!(value.0, peer);
915        assert_eq!(value.1.unwrap(), BytesMut::from(&b"world"[..]));
916
917        assert!(futures::poll!(set.next()).is_pending());
918    }
919
920    #[tokio::test]
921    async fn substream_closed() {
922        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();
923
924        let peer = PeerId::random();
925        let mut substream = MockSubstream::new();
926        substream
927            .expect_poll_next()
928            .times(1)
929            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
930        substream.expect_poll_next().times(1).return_once(|_| Poll::Ready(None));
931        substream.expect_poll_next().returning(|_| Poll::Pending);
932        set.insert(peer, substream);
933
934        let value = set.next().await.unwrap();
935        assert_eq!(value.0, peer);
936        assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..]));
937
938        match set.next().await {
939            Some((exited_peer, Err(SubstreamError::ConnectionClosed))) => {
940                assert_eq!(peer, exited_peer);
941            }
942            _ => panic!("inavlid event received"),
943        }
944    }
945
946    #[tokio::test]
947    async fn get_mut_substream() {
948        let _ = tracing_subscriber::fmt()
949            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
950            .try_init();
951
952        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();
953
954        let peer = PeerId::random();
955        let mut substream = MockSubstream::new();
956        substream
957            .expect_poll_next()
958            .times(1)
959            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
960        substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(())));
961        substream.expect_start_send().times(1).return_once(|_| Ok(()));
962        substream.expect_poll_flush().times(1).return_once(|_| Poll::Ready(Ok(())));
963        substream
964            .expect_poll_next()
965            .times(1)
966            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..])))));
967        substream.expect_poll_next().returning(|_| Poll::Pending);
968        set.insert(peer, substream);
969
970        let value = set.next().await.unwrap();
971        assert_eq!(value.0, peer);
972        assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..]));
973
974        let substream = set.get_mut(&peer).unwrap();
975        substream.send(vec![1, 2, 3, 4].into()).await.unwrap();
976
977        let value = set.next().await.unwrap();
978        assert_eq!(value.0, peer);
979        assert_eq!(value.1.unwrap(), BytesMut::from(&b"world"[..]));
980
981        // try to get non-existent substream
982        assert!(set.get_mut(&PeerId::random()).is_none());
983    }
984
985    #[tokio::test]
986    async fn poll_data_from_two_substreams() {
987        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();
988
989        // prepare first substream
990        let peer1 = PeerId::random();
991        let mut substream1 = MockSubstream::new();
992        substream1
993            .expect_poll_next()
994            .times(1)
995            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
996        substream1
997            .expect_poll_next()
998            .times(1)
999            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..])))));
1000        substream1.expect_poll_next().returning(|_| Poll::Pending);
1001        set.insert(peer1, substream1);
1002
1003        // prepare second substream
1004        let peer2 = PeerId::random();
1005        let mut substream2 = MockSubstream::new();
1006        substream2
1007            .expect_poll_next()
1008            .times(1)
1009            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"siip"[..])))));
1010        substream2
1011            .expect_poll_next()
1012            .times(1)
1013            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"huup"[..])))));
1014        substream2.expect_poll_next().returning(|_| Poll::Pending);
1015        set.insert(peer2, substream2);
1016
1017        let expected: Vec<Vec<(PeerId, BytesMut)>> = vec![
1018            vec![
1019                (peer1, BytesMut::from(&b"hello"[..])),
1020                (peer1, BytesMut::from(&b"world"[..])),
1021                (peer2, BytesMut::from(&b"siip"[..])),
1022                (peer2, BytesMut::from(&b"huup"[..])),
1023            ],
1024            vec![
1025                (peer1, BytesMut::from(&b"hello"[..])),
1026                (peer2, BytesMut::from(&b"siip"[..])),
1027                (peer1, BytesMut::from(&b"world"[..])),
1028                (peer2, BytesMut::from(&b"huup"[..])),
1029            ],
1030            vec![
1031                (peer2, BytesMut::from(&b"siip"[..])),
1032                (peer2, BytesMut::from(&b"huup"[..])),
1033                (peer1, BytesMut::from(&b"hello"[..])),
1034                (peer1, BytesMut::from(&b"world"[..])),
1035            ],
1036            vec![
1037                (peer1, BytesMut::from(&b"hello"[..])),
1038                (peer2, BytesMut::from(&b"siip"[..])),
1039                (peer2, BytesMut::from(&b"huup"[..])),
1040                (peer1, BytesMut::from(&b"world"[..])),
1041            ],
1042        ];
1043
1044        // poll values
1045        let mut values = Vec::new();
1046
1047        for _ in 0..4 {
1048            let value = set.next().await.unwrap();
1049            values.push((value.0, value.1.unwrap()));
1050        }
1051
1052        let mut correct_found = false;
1053
1054        for set in expected {
1055            if values == set {
1056                correct_found = true;
1057                break;
1058            }
1059        }
1060
1061        if !correct_found {
1062            panic!("invalid set generated");
1063        }
1064
1065        // rest of the calls return `Poll::Pending`
1066        for _ in 0..10 {
1067            assert!(futures::poll!(set.next()).is_pending());
1068        }
1069    }
1070}