1use crate::{
25 codec::ProtocolCodec, error::SubstreamError, transport::tcp, types::SubstreamId, PeerId,
26};
27
28#[cfg(feature = "quic")]
29use crate::transport::quic;
30#[cfg(feature = "webrtc")]
31use crate::transport::webrtc;
32#[cfg(feature = "websocket")]
33use crate::transport::websocket;
34
35use bytes::{Buf, Bytes, BytesMut};
36use futures::{Sink, Stream};
37use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
38use unsigned_varint::{decode, encode};
39
40use std::{
41 collections::{hash_map::Entry, HashMap, VecDeque},
42 fmt,
43 hash::Hash,
44 io::ErrorKind,
45 pin::Pin,
46 task::{Context, Poll},
47};
48
/// `tracing` target attached to the substream log events emitted from this module.
const LOG_TARGET: &str = "litep2p::substream";
51
/// Forward `poll_flush` to the transport stream wrapped by `$substream`.
///
/// `$substream` must evaluate to `&mut SubstreamType`. The test-only `Mock` variant has no
/// `AsyncWrite` implementation here; reaching it panics via `unreachable!()`.
macro_rules! poll_flush {
    ($substream:expr, $cx:ident) => {{
        match $substream {
            SubstreamType::Tcp(stream) => Pin::new(stream).poll_flush($cx),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(stream) => Pin::new(stream).poll_flush($cx),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(stream) => Pin::new(stream).poll_flush($cx),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(stream) => Pin::new(stream).poll_flush($cx),
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
        }
    }};
}
67
/// Forward `poll_write($cx, $frame)` to the transport stream wrapped by `$substream`.
///
/// `$substream` must evaluate to `&mut SubstreamType`; `$frame` is the byte slice to write.
/// The test-only `Mock` variant panics via `unreachable!()`.
macro_rules! poll_write {
    ($substream:expr, $cx:ident, $frame:expr) => {{
        match $substream {
            SubstreamType::Tcp(stream) => Pin::new(stream).poll_write($cx, $frame),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(stream) => Pin::new(stream).poll_write($cx, $frame),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(stream) => Pin::new(stream).poll_write($cx, $frame),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(stream) => Pin::new(stream).poll_write($cx, $frame),
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
        }
    }};
}
83
/// Forward `poll_read($cx, $buffer)` to the transport stream wrapped by `$substream`.
///
/// `$substream` must evaluate to `&mut SubstreamType`; `$buffer` is a `&mut ReadBuf`.
/// The test-only `Mock` variant panics via `unreachable!()`.
macro_rules! poll_read {
    ($substream:expr, $cx:ident, $buffer:expr) => {{
        match $substream {
            SubstreamType::Tcp(stream) => Pin::new(stream).poll_read($cx, $buffer),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(stream) => Pin::new(stream).poll_read($cx, $buffer),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(stream) => Pin::new(stream).poll_read($cx, $buffer),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(stream) => Pin::new(stream).poll_read($cx, $buffer),
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
        }
    }};
}
99
/// Forward `poll_shutdown` to the transport stream wrapped by `$substream`.
///
/// `$substream` must evaluate to `&mut SubstreamType`. The test-only `Mock` variant is a `Sink`,
/// so its `poll_close` is driven instead. Any error it reports is surfaced as an
/// `io::ErrorKind::Other` error, so the macro has the same `Poll<io::Result<()>>` type for every
/// variant.
macro_rules! poll_shutdown {
    ($substream:expr, $cx:ident) => {{
        match $substream {
            SubstreamType::Tcp(substream) => Pin::new(substream).poll_shutdown($cx),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(substream) => Pin::new(substream).poll_shutdown($cx),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(substream) => Pin::new(substream).poll_shutdown($cx),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(substream) => Pin::new(substream).poll_shutdown($cx),
            // Previously this polled `poll_close`, discarded the result and hit `todo!()`, so
            // shutting down a mock always panicked.
            #[cfg(test)]
            SubstreamType::Mock(substream) => Pin::new(substream)
                .poll_close($cx)
                .map_err(|_| std::io::Error::from(std::io::ErrorKind::Other)),
        }
    }};
}
118
/// In test builds, return early from the enclosing `poll_next` with the mock's own
/// `Stream::poll_next` result when `$substream` is the `Mock` variant; otherwise do nothing.
macro_rules! delegate_poll_next {
    ($substream:expr, $cx:ident) => {{
        #[cfg(test)]
        {
            match $substream {
                SubstreamType::Mock(mock) => return Pin::new(mock).poll_next($cx),
                _ => {}
            }
        }
    }};
}
127
/// In test builds, return early from the enclosing `poll_ready` with the mock's own
/// `Sink::poll_ready` result when `$substream` is the `Mock` variant; otherwise do nothing.
macro_rules! delegate_poll_ready {
    ($substream:expr, $cx:ident) => {{
        #[cfg(test)]
        {
            match $substream {
                SubstreamType::Mock(mock) => return Pin::new(mock).poll_ready($cx),
                _ => {}
            }
        }
    }};
}
136
/// In test builds, return early from the enclosing `start_send` with the mock's own
/// `Sink::start_send` result when `$substream` is the `Mock` variant; otherwise do nothing.
macro_rules! delegate_start_send {
    ($substream:expr, $item:ident) => {{
        #[cfg(test)]
        {
            match $substream {
                SubstreamType::Mock(mock) => return Pin::new(mock).start_send($item),
                _ => {}
            }
        }
    }};
}
145
/// In test builds, return early from the enclosing `poll_flush` with the mock's own
/// `Sink::poll_flush` result when `$substream` is the `Mock` variant; otherwise do nothing.
macro_rules! delegate_poll_flush {
    ($substream:expr, $cx:ident) => {{
        #[cfg(test)]
        {
            match $substream {
                SubstreamType::Mock(mock) => return Pin::new(mock).poll_flush($cx),
                _ => {}
            }
        }
    }};
}
154
/// Early-return `Err(SubstreamError::IoError(ErrorKind::PermissionDenied))` (converted with
/// `.into()`) from the enclosing function when `$size` exceeds the limit in `$max_size`.
///
/// `$max_size` is an `Option<usize>`; `None` means no limit, in which case `$size` is not
/// evaluated.
macro_rules! check_size {
    ($max_size:expr, $size:expr) => {{
        if matches!($max_size, Some(limit) if $size > limit) {
            return Err(SubstreamError::IoError(ErrorKind::PermissionDenied).into());
        }
    }};
}
164
165enum SubstreamType {
167 Tcp(tcp::Substream),
168 #[cfg(feature = "websocket")]
169 WebSocket(websocket::Substream),
170 #[cfg(feature = "quic")]
171 Quic(quic::Substream),
172 #[cfg(feature = "webrtc")]
173 WebRtc(webrtc::Substream),
174 #[cfg(test)]
175 Mock(Box<dyn crate::mock::substream::Substream>),
176}
177
178impl fmt::Debug for SubstreamType {
179 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180 match self {
181 Self::Tcp(_) => write!(f, "Tcp"),
182 #[cfg(feature = "websocket")]
183 Self::WebSocket(_) => write!(f, "WebSocket"),
184 #[cfg(feature = "quic")]
185 Self::Quic(_) => write!(f, "Quic"),
186 #[cfg(feature = "webrtc")]
187 Self::WebRtc(_) => write!(f, "WebRtc"),
188 #[cfg(test)]
189 Self::Mock(_) => write!(f, "Mock"),
190 }
191 }
192}
193
/// Queued outbound byte count (64 KiB) at which `Sink::poll_ready` flushes before accepting more.
const BACKPRESSURE_BOUNDARY: usize = 64 * 1024;
196
197pub struct Substream {
206 peer: PeerId,
208
209 substream: SubstreamType,
211
212 substream_id: SubstreamId,
214
215 codec: ProtocolCodec,
217
218 pending_out_frames: VecDeque<Bytes>,
219 pending_out_bytes: usize,
220 pending_out_frame: Option<Bytes>,
221
222 read_buffer: BytesMut,
223 offset: usize,
224 pending_frames: VecDeque<BytesMut>,
225 current_frame_size: Option<usize>,
226
227 size_vec: BytesMut,
228}
229
230impl fmt::Debug for Substream {
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 f.debug_struct("Substream")
233 .field("peer", &self.peer)
234 .field("substream_id", &self.substream_id)
235 .field("codec", &self.codec)
236 .field("protocol", &self.substream)
237 .finish()
238 }
239}
240
241impl Substream {
242 fn new(
244 peer: PeerId,
245 substream_id: SubstreamId,
246 substream: SubstreamType,
247 codec: ProtocolCodec,
248 ) -> Self {
249 Self {
250 peer,
251 substream,
252 codec,
253 substream_id,
254 read_buffer: BytesMut::zeroed(1024),
255 offset: 0usize,
256 pending_frames: VecDeque::new(),
257 current_frame_size: None,
258 pending_out_bytes: 0usize,
259 pending_out_frames: VecDeque::new(),
260 pending_out_frame: None,
261 size_vec: BytesMut::zeroed(10),
262 }
263 }
264
265 pub(crate) fn new_tcp(
267 peer: PeerId,
268 substream_id: SubstreamId,
269 substream: tcp::Substream,
270 codec: ProtocolCodec,
271 ) -> Self {
272 tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for tcp");
273
274 Self::new(peer, substream_id, SubstreamType::Tcp(substream), codec)
275 }
276
277 #[cfg(feature = "websocket")]
279 pub(crate) fn new_websocket(
280 peer: PeerId,
281 substream_id: SubstreamId,
282 substream: websocket::Substream,
283 codec: ProtocolCodec,
284 ) -> Self {
285 tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for websocket");
286
287 Self::new(
288 peer,
289 substream_id,
290 SubstreamType::WebSocket(substream),
291 codec,
292 )
293 }
294
295 #[cfg(feature = "quic")]
297 pub(crate) fn new_quic(
298 peer: PeerId,
299 substream_id: SubstreamId,
300 substream: quic::Substream,
301 codec: ProtocolCodec,
302 ) -> Self {
303 tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for quic");
304
305 Self::new(peer, substream_id, SubstreamType::Quic(substream), codec)
306 }
307
308 #[cfg(feature = "webrtc")]
310 pub(crate) fn new_webrtc(
311 peer: PeerId,
312 substream_id: SubstreamId,
313 substream: webrtc::Substream,
314 codec: ProtocolCodec,
315 ) -> Self {
316 tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for webrtc");
317
318 Self::new(peer, substream_id, SubstreamType::WebRtc(substream), codec)
319 }
320
321 #[cfg(test)]
323 pub(crate) fn new_mock(
324 peer: PeerId,
325 substream_id: SubstreamId,
326 substream: Box<dyn crate::mock::substream::Substream>,
327 ) -> Self {
328 tracing::trace!(target: LOG_TARGET, ?peer, "create new substream for mocking");
329
330 Self::new(
331 peer,
332 substream_id,
333 SubstreamType::Mock(substream),
334 ProtocolCodec::Unspecified,
335 )
336 }
337
338 pub async fn close(self) {
340 let _ = match self.substream {
341 SubstreamType::Tcp(mut substream) => substream.shutdown().await,
342 #[cfg(feature = "websocket")]
343 SubstreamType::WebSocket(mut substream) => substream.shutdown().await,
344 #[cfg(feature = "quic")]
345 SubstreamType::Quic(mut substream) => substream.shutdown().await,
346 #[cfg(feature = "webrtc")]
347 SubstreamType::WebRtc(mut substream) => substream.shutdown().await,
348 #[cfg(test)]
349 SubstreamType::Mock(mut substream) => {
350 let _ = futures::SinkExt::close(&mut substream).await;
351 Ok(())
352 }
353 };
354 }
355
356 async fn send_identity_payload<T: AsyncWrite + Unpin>(
358 io: &mut T,
359 payload_size: usize,
360 payload: Bytes,
361 ) -> Result<(), SubstreamError> {
362 if payload.len() != payload_size {
363 return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
364 }
365
366 io.write_all(&payload).await.map_err(|_| SubstreamError::ConnectionClosed)?;
367
368 io.flush().await.map_err(From::from)
370 }
371
372 async fn send_unsigned_varint_payload<T: AsyncWrite + Unpin>(
374 io: &mut T,
375 bytes: Bytes,
376 max_size: Option<usize>,
377 ) -> Result<(), SubstreamError> {
378 if let Some(max_size) = max_size {
379 if bytes.len() > max_size {
380 return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
381 }
382 }
383
384 let mut buffer = unsigned_varint::encode::usize_buffer();
386 let encoded_len = unsigned_varint::encode::usize(bytes.len(), &mut buffer).len();
387 io.write_all(&buffer[..encoded_len]).await?;
388
389 io.write_all(bytes.as_ref()).await?;
391
392 io.flush().await.map_err(From::from)
394 }
395
396 pub async fn send_framed(&mut self, bytes: Bytes) -> Result<(), SubstreamError> {
411 tracing::trace!(
412 target: LOG_TARGET,
413 peer = ?self.peer,
414 codec = ?self.codec,
415 frame_len = ?bytes.len(),
416 "send framed"
417 );
418
419 match &mut self.substream {
420 #[cfg(test)]
421 SubstreamType::Mock(ref mut substream) =>
422 futures::SinkExt::send(substream, bytes).await,
423 SubstreamType::Tcp(ref mut substream) => match self.codec {
424 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
425 ProtocolCodec::Identity(payload_size) =>
426 Self::send_identity_payload(substream, payload_size, bytes).await,
427 ProtocolCodec::UnsignedVarint(max_size) =>
428 Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
429 },
430 #[cfg(feature = "websocket")]
431 SubstreamType::WebSocket(ref mut substream) => match self.codec {
432 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
433 ProtocolCodec::Identity(payload_size) =>
434 Self::send_identity_payload(substream, payload_size, bytes).await,
435 ProtocolCodec::UnsignedVarint(max_size) =>
436 Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
437 },
438 #[cfg(feature = "quic")]
439 SubstreamType::Quic(ref mut substream) => match self.codec {
440 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
441 ProtocolCodec::Identity(payload_size) =>
442 Self::send_identity_payload(substream, payload_size, bytes).await,
443 ProtocolCodec::UnsignedVarint(max_size) => {
444 check_size!(max_size, bytes.len());
445
446 let mut buffer = unsigned_varint::encode::usize_buffer();
447 let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer);
448 let len = BytesMut::from(len);
449
450 substream.write_all_chunks(&mut [len.freeze(), bytes]).await
451 }
452 },
453 #[cfg(feature = "webrtc")]
454 SubstreamType::WebRtc(ref mut substream) => match self.codec {
455 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
456 ProtocolCodec::Identity(payload_size) =>
457 Self::send_identity_payload(substream, payload_size, bytes).await,
458 ProtocolCodec::UnsignedVarint(max_size) =>
459 Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
460 },
461 }
462 }
463}
464
465impl tokio::io::AsyncRead for Substream {
466 fn poll_read(
467 mut self: Pin<&mut Self>,
468 cx: &mut Context<'_>,
469 buf: &mut tokio::io::ReadBuf<'_>,
470 ) -> Poll<std::io::Result<()>> {
471 poll_read!(&mut self.substream, cx, buf)
472 }
473}
474
475impl tokio::io::AsyncWrite for Substream {
476 fn poll_write(
477 mut self: Pin<&mut Self>,
478 cx: &mut Context<'_>,
479 buf: &[u8],
480 ) -> Poll<Result<usize, std::io::Error>> {
481 poll_write!(&mut self.substream, cx, buf)
482 }
483
484 fn poll_flush(
485 mut self: Pin<&mut Self>,
486 cx: &mut Context<'_>,
487 ) -> Poll<Result<(), std::io::Error>> {
488 poll_flush!(&mut self.substream, cx)
489 }
490
491 fn poll_shutdown(
492 mut self: Pin<&mut Self>,
493 cx: &mut Context<'_>,
494 ) -> Poll<Result<(), std::io::Error>> {
495 poll_shutdown!(&mut self.substream, cx)
496 }
497}
498
/// Reasons `read_payload_size` could not yet produce a frame length.
enum ReadError {
    /// The maximum varint length was examined without finding a terminating byte.
    Overflow,
    /// No terminating byte yet and fewer than the maximum varint length bytes are buffered.
    NotEnoughBytes,
    /// A terminating byte was found, but decoding the prefix failed.
    DecodeError,
}
504
505fn read_payload_size(buffer: &[u8]) -> Result<(usize, usize), ReadError> {
507 let max_len = encode::usize_buffer().len();
508
509 for i in 0..std::cmp::min(buffer.len(), max_len) {
510 if decode::is_last(buffer[i]) {
511 match decode::usize(&buffer[..=i]) {
512 Err(_) => return Err(ReadError::DecodeError),
513 Ok(size) => return Ok((size.0, i + 1)),
514 }
515 }
516 }
517
518 match buffer.len() < max_len {
519 true => Err(ReadError::NotEnoughBytes),
520 false => Err(ReadError::Overflow),
521 }
522}
523
524impl Stream for Substream {
525 type Item = Result<BytesMut, SubstreamError>;
526
527 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
528 let this = Pin::into_inner(self);
529
530 delegate_poll_next!(&mut this.substream, cx);
532
533 loop {
534 match this.codec {
535 ProtocolCodec::Identity(payload_size) => {
536 let mut read_buf =
537 ReadBuf::new(&mut this.read_buffer[this.offset..payload_size]);
538
539 match futures::ready!(poll_read!(&mut this.substream, cx, &mut read_buf)) {
540 Ok(_) => {
541 let nread = read_buf.filled().len();
542 if nread == 0 {
543 tracing::trace!(
544 target: LOG_TARGET,
545 peer = ?this.peer,
546 "read zero bytes, substream closed"
547 );
548 return Poll::Ready(None);
549 }
550
551 this.offset = this.offset.saturating_add(nread);
552
553 if this.offset == payload_size {
554 let mut payload = std::mem::replace(
555 &mut this.read_buffer,
556 BytesMut::zeroed(payload_size),
557 );
558 payload.truncate(payload_size);
559 this.offset = 0usize;
560
561 return Poll::Ready(Some(Ok(payload)));
562 }
563 }
564 Err(error) => return Poll::Ready(Some(Err(error.into()))),
565 }
566 }
567 ProtocolCodec::UnsignedVarint(max_size) => {
568 loop {
569 if let Some(frame) = this.pending_frames.pop_front() {
571 return Poll::Ready(Some(Ok(frame)));
572 }
573
574 match this.current_frame_size.take() {
575 Some(frame_size) => {
576 let mut read_buf =
577 ReadBuf::new(&mut this.read_buffer[this.offset..]);
578 this.current_frame_size = Some(frame_size);
579
580 match futures::ready!(poll_read!(
581 &mut this.substream,
582 cx,
583 &mut read_buf
584 )) {
585 Err(_error) => return Poll::Ready(None),
586 Ok(_) => {
587 let nread = match read_buf.filled().len() {
588 0 => return Poll::Ready(None),
589 nread => nread,
590 };
591
592 this.offset += nread;
593
594 if this.offset == frame_size {
595 let out_frame = std::mem::replace(
596 &mut this.read_buffer,
597 BytesMut::new(),
598 );
599 this.offset = 0;
600 this.current_frame_size = None;
601
602 return Poll::Ready(Some(Ok(out_frame)));
603 } else {
604 this.current_frame_size = Some(frame_size);
605 continue;
606 }
607 }
608 }
609 }
610 None => {
611 let mut read_buf =
612 ReadBuf::new(&mut this.size_vec[this.offset..this.offset + 1]);
613
614 match futures::ready!(poll_read!(
615 &mut this.substream,
616 cx,
617 &mut read_buf
618 )) {
619 Err(_error) => return Poll::Ready(None),
620 Ok(_) => {
621 if read_buf.filled().is_empty() {
622 return Poll::Ready(None);
623 }
624 this.offset += 1;
625
626 match read_payload_size(&this.size_vec[..this.offset]) {
627 Err(ReadError::NotEnoughBytes) => continue,
628 Err(_) =>
629 return Poll::Ready(Some(Err(
630 SubstreamError::ReadFailure(Some(
631 this.substream_id,
632 )),
633 ))),
634 Ok((size, num_bytes)) => {
635 debug_assert_eq!(num_bytes, this.offset);
636
637 if let Some(max_size) = max_size {
638 if size > max_size {
639 return Poll::Ready(Some(Err(
640 SubstreamError::ReadFailure(Some(
641 this.substream_id,
642 )),
643 )));
644 }
645 }
646
647 this.offset = 0;
648 if size == 0 {
652 return Poll::Ready(Some(Ok(BytesMut::new())));
653 }
654
655 this.current_frame_size = Some(size);
656 this.read_buffer = BytesMut::zeroed(size);
657 }
658 }
659 }
660 }
661 }
662 }
663 }
664 }
665 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
666 }
667 }
668 }
669}
670
671impl Sink<Bytes> for Substream {
673 type Error = SubstreamError;
674
675 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
676 delegate_poll_ready!(&mut self.substream, cx);
678
679 if self.pending_out_bytes >= BACKPRESSURE_BOUNDARY {
680 match futures::Sink::poll_flush(self.as_mut(), cx) {
682 Poll::Ready(Ok(())) => {}
683 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
684 Poll::Pending => {
685 return Poll::Pending;
687 }
688 }
689 }
690
691 Poll::Ready(Ok(()))
692 }
693
694 fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
695 delegate_start_send!(&mut self.substream, item);
697
698 tracing::trace!(
699 target: LOG_TARGET,
700 peer = ?self.peer,
701 substream_id = ?self.substream_id,
702 data_len = item.len(),
703 "Substream::start_send()",
704 );
705
706 match self.codec {
707 ProtocolCodec::Identity(payload_size) => {
708 if item.len() != payload_size {
709 return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
710 }
711
712 self.pending_out_bytes += item.len();
713 self.pending_out_frames.push_back(item);
714 }
715 ProtocolCodec::UnsignedVarint(max_size) => {
716 check_size!(max_size, item.len());
717
718 let len = {
719 let mut buffer = unsigned_varint::encode::usize_buffer();
720 let len = unsigned_varint::encode::usize(item.len(), &mut buffer);
721 BytesMut::from(len)
722 };
723
724 self.pending_out_bytes += len.len() + item.len();
725 self.pending_out_frames.push_back(len.freeze());
726 self.pending_out_frames.push_back(item);
727 }
728 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
729 }
730
731 Ok(())
732 }
733
734 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
735 delegate_poll_flush!(&mut self.substream, cx);
737
738 loop {
739 let mut pending_frame = match self.pending_out_frame.take() {
740 Some(frame) => frame,
741 None => match self.pending_out_frames.pop_front() {
742 Some(frame) => frame,
743 None => break,
744 },
745 };
746
747 match poll_write!(&mut self.substream, cx, &pending_frame) {
748 Poll::Ready(Err(error)) => return Poll::Ready(Err(error.into())),
749 Poll::Pending => {
750 self.pending_out_frame = Some(pending_frame);
751 break;
752 }
753 Poll::Ready(Ok(nwritten)) => {
754 pending_frame.advance(nwritten);
755
756 self.pending_out_bytes = self.pending_out_bytes.saturating_sub(nwritten);
759
760 if !pending_frame.is_empty() {
761 self.pending_out_frame = Some(pending_frame);
762 }
763 }
764 }
765 }
766
767 poll_flush!(&mut self.substream, cx).map_err(From::from)
768 }
769
770 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
771 poll_shutdown!(&mut self.substream, cx).map_err(From::from)
772 }
773}
774
/// Bound bundle for keys of a `SubstreamSet`: hashable, comparable, `Copy`, `Unpin` and
/// `Debug`-printable (keys are logged with `?key`).
pub trait SubstreamSetKey: Copy + Eq + PartialEq + fmt::Debug + Unpin + Hash {}

/// Any type satisfying the bounds is automatically usable as a key.
impl<K> SubstreamSetKey for K where K: Copy + Eq + PartialEq + fmt::Debug + Unpin + Hash {}
779
780#[derive(Debug, Default)]
783pub struct SubstreamSet<K, S>
784where
785 K: SubstreamSetKey,
786 S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
787{
788 substreams: HashMap<K, S>,
789}
790
791impl<K, S> SubstreamSet<K, S>
792where
793 K: SubstreamSetKey,
794 S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
795{
796 pub fn new() -> Self {
798 Self {
799 substreams: HashMap::new(),
800 }
801 }
802
803 pub fn insert(&mut self, key: K, substream: S) {
805 match self.substreams.entry(key) {
806 Entry::Vacant(entry) => {
807 entry.insert(substream);
808 }
809 Entry::Occupied(_) => {
810 tracing::error!(?key, "substream already exists");
811 debug_assert!(false);
812 }
813 }
814 }
815
816 pub fn remove(&mut self, key: &K) -> Option<S> {
818 self.substreams.remove(key)
819 }
820
821 #[cfg(test)]
823 pub fn get_mut(&mut self, key: &K) -> Option<&mut S> {
824 self.substreams.get_mut(key)
825 }
826
827 pub fn len(&self) -> usize {
829 self.substreams.len()
830 }
831
832 pub fn is_empty(&self) -> bool {
834 self.substreams.is_empty()
835 }
836}
837
838impl<K, S> Stream for SubstreamSet<K, S>
839where
840 K: SubstreamSetKey,
841 S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
842{
843 type Item = (K, <S as Stream>::Item);
844
845 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
846 let inner = Pin::into_inner(self);
847
848 for (key, mut substream) in inner.substreams.iter_mut() {
849 match Pin::new(&mut substream).poll_next(cx) {
850 Poll::Pending => continue,
851 Poll::Ready(Some(data)) => return Poll::Ready(Some((*key, data))),
852 Poll::Ready(None) =>
853 return Poll::Ready(Some((*key, Err(SubstreamError::ConnectionClosed)))),
854 }
855 }
856
857 Poll::Pending
858 }
859}
860
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{mock::substream::MockSubstream, PeerId};
    use futures::{SinkExt, StreamExt};

    /// Build the `BytesMut` frame for an ASCII test payload.
    fn frame(data: &str) -> BytesMut {
        BytesMut::from(data.as_bytes())
    }

    /// Queue one single-use `poll_next` expectation per entry of `frames`, in order.
    fn expect_frames(substream: &mut MockSubstream, frames: &[&'static str]) {
        for &data in frames {
            substream
                .expect_poll_next()
                .times(1)
                .return_once(move |_| Poll::Ready(Some(Ok(BytesMut::from(data.as_bytes())))));
        }
    }

    /// Catch-all `poll_next` expectation used once the queued ones are exhausted.
    fn expect_pending_forever(substream: &mut MockSubstream) {
        substream.expect_poll_next().returning(|_| Poll::Pending);
    }

    #[test]
    fn add_substream() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        for _ in 0..2 {
            set.insert(PeerId::random(), MockSubstream::new());
        }
    }

    #[test]
    #[should_panic]
    #[cfg(debug_assertions)]
    fn add_same_peer_twice() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();
        let peer = PeerId::random();

        set.insert(peer, MockSubstream::new());
        set.insert(peer, MockSubstream::new());
    }

    #[test]
    fn remove_substream() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();
        let peers = [PeerId::random(), PeerId::random()];

        for peer in peers {
            set.insert(peer, MockSubstream::new());
        }

        for peer in &peers {
            assert!(set.remove(peer).is_some());
        }
        assert!(set.remove(&PeerId::random()).is_none());
    }

    #[tokio::test]
    async fn poll_data_from_substream() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();
        let peer = PeerId::random();

        let mut substream = MockSubstream::new();
        expect_frames(&mut substream, &["hello", "world"]);
        expect_pending_forever(&mut substream);
        set.insert(peer, substream);

        for expected in ["hello", "world"] {
            let (key, item) = set.next().await.unwrap();
            assert_eq!(key, peer);
            assert_eq!(item.unwrap(), frame(expected));
        }

        assert!(futures::poll!(set.next()).is_pending());
    }

    #[tokio::test]
    async fn substream_closed() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();
        let peer = PeerId::random();

        let mut substream = MockSubstream::new();
        expect_frames(&mut substream, &["hello"]);
        substream.expect_poll_next().times(1).return_once(|_| Poll::Ready(None));
        expect_pending_forever(&mut substream);
        set.insert(peer, substream);

        let (key, item) = set.next().await.unwrap();
        assert_eq!(key, peer);
        assert_eq!(item.unwrap(), frame("hello"));

        match set.next().await {
            Some((exited_peer, Err(SubstreamError::ConnectionClosed))) =>
                assert_eq!(peer, exited_peer),
            _ => panic!("inavlid event received"),
        }
    }

    #[tokio::test]
    async fn get_mut_substream() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();
        let peer = PeerId::random();

        // Expectations of different methods are independent; per-method order is preserved.
        let mut substream = MockSubstream::new();
        expect_frames(&mut substream, &["hello", "world"]);
        substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(())));
        substream.expect_start_send().times(1).return_once(|_| Ok(()));
        substream.expect_poll_flush().times(1).return_once(|_| Poll::Ready(Ok(())));
        expect_pending_forever(&mut substream);
        set.insert(peer, substream);

        let (key, item) = set.next().await.unwrap();
        assert_eq!(key, peer);
        assert_eq!(item.unwrap(), frame("hello"));

        let substream = set.get_mut(&peer).unwrap();
        substream.send(vec![1, 2, 3, 4].into()).await.unwrap();

        let (key, item) = set.next().await.unwrap();
        assert_eq!(key, peer);
        assert_eq!(item.unwrap(), frame("world"));

        assert!(set.get_mut(&PeerId::random()).is_none());
    }

    #[tokio::test]
    async fn poll_data_from_two_substreams() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer1 = PeerId::random();
        let mut substream1 = MockSubstream::new();
        expect_frames(&mut substream1, &["hello", "world"]);
        expect_pending_forever(&mut substream1);
        set.insert(peer1, substream1);

        let peer2 = PeerId::random();
        let mut substream2 = MockSubstream::new();
        expect_frames(&mut substream2, &["siip", "huup"]);
        expect_pending_forever(&mut substream2);
        set.insert(peer2, substream2);

        // Per-substream order is fixed; the interleaving depends on `HashMap` iteration order.
        let expected: Vec<Vec<(PeerId, BytesMut)>> = vec![
            vec![
                (peer1, frame("hello")),
                (peer1, frame("world")),
                (peer2, frame("siip")),
                (peer2, frame("huup")),
            ],
            vec![
                (peer1, frame("hello")),
                (peer2, frame("siip")),
                (peer1, frame("world")),
                (peer2, frame("huup")),
            ],
            vec![
                (peer2, frame("siip")),
                (peer2, frame("huup")),
                (peer1, frame("hello")),
                (peer1, frame("world")),
            ],
            vec![
                (peer1, frame("hello")),
                (peer2, frame("siip")),
                (peer2, frame("huup")),
                (peer1, frame("world")),
            ],
        ];

        let mut values = Vec::new();
        for _ in 0..4 {
            let (key, item) = set.next().await.unwrap();
            values.push((key, item.unwrap()));
        }

        if !expected.contains(&values) {
            panic!("invalid set generated");
        }

        for _ in 0..10 {
            assert!(futures::poll!(set.next()).is_pending());
        }
    }
}
1089}