1use crate::{
25 codec::ProtocolCodec, error::SubstreamError, transport::tcp, types::SubstreamId, PeerId,
26};
27
28#[cfg(feature = "quic")]
29use crate::transport::quic;
30#[cfg(feature = "webrtc")]
31use crate::transport::webrtc;
32#[cfg(feature = "websocket")]
33use crate::transport::websocket;
34
35use bytes::{Buf, Bytes, BytesMut};
36use futures::{Sink, Stream};
37use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
38use unsigned_varint::{decode, encode};
39
40use std::{
41 collections::{hash_map::Entry, HashMap, VecDeque},
42 fmt,
43 hash::Hash,
44 io::ErrorKind,
45 pin::Pin,
46 task::{Context, Poll},
47};
48
/// `tracing` target used by every log event emitted from this module.
const LOG_TARGET: &str = concat!("litep2p::", "substream");
51
/// Forward `poll_flush()` to the transport-specific substream inside a `SubstreamType`.
///
/// Expands to the transport's `Poll<Result<(), std::io::Error>>`. The mock substream is not
/// supported here and panics via `unreachable!()`.
macro_rules! poll_flush {
    ($stream:expr, $context:ident) => {{
        match $stream {
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(inner) => Pin::new(inner).poll_flush($context),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(inner) => Pin::new(inner).poll_flush($context),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(inner) => Pin::new(inner).poll_flush($context),
            SubstreamType::Tcp(inner) => Pin::new(inner).poll_flush($context),
        }
    }};
}
67
/// Forward `poll_write()` of `$data` to the transport-specific substream inside a
/// `SubstreamType`.
///
/// Expands to the transport's `Poll<Result<usize, std::io::Error>>`. The mock substream is
/// not supported here and panics via `unreachable!()`.
macro_rules! poll_write {
    ($stream:expr, $context:ident, $data:expr) => {{
        match $stream {
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(inner) => Pin::new(inner).poll_write($context, $data),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(inner) => Pin::new(inner).poll_write($context, $data),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(inner) => Pin::new(inner).poll_write($context, $data),
            SubstreamType::Tcp(inner) => Pin::new(inner).poll_write($context, $data),
        }
    }};
}
83
/// Forward `poll_read()` into `$buf` to the transport-specific substream inside a
/// `SubstreamType`.
///
/// Expands to the transport's `Poll<std::io::Result<()>>`. The mock substream is not
/// supported here and panics via `unreachable!()`.
macro_rules! poll_read {
    ($stream:expr, $context:ident, $buf:expr) => {{
        match $stream {
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(inner) => Pin::new(inner).poll_read($context, $buf),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(inner) => Pin::new(inner).poll_read($context, $buf),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(inner) => Pin::new(inner).poll_read($context, $buf),
            SubstreamType::Tcp(inner) => Pin::new(inner).poll_read($context, $buf),
        }
    }};
}
99
/// Forward `poll_shutdown()` to the transport-specific substream inside a `SubstreamType`.
///
/// Expands to `Poll<Result<(), std::io::Error>>`. The mock substream only implements
/// `Sink`, so its `poll_close()` is used instead. Its error is converted into an
/// `std::io::Error` carrying the `Debug` representation of the original error.
macro_rules! poll_shutdown {
    ($substream:expr, $cx:ident) => {{
        match $substream {
            SubstreamType::Tcp(substream) => Pin::new(substream).poll_shutdown($cx),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(substream) => Pin::new(substream).poll_shutdown($cx),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(substream) => Pin::new(substream).poll_shutdown($cx),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(substream) => Pin::new(substream).poll_shutdown($cx),
            // Previously this discarded the `poll_close()` result and hit `todo!()`, so
            // closing a mock-backed substream always panicked. Propagate the result instead.
            #[cfg(test)]
            SubstreamType::Mock(substream) =>
                Pin::new(substream).poll_close($cx).map_err(|error| {
                    std::io::Error::new(std::io::ErrorKind::Other, format!("{:?}", error))
                }),
        }
    }};
}
118
/// In test builds, return early from the enclosing `poll_next()` by forwarding the call to a
/// mock substream, which implements `Stream` itself. In other builds this expands to nothing.
macro_rules! delegate_poll_next {
    ($stream:expr, $context:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $stream {
            return Pin::new(mock).poll_next($context);
        }
    }};
}
127
/// In test builds, return early from the enclosing `poll_ready()` by forwarding the call to a
/// mock substream, which implements `Sink` itself. In other builds this expands to nothing.
macro_rules! delegate_poll_ready {
    ($stream:expr, $context:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $stream {
            return Pin::new(mock).poll_ready($context);
        }
    }};
}
136
/// In test builds, return early from the enclosing `start_send()` by forwarding `$frame` to a
/// mock substream, which implements `Sink` itself. In other builds this expands to nothing.
macro_rules! delegate_start_send {
    ($stream:expr, $frame:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $stream {
            return Pin::new(mock).start_send($frame);
        }
    }};
}
145
/// In test builds, return early from the enclosing `poll_flush()` by forwarding the call to a
/// mock substream, which implements `Sink` itself. In other builds this expands to nothing.
macro_rules! delegate_poll_flush {
    ($stream:expr, $context:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $stream {
            return Pin::new(mock).poll_flush($context);
        }
    }};
}
154
/// Return `Err(SubstreamError::IoError(ErrorKind::PermissionDenied))` (converted with
/// `.into()`) from the enclosing function when `$max_size` is `Some(limit)` and `$size`
/// exceeds `limit`. `$size` is only evaluated when a limit is set.
macro_rules! check_size {
    ($max_size:expr, $size:expr) => {{
        match $max_size {
            Some(limit) if $size > limit => {
                return Err(SubstreamError::IoError(ErrorKind::PermissionDenied).into());
            }
            _ => {}
        }
    }};
}
164
165enum SubstreamType {
167 Tcp(tcp::Substream),
168 #[cfg(feature = "websocket")]
169 WebSocket(websocket::Substream),
170 #[cfg(feature = "quic")]
171 Quic(quic::Substream),
172 #[cfg(feature = "webrtc")]
173 WebRtc(webrtc::Substream),
174 #[cfg(test)]
175 Mock(Box<dyn crate::mock::substream::Substream>),
176}
177
178impl fmt::Debug for SubstreamType {
179 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180 match self {
181 Self::Tcp(_) => write!(f, "Tcp"),
182 #[cfg(feature = "websocket")]
183 Self::WebSocket(_) => write!(f, "WebSocket"),
184 #[cfg(feature = "quic")]
185 Self::Quic(_) => write!(f, "Quic"),
186 #[cfg(feature = "webrtc")]
187 Self::WebRtc(_) => write!(f, "WebRtc"),
188 #[cfg(test)]
189 Self::Mock(_) => write!(f, "Mock"),
190 }
191 }
192}
193
/// Number of queued outbound bytes at or above which `Sink::poll_ready()` flushes before
/// reporting readiness (64 KiB).
const BACKPRESSURE_BOUNDARY: usize = 64 * 1024;
196
197pub struct Substream {
206 peer: PeerId,
208
209 substream: SubstreamType,
211
212 substream_id: SubstreamId,
214
215 codec: ProtocolCodec,
217
218 pending_out_frames: VecDeque<Bytes>,
219 pending_out_bytes: usize,
220 pending_out_frame: Option<Bytes>,
221
222 read_buffer: BytesMut,
223 offset: usize,
224 pending_frames: VecDeque<BytesMut>,
225 current_frame_size: Option<usize>,
226
227 size_vec: BytesMut,
228}
229
230impl fmt::Debug for Substream {
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 f.debug_struct("Substream")
233 .field("peer", &self.peer)
234 .field("substream_id", &self.substream_id)
235 .field("codec", &self.codec)
236 .field("protocol", &self.substream)
237 .finish()
238 }
239}
240
241impl Substream {
242 fn new(
244 peer: PeerId,
245 substream_id: SubstreamId,
246 substream: SubstreamType,
247 codec: ProtocolCodec,
248 ) -> Self {
249 Self {
250 peer,
251 substream,
252 codec,
253 substream_id,
254 read_buffer: BytesMut::zeroed(1024),
255 offset: 0usize,
256 pending_frames: VecDeque::new(),
257 current_frame_size: None,
258 pending_out_bytes: 0usize,
259 pending_out_frames: VecDeque::new(),
260 pending_out_frame: None,
261 size_vec: BytesMut::zeroed(10),
262 }
263 }
264
265 pub(crate) fn new_tcp(
267 peer: PeerId,
268 substream_id: SubstreamId,
269 substream: tcp::Substream,
270 codec: ProtocolCodec,
271 ) -> Self {
272 tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for tcp");
273
274 Self::new(peer, substream_id, SubstreamType::Tcp(substream), codec)
275 }
276
277 #[cfg(feature = "websocket")]
279 pub(crate) fn new_websocket(
280 peer: PeerId,
281 substream_id: SubstreamId,
282 substream: websocket::Substream,
283 codec: ProtocolCodec,
284 ) -> Self {
285 tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for websocket");
286
287 Self::new(
288 peer,
289 substream_id,
290 SubstreamType::WebSocket(substream),
291 codec,
292 )
293 }
294
295 #[cfg(feature = "quic")]
297 pub(crate) fn new_quic(
298 peer: PeerId,
299 substream_id: SubstreamId,
300 substream: quic::Substream,
301 codec: ProtocolCodec,
302 ) -> Self {
303 tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for quic");
304
305 Self::new(peer, substream_id, SubstreamType::Quic(substream), codec)
306 }
307
308 #[cfg(feature = "webrtc")]
310 pub(crate) fn new_webrtc(
311 peer: PeerId,
312 substream_id: SubstreamId,
313 substream: webrtc::Substream,
314 codec: ProtocolCodec,
315 ) -> Self {
316 tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for webrtc");
317
318 Self::new(peer, substream_id, SubstreamType::WebRtc(substream), codec)
319 }
320
321 #[cfg(test)]
323 pub(crate) fn new_mock(
324 peer: PeerId,
325 substream_id: SubstreamId,
326 substream: Box<dyn crate::mock::substream::Substream>,
327 ) -> Self {
328 tracing::trace!(target: LOG_TARGET, ?peer, "create new substream for mocking");
329
330 Self::new(
331 peer,
332 substream_id,
333 SubstreamType::Mock(substream),
334 ProtocolCodec::Unspecified,
335 )
336 }
337
338 pub async fn close(self) {
340 let _ = match self.substream {
341 SubstreamType::Tcp(mut substream) => substream.shutdown().await,
342 #[cfg(feature = "websocket")]
343 SubstreamType::WebSocket(mut substream) => substream.shutdown().await,
344 #[cfg(feature = "quic")]
345 SubstreamType::Quic(mut substream) => substream.shutdown().await,
346 #[cfg(feature = "webrtc")]
347 SubstreamType::WebRtc(mut substream) => substream.shutdown().await,
348 #[cfg(test)]
349 SubstreamType::Mock(mut substream) => {
350 let _ = futures::SinkExt::close(&mut substream).await;
351 Ok(())
352 }
353 };
354 }
355
356 async fn send_identity_payload<T: AsyncWrite + Unpin>(
358 io: &mut T,
359 payload_size: usize,
360 payload: Bytes,
361 ) -> Result<(), SubstreamError> {
362 if payload.len() != payload_size {
363 return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
364 }
365
366 io.write_all(&payload).await.map_err(|_| SubstreamError::ConnectionClosed)?;
367
368 io.flush().await.map_err(From::from)
370 }
371
372 async fn send_unsigned_varint_payload<T: AsyncWrite + Unpin>(
374 io: &mut T,
375 bytes: Bytes,
376 max_size: Option<usize>,
377 ) -> Result<(), SubstreamError> {
378 if let Some(max_size) = max_size {
379 if bytes.len() > max_size {
380 return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
381 }
382 }
383
384 let mut buffer = unsigned_varint::encode::usize_buffer();
386 let encoded_len = unsigned_varint::encode::usize(bytes.len(), &mut buffer).len();
387 io.write_all(&buffer[..encoded_len]).await?;
388
389 io.write_all(bytes.as_ref()).await?;
391
392 io.flush().await.map_err(From::from)
394 }
395
396 pub async fn send_framed(&mut self, bytes: Bytes) -> Result<(), SubstreamError> {
411 tracing::trace!(
412 target: LOG_TARGET,
413 peer = ?self.peer,
414 codec = ?self.codec,
415 frame_len = ?bytes.len(),
416 "send framed"
417 );
418
419 match &mut self.substream {
420 #[cfg(test)]
421 SubstreamType::Mock(ref mut substream) =>
422 futures::SinkExt::send(substream, bytes).await.map_err(Into::into),
423 SubstreamType::Tcp(ref mut substream) => match self.codec {
424 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
425 ProtocolCodec::Identity(payload_size) =>
426 Self::send_identity_payload(substream, payload_size, bytes).await,
427 ProtocolCodec::UnsignedVarint(max_size) =>
428 Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
429 },
430 #[cfg(feature = "websocket")]
431 SubstreamType::WebSocket(ref mut substream) => match self.codec {
432 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
433 ProtocolCodec::Identity(payload_size) =>
434 Self::send_identity_payload(substream, payload_size, bytes).await,
435 ProtocolCodec::UnsignedVarint(max_size) =>
436 Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
437 },
438 #[cfg(feature = "quic")]
439 SubstreamType::Quic(ref mut substream) => match self.codec {
440 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
441 ProtocolCodec::Identity(payload_size) =>
442 Self::send_identity_payload(substream, payload_size, bytes).await,
443 ProtocolCodec::UnsignedVarint(max_size) => {
444 check_size!(max_size, bytes.len());
445
446 let mut buffer = unsigned_varint::encode::usize_buffer();
447 let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer);
448 let len = BytesMut::from(len);
449
450 substream.write_all_chunks(&mut [len.freeze(), bytes]).await
451 }
452 },
453 #[cfg(feature = "webrtc")]
454 SubstreamType::WebRtc(ref mut substream) => match self.codec {
455 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
456 ProtocolCodec::Identity(payload_size) =>
457 Self::send_identity_payload(substream, payload_size, bytes).await,
458 ProtocolCodec::UnsignedVarint(max_size) =>
459 Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
460 },
461 }
462 }
463}
464
465impl tokio::io::AsyncRead for Substream {
466 fn poll_read(
467 mut self: Pin<&mut Self>,
468 cx: &mut Context<'_>,
469 buf: &mut tokio::io::ReadBuf<'_>,
470 ) -> Poll<std::io::Result<()>> {
471 poll_read!(&mut self.substream, cx, buf)
472 }
473}
474
475impl tokio::io::AsyncWrite for Substream {
476 fn poll_write(
477 mut self: Pin<&mut Self>,
478 cx: &mut Context<'_>,
479 buf: &[u8],
480 ) -> Poll<Result<usize, std::io::Error>> {
481 poll_write!(&mut self.substream, cx, buf)
482 }
483
484 fn poll_flush(
485 mut self: Pin<&mut Self>,
486 cx: &mut Context<'_>,
487 ) -> Poll<Result<(), std::io::Error>> {
488 poll_flush!(&mut self.substream, cx)
489 }
490
491 fn poll_shutdown(
492 mut self: Pin<&mut Self>,
493 cx: &mut Context<'_>,
494 ) -> Poll<Result<(), std::io::Error>> {
495 poll_shutdown!(&mut self.substream, cx)
496 }
497}
498
/// Reasons why [`read_payload_size()`] could not produce a frame size.
enum ReadError {
    /// No terminating varint byte yet and fewer bytes than the maximum encoding length.
    NotEnoughBytes,
    /// No terminating varint byte within the maximum encoding length of a `usize`.
    Overflow,
    /// A terminating byte was found but the varint could not be decoded.
    DecodeError,
}
504
505fn read_payload_size(buffer: &[u8]) -> Result<(usize, usize), ReadError> {
507 let max_len = encode::usize_buffer().len();
508
509 for i in 0..std::cmp::min(buffer.len(), max_len) {
510 if decode::is_last(buffer[i]) {
511 match decode::usize(&buffer[..=i]) {
512 Err(_) => return Err(ReadError::DecodeError),
513 Ok(size) => return Ok((size.0, i + 1)),
514 }
515 }
516 }
517
518 match buffer.len() < max_len {
519 true => Err(ReadError::NotEnoughBytes),
520 false => Err(ReadError::Overflow),
521 }
522}
523
524impl Stream for Substream {
525 type Item = Result<BytesMut, SubstreamError>;
526
527 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
528 let this = Pin::into_inner(self);
529
530 delegate_poll_next!(&mut this.substream, cx);
532
533 loop {
534 match this.codec {
535 ProtocolCodec::Identity(payload_size) => {
536 let mut read_buf =
537 ReadBuf::new(&mut this.read_buffer[this.offset..payload_size]);
538
539 match futures::ready!(poll_read!(&mut this.substream, cx, &mut read_buf)) {
540 Ok(_) => {
541 let nread = read_buf.filled().len();
542 if nread == 0 {
543 tracing::trace!(
544 target: LOG_TARGET,
545 peer = ?this.peer,
546 "read zero bytes, substream closed"
547 );
548 return Poll::Ready(None);
549 }
550
551 if nread == payload_size {
552 let mut payload = std::mem::replace(
553 &mut this.read_buffer,
554 BytesMut::zeroed(payload_size),
555 );
556 payload.truncate(payload_size);
557 this.offset = 0usize;
558
559 return Poll::Ready(Some(Ok(payload)));
560 } else {
561 this.offset += read_buf.filled().len();
562 }
563 }
564 Err(error) => return Poll::Ready(Some(Err(error.into()))),
565 }
566 }
567 ProtocolCodec::UnsignedVarint(max_size) => {
568 loop {
569 if let Some(frame) = this.pending_frames.pop_front() {
571 return Poll::Ready(Some(Ok(frame)));
572 }
573
574 match this.current_frame_size.take() {
575 Some(frame_size) => {
576 let mut read_buf =
577 ReadBuf::new(&mut this.read_buffer[this.offset..]);
578 this.current_frame_size = Some(frame_size);
579
580 match futures::ready!(poll_read!(
581 &mut this.substream,
582 cx,
583 &mut read_buf
584 )) {
585 Err(_error) => return Poll::Ready(None),
586 Ok(_) => {
587 let nread = match read_buf.filled().len() {
588 0 => return Poll::Ready(None),
589 nread => nread,
590 };
591
592 this.offset += nread;
593
594 if this.offset == frame_size {
595 let out_frame = std::mem::replace(
596 &mut this.read_buffer,
597 BytesMut::new(),
598 );
599 this.offset = 0;
600 this.current_frame_size = None;
601
602 return Poll::Ready(Some(Ok(out_frame)));
603 } else {
604 this.current_frame_size = Some(frame_size);
605 continue;
606 }
607 }
608 }
609 }
610 None => {
611 let mut read_buf =
612 ReadBuf::new(&mut this.size_vec[this.offset..this.offset + 1]);
613
614 match futures::ready!(poll_read!(
615 &mut this.substream,
616 cx,
617 &mut read_buf
618 )) {
619 Err(_error) => return Poll::Ready(None),
620 Ok(_) => {
621 if read_buf.filled().is_empty() {
622 return Poll::Ready(None);
623 }
624 this.offset += 1;
625
626 match read_payload_size(&this.size_vec[..this.offset]) {
627 Err(ReadError::NotEnoughBytes) => continue,
628 Err(_) =>
629 return Poll::Ready(Some(Err(
630 SubstreamError::ReadFailure(Some(
631 this.substream_id,
632 )),
633 ))),
634 Ok((size, num_bytes)) => {
635 debug_assert_eq!(num_bytes, this.offset);
636
637 if let Some(max_size) = max_size {
638 if size > max_size {
639 return Poll::Ready(Some(Err(
640 SubstreamError::ReadFailure(Some(
641 this.substream_id,
642 )),
643 )));
644 }
645 }
646
647 this.offset = 0;
648 this.current_frame_size = Some(size);
649 this.read_buffer = BytesMut::zeroed(size);
650 }
651 }
652 }
653 }
654 }
655 }
656 }
657 }
658 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
659 }
660 }
661 }
662}
663
664impl Sink<Bytes> for Substream {
666 type Error = SubstreamError;
667
668 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
669 delegate_poll_ready!(&mut self.substream, cx);
671
672 if self.pending_out_bytes >= BACKPRESSURE_BOUNDARY {
673 return poll_flush!(&mut self.substream, cx).map_err(From::from);
674 }
675
676 Poll::Ready(Ok(()))
677 }
678
679 fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
680 delegate_start_send!(&mut self.substream, item);
682
683 tracing::trace!(
684 target: LOG_TARGET,
685 peer = ?self.peer,
686 substream_id = ?self.substream_id,
687 data_len = item.len(),
688 "Substream::start_send()",
689 );
690
691 match self.codec {
692 ProtocolCodec::Identity(payload_size) => {
693 if item.len() != payload_size {
694 return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
695 }
696
697 self.pending_out_bytes += item.len();
698 self.pending_out_frames.push_back(item);
699 }
700 ProtocolCodec::UnsignedVarint(max_size) => {
701 check_size!(max_size, item.len());
702
703 let len = {
704 let mut buffer = unsigned_varint::encode::usize_buffer();
705 let len = unsigned_varint::encode::usize(item.len(), &mut buffer);
706 BytesMut::from(len)
707 };
708
709 self.pending_out_bytes += len.len() + item.len();
710 self.pending_out_frames.push_back(len.freeze());
711 self.pending_out_frames.push_back(item);
712 }
713 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
714 }
715
716 Ok(())
717 }
718
719 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
720 delegate_poll_flush!(&mut self.substream, cx);
722
723 loop {
724 let mut pending_frame = match self.pending_out_frame.take() {
725 Some(frame) => frame,
726 None => match self.pending_out_frames.pop_front() {
727 Some(frame) => frame,
728 None => break,
729 },
730 };
731
732 match poll_write!(&mut self.substream, cx, &pending_frame) {
733 Poll::Ready(Err(error)) => return Poll::Ready(Err(error.into())),
734 Poll::Pending => {
735 self.pending_out_frame = Some(pending_frame);
736 break;
737 }
738 Poll::Ready(Ok(nwritten)) => {
739 pending_frame.advance(nwritten);
740
741 if !pending_frame.is_empty() {
742 self.pending_out_frame = Some(pending_frame);
743 }
744 }
745 }
746 }
747
748 poll_flush!(&mut self.substream, cx).map_err(From::from)
749 }
750
751 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
752 poll_shutdown!(&mut self.substream, cx).map_err(From::from)
753 }
754}
755
/// Bounds required from the keys of a [`SubstreamSet`].
///
/// Blanket-implemented for every type that satisfies them.
pub trait SubstreamSetKey: Copy + Eq + PartialEq + Hash + fmt::Debug + Unpin {}

impl<K> SubstreamSetKey for K where K: Copy + Eq + PartialEq + Hash + fmt::Debug + Unpin {}
760
761#[derive(Debug, Default)]
763pub struct SubstreamSet<K, S>
764where
765 K: SubstreamSetKey,
766 S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
767{
768 substreams: HashMap<K, S>,
769}
770
771impl<K, S> SubstreamSet<K, S>
772where
773 K: SubstreamSetKey,
774 S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
775{
776 pub fn new() -> Self {
778 Self {
779 substreams: HashMap::new(),
780 }
781 }
782
783 pub fn insert(&mut self, key: K, substream: S) {
785 match self.substreams.entry(key) {
786 Entry::Vacant(entry) => {
787 entry.insert(substream);
788 }
789 Entry::Occupied(_) => {
790 tracing::error!(?key, "substream already exists");
791 debug_assert!(false);
792 }
793 }
794 }
795
796 pub fn remove(&mut self, key: &K) -> Option<S> {
798 self.substreams.remove(key)
799 }
800
801 #[cfg(test)]
803 pub fn get_mut(&mut self, key: &K) -> Option<&mut S> {
804 self.substreams.get_mut(key)
805 }
806
807 pub fn len(&self) -> usize {
809 self.substreams.len()
810 }
811
812 pub fn is_empty(&self) -> bool {
814 self.substreams.is_empty()
815 }
816}
817
818impl<K, S> Stream for SubstreamSet<K, S>
819where
820 K: SubstreamSetKey,
821 S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
822{
823 type Item = (K, <S as Stream>::Item);
824
825 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
826 let inner = Pin::into_inner(self);
827
828 for (key, mut substream) in inner.substreams.iter_mut() {
830 match Pin::new(&mut substream).poll_next(cx) {
831 Poll::Pending => continue,
832 Poll::Ready(Some(data)) => return Poll::Ready(Some((*key, data))),
833 Poll::Ready(None) =>
834 return Poll::Ready(Some((*key, Err(SubstreamError::ConnectionClosed)))),
835 }
836 }
837
838 Poll::Pending
839 }
840}
841
// Unit tests for `SubstreamSet`, driven by `MockSubstream`. Expectations registered on a
// mock are consumed in declaration order, so the order of the `expect_*` calls matters.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{mock::substream::MockSubstream, PeerId};
    use futures::{SinkExt, StreamExt};

    // Inserting substreams under two distinct keys succeeds.
    #[test]
    fn add_substream() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let substream = MockSubstream::new();
        set.insert(peer, substream);

        let peer = PeerId::random();
        let substream = MockSubstream::new();
        set.insert(peer, substream);
    }

    // Inserting a second substream under the same key trips the `debug_assert!` in
    // `SubstreamSet::insert()`, so this test only exists in debug builds.
    #[test]
    #[should_panic]
    #[cfg(debug_assertions)]
    fn add_same_peer_twice() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let substream1 = MockSubstream::new();
        let substream2 = MockSubstream::new();

        set.insert(peer, substream1);
        set.insert(peer, substream2);
    }

    // `remove()` returns the substream for known keys and `None` for unknown ones.
    #[test]
    fn remove_substream() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer1 = PeerId::random();
        let substream1 = MockSubstream::new();
        set.insert(peer1, substream1);

        let peer2 = PeerId::random();
        let substream2 = MockSubstream::new();
        set.insert(peer2, substream2);

        assert!(set.remove(&peer1).is_some());
        assert!(set.remove(&peer2).is_some());
        assert!(set.remove(&PeerId::random()).is_none());
    }

    // Items from a single substream are yielded in order, tagged with its key; afterwards
    // the set is pending.
    #[tokio::test]
    async fn poll_data_from_substream() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let mut substream = MockSubstream::new();
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..])))));
        substream.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer, substream);

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..]));

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"world"[..]));

        assert!(futures::poll!(set.next()).is_pending());
    }

    // A substream that ends (`Ready(None)`) is reported as `ConnectionClosed` for its key.
    #[tokio::test]
    async fn substream_closed() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let mut substream = MockSubstream::new();
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        substream.expect_poll_next().times(1).return_once(|_| Poll::Ready(None));
        substream.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer, substream);

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..]));

        match set.next().await {
            Some((exited_peer, Err(SubstreamError::ConnectionClosed))) => {
                assert_eq!(peer, exited_peer);
            }
            _ => panic!("inavlid event received"),
        }
    }

    // A substream obtained through `get_mut()` can be written to in between reads, and
    // unknown keys return `None`.
    #[tokio::test]
    async fn get_mut_substream() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let mut substream = MockSubstream::new();
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        // `SinkExt::send()` drives `poll_ready()`, `start_send()` and `poll_flush()`.
        substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(())));
        substream.expect_start_send().times(1).return_once(|_| Ok(()));
        substream.expect_poll_flush().times(1).return_once(|_| Poll::Ready(Ok(())));
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..])))));
        substream.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer, substream);

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..]));

        let substream = set.get_mut(&peer).unwrap();
        substream.send(vec![1, 2, 3, 4].into()).await.unwrap();

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"world"[..]));

        assert!(set.get_mut(&PeerId::random()).is_none());
    }

    // Items from two substreams are all delivered. `HashMap` iteration order is not fixed,
    // so several interleavings are accepted, but each substream's own order is preserved.
    #[tokio::test]
    async fn poll_data_from_two_substreams() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer1 = PeerId::random();
        let mut substream1 = MockSubstream::new();
        substream1
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        substream1
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..])))));
        substream1.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer1, substream1);

        let peer2 = PeerId::random();
        let mut substream2 = MockSubstream::new();
        substream2
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"siip"[..])))));
        substream2
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"huup"[..])))));
        substream2.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer2, substream2);

        let expected: Vec<Vec<(PeerId, BytesMut)>> = vec![
            vec![
                (peer1, BytesMut::from(&b"hello"[..])),
                (peer1, BytesMut::from(&b"world"[..])),
                (peer2, BytesMut::from(&b"siip"[..])),
                (peer2, BytesMut::from(&b"huup"[..])),
            ],
            vec![
                (peer1, BytesMut::from(&b"hello"[..])),
                (peer2, BytesMut::from(&b"siip"[..])),
                (peer1, BytesMut::from(&b"world"[..])),
                (peer2, BytesMut::from(&b"huup"[..])),
            ],
            vec![
                (peer2, BytesMut::from(&b"siip"[..])),
                (peer2, BytesMut::from(&b"huup"[..])),
                (peer1, BytesMut::from(&b"hello"[..])),
                (peer1, BytesMut::from(&b"world"[..])),
            ],
            vec![
                (peer1, BytesMut::from(&b"hello"[..])),
                (peer2, BytesMut::from(&b"siip"[..])),
                (peer2, BytesMut::from(&b"huup"[..])),
                (peer1, BytesMut::from(&b"world"[..])),
            ],
        ];

        // Collect the four items the mocks will produce.
        let mut values = Vec::new();

        for _ in 0..4 {
            let value = set.next().await.unwrap();
            values.push((value.0, value.1.unwrap()));
        }

        let mut correct_found = false;

        for set in expected {
            if values == set {
                correct_found = true;
                break;
            }
        }

        if !correct_found {
            panic!("invalid set generated");
        }

        // Once both substreams are exhausted, the set stays pending.
        for _ in 0..10 {
            assert!(futures::poll!(set.next()).is_pending());
        }
    }
}