use crate::types::ProtocolName;
use asynchronous_codec::Framed;
use bytes::BytesMut;
use futures::prelude::*;
use libp2p::core::{InboundUpgrade, OutboundUpgrade, UpgradeInfo};
use log::{error, warn};
use unsigned_varint::codec::UviBytes;
use std::{
fmt, io, mem,
pin::Pin,
task::{Context, Poll},
vec,
};
/// Maximum size, in bytes, of the handshake exchanged when a notifications substream is opened.
///
/// Enforced on the handshake received from the remote in both the inbound and the outbound
/// upgrade; an outbound initial message above this size only triggers an error log.
const MAX_HANDSHAKE_SIZE: usize = 1024;
/// Inbound upgrade for a notifications substream.
///
/// Advertises `protocol_names` and, once a protocol has been negotiated, reads the remote's
/// length-prefixed handshake.
#[derive(Debug, Clone)]
pub struct NotificationsIn {
/// Names to advertise: the main protocol name first, followed by the fallbacks.
protocol_names: Vec<ProtocolName>,
/// Maximum size, in bytes, of a single notification frame on the resulting substream.
max_notification_size: u64,
}
/// Outbound upgrade for a notifications substream.
///
/// Advertises `protocol_names`, sends `initial_message` as our handshake once a protocol has
/// been negotiated, then reads the remote's handshake.
#[derive(Debug, Clone)]
pub struct NotificationsOut {
/// Names to advertise: the main protocol name first, followed by the fallbacks.
protocol_names: Vec<ProtocolName>,
/// Handshake sent to the remote right after protocol negotiation.
initial_message: Vec<u8>,
/// Maximum size, in bytes, of a single notification frame on the resulting substream.
max_notification_size: u64,
}
/// Substream opened by the remote.
///
/// Incoming notifications are only read from it after our own handshake has been sent back.
#[pin_project::pin_project]
pub struct NotificationsInSubstream<TSubstream> {
/// Underlying socket, framed with unsigned-varint length prefixes.
#[pin]
socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>,
/// Progress of sending our handshake back, and of closing the substream.
handshake: NotificationsInSubstreamHandshake,
}
/// State machine of an inbound notifications substream.
///
/// Covers sending our handshake back to the remote, then closing once the remote has closed
/// its side.
#[derive(Debug)]
pub enum NotificationsInSubstreamHandshake {
/// `send_handshake` has not been called yet; no notification is read in this state.
NotSent,
/// Handshake queued by `send_handshake`, waiting for the sink to accept it.
PendingSend(Vec<u8>),
/// Handshake handed to the sink; waiting for the flush to complete.
Flush,
/// Handshake fully sent; incoming notifications are being read.
Sent,
/// The remote closed its side; we are closing ours in response.
ClosingInResponseToRemote,
/// Both sides are closed; the stream yields `None` from now on.
BothSidesClosed,
}
/// Substream opened by us, on which notifications are sent to the remote.
#[pin_project::pin_project]
pub struct NotificationsOutSubstream<TSubstream> {
/// Underlying socket, framed with unsigned-varint length prefixes.
#[pin]
socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>,
}
#[cfg(test)]
impl<TSubstream> NotificationsOutSubstream<TSubstream> {
    /// Test-only constructor wrapping an already-framed socket.
    pub fn new(socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>) -> Self {
        NotificationsOutSubstream { socket }
    }
}
impl NotificationsIn {
pub fn new(
main_protocol_name: impl Into<ProtocolName>,
fallback_names: Vec<ProtocolName>,
max_notification_size: u64,
) -> Self {
let mut protocol_names = fallback_names;
protocol_names.insert(0, main_protocol_name.into());
Self { protocol_names, max_notification_size }
}
}
impl UpgradeInfo for NotificationsIn {
    type Info = ProtocolName;
    type InfoIter = vec::IntoIter<Self::Info>;

    // Advertise the main protocol name followed by every fallback, in that order.
    fn protocol_info(&self) -> Self::InfoIter {
        let advertised: Vec<ProtocolName> = self.protocol_names.to_vec();
        advertised.into_iter()
    }
}
impl<TSubstream> InboundUpgrade<TSubstream> for NotificationsIn
where
    TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Output = NotificationsInOpen<TSubstream>;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
    type Error = NotificationsHandshakeError;

    // Reads the remote's varint-length-prefixed handshake (rejecting anything above
    // `MAX_HANDSHAKE_SIZE`), then wraps the socket in a length-delimited codec capped at
    // `max_notification_size`. Our own handshake is sent later, via `send_handshake`.
    fn upgrade_inbound(self, mut socket: TSubstream, _negotiated_name: Self::Info) -> Self::Future {
        let frame_limit = usize::try_from(self.max_notification_size).unwrap_or(usize::MAX);

        Box::pin(async move {
            let requested = unsigned_varint::aio::read_usize(&mut socket).await?;
            if requested > MAX_HANDSHAKE_SIZE {
                return Err(NotificationsHandshakeError::TooLarge {
                    requested,
                    max: MAX_HANDSHAKE_SIZE,
                })
            }

            let mut handshake = vec![0u8; requested];
            if requested != 0 {
                socket.read_exact(&mut handshake).await?;
            }

            let mut codec = UviBytes::default();
            codec.set_max_len(frame_limit);

            Ok(NotificationsInOpen {
                handshake,
                substream: NotificationsInSubstream {
                    socket: Framed::new(socket, codec),
                    handshake: NotificationsInSubstreamHandshake::NotSent,
                },
            })
        })
    }
}
/// Result of a successful inbound upgrade.
pub struct NotificationsInOpen<TSubstream> {
/// Handshake received from the remote.
pub handshake: Vec<u8>,
/// Substream on which notifications are received once our own handshake has been sent.
pub substream: NotificationsInSubstream<TSubstream>,
}
// Only the received handshake is printed; the substream is elided.
impl<TSubstream> fmt::Debug for NotificationsInOpen<TSubstream> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("NotificationsInOpen");
        out.field("handshake", &self.handshake);
        out.finish_non_exhaustive()
    }
}
impl<TSubstream> NotificationsInSubstream<TSubstream>
where
    TSubstream: AsyncRead + AsyncWrite + Unpin,
{
    /// Builds a substream from an already-framed socket and an explicit handshake state.
    ///
    /// Test-only constructor.
    #[cfg(test)]
    pub fn new(
        socket: Framed<TSubstream, UviBytes<io::Cursor<Vec<u8>>>>,
        handshake: NotificationsInSubstreamHandshake,
    ) -> Self {
        Self { socket, handshake }
    }

    /// Queues `message` as the handshake to send back to the remote.
    ///
    /// Nothing is written here. The message is sent by later calls to `poll_process` or
    /// `Stream::poll_next`. If the state is anything other than `NotSent` (for instance because
    /// this was already called), an error is logged and the call is ignored.
    pub fn send_handshake(&mut self, message: impl Into<Vec<u8>>) {
        if !matches!(self.handshake, NotificationsInSubstreamHandshake::NotSent) {
            error!(target: "sub-libp2p", "Tried to send handshake twice");
            return
        }
        self.handshake = NotificationsInSubstreamHandshake::PendingSend(message.into());
    }

    /// Drives the sending of the queued handshake without reading any notification.
    ///
    /// Returns `Poll::Ready(Ok(()))` once the handshake has been flushed. It also returns
    /// `Poll::Ready(Ok(()))` immediately when there is nothing to send: the handshake is not
    /// queued yet, was already sent, or the substream is closing.
    ///
    /// # Errors
    ///
    /// Returns `Poll::Ready(Err(_))` if the underlying sink reports an I/O error while
    /// accepting, sending or flushing the handshake.
    pub fn poll_process(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
        let mut this = self.project();

        loop {
            // Take the current state out; every arm stores the next state before looping
            // again or returning.
            match mem::replace(this.handshake, NotificationsInSubstreamHandshake::Sent) {
                NotificationsInSubstreamHandshake::PendingSend(msg) => {
                    match Sink::poll_ready(this.socket.as_mut(), cx) {
                        Poll::Ready(Ok(())) => {
                            *this.handshake = NotificationsInSubstreamHandshake::Flush;
                            if let Err(err) =
                                Sink::start_send(this.socket.as_mut(), io::Cursor::new(msg))
                            {
                                return Poll::Ready(Err(err))
                            }
                        },
                        Poll::Ready(Err(err)) => {
                            // The `Sink` contract forbids `start_send` after `poll_ready`
                            // failed. The handshake was not sent, so keep it pending and
                            // surface the error instead of swallowing it.
                            *this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
                            return Poll::Ready(Err(err))
                        },
                        Poll::Pending => {
                            *this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
                            return Poll::Pending
                        },
                    }
                },
                NotificationsInSubstreamHandshake::Flush => {
                    match Sink::poll_flush(this.socket.as_mut(), cx)? {
                        Poll::Ready(()) => {
                            *this.handshake = NotificationsInSubstreamHandshake::Sent;
                            return Poll::Ready(Ok(()))
                        },
                        Poll::Pending => {
                            *this.handshake = NotificationsInSubstreamHandshake::Flush;
                            return Poll::Pending
                        },
                    }
                },
                // No handshake in flight: restore the state unchanged and report completion.
                st @ NotificationsInSubstreamHandshake::NotSent |
                st @ NotificationsInSubstreamHandshake::Sent |
                st @ NotificationsInSubstreamHandshake::ClosingInResponseToRemote |
                st @ NotificationsInSubstreamHandshake::BothSidesClosed => {
                    *this.handshake = st;
                    return Poll::Ready(Ok(()))
                },
            }
        }
    }
}
impl<TSubstream> Stream for NotificationsInSubstream<TSubstream>
where
    TSubstream: AsyncRead + AsyncWrite + Unpin,
{
    type Item = Result<BytesMut, io::Error>;

    /// Sends the queued handshake if needed, then yields incoming notifications.
    ///
    /// Nothing is read before the handshake has been flushed. When the remote closes its side,
    /// our side is closed in response and the stream then terminates with `None`. I/O errors
    /// from the sink while sending the handshake are yielded as `Some(Err(_))`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        loop {
            // Take the current state out; every arm stores the next state before looping
            // again or returning.
            match mem::replace(this.handshake, NotificationsInSubstreamHandshake::Sent) {
                NotificationsInSubstreamHandshake::NotSent => {
                    // NOTE(review): no waker is registered here. This relies on the owner
                    // polling again after calling `send_handshake`.
                    *this.handshake = NotificationsInSubstreamHandshake::NotSent;
                    return Poll::Pending
                },
                NotificationsInSubstreamHandshake::PendingSend(msg) => {
                    match Sink::poll_ready(this.socket.as_mut(), cx) {
                        Poll::Ready(Ok(())) => {
                            *this.handshake = NotificationsInSubstreamHandshake::Flush;
                            if let Err(err) =
                                Sink::start_send(this.socket.as_mut(), io::Cursor::new(msg))
                            {
                                return Poll::Ready(Some(Err(err)))
                            }
                        },
                        Poll::Ready(Err(err)) => {
                            // The `Sink` contract forbids `start_send` after `poll_ready`
                            // failed. The handshake was not sent, so keep it pending and
                            // surface the error instead of swallowing it.
                            *this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
                            return Poll::Ready(Some(Err(err)))
                        },
                        Poll::Pending => {
                            *this.handshake = NotificationsInSubstreamHandshake::PendingSend(msg);
                            return Poll::Pending
                        },
                    }
                },
                NotificationsInSubstreamHandshake::Flush => {
                    match Sink::poll_flush(this.socket.as_mut(), cx)? {
                        Poll::Ready(()) =>
                            *this.handshake = NotificationsInSubstreamHandshake::Sent,
                        Poll::Pending => {
                            *this.handshake = NotificationsInSubstreamHandshake::Flush;
                            return Poll::Pending
                        },
                    }
                },
                NotificationsInSubstreamHandshake::Sent => {
                    match Stream::poll_next(this.socket.as_mut(), cx) {
                        // The remote closed its side: close ours in response.
                        Poll::Ready(None) =>
                            *this.handshake =
                                NotificationsInSubstreamHandshake::ClosingInResponseToRemote,
                        Poll::Ready(Some(msg)) => {
                            *this.handshake = NotificationsInSubstreamHandshake::Sent;
                            return Poll::Ready(Some(msg))
                        },
                        Poll::Pending => {
                            *this.handshake = NotificationsInSubstreamHandshake::Sent;
                            return Poll::Pending
                        },
                    }
                },
                NotificationsInSubstreamHandshake::ClosingInResponseToRemote =>
                    match Sink::poll_close(this.socket.as_mut(), cx)? {
                        Poll::Ready(()) =>
                            *this.handshake = NotificationsInSubstreamHandshake::BothSidesClosed,
                        Poll::Pending => {
                            *this.handshake =
                                NotificationsInSubstreamHandshake::ClosingInResponseToRemote;
                            return Poll::Pending
                        },
                    },
                NotificationsInSubstreamHandshake::BothSidesClosed => return Poll::Ready(None),
            }
        }
    }
}
impl NotificationsOut {
pub fn new(
main_protocol_name: impl Into<ProtocolName>,
fallback_names: Vec<ProtocolName>,
initial_message: impl Into<Vec<u8>>,
max_notification_size: u64,
) -> Self {
let initial_message = initial_message.into();
if initial_message.len() > MAX_HANDSHAKE_SIZE {
error!(target: "sub-libp2p", "Outbound networking handshake is above allowed protocol limit");
}
let mut protocol_names = fallback_names;
protocol_names.insert(0, main_protocol_name.into());
Self { protocol_names, initial_message, max_notification_size }
}
}
impl UpgradeInfo for NotificationsOut {
    type Info = ProtocolName;
    type InfoIter = vec::IntoIter<Self::Info>;

    // Advertise the main protocol name followed by every fallback, in that order.
    fn protocol_info(&self) -> Self::InfoIter {
        let advertised: Vec<ProtocolName> = self.protocol_names.to_vec();
        advertised.into_iter()
    }
}
impl<TSubstream> OutboundUpgrade<TSubstream> for NotificationsOut
where
    TSubstream: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    type Output = NotificationsOutOpen<TSubstream>;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
    type Error = NotificationsHandshakeError;

    // Steps:
    // 1. Write our varint-length-prefixed initial message and flush it.
    // 2. Read the remote's handshake, rejecting anything above `MAX_HANDSHAKE_SIZE`.
    // 3. Wrap the socket in a length-delimited codec capped at `max_notification_size`.
    // `negotiated_fallback` is `Some` only when the negotiated name differs from the main
    // (first) protocol name.
    fn upgrade_outbound(self, mut socket: TSubstream, negotiated_name: Self::Info) -> Self::Future {
        Box::pin(async move {
            let mut prefix_buf = unsigned_varint::encode::usize_buffer();
            let prefix =
                unsigned_varint::encode::usize(self.initial_message.len(), &mut prefix_buf);
            socket.write_all(prefix).await?;
            socket.write_all(&self.initial_message).await?;
            socket.flush().await?;

            let requested = unsigned_varint::aio::read_usize(&mut socket).await?;
            if requested > MAX_HANDSHAKE_SIZE {
                return Err(NotificationsHandshakeError::TooLarge {
                    requested,
                    max: MAX_HANDSHAKE_SIZE,
                })
            }

            let mut handshake = vec![0u8; requested];
            if requested != 0 {
                socket.read_exact(&mut handshake).await?;
            }

            let is_main_protocol = negotiated_name == self.protocol_names[0];
            let negotiated_fallback = if is_main_protocol { None } else { Some(negotiated_name) };

            let mut codec = UviBytes::default();
            codec.set_max_len(usize::try_from(self.max_notification_size).unwrap_or(usize::MAX));

            Ok(NotificationsOutOpen {
                handshake,
                negotiated_fallback,
                substream: NotificationsOutSubstream { socket: Framed::new(socket, codec) },
            })
        })
    }
}
/// Result of a successful outbound upgrade.
pub struct NotificationsOutOpen<TSubstream> {
/// Handshake received from the remote.
pub handshake: Vec<u8>,
/// `Some(name)` if the negotiated protocol is not the main (first) protocol name, i.e. a
/// fallback was negotiated; `None` otherwise.
pub negotiated_fallback: Option<ProtocolName>,
/// Substream on which notifications are sent.
pub substream: NotificationsOutSubstream<TSubstream>,
}
// Prints the handshake and negotiated fallback; the substream is elided.
impl<TSubstream> fmt::Debug for NotificationsOutOpen<TSubstream> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("NotificationsOutOpen");
        out.field("handshake", &self.handshake);
        out.field("negotiated_fallback", &self.negotiated_fallback);
        out.finish_non_exhaustive()
    }
}
impl<TSubstream> Sink<Vec<u8>> for NotificationsOutSubstream<TSubstream>
where
    TSubstream: AsyncRead + AsyncWrite + Unpin,
{
    type Error = NotificationsOutError;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        let socket = self.project().socket;
        Sink::poll_ready(socket, cx).map_err(NotificationsOutError::Io)
    }

    fn start_send(self: Pin<&mut Self>, item: Vec<u8>) -> Result<(), Self::Error> {
        let frame = io::Cursor::new(item);
        Sink::start_send(self.project().socket, frame).map_err(NotificationsOutError::Io)
    }

    // Before flushing, peek at the read side. The remote closing the substream is otherwise
    // not noticed until something is written, so a closed read half is reported as
    // `Terminated`. Any data that does arrive is unexpected on this substream; it is logged
    // and discarded.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        let mut socket = self.project().socket;

        if let Poll::Ready(incoming) = Stream::poll_next(socket.as_mut(), cx) {
            if incoming.is_none() {
                return Poll::Ready(Err(NotificationsOutError::Terminated))
            }
            error!(
                target: "sub-libp2p",
                "Unexpected incoming data in `NotificationsOutSubstream`",
            );
        }

        Sink::poll_flush(socket, cx).map_err(NotificationsOutError::Io)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
        let socket = self.project().socket;
        Sink::poll_close(socket, cx).map_err(NotificationsOutError::Io)
    }
}
/// Error that can occur while exchanging handshakes in the inbound or outbound upgrade.
#[derive(Debug, thiserror::Error)]
pub enum NotificationsHandshakeError {
/// I/O error on the underlying substream.
#[error(transparent)]
Io(#[from] io::Error),
/// The handshake length announced by the remote exceeds `MAX_HANDSHAKE_SIZE`.
#[error("Initial message or handshake was too large: {requested}")]
TooLarge {
/// Length announced by the remote, in bytes.
requested: usize,
/// Maximum allowed length, in bytes.
max: usize,
},
/// Failed to decode the unsigned-varint length prefix of the remote's handshake.
#[error(transparent)]
VarintDecode(#[from] unsigned_varint::decode::Error),
}
impl From<unsigned_varint::io::ReadError> for NotificationsHandshakeError {
fn from(err: unsigned_varint::io::ReadError) -> Self {
match err {
unsigned_varint::io::ReadError::Io(err) => Self::Io(err),
unsigned_varint::io::ReadError::Decode(err) => Self::VarintDecode(err),
_ => {
warn!("Unrecognized varint decoding error");
Self::Io(From::from(io::ErrorKind::InvalidData))
},
}
}
}
/// Error that can happen when sending on a `NotificationsOutSubstream`.
#[derive(Debug, thiserror::Error)]
pub enum NotificationsOutError {
/// I/O error on the underlying substream.
#[error(transparent)]
Io(#[from] io::Error),
/// The remote closed the substream; detected in `poll_flush` by the read side ending.
#[error("substream was closed/reset")]
Terminated,
}
// Tests driving the inbound and outbound upgrades against each other over real localhost TCP
// connections.
#[cfg(test)]
mod tests {
use crate::ProtocolName;
use super::{
NotificationsHandshakeError, NotificationsIn, NotificationsInOpen,
NotificationsInSubstream, NotificationsOut, NotificationsOutError, NotificationsOutOpen,
NotificationsOutSubstream,
};
use futures::{channel::oneshot, future, prelude::*, SinkExt, StreamExt};
use libp2p::core::{upgrade, InboundUpgrade, OutboundUpgrade, UpgradeInfo};
use std::{pin::Pin, task::Poll};
use tokio::net::{TcpListener, TcpStream};
use tokio_util::compat::TokioAsyncReadCompatExt;
/// Connects to `addr` and negotiates `/test/proto/1` with multistream-select.
///
/// Then runs the outbound upgrade, sending `handshake` as the initial message. Returns the
/// remote's handshake together with the outbound substream.
async fn dial(
addr: std::net::SocketAddr,
handshake: impl Into<Vec<u8>>,
) -> Result<
(
Vec<u8>,
NotificationsOutSubstream<
multistream_select::Negotiated<tokio_util::compat::Compat<TcpStream>>,
>,
),
NotificationsHandshakeError,
> {
let socket = TcpStream::connect(addr).await.unwrap();
let notifs_out = NotificationsOut::new("/test/proto/1", Vec::new(), handshake, 1024 * 1024);
let (_, substream) = multistream_select::dialer_select_proto(
socket.compat(),
notifs_out.protocol_info(),
upgrade::Version::V1,
)
.await
.unwrap();
let NotificationsOutOpen { handshake, substream, .. } =
<NotificationsOut as OutboundUpgrade<_>>::upgrade_outbound(
notifs_out,
substream,
"/test/proto/1".into(),
)
.await?;
Ok((handshake, substream))
}
/// Binds an ephemeral port on 127.0.0.1 and sends its address through `listener_addr_tx`.
///
/// Then accepts one connection, negotiates `/test/proto/1` and runs the inbound upgrade.
/// Returns the remote's handshake together with the inbound substream.
async fn listen_on_localhost(
listener_addr_tx: oneshot::Sender<std::net::SocketAddr>,
) -> Result<
(
Vec<u8>,
NotificationsInSubstream<
multistream_select::Negotiated<tokio_util::compat::Compat<TcpStream>>,
>,
),
NotificationsHandshakeError,
> {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();
let (socket, _) = listener.accept().await.unwrap();
let notifs_in = NotificationsIn::new("/test/proto/1", Vec::new(), 1024 * 1024);
let (_, substream) =
multistream_select::listener_select_proto(socket.compat(), notifs_in.protocol_info())
.await
.unwrap();
let NotificationsInOpen { handshake, substream, .. } =
<NotificationsIn as InboundUpgrade<_>>::upgrade_inbound(
notifs_in,
substream,
"/test/proto/1".into(),
)
.await?;
Ok((handshake, substream))
}
/// Both handshakes are exchanged and one notification flows from dialer to listener.
#[tokio::test]
async fn basic_works() {
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();
let client = tokio::spawn(async move {
let (handshake, mut substream) =
dial(listener_addr_rx.await.unwrap(), &b"initial message"[..]).await.unwrap();
assert_eq!(handshake, b"hello world");
substream.send(b"test message".to_vec()).await.unwrap();
});
let (handshake, mut substream) = listen_on_localhost(listener_addr_tx).await.unwrap();
assert_eq!(handshake, b"initial message");
substream.send_handshake(&b"hello world"[..]);
// Polling the stream first sends our handshake, then yields the dialer's notification.
let msg = substream.next().await.unwrap().unwrap();
assert_eq!(msg.as_ref(), b"test message");
client.await.unwrap();
}
/// Zero-length handshakes and a zero-length notification are supported.
#[tokio::test]
async fn empty_handshake() {
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();
let client = tokio::spawn(async move {
let (handshake, mut substream) =
dial(listener_addr_rx.await.unwrap(), vec![]).await.unwrap();
assert!(handshake.is_empty());
substream.send(Default::default()).await.unwrap();
});
let (handshake, mut substream) = listen_on_localhost(listener_addr_tx).await.unwrap();
assert!(handshake.is_empty());
substream.send_handshake(vec![]);
let msg = substream.next().await.unwrap().unwrap();
assert!(msg.as_ref().is_empty());
client.await.unwrap();
}
/// The listener drops the substream without sending its handshake; the dial must fail.
#[tokio::test]
async fn refused() {
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();
let client = tokio::spawn(async move {
let outcome = dial(listener_addr_rx.await.unwrap(), &b"hello"[..]).await;
assert!(outcome.is_err());
});
let (handshake, substream) = listen_on_localhost(listener_addr_tx).await.unwrap();
assert_eq!(handshake, b"hello");
drop(substream);
client.await.unwrap();
}
/// The dialer sends a 32768-byte initial message, above `MAX_HANDSHAKE_SIZE`.
///
/// The listener's upgrade rejects it and drops the connection, so the dial must fail.
#[tokio::test]
async fn large_initial_message_refused() {
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();
let client = tokio::spawn(async move {
let ret =
dial(listener_addr_rx.await.unwrap(), (0..32768).map(|_| 0).collect::<Vec<_>>())
.await;
assert!(ret.is_err());
});
let _ret = listen_on_localhost(listener_addr_tx).await;
client.await.unwrap();
}
/// The listener answers with a 32768-byte handshake, above `MAX_HANDSHAKE_SIZE`; the
/// dialer's upgrade must reject it.
#[tokio::test]
async fn large_handshake_refused() {
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();
let client = tokio::spawn(async move {
let ret = dial(listener_addr_rx.await.unwrap(), &b"initial message"[..]).await;
assert!(ret.is_err());
});
let (handshake, mut substream) = listen_on_localhost(listener_addr_tx).await.unwrap();
assert_eq!(handshake, b"initial message");
substream.send_handshake((0..32768).map(|_| 0).collect::<Vec<_>>());
let _ = substream.next().await;
client.await.unwrap();
}
/// `poll_process` alone is enough to deliver the handshake, without ever polling the stream.
///
/// Both upgrades are called directly on the TCP socket, without multistream-select.
#[tokio::test]
async fn send_handshake_without_polling_for_incoming_data() {
const PROTO_NAME: &str = "/test/proto/1";
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();
let client = tokio::spawn(async move {
let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap();
let NotificationsOutOpen { handshake, .. } = OutboundUpgrade::upgrade_outbound(
NotificationsOut::new(PROTO_NAME, Vec::new(), &b"initial message"[..], 1024 * 1024),
socket.compat(),
ProtocolName::Static(PROTO_NAME),
)
.await
.unwrap();
assert_eq!(handshake, b"hello world");
});
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();
let (socket, _) = listener.accept().await.unwrap();
let NotificationsInOpen { handshake, mut substream, .. } = InboundUpgrade::upgrade_inbound(
NotificationsIn::new(PROTO_NAME, Vec::new(), 1024 * 1024),
socket.compat(),
ProtocolName::Static(PROTO_NAME),
)
.await
.unwrap();
assert_eq!(handshake, b"initial message");
substream.send_handshake(&b"hello world"[..]);
future::poll_fn(|cx| Pin::new(&mut substream).poll_process(cx)).await.unwrap();
client.await.unwrap();
}
/// The outbound side must notice that the listener dropped the substream purely by polling
/// `poll_flush`, without writing any notification.
#[tokio::test]
async fn can_detect_dropped_out_substream_without_writing_data() {
const PROTO_NAME: &str = "/test/proto/1";
let (listener_addr_tx, listener_addr_rx) = oneshot::channel();
let client = tokio::spawn(async move {
let socket = TcpStream::connect(listener_addr_rx.await.unwrap()).await.unwrap();
let NotificationsOutOpen { handshake, mut substream, .. } =
OutboundUpgrade::upgrade_outbound(
NotificationsOut::new(
PROTO_NAME,
Vec::new(),
&b"initial message"[..],
1024 * 1024,
),
socket.compat(),
ProtocolName::Static(PROTO_NAME),
)
.await
.unwrap();
assert_eq!(handshake, b"hello world");
// Keep polling `poll_flush` (self-waking on success) until it reports `Terminated`.
future::poll_fn(|cx| match Pin::new(&mut substream).poll_flush(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(())) => {
cx.waker().wake_by_ref();
Poll::Pending
},
Poll::Ready(Err(e)) => {
assert!(matches!(e, NotificationsOutError::Terminated));
Poll::Ready(())
},
})
.await;
});
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
listener_addr_tx.send(listener.local_addr().unwrap()).unwrap();
let (socket, _) = listener.accept().await.unwrap();
let NotificationsInOpen { handshake, mut substream, .. } = InboundUpgrade::upgrade_inbound(
NotificationsIn::new(PROTO_NAME, Vec::new(), 1024 * 1024),
socket.compat(),
ProtocolName::Static(PROTO_NAME),
)
.await
.unwrap();
assert_eq!(handshake, b"initial message");
substream.send_handshake(&b"hello world"[..]);
future::poll_fn(|cx| Pin::new(&mut substream).poll_process(cx)).await.unwrap();
// Dropping the inbound substream closes the connection seen by the client.
drop(substream);
client.await.unwrap();
}
}