referrerpolicy=no-referrer-when-downgrade

sc_network/protocol/notifications/upgrade/
notifications.rs

1// This file is part of Substrate.
2
3// Copyright (C) Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
5
6// This program is free software: you can redistribute it and/or modify
7// it under the terms of the GNU General Public License as published by
8// the Free Software Foundation, either version 3 of the License, or
9// (at your option) any later version.
10
11// This program is distributed in the hope that it will be useful,
12// but WITHOUT ANY WARRANTY; without even the implied warranty of
13// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14// GNU General Public License for more details.
15
16// You should have received a copy of the GNU General Public License
17// along with this program. If not, see <https://www.gnu.org/licenses/>.
18
//! Notifications protocol.
//!
//! The Substrate notifications protocol consists of the following:
//!
//! - Node A opens a substream to node B and sends a message which contains some
//!   protocol-specific higher-level logic. This message is prefixed with a variable-length
//!   integer message length. This message can be empty, in which case `0` is sent.
//! - If node B accepts the substream, it sends back a message with the same properties.
//! - If instead B refuses the substream (which typically happens because no empty slot is
//!   available), then it immediately closes the substream without sending back anything.
//! - Node A can then send notifications to B, prefixed with a variable-length integer
//!   indicating the length of the message.
//! - Either node A or node B can signal that it doesn't want this notifications substream
//!   anymore by closing its writing side. The other party should respond by also closing their
//!   own writing side soon after.
//!
//! Notification substreams are unidirectional. If A opens a substream with B, then B is
//! encouraged but not required to open a substream to A as well.
37use crate::types::ProtocolName;
38
39use asynchronous_codec::Framed;
40use bytes::BytesMut;
41use futures::prelude::*;
42use libp2p::{
43	core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo},
44	PeerId,
45};
46use log::{debug, error, warn};
47use unsigned_varint::codec::UviBytes;
48
49use std::{
50	fmt, io, mem,
51	pin::Pin,
52	task::{Context, Poll},
53	vec,
54};
55
/// Default logging target for log statements emitted from this file.
const LOG_TARGET: &str = "sub-libp2p::notification::upgrade";

/// Maximum allowed size of the two handshake messages, in bytes.
///
/// A handshake received from the remote whose announced length exceeds this value is rejected
/// with [`NotificationsHandshakeError::TooLarge`]. An outgoing initial message above this size
/// only triggers an error log when the upgrade is built.
const MAX_HANDSHAKE_SIZE: usize = 1024;
61
/// Upgrade that accepts a substream, reads the remote's handshake, and then becomes a
/// unidirectional stream of messages once our own handshake has been sent back.
#[derive(Debug, Clone)]
pub struct NotificationsIn {
	/// Protocol names to use when negotiating the substream.
	/// The first one is the main name, while the other ones are fallbacks.
	protocol_names: Vec<ProtocolName>,
	/// Maximum allowed size for a single notification; passed to the length-prefix codec as its
	/// maximum frame length.
	max_notification_size: u64,
}
72
/// Upgrade that opens a substream, waits for the remote to accept by sending back a status
/// message, then becomes a unidirectional sink of data.
#[derive(Debug, Clone)]
pub struct NotificationsOut {
	/// Protocol names to use when negotiating the substream.
	/// The first one is the main name, while the other ones are fallbacks.
	protocol_names: Vec<ProtocolName>,
	/// Message to send, length-prefixed, when we start the handshake.
	initial_message: Vec<u8>,
	/// Maximum allowed size for a single notification; passed to the length-prefix codec as its
	/// maximum frame length.
	max_notification_size: u64,
	/// The peer ID of the remote. Handed over to the resulting [`NotificationsOutSubstream`],
	/// which includes it in its log messages.
	peer_id: PeerId,
}
87
/// A substream for incoming notification messages.
///
/// When created, this struct starts in a state in which we must first send back a handshake
/// message to the remote. No message will come before this has been done.
#[pin_project::pin_project]
pub struct NotificationsInSubstream<TSubstream> {
	/// Underlying substream, framed with unsigned-varint length prefixes.
	#[pin]
	socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>,
	/// Progress of sending back our handshake and, later, of closing the substream.
	handshake: NotificationsInSubstreamHandshake,
}
98
/// State of the handshake sending back process.
///
/// The normal progression is `NotSent` -> `PendingSend` -> `Flush` -> `Sent`. Once the remote
/// closes its writing side, it continues with `ClosingInResponseToRemote` -> `BothSidesClosed`.
#[derive(Debug)]
pub enum NotificationsInSubstreamHandshake {
	/// Waiting for the user to give us the handshake message.
	NotSent,
	/// User gave us the handshake message. Trying to push it in the socket.
	PendingSend(Vec<u8>),
	/// Handshake message was pushed in the socket. Still need to flush.
	Flush,
	/// Handshake message successfully sent and flushed.
	Sent,
	/// Remote has closed their writing side. We close our own writing side in return.
	ClosingInResponseToRemote,
	/// Both our side and the remote have closed their writing side.
	BothSidesClosed,
}
115
/// A substream for outgoing notification messages.
///
/// Implements `Sink<Vec<u8>>`; each item is written with an unsigned-varint length prefix.
#[pin_project::pin_project]
pub struct NotificationsOutSubstream<TSubstream> {
	/// Substream where to send messages.
	#[pin]
	socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>,

	/// The remote peer; only used for diagnostics in log messages.
	peer_id: PeerId,
}
126
#[cfg(test)]
impl<TSubstream> NotificationsOutSubstream<TSubstream> {
	/// Test-only constructor: wraps an already-framed socket and attributes it to a freshly
	/// generated random peer ID.
	pub fn new(socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>) -> Self {
		let peer_id = PeerId::random();
		NotificationsOutSubstream { socket, peer_id }
	}
}
133
134impl NotificationsIn {
135	/// Builds a new potential upgrade.
136	pub fn new(
137		main_protocol_name: impl Into<ProtocolName>,
138		fallback_names: Vec<ProtocolName>,
139		max_notification_size: u64,
140	) -> Self {
141		let mut protocol_names = fallback_names;
142		protocol_names.insert(0, main_protocol_name.into());
143
144		Self { protocol_names, max_notification_size }
145	}
146}
147
148impl UpgradeInfo for NotificationsIn {
149	type Info = ProtocolName;
150	type InfoIter = vec::IntoIter<Self::Info>;
151
152	fn protocol_info(&self) -> Self::InfoIter {
153		self.protocol_names.clone().into_iter()
154	}
155}
156
157impl<TSubstream> InboundUpgrade<TSubstream> for NotificationsIn
158where
159	TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
160{
161	type Output = NotificationsInOpen<TSubstream>;
162	type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
163	type Error = NotificationsHandshakeError;
164
165	fn upgrade_inbound(self, mut socket: TSubstream, _negotiated_name: Self::Info) -> Self::Future {
166		Box::pin(async move {
167			let handshake_len = unsigned_varint::aio::read_usize(&mut socket).await?;
168			if handshake_len > MAX_HANDSHAKE_SIZE {
169				return Err(NotificationsHandshakeError::TooLarge {
170					requested: handshake_len,
171					max: MAX_HANDSHAKE_SIZE,
172				})
173			}
174
175			let mut handshake = vec![0u8; handshake_len];
176			if !handshake.is_empty() {
177				socket.read_exact(&mut handshake).await?;
178			}
179
180			let mut codec = UviBytes::default();
181			codec.set_max_len(usize::try_from(self.max_notification_size).unwrap_or(usize::MAX));
182
183			let substream = NotificationsInSubstream {
184				socket: Framed::new(socket, codec),
185				handshake: NotificationsInSubstreamHandshake::NotSent,
186			};
187
188			Ok(NotificationsInOpen { handshake, substream })
189		})
190	}
191}
192
/// Yielded by the [`NotificationsIn`] after a successful upgrade.
pub struct NotificationsInOpen<TSubstream> {
	/// Handshake sent by the remote.
	pub handshake: Vec<u8>,
	/// Implementation of `Stream` that allows receiving messages from the substream.
	pub substream: NotificationsInSubstream<TSubstream>,
}
200
201impl<TSubstream> fmt::Debug for NotificationsInOpen<TSubstream> {
202	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203		f.debug_struct("NotificationsInOpen")
204			.field("handshake", &self.handshake)
205			.finish_non_exhaustive()
206	}
207}
208
209impl<TSubstream> NotificationsInSubstream<TSubstream>
210where
211	TSubstream: AsyncRead + AsyncWrite + Unpin,
212{
213	#[cfg(test)]
214	pub fn new(
215		socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>,
216		handshake: NotificationsInSubstreamHandshake,
217	) -> Self {
218		Self { socket, handshake }
219	}
220
221	/// Sends the handshake in order to inform the remote that we accept the substream.
222	pub fn send_handshake(&mut self, message: impl Into<Vec<u8>>) {
223		if !matches!(self.handshake, NotificationsInSubstreamHandshake::NotSent) {
224			error!(target: LOG_TARGET, "Tried to send handshake twice");
225			return
226		}
227
228		self.handshake = NotificationsInSubstreamHandshake::PendingSend(message.into());
229	}
230
231	/// Equivalent to `Stream::poll_next`, except that it only drives the handshake and is
232	/// guaranteed to not generate any notification.
233	pub fn poll_process(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
234		let mut this = self.project();
235
236		loop {
237			match mem::replace(this.handshake, NotificationsInSubstreamHandshake::Sent) {
238				NotificationsInSubstreamHandshake::PendingSend(msg) => {
239					match Sink::poll_ready(this.socket.as_mut(), cx) {
240						Poll::Ready(_) => {
241							*this.handshake = NotificationsInSubstreamHandshake::Flush;
242							match Sink::start_send(this.socket.as_mut(), io::Cursor::new(msg)) {
243								Ok(()) => {},
244								Err(err) => return Poll::Ready(Err(err)),
245							}
246						},
247						Poll::Pending => {
248							*this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
249							return Poll::Pending
250						},
251					}
252				},
253				NotificationsInSubstreamHandshake::Flush => {
254					match Sink::poll_flush(this.socket.as_mut(), cx)? {
255						Poll::Ready(()) => {
256							*this.handshake = NotificationsInSubstreamHandshake::Sent;
257							return Poll::Ready(Ok(()));
258						},
259						Poll::Pending => {
260							*this.handshake = NotificationsInSubstreamHandshake::Flush;
261							return Poll::Pending
262						},
263					}
264				},
265
266				st @ NotificationsInSubstreamHandshake::NotSent |
267				st @ NotificationsInSubstreamHandshake::Sent |
268				st @ NotificationsInSubstreamHandshake::ClosingInResponseToRemote |
269				st @ NotificationsInSubstreamHandshake::BothSidesClosed => {
270					*this.handshake = st;
271					return Poll::Ready(Ok(()));
272				},
273			}
274		}
275	}
276}
277
278impl<TSubstream> Stream for NotificationsInSubstream<TSubstream>
279where
280	TSubstream: AsyncRead + AsyncWrite + Unpin,
281{
282	type Item = Result<BytesMut, io::Error>;
283
284	fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
285		let mut this = self.project();
286
287		// This `Stream` implementation first tries to send back the handshake if necessary.
288		loop {
289			match mem::replace(this.handshake, NotificationsInSubstreamHandshake::Sent) {
290				NotificationsInSubstreamHandshake::NotSent => {
291					*this.handshake = NotificationsInSubstreamHandshake::NotSent;
292					return Poll::Pending
293				},
294				NotificationsInSubstreamHandshake::PendingSend(msg) => {
295					match Sink::poll_ready(this.socket.as_mut(), cx) {
296						Poll::Ready(_) => {
297							*this.handshake = NotificationsInSubstreamHandshake::Flush;
298							match Sink::start_send(this.socket.as_mut(), io::Cursor::new(msg)) {
299								Ok(()) => {},
300								Err(err) => return Poll::Ready(Some(Err(err))),
301							}
302						},
303						Poll::Pending => {
304							*this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
305							return Poll::Pending
306						},
307					}
308				},
309				NotificationsInSubstreamHandshake::Flush => {
310					match Sink::poll_flush(this.socket.as_mut(), cx)? {
311						Poll::Ready(()) =>
312							*this.handshake = NotificationsInSubstreamHandshake::Sent,
313						Poll::Pending => {
314							*this.handshake = NotificationsInSubstreamHandshake::Flush;
315							return Poll::Pending
316						},
317					}
318				},
319
320				NotificationsInSubstreamHandshake::Sent => {
321					match Stream::poll_next(this.socket.as_mut(), cx) {
322						Poll::Ready(None) =>
323							*this.handshake =
324								NotificationsInSubstreamHandshake::ClosingInResponseToRemote,
325						Poll::Ready(Some(msg)) => {
326							*this.handshake = NotificationsInSubstreamHandshake::Sent;
327							return Poll::Ready(Some(msg))
328						},
329						Poll::Pending => {
330							*this.handshake = NotificationsInSubstreamHandshake::Sent;
331							return Poll::Pending
332						},
333					}
334				},
335
336				NotificationsInSubstreamHandshake::ClosingInResponseToRemote =>
337					match Sink::poll_close(this.socket.as_mut(), cx)? {
338						Poll::Ready(()) =>
339							*this.handshake = NotificationsInSubstreamHandshake::BothSidesClosed,
340						Poll::Pending => {
341							*this.handshake =
342								NotificationsInSubstreamHandshake::ClosingInResponseToRemote;
343							return Poll::Pending
344						},
345					},
346
347				NotificationsInSubstreamHandshake::BothSidesClosed => return Poll::Ready(None),
348			}
349		}
350	}
351}
352
353impl NotificationsOut {
354	/// Builds a new potential upgrade.
355	pub fn new(
356		main_protocol_name: impl Into<ProtocolName>,
357		fallback_names: Vec<ProtocolName>,
358		initial_message: impl Into<Vec<u8>>,
359		max_notification_size: u64,
360		peer_id: PeerId,
361	) -> Self {
362		let initial_message = initial_message.into();
363		if initial_message.len() > MAX_HANDSHAKE_SIZE {
364			error!(target: LOG_TARGET, "Outbound networking handshake is above allowed protocol limit");
365		}
366
367		let mut protocol_names = fallback_names;
368		protocol_names.insert(0, main_protocol_name.into());
369
370		Self { protocol_names, initial_message, max_notification_size, peer_id }
371	}
372}
373
374impl UpgradeInfo for NotificationsOut {
375	type Info = ProtocolName;
376	type InfoIter = vec::IntoIter<Self::Info>;
377
378	fn protocol_info(&self) -> Self::InfoIter {
379		self.protocol_names.clone().into_iter()
380	}
381}
382
383impl<TSubstream> OutboundUpgrade<TSubstream> for NotificationsOut
384where
385	TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
386{
387	type Output = NotificationsOutOpen<TSubstream>;
388	type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
389	type Error = NotificationsHandshakeError;
390
391	fn upgrade_outbound(self, mut socket: TSubstream, negotiated_name: Self::Info) -> Self::Future {
392		Box::pin(async move {
393			{
394				let mut len_data = unsigned_varint::encode::usize_buffer();
395				let encoded_len =
396					unsigned_varint::encode::usize(self.initial_message.len(), &mut len_data).len();
397				socket.write_all(&len_data[..encoded_len]).await?;
398			}
399			socket.write_all(&self.initial_message).await?;
400			socket.flush().await?;
401
402			// Reading handshake.
403			let handshake_len = unsigned_varint::aio::read_usize(&mut socket).await?;
404			if handshake_len > MAX_HANDSHAKE_SIZE {
405				return Err(NotificationsHandshakeError::TooLarge {
406					requested: handshake_len,
407					max: MAX_HANDSHAKE_SIZE,
408				})
409			}
410
411			let mut handshake = vec![0u8; handshake_len];
412			if !handshake.is_empty() {
413				socket.read_exact(&mut handshake).await?;
414			}
415
416			let mut codec = UviBytes::default();
417			codec.set_max_len(usize::try_from(self.max_notification_size).unwrap_or(usize::MAX));
418
419			Ok(NotificationsOutOpen {
420				handshake,
421				negotiated_fallback: if negotiated_name == self.protocol_names[0] {
422					None
423				} else {
424					Some(negotiated_name)
425				},
426				substream: NotificationsOutSubstream {
427					socket: Framed::new(socket, codec),
428					peer_id: self.peer_id,
429				},
430			})
431		})
432	}
433}
434
/// Yielded by the [`NotificationsOut`] after a successful upgrade.
pub struct NotificationsOutOpen<TSubstream> {
	/// Handshake returned by the remote.
	pub handshake: Vec<u8>,
	/// If the negotiated name is not the "main" protocol name but a fallback, contains the
	/// name of the negotiated fallback.
	pub negotiated_fallback: Option<ProtocolName>,
	/// Implementation of `Sink` that allows sending messages on the substream.
	pub substream: NotificationsOutSubstream<TSubstream>,
}
445
446impl<TSubstream> fmt::Debug for NotificationsOutOpen<TSubstream> {
447	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
448		f.debug_struct("NotificationsOutOpen")
449			.field("handshake", &self.handshake)
450			.field("negotiated_fallback", &self.negotiated_fallback)
451			.finish_non_exhaustive()
452	}
453}
454
455impl<TSubstream> Sink<Vec<u8>> for NotificationsOutSubstream<TSubstream>
456where
457	TSubstream: AsyncRead + AsyncWrite + Unpin,
458{
459	type Error = NotificationsOutError;
460
461	fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
462		let mut this = self.project();
463		Sink::poll_ready(this.socket.as_mut(), cx).map_err(NotificationsOutError::Io)
464	}
465
466	fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
467		let mut this = self.project();
468		Sink::start_send(this.socket.as_mut(), io::Cursor::new(item))
469			.map_err(NotificationsOutError::Io)
470	}
471
472	fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
473		let mut this = self.project();
474
475		// `Sink::poll_flush` does not expose stream closed error until we write something into
476		// the stream, so the code below makes sure we detect that the substream was closed
477		// even if we don't write anything into it.
478		match Stream::poll_next(this.socket.as_mut(), cx) {
479			Poll::Pending => {},
480			Poll::Ready(Some(result)) => match result {
481				Ok(_) => {
482					debug!(
483						target: "sub-libp2p",
484						"Unexpected incoming data in `NotificationsOutSubstream` peer={:?}",
485						this.peer_id
486					);
487
488					return Poll::Ready(Err(NotificationsOutError::UnexpectedData));
489				},
490				Err(error) => {
491					debug!(
492						target: "sub-libp2p",
493						"Error while reading from `NotificationsOutSubstream` peer={:?} error={error:?}",
494						this.peer_id
495					);
496
497					// The expectation is that the remote has closed the substream.
498					return Poll::Ready(Err(NotificationsOutError::Closed));
499				},
500			},
501			Poll::Ready(None) => return Poll::Ready(Err(NotificationsOutError::Closed)),
502		}
503
504		Sink::poll_flush(this.socket.as_mut(), cx).map_err(NotificationsOutError::Io)
505	}
506
507	fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
508		let mut this = self.project();
509		Sink::poll_close(this.socket.as_mut(), cx).map_err(NotificationsOutError::Io)
510	}
511}
512
/// Error that can happen while exchanging handshake messages during the inbound or outbound
/// notifications upgrade.
#[derive(Debug, thiserror::Error)]
pub enum NotificationsHandshakeError {
	/// I/O error on the substream.
	#[error(transparent)]
	Io(#[from] io::Error),

	/// Initial message or handshake was too large.
	///
	/// Produced when the length announced by the remote exceeds the allowed maximum.
	#[error("Initial message or handshake was too large: {requested}")]
	TooLarge {
		/// Size requested by the remote.
		requested: usize,
		/// Maximum allowed.
		max: usize,
	},

	/// Error while decoding the variable-length integer.
	#[error(transparent)]
	VarintDecode(#[from] unsigned_varint::decode::Error),
}
533
534impl From<unsigned_varint::io::ReadError> for NotificationsHandshakeError {
535	fn from(err: unsigned_varint::io::ReadError) -> Self {
536		match err {
537			unsigned_varint::io::ReadError::Io(err) => Self::Io(err),
538			unsigned_varint::io::ReadError::Decode(err) => Self::VarintDecode(err),
539			_ => {
540				warn!("Unrecognized varint decoding error");
541				Self::Io(From::from(io::ErrorKind::InvalidData))
542			},
543		}
544	}
545}
546
/// Error generated by sending on a notifications out substream.
#[derive(Debug, thiserror::Error)]
pub enum NotificationsOutError {
	/// I/O error on the substream.
	#[error(transparent)]
	Io(#[from] io::Error),

	/// The substream was closed.
	///
	/// Reported by `poll_flush` when the read side of the substream ends or fails.
	#[error("substream was closed/reset")]
	Closed,

	/// The remote peer did not comply with the notification spec.
	///
	/// Reported by `poll_flush` when the remote writes data on this unidirectional substream.
	/// This is a terminal error and the peer should be banned immediately.
	#[error("unexpected data received from the remote peer")]
	UnexpectedData,
}
564
#[cfg(test)]
mod tests {
	// Integration tests running the inbound and outbound upgrades against each other over a
	// real localhost TCP connection.
	use crate::ProtocolName;

	use super::{
		NotificationsHandshakeError, NotificationsIn, NotificationsInOpen,
		NotificationsInSubstream, NotificationsOut, NotificationsOutError, NotificationsOutOpen,
		NotificationsOutSubstream,
	};
	use futures::{channel::oneshot, future, prelude::*, SinkExt, StreamExt};
	use libp2p::{
		core::{upgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo},
		PeerId,
	};
	use std::{pin::Pin, task::Poll};
	use tokio::net::{TcpListener, TcpStream};
	use tokio_util::compat::TokioAsyncReadCompatExt;

	/// Opens a substream to the given address, negotiates the protocol, and returns the substream
	/// along with the handshake message sent back by the listener.
	async fn dial(
		addr: std::net::SocketAddr,
		handshake: impl Into<Vec<u8>>,
	) -> Result<
		(
			Vec<u8>,
			NotificationsOutSubstream<
				multistream_select::Negotiated<tokio_util::compat::Compat<TcpStream>>,
			>,
		),
		NotificationsHandshakeError,
	> {
		let socket = TcpStream::connect(addr).await.unwrap();
		let notifs_out = NotificationsOut::new(
			"/test/proto/1",
			Vec::new(),
			handshake,
			1024 * 1024,
			PeerId::random(),
		);
		let (_, substream) = multistream_select::dialer_select_proto(
			socket.compat(),
			notifs_out.protocol_info(),
			upgrade::Version::V1,
		)
		.await
		.unwrap();
		let NotificationsOutOpen { handshake, substream, .. } =
			<NotificationsOut as OutboundUpgrade<_>>::upgrade_outbound(
				notifs_out,
				substream,
				"/test/proto/1".into(),
			)
			.await?;
		Ok((handshake, substream))
	}

	/// Listens on localhost, negotiates the protocol, and returns the substream along with the
	/// handshake message sent by the dialer.
	///
	/// Also sends the listener address through the given channel.
	async fn listen_on_localhost(
		listener_addr_tx: oneshot::Sender<std::net::SocketAddr>,
	) -> Result<
		(
			Vec<u8>,
			NotificationsInSubstream<
				multistream_select::Negotiated<tokio_util::compat::Compat<TcpStream>>,
			>,
		),
		NotificationsHandshakeError,
	> {
		let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
		listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();

		let (socket, _) = listener.accept().await.unwrap();
		let notifs_in = NotificationsIn::new("/test/proto/1", Vec::new(), 1024 * 1024);
		let (_, substream) =
			multistream_select::listener_select_proto(socket.compat(), notifs_in.protocol_info())
				.await
				.unwrap();
		let NotificationsInOpen { handshake, substream, .. } =
			<NotificationsIn as InboundUpgrade<_>>::upgrade_inbound(
				notifs_in,
				substream,
				"/test/proto/1".into(),
			)
			.await?;
		Ok((handshake, substream))
	}

	// Both handshakes are exchanged and a single notification goes from dialer to listener.
	#[tokio::test]
	async fn basic_works() {
		let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

		let client = tokio::spawn(async move {
			let (handshake, mut substream) =
				dial(listener_addr_rx.await.unwrap(), &b"initial message"[..]).await.unwrap();

			assert_eq!(handshake, b"hello world");
			substream.send(b"test message".to_vec()).await.unwrap();
		});

		let (handshake, mut substream) = listen_on_localhost(listener_addr_tx).await.unwrap();

		assert_eq!(handshake, b"initial message");
		substream.send_handshake(&b"hello world"[..]);

		let msg = substream.next().await.unwrap().unwrap();
		assert_eq!(msg.as_ref(), b"test message");

		client.await.unwrap();
	}

	#[tokio::test]
	async fn empty_handshake() {
		// Check that everything still works when the handshake messages are empty.

		let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

		let client = tokio::spawn(async move {
			let (handshake, mut substream) =
				dial(listener_addr_rx.await.unwrap(), vec![]).await.unwrap();

			assert!(handshake.is_empty());
			substream.send(Default::default()).await.unwrap();
		});

		let (handshake, mut substream) = listen_on_localhost(listener_addr_tx).await.unwrap();

		assert!(handshake.is_empty());
		substream.send_handshake(vec![]);

		let msg = substream.next().await.unwrap().unwrap();
		assert!(msg.as_ref().is_empty());

		client.await.unwrap();
	}

	#[tokio::test]
	async fn refused() {
		let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

		let client = tokio::spawn(async move {
			let outcome = dial(listener_addr_rx.await.unwrap(), &b"hello"[..]).await;

			// Despite the protocol negotiation being successfully conducted on the listener
			// side, we have to receive an error here because the listener didn't send the
			// handshake.
			assert!(outcome.is_err());
		});

		let (handshake, substream) = listen_on_localhost(listener_addr_tx).await.unwrap();
		assert_eq!(handshake, b"hello");

		// We successfully upgrade to the protocol, but then close the substream.
		drop(substream);

		client.await.unwrap();
	}

	// The dialer's 32 KiB initial message exceeds `MAX_HANDSHAKE_SIZE`, so the listener rejects
	// it and the dialer never receives a handshake back.
	#[tokio::test]
	async fn large_initial_message_refused() {
		let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

		let client = tokio::spawn(async move {
			let ret =
				dial(listener_addr_rx.await.unwrap(), (0..32768).map(|_| 0).collect::<Vec<_>>())
					.await;
			assert!(ret.is_err());
		});

		let _ret = listen_on_localhost(listener_addr_tx).await;
		client.await.unwrap();
	}

	#[tokio::test]
	async fn large_handshake_refused() {
		let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

		let client = tokio::spawn(async move {
			let ret = dial(listener_addr_rx.await.unwrap(), &b"initial message"[..]).await;
			assert!(ret.is_err());
		});

		let (handshake, mut substream) = listen_on_localhost(listener_addr_tx).await.unwrap();
		assert_eq!(handshake, b"initial message");

		// We check that a handshake that is too large gets refused.
		substream.send_handshake((0..32768).map(|_| 0).collect::<Vec<_>>());
		let _ = substream.next().await;

		client.await.unwrap();
	}

	// `poll_process` alone must be enough to push the handshake out, without ever polling the
	// substream for incoming notifications. Protocol negotiation is skipped here; the upgrades
	// are run directly on the TCP sockets.
	#[tokio::test]
	async fn send_handshake_without_polling_for_incoming_data() {
		const PROTO_NAME: &str = "/test/proto/1";
		let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

		let client = tokio::spawn(async move {
			let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap();
			let NotificationsOutOpen { handshake, .. } = OutboundUpgrade::upgrade_outbound(
				NotificationsOut::new(
					PROTO_NAME,
					Vec::new(),
					&b"initial message"[..],
					1024 * 1024,
					PeerId::random(),
				),
				socket.compat(),
				ProtocolName::Static(PROTO_NAME),
			)
			.await
			.unwrap();

			assert_eq!(handshake, b"hello world");
		});

		let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
		listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();

		let (socket, _) = listener.accept().await.unwrap();
		let NotificationsInOpen { handshake, mut substream, .. } = InboundUpgrade::upgrade_inbound(
			NotificationsIn::new(PROTO_NAME, Vec::new(), 1024 * 1024),
			socket.compat(),
			ProtocolName::Static(PROTO_NAME),
		)
		.await
		.unwrap();

		assert_eq!(handshake, b"initial message");
		substream.send_handshake(&b"hello world"[..]);

		// Actually send the handshake.
		future::poll_fn(|cx| Pin::new(&mut substream).poll_process(cx)).await.unwrap();

		client.await.unwrap();
	}

	// After the listener drops its substream, repeatedly flushing the outbound substream
	// (without writing any data) must eventually report `NotificationsOutError::Closed`.
	#[tokio::test]
	async fn can_detect_dropped_out_substream_without_writing_data() {
		const PROTO_NAME: &str = "/test/proto/1";
		let (listener_addr_tx, listener_addr_rx) = oneshot::channel();

		let client = tokio::spawn(async move {
			let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap();
			let NotificationsOutOpen { handshake, mut substream, .. } =
				OutboundUpgrade::upgrade_outbound(
					NotificationsOut::new(
						PROTO_NAME,
						Vec::new(),
						&b"initial message"[..],
						1024 * 1024,
						PeerId::random(),
					),
					socket.compat(),
					ProtocolName::Static(PROTO_NAME),
				)
				.await
				.unwrap();

			assert_eq!(handshake, b"hello world");

			future::poll_fn(|cx| match Pin::new(&mut substream).poll_flush(cx) {
				Poll::Pending => Poll::Pending,
				// A successful flush doesn't prove the remote is still there; wake ourselves up
				// to poll again until the closure is observed.
				Poll::Ready(Ok(())) => {
					cx.waker().wake_by_ref();
					Poll::Pending
				},
				Poll::Ready(Err(e)) => {
					assert!(matches!(e, NotificationsOutError::Closed));
					Poll::Ready(())
				},
			})
			.await;
		});

		let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
		listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();

		let (socket, _) = listener.accept().await.unwrap();
		let NotificationsInOpen { handshake, mut substream, .. } = InboundUpgrade::upgrade_inbound(
			NotificationsIn::new(PROTO_NAME, Vec::new(), 1024 * 1024),
			socket.compat(),
			ProtocolName::Static(PROTO_NAME),
		)
		.await
		.unwrap();

		assert_eq!(handshake, b"initial message");

		// Send the handshake.
		substream.send_handshake(&b"hello world"[..]);
		future::poll_fn(|cx| Pin::new(&mut substream).poll_process(cx)).await.unwrap();

		drop(substream);

		client.await.unwrap();
	}
}