use crate::{strategy::StrategyKey, types::PeerRequest, LOG_TARGET};
use futures::{
channel::oneshot,
future::BoxFuture,
stream::{BoxStream, FusedStream, Stream},
FutureExt, StreamExt,
};
use log::error;
use sc_network::{request_responses::RequestFailure, types::ProtocolName};
use sc_network_types::PeerId;
use sp_runtime::traits::Block as BlockT;
use std::task::{Context, Poll, Waker};
use tokio_stream::StreamMap;
/// Outcome of waiting for a response to a request.
///
/// The outer `Err` is [`oneshot::Canceled`]. The inner `Result` carries either the raw response
/// bytes together with a [`ProtocolName`], or a [`RequestFailure`].
type ResponseResult = Result<Result<(Vec<u8>, ProtocolName), RequestFailure>, oneshot::Canceled>;
/// Boxed `'static` future that resolves to a [`ResponseResult`].
type ResponseFuture = BoxFuture<'static, ResponseResult>;
/// Item yielded by the [`PendingResponses`] stream once a registered response future resolves.
pub(crate) struct ResponseEvent<B: BlockT> {
    /// Peer the response was registered under.
    pub peer_id: PeerId,
    /// Strategy key the response was registered under.
    pub key: StrategyKey,
    /// The request that was handed over together with the response future.
    pub request: PeerRequest<B>,
    /// What the response future resolved to.
    pub response: ResponseResult,
}
/// Collection of in-flight responses, exposed as a [`Stream`] of [`ResponseEvent`]s.
pub(crate) struct PendingResponses<B: BlockT> {
    /// One boxed stream per `(peer, strategy key)` pair; each stream is built from a single
    /// future and therefore produces the request paired with its response.
    pending_responses:
        StreamMap<(PeerId, StrategyKey), BoxStream<'static, (PeerRequest<B>, ResponseResult)>>,
    /// Waker saved by `poll_next` when nothing was ready; `insert` takes and wakes it so the
    /// newly added response gets polled.
    waker: Option<Waker>,
}
impl<B: BlockT> PendingResponses<B> {
    /// Creates an empty collection with no stored waker.
    pub fn new() -> Self {
        Self { pending_responses: StreamMap::new(), waker: None }
    }

    /// Registers `response_future` as the pending response for `(peer_id, key)`.
    ///
    /// `request` is kept alongside the future and handed back in the resulting
    /// [`ResponseEvent`]. If an entry already exists for the same `(peer_id, key)` it is
    /// replaced: the old one is discarded, an error is logged, and debug builds panic.
    /// Finally, a waker saved by a previous `poll_next` (if any) is woken.
    pub fn insert(
        &mut self,
        peer_id: PeerId,
        key: StrategyKey,
        request: PeerRequest<B>,
        response_future: ResponseFuture,
    ) {
        // Must be read before `request` is moved into the async block below.
        let request_type = request.get_type();

        let slot = (peer_id, key);
        let single_shot = async move { (request, response_future.await) }.into_stream();
        let replaced = self.pending_responses.insert(slot, Box::pin(single_shot)).is_some();

        if replaced {
            error!(
                target: LOG_TARGET,
                "Discarded pending response from peer {peer_id}, request type: {request_type:?}.",
            );
            debug_assert!(false);
        }

        // Make sure the task polling this stream notices the new entry.
        if let Some(parked) = self.waker.take() {
            parked.wake();
        }
    }

    /// Drops the pending response registered for `(peer_id, key)`.
    ///
    /// Returns `true` if such an entry existed.
    pub fn remove(&mut self, peer_id: PeerId, key: StrategyKey) -> bool {
        let slot = (peer_id, key);
        self.pending_responses.remove(&slot).is_some()
    }

    /// Drops every pending response registered for `peer_id`, whatever its strategy key.
    pub fn remove_all(&mut self, peer_id: &PeerId) {
        // Keys are collected first because the map cannot be mutated while iterating over it.
        let stale: Vec<_> = self
            .pending_responses
            .keys()
            .filter(|slot| slot.0 == *peer_id)
            .cloned()
            .collect();

        for slot in &stale {
            self.pending_responses.remove(slot);
        }
    }

    /// Number of entries currently held in the map.
    pub fn len(&self) -> usize {
        self.pending_responses.len()
    }
}
impl<B: BlockT> Stream for PendingResponses<B> {
    type Item = ResponseEvent<B>;

    /// Yields the next resolved response as a [`ResponseEvent`].
    ///
    /// Never returns `Poll::Ready(None)`: when the underlying map has nothing to yield, the
    /// caller's waker is saved (so that `insert` can wake it) and `Poll::Pending` is returned.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        if let Poll::Ready(Some(((peer_id, key), (request, response)))) =
            self.pending_responses.poll_next_unpin(cx)
        {
            // Remove the entry right away instead of waiting for the map to observe the end
            // of this one-item stream; otherwise a new request for the same `(peer_id, key)`
            // inserted before the next poll would collide with the spent entry.
            self.pending_responses.remove(&(peer_id, key));
            return Poll::Ready(Some(ResponseEvent { peer_id, key, request, response }));
        }

        // Either the map is empty (`Ready(None)`) or nothing has resolved yet (`Pending`).
        self.waker = Some(cx.waker().clone());
        Poll::Pending
    }
}
impl<B: BlockT> FusedStream for PendingResponses<B> {
    /// Always reports `false`: `poll_next` turns an exhausted map into `Poll::Pending` rather
    /// than `Poll::Ready(None)`, so this stream never finishes.
    fn is_terminated(&self) -> bool {
        false
    }
}