1use crate::{
25 codec::ProtocolCodec, error::SubstreamError, transport::tcp, types::SubstreamId, PeerId,
26};
27
28#[cfg(feature = "quic")]
29use crate::transport::quic;
30#[cfg(feature = "webrtc")]
31use crate::transport::webrtc;
32#[cfg(feature = "websocket")]
33use crate::transport::websocket;
34
35use bytes::{Buf, Bytes, BytesMut};
36use futures::{Sink, Stream};
37use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
38use unsigned_varint::{decode, encode};
39
40use std::{
41 collections::{hash_map::Entry, HashMap, VecDeque},
42 fmt,
43 hash::Hash,
44 io::ErrorKind,
45 pin::Pin,
46 task::{Context, Poll},
47};
48
/// Log target passed to the `tracing` calls in this file that specify one.
const LOG_TARGET: &str = "litep2p::substream";
51
/// Forward `poll_flush()` to the transport-specific substream held in a [`SubstreamType`].
///
/// The test-only `Mock` variant is not an `AsyncWrite` and panics with `unreachable!()`.
macro_rules! poll_flush {
    ($substream:expr, $cx:ident) => {{
        match $substream {
            SubstreamType::Tcp(inner) => Pin::new(inner).poll_flush($cx),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(inner) => Pin::new(inner).poll_flush($cx),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(inner) => Pin::new(inner).poll_flush($cx),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(inner) => Pin::new(inner).poll_flush($cx),
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
        }
    }};
}
67
/// Forward `poll_write()` of `$frame` to the transport-specific substream held in a
/// [`SubstreamType`].
///
/// The test-only `Mock` variant is not an `AsyncWrite` and panics with `unreachable!()`.
macro_rules! poll_write {
    ($substream:expr, $cx:ident, $frame:expr) => {{
        match $substream {
            SubstreamType::Tcp(inner) => Pin::new(inner).poll_write($cx, $frame),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(inner) => Pin::new(inner).poll_write($cx, $frame),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(inner) => Pin::new(inner).poll_write($cx, $frame),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(inner) => Pin::new(inner).poll_write($cx, $frame),
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
        }
    }};
}
83
/// Forward `poll_read()` into `$buffer` to the transport-specific substream held in a
/// [`SubstreamType`].
///
/// The test-only `Mock` variant is not an `AsyncRead` and panics with `unreachable!()`.
macro_rules! poll_read {
    ($substream:expr, $cx:ident, $buffer:expr) => {{
        match $substream {
            SubstreamType::Tcp(inner) => Pin::new(inner).poll_read($cx, $buffer),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(inner) => Pin::new(inner).poll_read($cx, $buffer),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(inner) => Pin::new(inner).poll_read($cx, $buffer),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(inner) => Pin::new(inner).poll_read($cx, $buffer),
            #[cfg(test)]
            SubstreamType::Mock(_) => unreachable!(),
        }
    }};
}
99
/// Forward `poll_shutdown()` to the transport-specific substream held in a [`SubstreamType`].
///
/// For the test-only `Mock` variant the mock's `poll_close()` is polled once and the
/// macro then panics with `todo!()`.
macro_rules! poll_shutdown {
    ($substream:expr, $cx:ident) => {{
        match $substream {
            SubstreamType::Tcp(inner) => Pin::new(inner).poll_shutdown($cx),
            #[cfg(feature = "websocket")]
            SubstreamType::WebSocket(inner) => Pin::new(inner).poll_shutdown($cx),
            #[cfg(feature = "quic")]
            SubstreamType::Quic(inner) => Pin::new(inner).poll_shutdown($cx),
            #[cfg(feature = "webrtc")]
            SubstreamType::WebRtc(inner) => Pin::new(inner).poll_shutdown($cx),
            #[cfg(test)]
            SubstreamType::Mock(mock) => {
                let _ = Pin::new(mock).poll_close($cx);
                todo!();
            }
        }
    }};
}
118
/// In test builds, return early from the calling function with the mock's own
/// `poll_next()` result when `$substream` is the `Mock` variant. It is a no-op otherwise.
macro_rules! delegate_poll_next {
    ($substream:expr, $cx:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $substream {
            return Pin::new(mock).poll_next($cx);
        }
    }};
}
127
/// In test builds, return early from the calling function with the mock's own
/// `poll_ready()` result when `$substream` is the `Mock` variant. It is a no-op otherwise.
macro_rules! delegate_poll_ready {
    ($substream:expr, $cx:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $substream {
            return Pin::new(mock).poll_ready($cx);
        }
    }};
}
136
/// In test builds, hand `$item` to the mock's own `start_send()` and return its result
/// when `$substream` is the `Mock` variant. It is a no-op otherwise.
macro_rules! delegate_start_send {
    ($substream:expr, $item:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $substream {
            return Pin::new(mock).start_send($item);
        }
    }};
}
145
/// In test builds, return early from the calling function with the mock's own
/// `poll_flush()` result when `$substream` is the `Mock` variant. It is a no-op otherwise.
macro_rules! delegate_poll_flush {
    ($substream:expr, $cx:ident) => {{
        #[cfg(test)]
        if let SubstreamType::Mock(mock) = $substream {
            return Pin::new(mock).poll_flush($cx);
        }
    }};
}
154
/// Return `Err(SubstreamError::IoError(ErrorKind::PermissionDenied).into())` from the
/// calling function when a limit is configured (`$max_size` is `Some`) and `$size`
/// exceeds it.
///
/// `$size` is only evaluated when a limit is present.
macro_rules! check_size {
    ($max_size:expr, $size:expr) => {{
        match $max_size {
            Some(limit) if $size > limit =>
                return Err(SubstreamError::IoError(ErrorKind::PermissionDenied).into()),
            _ => {}
        }
    }};
}
164
/// Transport-specific substream wrapped by [`Substream`].
///
/// Only `Tcp` is always available. The other transports are compiled in behind their cargo
/// features, and `Mock` exists only in test builds.
enum SubstreamType {
    /// Substream of the TCP transport.
    Tcp(tcp::Substream),
    /// Substream of the WebSocket transport (`websocket` feature).
    #[cfg(feature = "websocket")]
    WebSocket(websocket::Substream),
    /// Substream of the QUIC transport (`quic` feature).
    #[cfg(feature = "quic")]
    Quic(quic::Substream),
    /// Substream of the WebRTC transport (`webrtc` feature).
    #[cfg(feature = "webrtc")]
    WebRtc(webrtc::Substream),
    /// Test double. `Stream`/`Sink` calls on [`Substream`] are delegated to it directly.
    #[cfg(test)]
    Mock(Box<dyn crate::mock::substream::Substream>),
}
177
178impl fmt::Debug for SubstreamType {
179 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180 match self {
181 Self::Tcp(_) => write!(f, "Tcp"),
182 #[cfg(feature = "websocket")]
183 Self::WebSocket(_) => write!(f, "WebSocket"),
184 #[cfg(feature = "quic")]
185 Self::Quic(_) => write!(f, "Quic"),
186 #[cfg(feature = "webrtc")]
187 Self::WebRtc(_) => write!(f, "WebRtc"),
188 #[cfg(test)]
189 Self::Mock(_) => write!(f, "Mock"),
190 }
191 }
192}
193
/// Number of queued outbound bytes (64 KiB) at or above which `Sink::poll_ready()` stops
/// reporting readiness immediately and flushes first.
const BACKPRESSURE_BOUNDARY: usize = 64 * 1024;
196
/// Protocol substream over one of the supported transports.
///
/// It can be used as raw bytes (`AsyncRead`/`AsyncWrite`) or as frames decoded/encoded
/// with its [`ProtocolCodec`] (`Stream`/`Sink`, `send_framed()`).
pub struct Substream {
    /// Remote peer of this substream (used in log output and `Debug`).
    peer: PeerId,

    /// Transport-specific substream all I/O is forwarded to.
    substream: SubstreamType,

    /// Identifier reported in `SubstreamError::ReadFailure`.
    substream_id: SubstreamId,

    /// Framing used when reading and writing frames.
    codec: ProtocolCodec,

    // Outbound (`Sink`) state.
    /// Frames queued by `start_send()` and written out by `poll_flush()`.
    pending_out_frames: VecDeque<Bytes>,
    /// Queued outbound bytes, compared against `BACKPRESSURE_BOUNDARY` in `poll_ready()`.
    pending_out_bytes: usize,
    /// Remainder of a frame that was only partially written.
    pending_out_frame: Option<Bytes>,

    // Inbound (`Stream`) state.
    /// Buffer the current inbound payload is read into.
    read_buffer: BytesMut,
    /// Bytes read so far into `read_buffer`, or into `size_vec` while a varint length
    /// prefix is being read.
    offset: usize,
    /// Frames returned before any new read (varint codec).
    /// NOTE(review): nothing in this file pushes to it.
    pending_frames: VecDeque<BytesMut>,
    /// Size of the varint-framed payload being read; `None` while reading the prefix.
    current_frame_size: Option<usize>,

    /// Scratch space for the bytes of an unsigned-varint length prefix.
    size_vec: BytesMut,
}
229
230impl fmt::Debug for Substream {
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 f.debug_struct("Substream")
233 .field("peer", &self.peer)
234 .field("substream_id", &self.substream_id)
235 .field("codec", &self.codec)
236 .field("protocol", &self.substream)
237 .finish()
238 }
239}
240
241impl Substream {
242 fn new(
244 peer: PeerId,
245 substream_id: SubstreamId,
246 substream: SubstreamType,
247 codec: ProtocolCodec,
248 ) -> Self {
249 Self {
250 peer,
251 substream,
252 codec,
253 substream_id,
254 read_buffer: BytesMut::zeroed(1024),
255 offset: 0usize,
256 pending_frames: VecDeque::new(),
257 current_frame_size: None,
258 pending_out_bytes: 0usize,
259 pending_out_frames: VecDeque::new(),
260 pending_out_frame: None,
261 size_vec: BytesMut::zeroed(10),
262 }
263 }
264
265 pub(crate) fn new_tcp(
267 peer: PeerId,
268 substream_id: SubstreamId,
269 substream: tcp::Substream,
270 codec: ProtocolCodec,
271 ) -> Self {
272 tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for tcp");
273
274 Self::new(peer, substream_id, SubstreamType::Tcp(substream), codec)
275 }
276
277 #[cfg(feature = "websocket")]
279 pub(crate) fn new_websocket(
280 peer: PeerId,
281 substream_id: SubstreamId,
282 substream: websocket::Substream,
283 codec: ProtocolCodec,
284 ) -> Self {
285 tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for websocket");
286
287 Self::new(
288 peer,
289 substream_id,
290 SubstreamType::WebSocket(substream),
291 codec,
292 )
293 }
294
295 #[cfg(feature = "quic")]
297 pub(crate) fn new_quic(
298 peer: PeerId,
299 substream_id: SubstreamId,
300 substream: quic::Substream,
301 codec: ProtocolCodec,
302 ) -> Self {
303 tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for quic");
304
305 Self::new(peer, substream_id, SubstreamType::Quic(substream), codec)
306 }
307
308 #[cfg(feature = "webrtc")]
310 pub(crate) fn new_webrtc(
311 peer: PeerId,
312 substream_id: SubstreamId,
313 substream: webrtc::Substream,
314 codec: ProtocolCodec,
315 ) -> Self {
316 tracing::trace!(target: LOG_TARGET, ?peer, ?codec, "create new substream for webrtc");
317
318 Self::new(peer, substream_id, SubstreamType::WebRtc(substream), codec)
319 }
320
321 #[cfg(test)]
323 pub(crate) fn new_mock(
324 peer: PeerId,
325 substream_id: SubstreamId,
326 substream: Box<dyn crate::mock::substream::Substream>,
327 ) -> Self {
328 tracing::trace!(target: LOG_TARGET, ?peer, "create new substream for mocking");
329
330 Self::new(
331 peer,
332 substream_id,
333 SubstreamType::Mock(substream),
334 ProtocolCodec::Unspecified,
335 )
336 }
337
338 pub async fn close(self) {
340 let _ = match self.substream {
341 SubstreamType::Tcp(mut substream) => substream.shutdown().await,
342 #[cfg(feature = "websocket")]
343 SubstreamType::WebSocket(mut substream) => substream.shutdown().await,
344 #[cfg(feature = "quic")]
345 SubstreamType::Quic(mut substream) => substream.shutdown().await,
346 #[cfg(feature = "webrtc")]
347 SubstreamType::WebRtc(mut substream) => substream.shutdown().await,
348 #[cfg(test)]
349 SubstreamType::Mock(mut substream) => {
350 let _ = futures::SinkExt::close(&mut substream).await;
351 Ok(())
352 }
353 };
354 }
355
356 async fn send_identity_payload<T: AsyncWrite + Unpin>(
358 io: &mut T,
359 payload_size: usize,
360 payload: Bytes,
361 ) -> Result<(), SubstreamError> {
362 if payload.len() != payload_size {
363 return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
364 }
365
366 io.write_all(&payload).await.map_err(|_| SubstreamError::ConnectionClosed)?;
367
368 io.flush().await.map_err(From::from)
370 }
371
372 async fn send_unsigned_varint_payload<T: AsyncWrite + Unpin>(
374 io: &mut T,
375 bytes: Bytes,
376 max_size: Option<usize>,
377 ) -> Result<(), SubstreamError> {
378 if let Some(max_size) = max_size {
379 if bytes.len() > max_size {
380 return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
381 }
382 }
383
384 let mut buffer = unsigned_varint::encode::usize_buffer();
386 let encoded_len = unsigned_varint::encode::usize(bytes.len(), &mut buffer).len();
387 io.write_all(&buffer[..encoded_len]).await?;
388
389 io.write_all(bytes.as_ref()).await?;
391
392 io.flush().await.map_err(From::from)
394 }
395
396 pub async fn send_framed(&mut self, bytes: Bytes) -> Result<(), SubstreamError> {
411 tracing::trace!(
412 target: LOG_TARGET,
413 peer = ?self.peer,
414 codec = ?self.codec,
415 frame_len = ?bytes.len(),
416 "send framed"
417 );
418
419 match &mut self.substream {
420 #[cfg(test)]
421 SubstreamType::Mock(ref mut substream) =>
422 futures::SinkExt::send(substream, bytes).await,
423 SubstreamType::Tcp(ref mut substream) => match self.codec {
424 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
425 ProtocolCodec::Identity(payload_size) =>
426 Self::send_identity_payload(substream, payload_size, bytes).await,
427 ProtocolCodec::UnsignedVarint(max_size) =>
428 Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
429 },
430 #[cfg(feature = "websocket")]
431 SubstreamType::WebSocket(ref mut substream) => match self.codec {
432 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
433 ProtocolCodec::Identity(payload_size) =>
434 Self::send_identity_payload(substream, payload_size, bytes).await,
435 ProtocolCodec::UnsignedVarint(max_size) =>
436 Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
437 },
438 #[cfg(feature = "quic")]
439 SubstreamType::Quic(ref mut substream) => match self.codec {
440 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
441 ProtocolCodec::Identity(payload_size) =>
442 Self::send_identity_payload(substream, payload_size, bytes).await,
443 ProtocolCodec::UnsignedVarint(max_size) => {
444 check_size!(max_size, bytes.len());
445
446 let mut buffer = unsigned_varint::encode::usize_buffer();
447 let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer);
448 let len = BytesMut::from(len);
449
450 substream.write_all_chunks(&mut [len.freeze(), bytes]).await
451 }
452 },
453 #[cfg(feature = "webrtc")]
454 SubstreamType::WebRtc(ref mut substream) => match self.codec {
455 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
456 ProtocolCodec::Identity(payload_size) =>
457 Self::send_identity_payload(substream, payload_size, bytes).await,
458 ProtocolCodec::UnsignedVarint(max_size) =>
459 Self::send_unsigned_varint_payload(substream, bytes, max_size).await,
460 },
461 }
462 }
463}
464
465impl tokio::io::AsyncRead for Substream {
466 fn poll_read(
467 mut self: Pin<&mut Self>,
468 cx: &mut Context<'_>,
469 buf: &mut tokio::io::ReadBuf<'_>,
470 ) -> Poll<std::io::Result<()>> {
471 poll_read!(&mut self.substream, cx, buf)
472 }
473}
474
475impl tokio::io::AsyncWrite for Substream {
476 fn poll_write(
477 mut self: Pin<&mut Self>,
478 cx: &mut Context<'_>,
479 buf: &[u8],
480 ) -> Poll<Result<usize, std::io::Error>> {
481 poll_write!(&mut self.substream, cx, buf)
482 }
483
484 fn poll_flush(
485 mut self: Pin<&mut Self>,
486 cx: &mut Context<'_>,
487 ) -> Poll<Result<(), std::io::Error>> {
488 poll_flush!(&mut self.substream, cx)
489 }
490
491 fn poll_shutdown(
492 mut self: Pin<&mut Self>,
493 cx: &mut Context<'_>,
494 ) -> Poll<Result<(), std::io::Error>> {
495 poll_shutdown!(&mut self.substream, cx)
496 }
497}
498
/// Reasons `read_payload_size()` can fail to decode an unsigned-varint length prefix.
enum ReadError {
    /// No terminating byte within the maximum varint length for `usize`.
    Overflow,
    /// The buffer ended before a terminating byte; more input is needed.
    NotEnoughBytes,
    /// A terminating byte was found but `unsigned_varint::decode::usize` rejected the prefix.
    DecodeError,
}
504
505fn read_payload_size(buffer: &[u8]) -> Result<(usize, usize), ReadError> {
507 let max_len = encode::usize_buffer().len();
508
509 for i in 0..std::cmp::min(buffer.len(), max_len) {
510 if decode::is_last(buffer[i]) {
511 match decode::usize(&buffer[..=i]) {
512 Err(_) => return Err(ReadError::DecodeError),
513 Ok(size) => return Ok((size.0, i + 1)),
514 }
515 }
516 }
517
518 match buffer.len() < max_len {
519 true => Err(ReadError::NotEnoughBytes),
520 false => Err(ReadError::Overflow),
521 }
522}
523
524impl Stream for Substream {
525 type Item = Result<BytesMut, SubstreamError>;
526
527 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
528 let this = Pin::into_inner(self);
529
530 delegate_poll_next!(&mut this.substream, cx);
532
533 loop {
534 match this.codec {
535 ProtocolCodec::Identity(payload_size) => {
536 let mut read_buf =
537 ReadBuf::new(&mut this.read_buffer[this.offset..payload_size]);
538
539 match futures::ready!(poll_read!(&mut this.substream, cx, &mut read_buf)) {
540 Ok(_) => {
541 let nread = read_buf.filled().len();
542 if nread == 0 {
543 tracing::trace!(
544 target: LOG_TARGET,
545 peer = ?this.peer,
546 "read zero bytes, substream closed"
547 );
548 return Poll::Ready(None);
549 }
550
551 if nread == payload_size {
552 let mut payload = std::mem::replace(
553 &mut this.read_buffer,
554 BytesMut::zeroed(payload_size),
555 );
556 payload.truncate(payload_size);
557 this.offset = 0usize;
558
559 return Poll::Ready(Some(Ok(payload)));
560 } else {
561 this.offset += read_buf.filled().len();
562 }
563 }
564 Err(error) => return Poll::Ready(Some(Err(error.into()))),
565 }
566 }
567 ProtocolCodec::UnsignedVarint(max_size) => {
568 loop {
569 if let Some(frame) = this.pending_frames.pop_front() {
571 return Poll::Ready(Some(Ok(frame)));
572 }
573
574 match this.current_frame_size.take() {
575 Some(frame_size) => {
576 let mut read_buf =
577 ReadBuf::new(&mut this.read_buffer[this.offset..]);
578 this.current_frame_size = Some(frame_size);
579
580 match futures::ready!(poll_read!(
581 &mut this.substream,
582 cx,
583 &mut read_buf
584 )) {
585 Err(_error) => return Poll::Ready(None),
586 Ok(_) => {
587 let nread = match read_buf.filled().len() {
588 0 => return Poll::Ready(None),
589 nread => nread,
590 };
591
592 this.offset += nread;
593
594 if this.offset == frame_size {
595 let out_frame = std::mem::replace(
596 &mut this.read_buffer,
597 BytesMut::new(),
598 );
599 this.offset = 0;
600 this.current_frame_size = None;
601
602 return Poll::Ready(Some(Ok(out_frame)));
603 } else {
604 this.current_frame_size = Some(frame_size);
605 continue;
606 }
607 }
608 }
609 }
610 None => {
611 let mut read_buf =
612 ReadBuf::new(&mut this.size_vec[this.offset..this.offset + 1]);
613
614 match futures::ready!(poll_read!(
615 &mut this.substream,
616 cx,
617 &mut read_buf
618 )) {
619 Err(_error) => return Poll::Ready(None),
620 Ok(_) => {
621 if read_buf.filled().is_empty() {
622 return Poll::Ready(None);
623 }
624 this.offset += 1;
625
626 match read_payload_size(&this.size_vec[..this.offset]) {
627 Err(ReadError::NotEnoughBytes) => continue,
628 Err(_) =>
629 return Poll::Ready(Some(Err(
630 SubstreamError::ReadFailure(Some(
631 this.substream_id,
632 )),
633 ))),
634 Ok((size, num_bytes)) => {
635 debug_assert_eq!(num_bytes, this.offset);
636
637 if let Some(max_size) = max_size {
638 if size > max_size {
639 return Poll::Ready(Some(Err(
640 SubstreamError::ReadFailure(Some(
641 this.substream_id,
642 )),
643 )));
644 }
645 }
646
647 this.offset = 0;
648 if size == 0 {
652 return Poll::Ready(Some(Ok(BytesMut::new())));
653 }
654
655 this.current_frame_size = Some(size);
656 this.read_buffer = BytesMut::zeroed(size);
657 }
658 }
659 }
660 }
661 }
662 }
663 }
664 }
665 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
666 }
667 }
668 }
669}
670
671impl Sink<Bytes> for Substream {
673 type Error = SubstreamError;
674
675 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
676 delegate_poll_ready!(&mut self.substream, cx);
678
679 if self.pending_out_bytes >= BACKPRESSURE_BOUNDARY {
680 return poll_flush!(&mut self.substream, cx).map_err(From::from);
681 }
682
683 Poll::Ready(Ok(()))
684 }
685
686 fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
687 delegate_start_send!(&mut self.substream, item);
689
690 tracing::trace!(
691 target: LOG_TARGET,
692 peer = ?self.peer,
693 substream_id = ?self.substream_id,
694 data_len = item.len(),
695 "Substream::start_send()",
696 );
697
698 match self.codec {
699 ProtocolCodec::Identity(payload_size) => {
700 if item.len() != payload_size {
701 return Err(SubstreamError::IoError(ErrorKind::PermissionDenied));
702 }
703
704 self.pending_out_bytes += item.len();
705 self.pending_out_frames.push_back(item);
706 }
707 ProtocolCodec::UnsignedVarint(max_size) => {
708 check_size!(max_size, item.len());
709
710 let len = {
711 let mut buffer = unsigned_varint::encode::usize_buffer();
712 let len = unsigned_varint::encode::usize(item.len(), &mut buffer);
713 BytesMut::from(len)
714 };
715
716 self.pending_out_bytes += len.len() + item.len();
717 self.pending_out_frames.push_back(len.freeze());
718 self.pending_out_frames.push_back(item);
719 }
720 ProtocolCodec::Unspecified => panic!("codec is unspecified"),
721 }
722
723 Ok(())
724 }
725
726 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
727 delegate_poll_flush!(&mut self.substream, cx);
729
730 loop {
731 let mut pending_frame = match self.pending_out_frame.take() {
732 Some(frame) => frame,
733 None => match self.pending_out_frames.pop_front() {
734 Some(frame) => frame,
735 None => break,
736 },
737 };
738
739 match poll_write!(&mut self.substream, cx, &pending_frame) {
740 Poll::Ready(Err(error)) => return Poll::Ready(Err(error.into())),
741 Poll::Pending => {
742 self.pending_out_frame = Some(pending_frame);
743 break;
744 }
745 Poll::Ready(Ok(nwritten)) => {
746 pending_frame.advance(nwritten);
747
748 if !pending_frame.is_empty() {
749 self.pending_out_frame = Some(pending_frame);
750 }
751 }
752 }
753 }
754
755 poll_flush!(&mut self.substream, cx).map_err(From::from)
756 }
757
758 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
759 poll_shutdown!(&mut self.substream, cx).map_err(From::from)
760 }
761}
762
/// Bounds required of the key type of a `SubstreamSet`.
///
/// Every type meeting the bounds implements it through the blanket impl below, so it
/// never needs a manual impl.
pub trait SubstreamSetKey: Copy + Eq + PartialEq + Hash + Unpin + fmt::Debug {}

impl<K> SubstreamSetKey for K where K: Copy + Eq + PartialEq + Hash + Unpin + fmt::Debug {}
767
768#[derive(Debug, Default)]
771pub struct SubstreamSet<K, S>
772where
773 K: SubstreamSetKey,
774 S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
775{
776 substreams: HashMap<K, S>,
777}
778
779impl<K, S> SubstreamSet<K, S>
780where
781 K: SubstreamSetKey,
782 S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
783{
784 pub fn new() -> Self {
786 Self {
787 substreams: HashMap::new(),
788 }
789 }
790
791 pub fn insert(&mut self, key: K, substream: S) {
793 match self.substreams.entry(key) {
794 Entry::Vacant(entry) => {
795 entry.insert(substream);
796 }
797 Entry::Occupied(_) => {
798 tracing::error!(?key, "substream already exists");
799 debug_assert!(false);
800 }
801 }
802 }
803
804 pub fn remove(&mut self, key: &K) -> Option<S> {
806 self.substreams.remove(key)
807 }
808
809 #[cfg(test)]
811 pub fn get_mut(&mut self, key: &K) -> Option<&mut S> {
812 self.substreams.get_mut(key)
813 }
814
815 pub fn len(&self) -> usize {
817 self.substreams.len()
818 }
819
820 pub fn is_empty(&self) -> bool {
822 self.substreams.is_empty()
823 }
824}
825
826impl<K, S> Stream for SubstreamSet<K, S>
827where
828 K: SubstreamSetKey,
829 S: Stream<Item = Result<BytesMut, SubstreamError>> + Unpin,
830{
831 type Item = (K, <S as Stream>::Item);
832
833 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
834 let inner = Pin::into_inner(self);
835
836 for (key, mut substream) in inner.substreams.iter_mut() {
837 match Pin::new(&mut substream).poll_next(cx) {
838 Poll::Pending => continue,
839 Poll::Ready(Some(data)) => return Poll::Ready(Some((*key, data))),
840 Poll::Ready(None) =>
841 return Poll::Ready(Some((*key, Err(SubstreamError::ConnectionClosed)))),
842 }
843 }
844
845 Poll::Pending
846 }
847}
848
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{mock::substream::MockSubstream, PeerId};
    use futures::{SinkExt, StreamExt};

    #[test]
    fn add_substream() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let substream = MockSubstream::new();
        set.insert(peer, substream);

        let peer = PeerId::random();
        let substream = MockSubstream::new();
        set.insert(peer, substream);
    }

    #[test]
    #[should_panic]
    #[cfg(debug_assertions)]
    fn add_same_peer_twice() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let substream1 = MockSubstream::new();
        let substream2 = MockSubstream::new();

        set.insert(peer, substream1);
        set.insert(peer, substream2);
    }

    #[test]
    fn remove_substream() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer1 = PeerId::random();
        let substream1 = MockSubstream::new();
        set.insert(peer1, substream1);

        let peer2 = PeerId::random();
        let substream2 = MockSubstream::new();
        set.insert(peer2, substream2);

        assert!(set.remove(&peer1).is_some());
        assert!(set.remove(&peer2).is_some());
        assert!(set.remove(&PeerId::random()).is_none());
    }

    #[tokio::test]
    async fn poll_data_from_substream() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let mut substream = MockSubstream::new();
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..])))));
        substream.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer, substream);

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..]));

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"world"[..]));

        assert!(futures::poll!(set.next()).is_pending());
    }

    #[tokio::test]
    async fn substream_closed() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let mut substream = MockSubstream::new();
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        substream.expect_poll_next().times(1).return_once(|_| Poll::Ready(None));
        substream.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer, substream);

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..]));

        match set.next().await {
            Some((exited_peer, Err(SubstreamError::ConnectionClosed))) => {
                assert_eq!(peer, exited_peer);
            }
            _ => panic!("invalid event received"),
        }
    }

    #[tokio::test]
    async fn get_mut_substream() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer = PeerId::random();
        let mut substream = MockSubstream::new();
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(())));
        substream.expect_start_send().times(1).return_once(|_| Ok(()));
        substream.expect_poll_flush().times(1).return_once(|_| Poll::Ready(Ok(())));
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..])))));
        substream.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer, substream);

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"hello"[..]));

        let substream = set.get_mut(&peer).unwrap();
        substream.send(vec![1, 2, 3, 4].into()).await.unwrap();

        let value = set.next().await.unwrap();
        assert_eq!(value.0, peer);
        assert_eq!(value.1.unwrap(), BytesMut::from(&b"world"[..]));

        assert!(set.get_mut(&PeerId::random()).is_none());
    }

    #[tokio::test]
    async fn poll_data_from_two_substreams() {
        let mut set = SubstreamSet::<PeerId, MockSubstream>::new();

        let peer1 = PeerId::random();
        let mut substream1 = MockSubstream::new();
        substream1
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"hello"[..])))));
        substream1
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"world"[..])))));
        substream1.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer1, substream1);

        let peer2 = PeerId::random();
        let mut substream2 = MockSubstream::new();
        substream2
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"siip"[..])))));
        substream2
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Ok(BytesMut::from(&b"huup"[..])))));
        substream2.expect_poll_next().returning(|_| Poll::Pending);
        set.insert(peer2, substream2);

        // `HashMap` iteration order is unspecified, so any interleaving that keeps each
        // substream's own events in order is acceptable.
        let expected: Vec<Vec<(PeerId, BytesMut)>> = vec![
            vec![
                (peer1, BytesMut::from(&b"hello"[..])),
                (peer1, BytesMut::from(&b"world"[..])),
                (peer2, BytesMut::from(&b"siip"[..])),
                (peer2, BytesMut::from(&b"huup"[..])),
            ],
            vec![
                (peer1, BytesMut::from(&b"hello"[..])),
                (peer2, BytesMut::from(&b"siip"[..])),
                (peer1, BytesMut::from(&b"world"[..])),
                (peer2, BytesMut::from(&b"huup"[..])),
            ],
            vec![
                (peer2, BytesMut::from(&b"siip"[..])),
                (peer2, BytesMut::from(&b"huup"[..])),
                (peer1, BytesMut::from(&b"hello"[..])),
                (peer1, BytesMut::from(&b"world"[..])),
            ],
            vec![
                (peer1, BytesMut::from(&b"hello"[..])),
                (peer2, BytesMut::from(&b"siip"[..])),
                (peer2, BytesMut::from(&b"huup"[..])),
                (peer1, BytesMut::from(&b"world"[..])),
            ],
        ];

        let mut values = Vec::new();

        for _ in 0..4 {
            let value = set.next().await.unwrap();
            values.push((value.0, value.1.unwrap()));
        }

        assert!(expected.contains(&values), "invalid set generated");

        for _ in 0..10 {
            assert!(futures::poll!(set.next()).is_pending());
        }
    }

    #[test]
    fn read_payload_size_decodes_varint_prefixes() {
        // Single-byte prefix.
        assert!(matches!(read_payload_size(&[0x05]), Ok((5, 1))));
        // Multi-byte prefix: 300 is encoded as `[0xac, 0x02]`.
        assert!(matches!(read_payload_size(&[0xac, 0x02]), Ok((300, 2))));
        // Bytes after the terminating byte are not part of the prefix.
        assert!(matches!(read_payload_size(&[0x01, 0xff]), Ok((1, 1))));
        // Continuation bit set and the buffer is short: more input is needed.
        assert!(matches!(read_payload_size(&[0x80]), Err(ReadError::NotEnoughBytes)));
        // Ten continuation bytes reach the maximum varint length for a `usize`.
        assert!(matches!(read_payload_size(&[0x80; 10]), Err(ReadError::Overflow)));
    }
}