use crate::{protocol::libp2p::kademlia::query::QueryId, substream::Substream, PeerId};
use bytes::{Bytes, BytesMut};
use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll, Waker},
time::Duration,
};
/// How long the executor waits for an inbound message on a substream before the operation
/// resolves to [`QueryResult::Timeout`] (15 seconds).
const READ_TIMEOUT: Duration = Duration::from_millis(15_000);
/// Outcome of a single substream operation driven by [`QueryExecutor`].
#[derive(Debug)]
pub enum QueryResult {
    /// The outbound frame was sent; the still-open substream is handed back to the caller.
    SendSuccess { substream: Substream },

    /// A frame was received within [`READ_TIMEOUT`]; carries the substream and the raw frame.
    ReadSuccess { substream: Substream, message: BytesMut },

    /// No frame arrived within [`READ_TIMEOUT`].
    Timeout,

    /// Sending failed, or the substream ended or yielded an error while reading.
    SubstreamClosed,
}
/// Result of one executor operation, tagged with the peer and query it belongs to.
#[derive(Debug)]
pub struct QueryContext {
    /// Remote peer the operation was performed against.
    pub peer: PeerId,

    /// Query the operation belongs to; always `None` for operations started via
    /// [`QueryExecutor::send_message`].
    pub query_id: Option<QueryId>,

    /// What happened on the substream.
    pub result: QueryResult,
}
#[derive(Default)]
pub struct FuturesStream<F> {
futures: FuturesUnordered<F>,
waker: Option<Waker>,
}
impl<F> FuturesStream<F> {
    /// Creates an empty stream with no futures and no stored waker.
    pub fn new() -> Self {
        Self {
            waker: None,
            futures: FuturesUnordered::default(),
        }
    }

    /// Adds `future` to the set, then wakes the task that last polled this stream, if any.
    ///
    /// `poll_next` reports `Pending` even when the set is empty, so this wake-up is what gets
    /// a parked consumer to poll again and pick up the newly added future.
    pub fn push(&mut self, future: F) {
        self.futures.push(future);

        let Some(parked_task) = self.waker.take() else {
            return;
        };
        parked_task.wake();
    }
}
impl<F: Future> Stream for FuturesStream<F> {
    type Item = <F as Future>::Output;

    /// Yields the output of the next future to complete.
    ///
    /// This never returns `Poll::Ready(None)`. If the inner set is exhausted or nothing is
    /// ready yet, the current task's waker is stored (for `push` to use) and `Poll::Pending`
    /// is returned.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.futures.poll_next_unpin(cx) {
            Poll::Ready(Some(output)) => Poll::Ready(Some(output)),
            Poll::Ready(None) | Poll::Pending => {
                self.waker = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}
/// Drives outstanding Kademlia substream operations (sends, reads, request/response pairs)
/// and yields a [`QueryContext`] for each one as it finishes.
pub struct QueryExecutor {
    /// In-flight operations, each boxed into a `'static` future resolving to its context.
    futures: FuturesStream<BoxFuture<'static, QueryContext>>,
}
impl QueryExecutor {
pub fn new() -> Self {
Self {
futures: FuturesStream::new(),
}
}
pub fn send_message(&mut self, peer: PeerId, message: Bytes, mut substream: Substream) {
self.futures.push(Box::pin(async move {
match substream.send_framed(message).await {
Ok(_) => QueryContext {
peer,
query_id: None,
result: QueryResult::SendSuccess { substream },
},
Err(_) => QueryContext {
peer,
query_id: None,
result: QueryResult::SubstreamClosed,
},
}
}));
}
pub fn read_message(
&mut self,
peer: PeerId,
query_id: Option<QueryId>,
mut substream: Substream,
) {
self.futures.push(Box::pin(async move {
match tokio::time::timeout(READ_TIMEOUT, substream.next()).await {
Err(_) => QueryContext {
peer,
query_id,
result: QueryResult::Timeout,
},
Ok(Some(Ok(message))) => QueryContext {
peer,
query_id,
result: QueryResult::ReadSuccess { substream, message },
},
Ok(None) | Ok(Some(Err(_))) => QueryContext {
peer,
query_id,
result: QueryResult::SubstreamClosed,
},
}
}));
}
pub fn send_request_read_response(
&mut self,
peer: PeerId,
query_id: Option<QueryId>,
message: Bytes,
mut substream: Substream,
) {
self.futures.push(Box::pin(async move {
if let Err(_) = substream.send_framed(message).await {
let _ = substream.close().await;
return QueryContext {
peer,
query_id,
result: QueryResult::SubstreamClosed,
};
}
match tokio::time::timeout(READ_TIMEOUT, substream.next()).await {
Err(_) => QueryContext {
peer,
query_id,
result: QueryResult::Timeout,
},
Ok(Some(Ok(message))) => QueryContext {
peer,
query_id,
result: QueryResult::ReadSuccess { substream, message },
},
Ok(None) | Ok(Some(Err(_))) => QueryContext {
peer,
query_id,
result: QueryResult::SubstreamClosed,
},
}
}));
}
}
impl Stream for QueryExecutor {
    type Item = QueryContext;

    /// Yields the [`QueryContext`] of the next operation to finish.
    ///
    /// This delegates to the inner [`FuturesStream`]. While no operations are outstanding, it
    /// reports `Pending` rather than ending the stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `QueryExecutor` is `Unpin`, so the pin can be dropped safely.
        let executor = self.get_mut();
        executor.futures.poll_next_unpin(cx)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{mock::substream::MockSubstream, types::SubstreamId};

    // NOTE: the timeout tests wait for `READ_TIMEOUT` to actually elapse. The tokio clock is not
    // paused here (plain `#[tokio::test]`), so each of them takes that long in wall-clock time.

    /// Wraps a configured mock as substream 0 of `peer`.
    fn mock_substream(peer: PeerId, mock: MockSubstream) -> Substream {
        Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(mock))
    }

    /// Waits (at most 20 seconds) for the executor's next `QueryContext`.
    ///
    /// Panics if the executor yields nothing or the wait times out.
    async fn next_context(executor: &mut QueryExecutor) -> QueryContext {
        match tokio::time::timeout(Duration::from_secs(20), executor.next()).await {
            Ok(Some(context)) => context,
            result => panic!("invalid result received: {result:?}"),
        }
    }

    #[tokio::test]
    async fn substream_read_timeout() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();

        let mut mock = MockSubstream::new();
        mock.expect_poll_next().returning(|_| Poll::Pending);
        executor.read_message(peer, None, mock_substream(peer, mock));

        let context = next_context(&mut executor).await;
        assert_eq!(peer, context.peer);
        assert!(context.query_id.is_none());
        assert!(matches!(context.result, QueryResult::Timeout));
    }

    #[tokio::test]
    async fn substream_read_substream_closed() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();

        let mut mock = MockSubstream::new();
        mock.expect_poll_next().times(1).return_once(|_| {
            Poll::Ready(Some(Err(crate::error::SubstreamError::ConnectionClosed)))
        });
        executor.read_message(peer, Some(QueryId(1338)), mock_substream(peer, mock));

        let context = next_context(&mut executor).await;
        assert_eq!(peer, context.peer);
        assert_eq!(context.query_id, Some(QueryId(1338)));
        assert!(matches!(context.result, QueryResult::SubstreamClosed));
    }

    #[tokio::test]
    async fn send_succeeds_no_message_read() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();

        // The send path succeeds; the subsequent read sees the connection close.
        let mut mock = MockSubstream::new();
        mock.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(())));
        mock.expect_start_send().times(1).return_once(|_| Ok(()));
        mock.expect_poll_flush().times(1).return_once(|_| Poll::Ready(Ok(())));
        mock.expect_poll_next().times(1).return_once(|_| {
            Poll::Ready(Some(Err(crate::error::SubstreamError::ConnectionClosed)))
        });
        executor.send_request_read_response(
            peer,
            Some(QueryId(1337)),
            Bytes::from_static(b"hello, world"),
            mock_substream(peer, mock),
        );

        let context = next_context(&mut executor).await;
        assert_eq!(peer, context.peer);
        assert_eq!(context.query_id, Some(QueryId(1337)));
        assert!(matches!(context.result, QueryResult::SubstreamClosed));
    }

    #[tokio::test]
    async fn send_fails_no_message_read() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();

        // The send fails immediately, so the executor must close the substream and never read.
        let mut mock = MockSubstream::new();
        mock.expect_poll_ready()
            .times(1)
            .return_once(|_| Poll::Ready(Err(crate::error::SubstreamError::ConnectionClosed)));
        mock.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(())));
        executor.send_request_read_response(
            peer,
            Some(QueryId(1337)),
            Bytes::from_static(b"hello, world"),
            mock_substream(peer, mock),
        );

        let context = next_context(&mut executor).await;
        assert_eq!(peer, context.peer);
        assert_eq!(context.query_id, Some(QueryId(1337)));
        assert!(matches!(context.result, QueryResult::SubstreamClosed));
    }

    #[tokio::test]
    async fn read_message_timeout() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();

        let mut mock = MockSubstream::new();
        mock.expect_poll_next().returning(|_| Poll::Pending);
        executor.read_message(peer, Some(QueryId(1336)), mock_substream(peer, mock));

        let context = next_context(&mut executor).await;
        assert_eq!(peer, context.peer);
        assert_eq!(context.query_id, Some(QueryId(1336)));
        assert!(matches!(context.result, QueryResult::Timeout));
    }

    #[tokio::test]
    async fn read_message_substream_closed() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();

        let mut mock = MockSubstream::new();
        mock.expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Err(crate::error::SubstreamError::ChannelClogged))));
        executor.read_message(peer, Some(QueryId(1335)), mock_substream(peer, mock));

        let context = next_context(&mut executor).await;
        assert_eq!(peer, context.peer);
        assert_eq!(context.query_id, Some(QueryId(1335)));
        assert!(matches!(context.result, QueryResult::SubstreamClosed));
    }
}