1use crate::types::ProtocolName;
38
39use asynchronous_codec::Framed;
40use bytes::BytesMut;
41use futures::prelude::*;
42use libp2p::{
43 core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo},
44 PeerId,
45};
46use log::{debug, error, warn};
47use unsigned_varint::codec::UviBytes;
48
49use std::{
50 fmt, io, mem,
51 pin::Pin,
52 task::{Context, Poll},
53 vec,
54};
55
/// Log target used when reporting handshake misuse or oversized handshakes.
const LOG_TARGET: &str = "sub-libp2p::notification::upgrade";

/// Maximum allowed size, in bytes, of a handshake message.
///
/// Handshakes received from the remote (on both inbound and outbound substreams) whose
/// announced length exceeds this value are rejected with
/// [`NotificationsHandshakeError::TooLarge`]. An outbound initial message above this size is
/// only logged as an error by [`NotificationsOut::new`]; it is still sent.
const MAX_HANDSHAKE_SIZE: usize = 1024;
61
/// Upgrade that accepts an inbound notifications substream.
///
/// The upgrade reads the handshake sent by the remote and produces a [`NotificationsInOpen`].
/// Our own handshake is sent afterwards through [`NotificationsInSubstream::send_handshake`].
#[derive(Debug, Clone)]
pub struct NotificationsIn {
    /// Protocol names to negotiate: the main name first, followed by the fallbacks.
    protocol_names: Vec<ProtocolName>,
    /// Maximum frame length, in bytes, configured on the resulting substream's codec.
    max_notification_size: u64,
}
72
/// Upgrade that opens an outbound notifications substream.
///
/// The upgrade sends `initial_message` as our handshake, waits for the remote's handshake, and
/// produces a [`NotificationsOutOpen`].
#[derive(Debug, Clone)]
pub struct NotificationsOut {
    /// Protocol names to negotiate: the main name first, followed by the fallbacks.
    protocol_names: Vec<ProtocolName>,
    /// Handshake message written to the remote right after protocol negotiation.
    initial_message: Vec<u8>,
    /// Maximum frame length, in bytes, configured on the resulting substream's codec.
    max_notification_size: u64,
    /// Remote peer; carried into the resulting substream, where it appears in debug logs.
    peer_id: PeerId,
}
87
/// Inbound notifications substream, obtained from a successful [`NotificationsIn`] upgrade.
///
/// Once our handshake has been sent, it yields the notifications sent by the remote. The only
/// message this side writes is that handshake.
#[pin_project::pin_project]
pub struct NotificationsInSubstream<TSubstream> {
    /// Underlying substream, framed with an unsigned-varint length prefix per message.
    #[pin]
    socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>,
    /// Progress of sending our handshake and, later, of closing the substream.
    handshake: NotificationsInSubstreamHandshake,
}
98
/// State of a [`NotificationsInSubstream`]: sending our handshake, then reading, then closing.
#[derive(Debug)]
pub enum NotificationsInSubstreamHandshake {
    /// No handshake has been queued with `send_handshake` yet.
    NotSent,
    /// Handshake queued by `send_handshake`, not yet handed to the sink.
    PendingSend(Vec<u8>),
    /// Handshake handed to the sink; the sink still needs to be flushed.
    Flush,
    /// Handshake sent; incoming notifications are being read.
    Sent,
    /// The remote closed its side of the substream; we are closing ours in response.
    ClosingInResponseToRemote,
    /// Both sides are closed; the stream has ended.
    BothSidesClosed,
}
115
/// Outbound notifications substream, obtained from a successful [`NotificationsOut`] upgrade.
///
/// Notifications are sent through its [`Sink`] implementation. Any data received from the
/// remote while flushing is reported as an error.
#[pin_project::pin_project]
pub struct NotificationsOutSubstream<TSubstream> {
    /// Underlying substream, framed with an unsigned-varint length prefix per message.
    #[pin]
    socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>,

    /// Remote peer; only used in debug logs.
    peer_id: PeerId,
}
126
#[cfg(test)]
impl<TSubstream> NotificationsOutSubstream<TSubstream> {
    /// Test-only constructor: wraps an already-framed socket and assigns a random peer id.
    pub fn new(socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>) -> Self {
        let peer_id = PeerId::random();
        Self { peer_id, socket }
    }
}
133
134impl NotificationsIn {
135 pub fn new(
137 main_protocol_name: impl Into<ProtocolName>,
138 fallback_names: Vec<ProtocolName>,
139 max_notification_size: u64,
140 ) -> Self {
141 let mut protocol_names = fallback_names;
142 protocol_names.insert(0, main_protocol_name.into());
143
144 Self { protocol_names, max_notification_size }
145 }
146}
147
148impl UpgradeInfo for NotificationsIn {
149 type Info = ProtocolName;
150 type InfoIter = vec::IntoIter<Self::Info>;
151
152 fn protocol_info(&self) -> Self::InfoIter {
153 self.protocol_names.clone().into_iter()
154 }
155}
156
157impl<TSubstream> InboundUpgrade<TSubstream> for NotificationsIn
158where
159 TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
160{
161 type Output = NotificationsInOpen<TSubstream>;
162 type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
163 type Error = NotificationsHandshakeError;
164
165 fn upgrade_inbound(self, mut socket: TSubstream, _negotiated_name: Self::Info) -> Self::Future {
166 Box::pin(async move {
167 let handshake_len = unsigned_varint::aio::read_usize(&mut socket).await?;
168 if handshake_len > MAX_HANDSHAKE_SIZE {
169 return Err(NotificationsHandshakeError::TooLarge {
170 requested: handshake_len,
171 max: MAX_HANDSHAKE_SIZE,
172 })
173 }
174
175 let mut handshake = vec![0u8; handshake_len];
176 if !handshake.is_empty() {
177 socket.read_exact(&mut handshake).await?;
178 }
179
180 let mut codec = UviBytes::default();
181 codec.set_max_len(usize::try_from(self.max_notification_size).unwrap_or(usize::MAX));
182
183 let substream = NotificationsInSubstream {
184 socket: Framed::new(socket, codec),
185 handshake: NotificationsInSubstreamHandshake::NotSent,
186 };
187
188 Ok(NotificationsInOpen { handshake, substream })
189 })
190 }
191}
192
/// Output of a successful [`NotificationsIn`] upgrade.
pub struct NotificationsInOpen<TSubstream> {
    /// Handshake sent by the remote.
    pub handshake: Vec<u8>,
    /// Substream used to send our handshake and then receive notifications.
    pub substream: NotificationsInSubstream<TSubstream>,
}
200
201impl<TSubstream> fmt::Debug for NotificationsInOpen<TSubstream> {
202 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203 f.debug_struct("NotificationsInOpen")
204 .field("handshake", &self.handshake)
205 .finish_non_exhaustive()
206 }
207}
208
209impl<TSubstream> NotificationsInSubstream<TSubstream>
210where
211 TSubstream: AsyncRead + AsyncWrite + Unpin,
212{
213 #[cfg(test)]
214 pub fn new(
215 socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>,
216 handshake: NotificationsInSubstreamHandshake,
217 ) -> Self {
218 Self { socket, handshake }
219 }
220
221 pub fn send_handshake(&mut self, message: impl Into<Vec<u8>>) {
223 if !matches!(self.handshake, NotificationsInSubstreamHandshake::NotSent) {
224 error!(target: LOG_TARGET, "Tried to send handshake twice");
225 return
226 }
227
228 self.handshake = NotificationsInSubstreamHandshake::PendingSend(message.into());
229 }
230
231 pub fn poll_process(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
234 let mut this = self.project();
235
236 loop {
237 match mem::replace(this.handshake, NotificationsInSubstreamHandshake::Sent) {
238 NotificationsInSubstreamHandshake::PendingSend(msg) => {
239 match Sink::poll_ready(this.socket.as_mut(), cx) {
240 Poll::Ready(_) => {
241 *this.handshake = NotificationsInSubstreamHandshake::Flush;
242 match Sink::start_send(this.socket.as_mut(), io::Cursor::new(msg)) {
243 Ok(()) => {},
244 Err(err) => return Poll::Ready(Err(err)),
245 }
246 },
247 Poll::Pending => {
248 *this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
249 return Poll::Pending
250 },
251 }
252 },
253 NotificationsInSubstreamHandshake::Flush => {
254 match Sink::poll_flush(this.socket.as_mut(), cx)? {
255 Poll::Ready(()) => {
256 *this.handshake = NotificationsInSubstreamHandshake::Sent;
257 return Poll::Ready(Ok(()));
258 },
259 Poll::Pending => {
260 *this.handshake = NotificationsInSubstreamHandshake::Flush;
261 return Poll::Pending
262 },
263 }
264 },
265
266 st @ NotificationsInSubstreamHandshake::NotSent |
267 st @ NotificationsInSubstreamHandshake::Sent |
268 st @ NotificationsInSubstreamHandshake::ClosingInResponseToRemote |
269 st @ NotificationsInSubstreamHandshake::BothSidesClosed => {
270 *this.handshake = st;
271 return Poll::Ready(Ok(()));
272 },
273 }
274 }
275 }
276}
277
278impl<TSubstream> Stream for NotificationsInSubstream<TSubstream>
279where
280 TSubstream: AsyncRead + AsyncWrite + Unpin,
281{
282 type Item = Result<BytesMut, io::Error>;
283
284 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
285 let mut this = self.project();
286
287 loop {
289 match mem::replace(this.handshake, NotificationsInSubstreamHandshake::Sent) {
290 NotificationsInSubstreamHandshake::NotSent => {
291 *this.handshake = NotificationsInSubstreamHandshake::NotSent;
292 return Poll::Pending
293 },
294 NotificationsInSubstreamHandshake::PendingSend(msg) => {
295 match Sink::poll_ready(this.socket.as_mut(), cx) {
296 Poll::Ready(_) => {
297 *this.handshake = NotificationsInSubstreamHandshake::Flush;
298 match Sink::start_send(this.socket.as_mut(), io::Cursor::new(msg)) {
299 Ok(()) => {},
300 Err(err) => return Poll::Ready(Some(Err(err))),
301 }
302 },
303 Poll::Pending => {
304 *this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
305 return Poll::Pending
306 },
307 }
308 },
309 NotificationsInSubstreamHandshake::Flush => {
310 match Sink::poll_flush(this.socket.as_mut(), cx)? {
311 Poll::Ready(()) =>
312 *this.handshake = NotificationsInSubstreamHandshake::Sent,
313 Poll::Pending => {
314 *this.handshake = NotificationsInSubstreamHandshake::Flush;
315 return Poll::Pending
316 },
317 }
318 },
319
320 NotificationsInSubstreamHandshake::Sent => {
321 match Stream::poll_next(this.socket.as_mut(), cx) {
322 Poll::Ready(None) =>
323 *this.handshake =
324 NotificationsInSubstreamHandshake::ClosingInResponseToRemote,
325 Poll::Ready(Some(msg)) => {
326 *this.handshake = NotificationsInSubstreamHandshake::Sent;
327 return Poll::Ready(Some(msg))
328 },
329 Poll::Pending => {
330 *this.handshake = NotificationsInSubstreamHandshake::Sent;
331 return Poll::Pending
332 },
333 }
334 },
335
336 NotificationsInSubstreamHandshake::ClosingInResponseToRemote =>
337 match Sink::poll_close(this.socket.as_mut(), cx)? {
338 Poll::Ready(()) =>
339 *this.handshake = NotificationsInSubstreamHandshake::BothSidesClosed,
340 Poll::Pending => {
341 *this.handshake =
342 NotificationsInSubstreamHandshake::ClosingInResponseToRemote;
343 return Poll::Pending
344 },
345 },
346
347 NotificationsInSubstreamHandshake::BothSidesClosed => return Poll::Ready(None),
348 }
349 }
350 }
351}
352
353impl NotificationsOut {
354 pub fn new(
356 main_protocol_name: impl Into<ProtocolName>,
357 fallback_names: Vec<ProtocolName>,
358 initial_message: impl Into<Vec<u8>>,
359 max_notification_size: u64,
360 peer_id: PeerId,
361 ) -> Self {
362 let initial_message = initial_message.into();
363 if initial_message.len() > MAX_HANDSHAKE_SIZE {
364 error!(target: LOG_TARGET, "Outbound networking handshake is above allowed protocol limit");
365 }
366
367 let mut protocol_names = fallback_names;
368 protocol_names.insert(0, main_protocol_name.into());
369
370 Self { protocol_names, initial_message, max_notification_size, peer_id }
371 }
372}
373
374impl UpgradeInfo for NotificationsOut {
375 type Info = ProtocolName;
376 type InfoIter = vec::IntoIter<Self::Info>;
377
378 fn protocol_info(&self) -> Self::InfoIter {
379 self.protocol_names.clone().into_iter()
380 }
381}
382
383impl<TSubstream> OutboundUpgrade<TSubstream> for NotificationsOut
384where
385 TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
386{
387 type Output = NotificationsOutOpen<TSubstream>;
388 type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
389 type Error = NotificationsHandshakeError;
390
391 fn upgrade_outbound(self, mut socket: TSubstream, negotiated_name: Self::Info) -> Self::Future {
392 Box::pin(async move {
393 {
394 let mut len_data = unsigned_varint::encode::usize_buffer();
395 let encoded_len =
396 unsigned_varint::encode::usize(self.initial_message.len(), &mut len_data).len();
397 socket.write_all(&len_data[..encoded_len]).await?;
398 }
399 socket.write_all(&self.initial_message).await?;
400 socket.flush().await?;
401
402 let handshake_len = unsigned_varint::aio::read_usize(&mut socket).await?;
404 if handshake_len > MAX_HANDSHAKE_SIZE {
405 return Err(NotificationsHandshakeError::TooLarge {
406 requested: handshake_len,
407 max: MAX_HANDSHAKE_SIZE,
408 })
409 }
410
411 let mut handshake = vec![0u8; handshake_len];
412 if !handshake.is_empty() {
413 socket.read_exact(&mut handshake).await?;
414 }
415
416 let mut codec = UviBytes::default();
417 codec.set_max_len(usize::try_from(self.max_notification_size).unwrap_or(usize::MAX));
418
419 Ok(NotificationsOutOpen {
420 handshake,
421 negotiated_fallback: if negotiated_name == self.protocol_names[0] {
422 None
423 } else {
424 Some(negotiated_name)
425 },
426 substream: NotificationsOutSubstream {
427 socket: Framed::new(socket, codec),
428 peer_id: self.peer_id,
429 },
430 })
431 })
432 }
433}
434
/// Output of a successful [`NotificationsOut`] upgrade.
pub struct NotificationsOutOpen<TSubstream> {
    /// Handshake sent back by the remote.
    pub handshake: Vec<u8>,
    /// `Some(name)` when the negotiated protocol name differs from the main protocol name,
    /// `None` when the main name was negotiated.
    pub negotiated_fallback: Option<ProtocolName>,
    /// Substream used to send notifications to the remote.
    pub substream: NotificationsOutSubstream<TSubstream>,
}
445
446impl<TSubstream> fmt::Debug for NotificationsOutOpen<TSubstream> {
447 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
448 f.debug_struct("NotificationsOutOpen")
449 .field("handshake", &self.handshake)
450 .field("negotiated_fallback", &self.negotiated_fallback)
451 .finish_non_exhaustive()
452 }
453}
454
455impl<TSubstream> Sink<Vec<u8>> for NotificationsOutSubstream<TSubstream>
456where
457 TSubstream: AsyncRead + AsyncWrite + Unpin,
458{
459 type Error = NotificationsOutError;
460
461 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
462 let mut this = self.project();
463 Sink::poll_ready(this.socket.as_mut(), cx).map_err(NotificationsOutError::Io)
464 }
465
466 fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
467 let mut this = self.project();
468 Sink::start_send(this.socket.as_mut(), io::Cursor::new(item))
469 .map_err(NotificationsOutError::Io)
470 }
471
472 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
473 let mut this = self.project();
474
475 match Stream::poll_next(this.socket.as_mut(), cx) {
479 Poll::Pending => {},
480 Poll::Ready(Some(result)) => match result {
481 Ok(_) => {
482 debug!(
483 target: "sub-libp2p",
484 "Unexpected incoming data in `NotificationsOutSubstream` peer={:?}",
485 this.peer_id
486 );
487
488 return Poll::Ready(Err(NotificationsOutError::UnexpectedData));
489 },
490 Err(error) => {
491 debug!(
492 target: "sub-libp2p",
493 "Error while reading from `NotificationsOutSubstream` peer={:?} error={error:?}",
494 this.peer_id
495 );
496
497 return Poll::Ready(Err(NotificationsOutError::Closed));
499 },
500 },
501 Poll::Ready(None) => return Poll::Ready(Err(NotificationsOutError::Closed)),
502 }
503
504 Sink::poll_flush(this.socket.as_mut(), cx).map_err(NotificationsOutError::Io)
505 }
506
507 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
508 let mut this = self.project();
509 Sink::poll_close(this.socket.as_mut(), cx).map_err(NotificationsOutError::Io)
510 }
511}
512
/// Error while exchanging handshakes on a notifications substream.
#[derive(Debug, thiserror::Error)]
pub enum NotificationsHandshakeError {
    /// I/O error on the substream.
    #[error(transparent)]
    Io(#[from] io::Error),

    /// The remote announced a handshake longer than `MAX_HANDSHAKE_SIZE`.
    #[error("Initial message or handshake was too large: {requested}")]
    TooLarge {
        /// Length announced by the remote, in bytes.
        requested: usize,
        /// Maximum length allowed, in bytes.
        max: usize,
    },

    /// Failed to decode the unsigned-varint length prefix of the handshake.
    #[error(transparent)]
    VarintDecode(#[from] unsigned_varint::decode::Error),
}
533
534impl From<unsigned_varint::io::ReadError> for NotificationsHandshakeError {
535 fn from(err: unsigned_varint::io::ReadError) -> Self {
536 match err {
537 unsigned_varint::io::ReadError::Io(err) => Self::Io(err),
538 unsigned_varint::io::ReadError::Decode(err) => Self::VarintDecode(err),
539 _ => {
540 warn!("Unrecognized varint decoding error");
541 Self::Io(From::from(io::ErrorKind::InvalidData))
542 },
543 }
544 }
545}
546
/// Error produced by the [`Sink`] implementation of [`NotificationsOutSubstream`].
#[derive(Debug, thiserror::Error)]
pub enum NotificationsOutError {
    /// I/O error on the underlying substream while writing, flushing or closing.
    #[error(transparent)]
    Io(#[from] io::Error),

    /// The remote closed the substream, or reading from it failed.
    #[error("substream was closed/reset")]
    Closed,

    /// The remote sent data on a substream where it is not expected to send anything.
    #[error("unexpected data received from the remote peer")]
    UnexpectedData,
}
564
#[cfg(test)]
mod tests {
    use crate::ProtocolName;

    use super::{
        NotificationsHandshakeError, NotificationsIn, NotificationsInOpen,
        NotificationsInSubstream, NotificationsOut, NotificationsOutError, NotificationsOutOpen,
        NotificationsOutSubstream,
    };
    use futures::{channel::oneshot, future, prelude::*, SinkExt, StreamExt};
    use libp2p::{
        core::{upgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo},
        PeerId,
    };
    use std::{pin::Pin, task::Poll};
    use tokio::net::{TcpListener, TcpStream};
    use tokio_util::compat::TokioAsyncReadCompatExt;

    /// Connects to `addr`, negotiates `/test/proto/1` via multistream-select, and runs the
    /// outbound upgrade with `handshake` as our initial message.
    ///
    /// Returns the remote's handshake and the outbound substream.
    async fn dial(
        addr: std::net::SocketAddr,
        handshake: impl Into<Vec<u8>>,
    ) -> Result<
        (
            Vec<u8>,
            NotificationsOutSubstream<
                multistream_select::Negotiated<tokio_util::compat::Compat<TcpStream>>,
            >,
        ),
        NotificationsHandshakeError,
    > {
        let socket = TcpStream::connect(addr).await.unwrap();
        let notifs_out = NotificationsOut::new(
            "/test/proto/1",
            Vec::new(),
            handshake,
            1024 * 1024,
            PeerId::random(),
        );
        let (_, substream) = multistream_select::dialer_select_proto(
            socket.compat(),
            notifs_out.protocol_info(),
            upgrade::Version::V1,
        )
        .await
        .unwrap();
        let NotificationsOutOpen { handshake, substream, .. } =
            <NotificationsOut as OutboundUpgrade<_>>::upgrade_outbound(
                notifs_out,
                substream,
                "/test/proto/1".into(),
            )
            .await?;
        Ok((handshake, substream))
    }

    /// Binds a listener on a random localhost port and reports its address through
    /// `listener_addr_tx`. Accepts one connection, negotiates `/test/proto/1`, and runs the
    /// inbound upgrade.
    ///
    /// Returns the remote's handshake and the inbound substream.
    async fn listen_on_localhost(
        listener_addr_tx: oneshot::Sender<std::net::SocketAddr>,
    ) -> Result<
        (
            Vec<u8>,
            NotificationsInSubstream<
                multistream_select::Negotiated<tokio_util::compat::Compat<TcpStream>>,
            >,
        ),
        NotificationsHandshakeError,
    > {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();

        let (socket, _) = listener.accept().await.unwrap();
        let notifs_in = NotificationsIn::new("/test/proto/1", Vec::new(), 1024 * 1024);
        let (_, substream) =
            multistream_select::listener_select_proto(socket.compat(), notifs_in.protocol_info())
                .await
                .unwrap();
        let NotificationsInOpen { handshake, substream, .. } =
            <NotificationsIn as InboundUpgrade<_>>::upgrade_inbound(
                notifs_in,
                substream,
                "/test/proto/1".into(),
            )
            .await?;
        Ok((handshake, substream))
    }

    /// Handshakes travel in both directions and a notification reaches the listener.
    #[tokio::test]
    async fn basic_works() {
        let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

        let client = tokio::spawn(async move {
            let (handshake, mut substream) =
                dial(listener_addr_rx.await.unwrap(), &b"initial message"[..]).await.unwrap();

            assert_eq!(handshake, b"hello world");
            substream.send(b"test message".to_vec()).await.unwrap();
        });

        let (handshake, mut substream) = listen_on_localhost(listener_addr_tx).await.unwrap();

        assert_eq!(handshake, b"initial message");
        substream.send_handshake(&b"hello world"[..]);

        // Polling the stream sends our handshake before reading the notification.
        let msg = substream.next().await.unwrap().unwrap();
        assert_eq!(msg.as_ref(), b"test message");

        client.await.unwrap();
    }

    /// Empty handshakes in both directions and an empty notification are handled.
    #[tokio::test]
    async fn empty_handshake() {
        let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

        let client = tokio::spawn(async move {
            let (handshake, mut substream) =
                dial(listener_addr_rx.await.unwrap(), vec![]).await.unwrap();

            assert!(handshake.is_empty());
            substream.send(Default::default()).await.unwrap();
        });

        let (handshake, mut substream) = listen_on_localhost(listener_addr_tx).await.unwrap();

        assert!(handshake.is_empty());
        substream.send_handshake(vec![]);

        let msg = substream.next().await.unwrap().unwrap();
        assert!(msg.as_ref().is_empty());

        client.await.unwrap();
    }

    /// Dropping the inbound substream without sending a handshake makes the dialer's
    /// upgrade fail.
    #[tokio::test]
    async fn refused() {
        let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

        let client = tokio::spawn(async move {
            let outcome = dial(listener_addr_rx.await.unwrap(), &b"hello"[..]).await;

            assert!(outcome.is_err());
        });

        let (handshake, substream) = listen_on_localhost(listener_addr_tx).await.unwrap();
        assert_eq!(handshake, b"hello");

        drop(substream);

        client.await.unwrap();
    }

    /// A 32768-byte initial message (above `MAX_HANDSHAKE_SIZE`) is rejected by the listener,
    /// so the dial fails.
    #[tokio::test]
    async fn large_initial_message_refused() {
        let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

        let client = tokio::spawn(async move {
            let ret =
                dial(listener_addr_rx.await.unwrap(), (0..32768).map(|_| 0).collect::<Vec<_>>())
                    .await;
            assert!(ret.is_err());
        });

        // The listener side is expected to fail; its result is not checked.
        let _ret = listen_on_localhost(listener_addr_tx).await;
        client.await.unwrap();
    }

    /// A 32768-byte handshake sent back by the listener is rejected by the dialer.
    #[tokio::test]
    async fn large_handshake_refused() {
        let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

        let client = tokio::spawn(async move {
            let ret = dial(listener_addr_rx.await.unwrap(), &b"initial message"[..]).await;
            assert!(ret.is_err());
        });

        let (handshake, mut substream) = listen_on_localhost(listener_addr_tx).await.unwrap();
        assert_eq!(handshake, b"initial message");

        substream.send_handshake((0..32768).map(|_| 0).collect::<Vec<_>>());
        // Polling the stream is what actually writes the queued handshake.
        let _ = substream.next().await;

        client.await.unwrap();
    }

    /// `poll_process` alone delivers our handshake, without polling for incoming data.
    /// The upgrades run directly on the TCP socket, with no multistream-select negotiation.
    #[tokio::test]
    async fn send_handshake_without_polling_for_incoming_data() {
        const PROTO_NAME: &str = "/test/proto/1";
        let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

        let client = tokio::spawn(async move {
            let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap();
            let NotificationsOutOpen { handshake, .. } = OutboundUpgrade::upgrade_outbound(
                NotificationsOut::new(
                    PROTO_NAME,
                    Vec::new(),
                    &b"initial message"[..],
                    1024 * 1024,
                    PeerId::random(),
                ),
                socket.compat(),
                ProtocolName::Static(PROTO_NAME),
            )
            .await
            .unwrap();

            assert_eq!(handshake, b"hello world");
        });

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();

        let (socket, _) = listener.accept().await.unwrap();
        let NotificationsInOpen { handshake, mut substream, .. } = InboundUpgrade::upgrade_inbound(
            NotificationsIn::new(PROTO_NAME, Vec::new(), 1024 * 1024),
            socket.compat(),
            ProtocolName::Static(PROTO_NAME),
        )
        .await
        .unwrap();

        assert_eq!(handshake, b"initial message");
        substream.send_handshake(&b"hello world"[..]);

        future::poll_fn(|cx| Pin::new(&mut substream).poll_process(cx)).await.unwrap();

        client.await.unwrap();
    }

    /// The dialer's `poll_flush` reports `Closed` once the listener drops the inbound
    /// substream, even though the dialer never writes any notification.
    #[tokio::test]
    async fn can_detect_dropped_out_substream_without_writing_data() {
        const PROTO_NAME: &str = "/test/proto/1";
        let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

        let client = tokio::spawn(async move {
            let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap();
            let NotificationsOutOpen { handshake, mut substream, .. } =
                OutboundUpgrade::upgrade_outbound(
                    NotificationsOut::new(
                        PROTO_NAME,
                        Vec::new(),
                        &b"initial message"[..],
                        1024 * 1024,
                        PeerId::random(),
                    ),
                    socket.compat(),
                    ProtocolName::Static(PROTO_NAME),
                )
                .await
                .unwrap();

            assert_eq!(handshake, b"hello world");

            // Keep flushing (re-waking after each successful flush) until an error shows up,
            // and require that error to be `Closed`.
            future::poll_fn(|cx| match Pin::new(&mut substream).poll_flush(cx) {
                Poll::Pending => Poll::Pending,
                Poll::Ready(Ok(())) => {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                },
                Poll::Ready(Err(e)) => {
                    assert!(matches!(e, NotificationsOutError::Closed));
                    Poll::Ready(())
                },
            })
            .await;
        });

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();

        let (socket, _) = listener.accept().await.unwrap();
        let NotificationsInOpen { handshake, mut substream, .. } = InboundUpgrade::upgrade_inbound(
            NotificationsIn::new(PROTO_NAME, Vec::new(), 1024 * 1024),
            socket.compat(),
            ProtocolName::Static(PROTO_NAME),
        )
        .await
        .unwrap();

        assert_eq!(handshake, b"initial message");

        substream.send_handshake(&b"hello world"[..]);
        future::poll_fn(|cx| Pin::new(&mut substream).poll_process(cx)).await.unwrap();

        drop(substream);

        client.await.unwrap();
    }
}