use crate::{
peer_store::{PeerStoreProvider, BANNED_THRESHOLD},
service::traits::RequestResponseConfig as RequestResponseConfigT,
types::ProtocolName,
ReputationChange,
};
use futures::{channel::oneshot, prelude::*};
use libp2p::{
core::{Endpoint, Multiaddr},
request_response::{self, Behaviour, Codec, Message, ProtocolSupport, ResponseChannel},
swarm::{
behaviour::{ConnectionClosed, FromSwarm},
handler::multi::MultiHandler,
ConnectionDenied, ConnectionId, NetworkBehaviour, PollParameters, THandler,
THandlerInEvent, THandlerOutEvent, ToSwarm,
},
PeerId,
};
use std::{
collections::{hash_map::Entry, HashMap},
io, iter,
ops::Deref,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::{Duration, Instant},
};
pub use libp2p::request_response::{Config, RequestId};
/// Reason an outbound request failed at the network level.
///
/// Local mirror of `libp2p::request_response::OutboundFailure`; values are produced through the
/// `From` conversion defined right after this enum.
#[derive(Debug, thiserror::Error)]
pub enum OutboundFailure {
    /// The request could not be sent because a dialing attempt failed.
    #[error("Failed to dial the requested peer")]
    DialFailure,
    /// The request timed out before a response was received.
    #[error("Timeout while waiting for a response")]
    Timeout,
    /// The connection closed before a response was received.
    #[error("Connection was closed before a response was received")]
    ConnectionClosed,
    /// The remote supports none of the requested protocols.
    #[error("The remote supports none of the requested protocols")]
    UnsupportedProtocols,
}
/// One-to-one mapping from the libp2p outbound failure onto the local [`OutboundFailure`].
impl From<request_response::OutboundFailure> for OutboundFailure {
    fn from(failure: request_response::OutboundFailure) -> Self {
        match failure {
            request_response::OutboundFailure::UnsupportedProtocols => Self::UnsupportedProtocols,
            request_response::OutboundFailure::ConnectionClosed => Self::ConnectionClosed,
            request_response::OutboundFailure::Timeout => Self::Timeout,
            request_response::OutboundFailure::DialFailure => Self::DialFailure,
        }
    }
}
/// Reason an inbound request could not be answered.
///
/// Local mirror of `libp2p::request_response::InboundFailure`; values are produced through the
/// `From` conversion defined right after this enum.
#[derive(Debug, thiserror::Error)]
pub enum InboundFailure {
    /// Receiving the request or sending the response took too long.
    #[error("Timeout while receiving request or sending response")]
    Timeout,
    /// The connection closed before the response could be sent.
    #[error("Connection was closed before a response could be sent")]
    ConnectionClosed,
    /// The remote asked for protocols the local node does not support.
    #[error("The local peer supports none of the protocols requested by the remote")]
    UnsupportedProtocols,
    /// The response channel was dropped without a response being sent.
    #[error("The response channel was dropped without sending a response to the remote")]
    ResponseOmission,
}
/// One-to-one mapping from the libp2p inbound failure onto the local [`InboundFailure`].
impl From<request_response::InboundFailure> for InboundFailure {
    fn from(failure: request_response::InboundFailure) -> Self {
        match failure {
            request_response::InboundFailure::Timeout => Self::Timeout,
            request_response::InboundFailure::ConnectionClosed => Self::ConnectionClosed,
            request_response::InboundFailure::UnsupportedProtocols => Self::UnsupportedProtocols,
            request_response::InboundFailure::ResponseOmission => Self::ResponseOmission,
        }
    }
}
/// Failure reported to the local requester of an outbound request, and in
/// `Event::RequestFinished`.
#[derive(Debug, thiserror::Error)]
#[allow(missing_docs)]
pub enum RequestFailure {
    /// The target was not connected and `IfDisconnected::ImmediateError` forbade dialing.
    #[error("We are not currently connected to the requested peer.")]
    NotConnected,
    /// The protocol passed to `send_request` (or a fallback protocol) is not registered.
    #[error("Given protocol hasn't been registered.")]
    UnknownProtocol,
    /// The remote closed the substream without writing a response (the codec decodes this as
    /// `Err(())`).
    #[error("Remote has closed the substream before answering, thereby signaling that it considers the request as valid, but refused to answer it.")]
    Refused,
    /// The outcome could not be delivered because the local receiver was already dropped.
    #[error("The remote replied, but the local node is no longer interested in the response.")]
    Obsolete,
    /// The request failed at the network level.
    #[error("Problem on the network: {0}")]
    Network(OutboundFailure),
}
/// Configuration for a single request-response protocol.
#[derive(Debug, Clone)]
pub struct ProtocolConfig {
    /// Main name of the protocol; it is also the key the protocol is registered under, so it
    /// must be unique across all configs passed to `RequestResponsesBehaviour::new`.
    pub name: ProtocolName,
    /// Additional names negotiated for the same protocol, after `name`.
    pub fallback_names: Vec<ProtocolName>,
    /// Maximum accepted length, in bytes, of an inbound request payload.
    pub max_request_size: u64,
    /// Maximum accepted length, in bytes, of a response payload received from a remote.
    pub max_response_size: u64,
    /// Timeout passed to libp2p's `Config::set_request_timeout`.
    pub request_timeout: Duration,
    /// Queue that inbound requests are forwarded to.
    ///
    /// `Some` registers the protocol with full (inbound + outbound) support; `None` makes it
    /// outbound-only. Forwarding uses `try_send`, so a full queue drops the request.
    pub inbound_queue: Option<async_channel::Sender<IncomingRequest>>,
}
impl RequestResponseConfigT for ProtocolConfig {
    /// Returns the main protocol name; fallback names are not included.
    fn protocol_name(&self) -> &ProtocolName {
        let Self { name, .. } = self;
        name
    }
}
/// An inbound request forwarded to a protocol's `inbound_queue`.
#[derive(Debug)]
pub struct IncomingRequest {
    /// Peer that sent the request.
    pub peer: sc_network_types::PeerId,
    /// Raw request payload as read by the codec.
    pub payload: Vec<u8>,
    /// Channel used to answer the request.
    ///
    /// If it is dropped without sending, the behaviour discards the request's libp2p
    /// `ResponseChannel` and no response is written.
    pub pending_response: oneshot::Sender<OutgoingResponse>,
}
/// Answer to an [`IncomingRequest`].
#[derive(Debug)]
pub struct OutgoingResponse {
    /// `Ok(bytes)` sends `bytes` as the response; `Err(())` sends nothing.
    pub result: Result<Vec<u8>, ()>,
    /// Reputation changes for the requesting peer. When non-empty they are emitted as
    /// `Event::ReputationChanges`, independently of `result`.
    pub reputation_changes: Vec<ReputationChange>,
    /// Optional notification fired once libp2p reports the response as sent. It is dropped
    /// without firing if handing the response to libp2p fails or an inbound failure is reported.
    pub sent_feedback: Option<oneshot::Sender<()>>,
}
/// Bookkeeping for an outbound request that has not completed yet.
struct PendingRequest {
    /// When the request was handed to libp2p; used for `Event::RequestFinished::duration`.
    started_at: Instant,
    /// Where the response (paired with the protocol that produced it) or the failure goes.
    response_tx: oneshot::Sender<Result<(Vec<u8>, ProtocolName), RequestFailure>>,
    /// Payload and protocol to retry with if the remote reports `UnsupportedProtocols`.
    fallback_request: Option<(Vec<u8>, ProtocolName)>,
}
/// What `send_request` should do when the target peer is not connected.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum IfDisconnected {
    /// Hand the request to libp2p anyway, letting it try to connect.
    TryConnect,
    /// Fail right away with `RequestFailure::NotConnected`.
    ImmediateError,
}
impl IfDisconnected {
    /// Returns `true` for [`IfDisconnected::TryConnect`], meaning a disconnected peer should be
    /// dialed, and `false` when the request must fail immediately instead.
    pub fn should_connect(self) -> bool {
        matches!(self, Self::TryConnect)
    }
}
/// Event generated by the [`RequestResponsesBehaviour`].
#[derive(Debug)]
pub enum Event {
    /// Processing of an inbound request has ended.
    ///
    /// Emitted with `Ok` once the response was reported as sent, or with `Err` when libp2p
    /// reported an inbound failure for the request.
    InboundRequest {
        /// Peer that sent the request.
        peer: PeerId,
        /// Protocol the request arrived on.
        protocol: ProtocolName,
        /// Time elapsed since the request arrived, or the failure reason.
        result: Result<Duration, ResponseFailure>,
    },
    /// An outbound request has ended.
    RequestFinished {
        /// Peer the request was sent to.
        peer: PeerId,
        /// Protocol whose behaviour reported the outcome (the fallback protocol if one was used).
        protocol: ProtocolName,
        /// Time elapsed since the request (or its fallback retry) was sent.
        duration: Duration,
        /// `Ok(())` when the outcome was delivered to the local requester. Note that this
        /// includes the case where that outcome is `Err(RequestFailure::Refused)`.
        /// `Err(RequestFailure::Obsolete)` means the requester was gone, and
        /// `Err(RequestFailure::Network(_))` means an outbound failure.
        result: Result<(), RequestFailure>,
    },
    /// Reputation changes requested through `OutgoingResponse::reputation_changes`.
    ReputationChanges {
        /// Peer the changes apply to.
        peer: PeerId,
        /// The requested changes, never empty.
        changes: Vec<ReputationChange>,
    },
}
/// Map key combining a protocol name with a request id.
///
/// Each protocol owns a separate libp2p `Behaviour` that allocates its own request ids, so the
/// protocol name is part of the key to keep ids issued by different protocols apart.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ProtocolRequestId {
    /// Protocol the request belongs to.
    protocol: ProtocolName,
    /// Id assigned by that protocol's behaviour.
    request_id: RequestId,
}
impl From<(ProtocolName, RequestId)> for ProtocolRequestId {
fn from((protocol, request_id): (ProtocolName, RequestId)) -> Self {
Self { protocol, request_id }
}
}
/// Network behaviour multiplexing any number of request-response protocols, each backed by its
/// own libp2p request-response `Behaviour`.
pub struct RequestResponsesBehaviour {
    /// Registered protocols keyed by main name: the libp2p behaviour plus the optional queue
    /// inbound requests are forwarded to (`None` for outbound-only protocols).
    protocols: HashMap<
        ProtocolName,
        (Behaviour<GenericCodec>, Option<async_channel::Sender<IncomingRequest>>),
    >,
    /// Outbound requests that are still waiting for a response or failure.
    pending_requests: HashMap<ProtocolRequestId, PendingRequest>,
    /// Futures resolving once the local handler answered an inbound request; they resolve to
    /// `None` when the handler dropped its `pending_response` sender.
    pending_responses: stream::FuturesUnordered<
        Pin<Box<dyn Future<Output = Option<RequestProcessingOutcome>> + Send>>,
    >,
    /// Arrival time of inbound requests, used for the duration in `Event::InboundRequest`.
    pending_responses_arrival_time: HashMap<ProtocolRequestId, Instant>,
    /// `sent_feedback` senders of responses handed to libp2p, fired on `ResponseSent`.
    send_feedback: HashMap<ProtocolRequestId, oneshot::Sender<()>>,
    /// Queried for peer reputations; requests from peers below `BANNED_THRESHOLD` are ignored.
    peer_store: Arc<dyn PeerStoreProvider>,
}
/// An inbound request together with the answer produced by the local handler.
struct RequestProcessingOutcome {
    /// Peer that sent the request.
    peer: PeerId,
    /// Id of the request within its protocol.
    request_id: RequestId,
    /// Protocol the request arrived on.
    protocol: ProtocolName,
    /// libp2p channel the response has to be written to.
    inner_channel: ResponseChannel<Result<Vec<u8>, ()>>,
    /// Answer produced by the handler.
    response: OutgoingResponse,
}
impl RequestResponsesBehaviour {
pub fn new(
list: impl Iterator<Item = ProtocolConfig>,
peer_store: Arc<dyn PeerStoreProvider>,
) -> Result<Self, RegisterError> {
let mut protocols = HashMap::new();
for protocol in list {
let mut cfg = Config::default();
cfg.set_request_timeout(protocol.request_timeout);
let protocol_support = if protocol.inbound_queue.is_some() {
ProtocolSupport::Full
} else {
ProtocolSupport::Outbound
};
let rq_rp = Behaviour::with_codec(
GenericCodec {
max_request_size: protocol.max_request_size,
max_response_size: protocol.max_response_size,
},
iter::once(protocol.name.clone())
.chain(protocol.fallback_names)
.zip(iter::repeat(protocol_support)),
cfg,
);
match protocols.entry(protocol.name) {
Entry::Vacant(e) => e.insert((rq_rp, protocol.inbound_queue)),
Entry::Occupied(e) => return Err(RegisterError::DuplicateProtocol(e.key().clone())),
};
}
Ok(Self {
protocols,
pending_requests: Default::default(),
pending_responses: Default::default(),
pending_responses_arrival_time: Default::default(),
send_feedback: Default::default(),
peer_store,
})
}
pub fn send_request(
&mut self,
target: &PeerId,
protocol_name: ProtocolName,
request: Vec<u8>,
fallback_request: Option<(Vec<u8>, ProtocolName)>,
pending_response: oneshot::Sender<Result<(Vec<u8>, ProtocolName), RequestFailure>>,
connect: IfDisconnected,
) {
log::trace!(target: "sub-libp2p", "send request to {target} ({protocol_name:?}), {} bytes", request.len());
if let Some((protocol, _)) = self.protocols.get_mut(protocol_name.deref()) {
Self::send_request_inner(
protocol,
&mut self.pending_requests,
target,
protocol_name,
request,
fallback_request,
pending_response,
connect,
)
} else if pending_response.send(Err(RequestFailure::UnknownProtocol)).is_err() {
log::debug!(
target: "sub-libp2p",
"Unknown protocol {:?}. At the same time local \
node is no longer interested in the result.",
protocol_name,
);
}
}
fn send_request_inner(
behaviour: &mut Behaviour<GenericCodec>,
pending_requests: &mut HashMap<ProtocolRequestId, PendingRequest>,
target: &PeerId,
protocol_name: ProtocolName,
request: Vec<u8>,
fallback_request: Option<(Vec<u8>, ProtocolName)>,
pending_response: oneshot::Sender<Result<(Vec<u8>, ProtocolName), RequestFailure>>,
connect: IfDisconnected,
) {
if behaviour.is_connected(target) || connect.should_connect() {
let request_id = behaviour.send_request(target, request);
let prev_req_id = pending_requests.insert(
(protocol_name.to_string().into(), request_id).into(),
PendingRequest {
started_at: Instant::now(),
response_tx: pending_response,
fallback_request,
},
);
debug_assert!(prev_req_id.is_none(), "Expect request id to be unique.");
} else if pending_response.send(Err(RequestFailure::NotConnected)).is_err() {
log::debug!(
target: "sub-libp2p",
"Not connected to peer {:?}. At the same time local \
node is no longer interested in the result.",
target,
);
}
}
}
impl NetworkBehaviour for RequestResponsesBehaviour {
type ConnectionHandler =
MultiHandler<String, <Behaviour<GenericCodec> as NetworkBehaviour>::ConnectionHandler>;
type ToSwarm = Event;
fn handle_pending_inbound_connection(
&mut self,
_connection_id: ConnectionId,
_local_addr: &Multiaddr,
_remote_addr: &Multiaddr,
) -> Result<(), ConnectionDenied> {
Ok(())
}
fn handle_pending_outbound_connection(
&mut self,
_connection_id: ConnectionId,
_maybe_peer: Option<PeerId>,
_addresses: &[Multiaddr],
_effective_role: Endpoint,
) -> Result<Vec<Multiaddr>, ConnectionDenied> {
Ok(Vec::new())
}
fn handle_established_inbound_connection(
&mut self,
connection_id: ConnectionId,
peer: PeerId,
local_addr: &Multiaddr,
remote_addr: &Multiaddr,
) -> Result<THandler<Self>, ConnectionDenied> {
let iter = self.protocols.iter_mut().filter_map(|(p, (r, _))| {
if let Ok(handler) = r.handle_established_inbound_connection(
connection_id,
peer,
local_addr,
remote_addr,
) {
Some((p.to_string(), handler))
} else {
None
}
});
Ok(MultiHandler::try_from_iter(iter).expect(
"Protocols are in a HashMap and there can be at most one handler per protocol name, \
which is the only possible error; qed",
))
}
fn handle_established_outbound_connection(
&mut self,
connection_id: ConnectionId,
peer: PeerId,
addr: &Multiaddr,
role_override: Endpoint,
) -> Result<THandler<Self>, ConnectionDenied> {
let iter = self.protocols.iter_mut().filter_map(|(p, (r, _))| {
if let Ok(handler) =
r.handle_established_outbound_connection(connection_id, peer, addr, role_override)
{
Some((p.to_string(), handler))
} else {
None
}
});
Ok(MultiHandler::try_from_iter(iter).expect(
"Protocols are in a HashMap and there can be at most one handler per protocol name, \
which is the only possible error; qed",
))
}
fn on_swarm_event(&mut self, event: FromSwarm<Self::ConnectionHandler>) {
match event {
FromSwarm::ConnectionEstablished(e) =>
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::on_swarm_event(p, FromSwarm::ConnectionEstablished(e));
},
FromSwarm::ConnectionClosed(ConnectionClosed {
peer_id,
connection_id,
endpoint,
handler,
remaining_established,
}) =>
for (p_name, p_handler) in handler.into_iter() {
if let Some((proto, _)) = self.protocols.get_mut(p_name.as_str()) {
proto.on_swarm_event(FromSwarm::ConnectionClosed(ConnectionClosed {
peer_id,
connection_id,
endpoint,
handler: p_handler,
remaining_established,
}));
} else {
log::error!(
target: "sub-libp2p",
"on_swarm_event/connection_closed: no request-response instance registered for protocol {:?}",
p_name,
)
}
},
FromSwarm::DialFailure(e) =>
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::on_swarm_event(p, FromSwarm::DialFailure(e));
},
FromSwarm::ListenerClosed(e) =>
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::on_swarm_event(p, FromSwarm::ListenerClosed(e));
},
FromSwarm::ListenFailure(e) =>
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::on_swarm_event(p, FromSwarm::ListenFailure(e));
},
FromSwarm::ListenerError(e) =>
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::on_swarm_event(p, FromSwarm::ListenerError(e));
},
FromSwarm::ExternalAddrExpired(e) =>
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::on_swarm_event(p, FromSwarm::ExternalAddrExpired(e));
},
FromSwarm::NewListener(e) =>
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::on_swarm_event(p, FromSwarm::NewListener(e));
},
FromSwarm::ExpiredListenAddr(e) =>
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::on_swarm_event(p, FromSwarm::ExpiredListenAddr(e));
},
FromSwarm::NewExternalAddrCandidate(e) =>
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::on_swarm_event(p, FromSwarm::NewExternalAddrCandidate(e));
},
FromSwarm::ExternalAddrConfirmed(e) =>
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::on_swarm_event(p, FromSwarm::ExternalAddrConfirmed(e));
},
FromSwarm::AddressChange(e) =>
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::on_swarm_event(p, FromSwarm::AddressChange(e));
},
FromSwarm::NewListenAddr(e) =>
for (p, _) in self.protocols.values_mut() {
NetworkBehaviour::on_swarm_event(p, FromSwarm::NewListenAddr(e));
},
}
}
fn on_connection_handler_event(
&mut self,
peer_id: PeerId,
connection_id: ConnectionId,
event: THandlerOutEvent<Self>,
) {
let p_name = event.0;
if let Some((proto, _)) = self.protocols.get_mut(p_name.as_str()) {
return proto.on_connection_handler_event(peer_id, connection_id, event.1)
} else {
log::warn!(
target: "sub-libp2p",
"on_connection_handler_event: no request-response instance registered for protocol {:?}",
p_name
);
}
}
fn poll(
&mut self,
cx: &mut Context,
params: &mut impl PollParameters,
) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
'poll_all: loop {
while let Poll::Ready(Some(outcome)) = self.pending_responses.poll_next_unpin(cx) {
let RequestProcessingOutcome {
peer,
request_id,
protocol: protocol_name,
inner_channel,
response: OutgoingResponse { result, reputation_changes, sent_feedback },
} = match outcome {
Some(outcome) => outcome,
None => continue,
};
if let Ok(payload) = result {
if let Some((protocol, _)) = self.protocols.get_mut(&*protocol_name) {
log::trace!(target: "sub-libp2p", "send response to {peer} ({protocol_name:?}), {} bytes", payload.len());
if protocol.send_response(inner_channel, Ok(payload)).is_err() {
log::debug!(
target: "sub-libp2p",
"Failed to send response for {:?} on protocol {:?} due to a \
timeout or due to the connection to the peer being closed. \
Dropping response",
request_id, protocol_name,
);
} else if let Some(sent_feedback) = sent_feedback {
self.send_feedback
.insert((protocol_name, request_id).into(), sent_feedback);
}
}
}
if !reputation_changes.is_empty() {
return Poll::Ready(ToSwarm::GenerateEvent(Event::ReputationChanges {
peer,
changes: reputation_changes,
}))
}
}
let mut fallback_requests = vec![];
for (protocol, (ref mut behaviour, ref mut resp_builder)) in &mut self.protocols {
'poll_protocol: while let Poll::Ready(ev) = behaviour.poll(cx, params) {
let ev = match ev {
ToSwarm::GenerateEvent(ev) => ev,
ToSwarm::Dial { opts } => {
if opts.get_peer_id().is_none() {
log::error!(
"The request-response isn't supposed to start dialing addresses"
);
}
return Poll::Ready(ToSwarm::Dial { opts })
},
ToSwarm::NotifyHandler { peer_id, handler, event } =>
return Poll::Ready(ToSwarm::NotifyHandler {
peer_id,
handler,
event: ((*protocol).to_string(), event),
}),
ToSwarm::CloseConnection { peer_id, connection } =>
return Poll::Ready(ToSwarm::CloseConnection { peer_id, connection }),
ToSwarm::NewExternalAddrCandidate(observed) =>
return Poll::Ready(ToSwarm::NewExternalAddrCandidate(observed)),
ToSwarm::ExternalAddrConfirmed(addr) =>
return Poll::Ready(ToSwarm::ExternalAddrConfirmed(addr)),
ToSwarm::ExternalAddrExpired(addr) =>
return Poll::Ready(ToSwarm::ExternalAddrExpired(addr)),
ToSwarm::ListenOn { opts } =>
return Poll::Ready(ToSwarm::ListenOn { opts }),
ToSwarm::RemoveListener { id } =>
return Poll::Ready(ToSwarm::RemoveListener { id }),
};
match ev {
request_response::Event::Message {
peer,
message: Message::Request { request_id, request, channel, .. },
} => {
self.pending_responses_arrival_time
.insert((protocol.clone(), request_id).into(), Instant::now());
let reputation = self.peer_store.peer_reputation(&peer.into());
if reputation < BANNED_THRESHOLD {
log::debug!(
target: "sub-libp2p",
"Cannot handle requests from a node with a low reputation {}: {}",
peer,
reputation,
);
continue 'poll_protocol
}
let (tx, rx) = oneshot::channel();
if let Some(resp_builder) = resp_builder {
let _ = resp_builder.try_send(IncomingRequest {
peer: peer.into(),
payload: request,
pending_response: tx,
});
} else {
debug_assert!(false, "Received message on outbound-only protocol.");
}
let protocol = protocol.clone();
self.pending_responses.push(Box::pin(async move {
rx.await.map_or(None, |response| {
Some(RequestProcessingOutcome {
peer,
request_id,
protocol,
inner_channel: channel,
response,
})
})
}));
continue 'poll_all
},
request_response::Event::Message {
peer,
message: Message::Response { request_id, response },
..
} => {
let (started, delivered) = match self
.pending_requests
.remove(&(protocol.clone(), request_id).into())
{
Some(PendingRequest { started_at, response_tx, .. }) => {
log::trace!(
target: "sub-libp2p",
"received response from {peer} ({protocol:?}), {} bytes",
response.as_ref().map_or(0usize, |response| response.len()),
);
let delivered = response_tx
.send(
response
.map_err(|()| RequestFailure::Refused)
.map(|resp| (resp, protocol.clone())),
)
.map_err(|_| RequestFailure::Obsolete);
(started_at, delivered)
},
None => {
log::warn!(
target: "sub-libp2p",
"Received `RequestResponseEvent::Message` with unexpected request id {:?}",
request_id,
);
debug_assert!(false);
continue
},
};
let out = Event::RequestFinished {
peer,
protocol: protocol.clone(),
duration: started.elapsed(),
result: delivered,
};
return Poll::Ready(ToSwarm::GenerateEvent(out))
},
request_response::Event::OutboundFailure {
peer,
request_id,
error,
..
} => {
let started = match self
.pending_requests
.remove(&(protocol.clone(), request_id).into())
{
Some(PendingRequest {
started_at,
response_tx,
fallback_request,
}) => {
if let request_response::OutboundFailure::UnsupportedProtocols =
error
{
if let Some((fallback_request, fallback_protocol)) =
fallback_request
{
log::trace!(
target: "sub-libp2p",
"Request with id {:?} failed. Trying the fallback protocol. {}",
request_id,
fallback_protocol.deref()
);
fallback_requests.push((
peer,
fallback_protocol,
fallback_request,
response_tx,
));
continue
}
}
if response_tx
.send(Err(RequestFailure::Network(error.clone().into())))
.is_err()
{
log::debug!(
target: "sub-libp2p",
"Request with id {:?} failed. At the same time local \
node is no longer interested in the result.",
request_id,
);
}
started_at
},
None => {
log::warn!(
target: "sub-libp2p",
"Received `RequestResponseEvent::Message` with unexpected request id {:?}",
request_id,
);
debug_assert!(false);
continue
},
};
let out = Event::RequestFinished {
peer,
protocol: protocol.clone(),
duration: started.elapsed(),
result: Err(RequestFailure::Network(error.into())),
};
return Poll::Ready(ToSwarm::GenerateEvent(out))
},
request_response::Event::InboundFailure {
request_id, peer, error, ..
} => {
self.pending_responses_arrival_time
.remove(&(protocol.clone(), request_id).into());
self.send_feedback.remove(&(protocol.clone(), request_id).into());
let out = Event::InboundRequest {
peer,
protocol: protocol.clone(),
result: Err(ResponseFailure::Network(error.into())),
};
return Poll::Ready(ToSwarm::GenerateEvent(out))
},
request_response::Event::ResponseSent { request_id, peer } => {
let arrival_time = self
.pending_responses_arrival_time
.remove(&(protocol.clone(), request_id).into())
.map(|t| t.elapsed())
.expect(
"Time is added for each inbound request on arrival and only \
removed on success (`ResponseSent`) or failure \
(`InboundFailure`). One can not receive a success event for a \
request that either never arrived, or that has previously \
failed; qed.",
);
if let Some(send_feedback) =
self.send_feedback.remove(&(protocol.clone(), request_id).into())
{
let _ = send_feedback.send(());
}
let out = Event::InboundRequest {
peer,
protocol: protocol.clone(),
result: Ok(arrival_time),
};
return Poll::Ready(ToSwarm::GenerateEvent(out))
},
};
}
}
for (peer, protocol, request, pending_response) in fallback_requests.drain(..) {
if let Some((behaviour, _)) = self.protocols.get_mut(&protocol) {
Self::send_request_inner(
behaviour,
&mut self.pending_requests,
&peer,
protocol,
request,
None,
pending_response,
IfDisconnected::ImmediateError,
);
}
}
break Poll::Pending
}
}
}
/// Error returned by `RequestResponsesBehaviour::new`.
#[derive(Debug, thiserror::Error)]
pub enum RegisterError {
    /// The same main protocol name was configured more than once. Only each config's `name` is
    /// checked; fallback names are not compared.
    #[error("{0}")]
    DuplicateProtocol(ProtocolName),
}
/// Failure to answer an inbound request, reported in `Event::InboundRequest`.
#[derive(Debug, thiserror::Error)]
pub enum ResponseFailure {
    /// The inbound request failed at the network level.
    #[error("Problem on the network: {0}")]
    Network(InboundFailure),
}
/// Length-prefixed codec shared by every protocol of [`RequestResponsesBehaviour`].
///
/// Messages are written as an unsigned-varint length followed by the raw bytes. The limits
/// below are enforced on the reading side.
#[doc(hidden)]
#[derive(Clone, Debug)]
pub struct GenericCodec {
    /// Largest request payload, in bytes, that `read_request` accepts.
    max_request_size: u64,
    /// Largest response payload, in bytes, that `read_response` accepts.
    max_response_size: u64,
}
/// Wire format: a request, or a successful response, is an unsigned-varint length prefix
/// followed by that many bytes. The writer then closes its side of the substream. A response of
/// `Err(())` is written as no bytes at all, just the close.
#[async_trait::async_trait]
impl Codec for GenericCodec {
    type Protocol = ProtocolName;
    type Request = Vec<u8>;
    type Response = Result<Vec<u8>, ()>;

    /// Reads one length-prefixed request.
    ///
    /// Fails with `InvalidInput` if the prefix is malformed or announces more than
    /// `max_request_size` bytes.
    async fn read_request<T>(
        &mut self,
        _: &Self::Protocol,
        mut io: &mut T,
    ) -> io::Result<Self::Request>
    where
        T: AsyncRead + Unpin + Send,
    {
        let limit = usize::try_from(self.max_request_size).unwrap_or(usize::MAX);
        let length = match unsigned_varint::aio::read_usize(&mut io).await {
            Ok(length) => length,
            Err(err) => return Err(io::Error::new(io::ErrorKind::InvalidInput, err)),
        };

        // Reject oversized payloads before allocating the buffer for them.
        if length > limit {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Request size exceeds limit: {} > {}", length, self.max_request_size),
            ))
        }

        let mut payload = vec![0u8; length];
        io.read_exact(&mut payload).await?;
        Ok(payload)
    }

    /// Reads one length-prefixed response.
    ///
    /// Hitting EOF while reading the prefix yields `Ok(Err(()))`, meaning the remote closed the
    /// substream without answering. Other prefix errors and oversized payloads fail with
    /// `InvalidInput`.
    async fn read_response<T>(
        &mut self,
        _: &Self::Protocol,
        mut io: &mut T,
    ) -> io::Result<Self::Response>
    where
        T: AsyncRead + Unpin + Send,
    {
        let limit = usize::try_from(self.max_response_size).unwrap_or(usize::MAX);
        let length = match unsigned_varint::aio::read_usize(&mut io).await {
            Ok(length) => length,
            // Matches `write_response` for `Err(())`, which closes without writing anything.
            Err(unsigned_varint::io::ReadError::Io(err))
                if err.kind() == io::ErrorKind::UnexpectedEof =>
                return Ok(Err(())),
            Err(err) => return Err(io::Error::new(io::ErrorKind::InvalidInput, err)),
        };

        // Reject oversized payloads before allocating the buffer for them.
        if length > limit {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Response size exceeds limit: {} > {}", length, self.max_response_size),
            ))
        }

        let mut payload = vec![0u8; length];
        io.read_exact(&mut payload).await?;
        Ok(Ok(payload))
    }

    /// Writes `req` with its length prefix, then closes the substream.
    async fn write_request<T>(
        &mut self,
        _: &Self::Protocol,
        io: &mut T,
        req: Self::Request,
    ) -> io::Result<()>
    where
        T: AsyncWrite + Unpin + Send,
    {
        let mut len_prefix = unsigned_varint::encode::usize_buffer();
        io.write_all(unsigned_varint::encode::usize(req.len(), &mut len_prefix)).await?;
        io.write_all(&req).await?;
        io.close().await
    }

    /// Writes a successful response with its length prefix; for `Err(())` nothing is written.
    /// The substream is closed in both cases.
    async fn write_response<T>(
        &mut self,
        _: &Self::Protocol,
        io: &mut T,
        res: Self::Response,
    ) -> io::Result<()>
    where
        T: AsyncWrite + Unpin + Send,
    {
        if let Ok(payload) = res {
            let mut len_prefix = unsigned_varint::encode::usize_buffer();
            io.write_all(unsigned_varint::encode::usize(payload.len(), &mut len_prefix)).await?;
            io.write_all(&payload).await?;
        }
        io.close().await
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::mock::MockPeerStore;
use assert_matches::assert_matches;
use futures::{channel::oneshot, executor::LocalPool, task::Spawn};
use libp2p::{
core::{
transport::{MemoryTransport, Transport},
upgrade,
},
identity::Keypair,
noise,
swarm::{Config as SwarmConfig, Executor, Swarm, SwarmEvent},
Multiaddr,
};
use std::{iter, time::Duration};
struct TokioExecutor(tokio::runtime::Runtime);
impl Executor for TokioExecutor {
fn exec(&self, f: Pin<Box<dyn Future<Output = ()> + Send>>) {
let _ = self.0.spawn(f);
}
}
fn build_swarm(
list: impl Iterator<Item = ProtocolConfig>,
) -> (Swarm<RequestResponsesBehaviour>, Multiaddr) {
let keypair = Keypair::generate_ed25519();
let transport = MemoryTransport::new()
.upgrade(upgrade::Version::V1)
.authenticate(noise::Config::new(&keypair).unwrap())
.multiplex(libp2p::yamux::Config::default())
.boxed();
let behaviour = RequestResponsesBehaviour::new(list, Arc::new(MockPeerStore {})).unwrap();
let runtime = tokio::runtime::Runtime::new().unwrap();
let mut swarm = Swarm::new(
transport,
behaviour,
keypair.public().to_peer_id(),
SwarmConfig::with_executor(TokioExecutor(runtime)),
);
let listen_addr: Multiaddr = format!("/memory/{}", rand::random::<u64>()).parse().unwrap();
swarm.listen_on(listen_addr.clone()).unwrap();
(swarm, listen_addr)
}
#[test]
fn basic_request_response_works() {
let protocol_name = ProtocolName::from("/test/req-resp/1");
let mut pool = LocalPool::new();
let mut swarms = (0..2)
.map(|_| {
let (tx, mut rx) = async_channel::bounded::<IncomingRequest>(64);
pool.spawner()
.spawn_obj(
async move {
while let Some(rq) = rx.next().await {
let (fb_tx, fb_rx) = oneshot::channel();
assert_eq!(rq.payload, b"this is a request");
let _ = rq.pending_response.send(super::OutgoingResponse {
result: Ok(b"this is a response".to_vec()),
reputation_changes: Vec::new(),
sent_feedback: Some(fb_tx),
});
fb_rx.await.unwrap();
}
}
.boxed()
.into(),
)
.unwrap();
let protocol_config = ProtocolConfig {
name: protocol_name.clone(),
fallback_names: Vec::new(),
max_request_size: 1024,
max_response_size: 1024 * 1024,
request_timeout: Duration::from_secs(30),
inbound_queue: Some(tx),
};
build_swarm(iter::once(protocol_config))
})
.collect::<Vec<_>>();
{
let dial_addr = swarms[1].1.clone();
Swarm::dial(&mut swarms[0].0, dial_addr).unwrap();
}
let (mut swarm, _) = swarms.remove(0);
pool.spawner()
.spawn_obj({
async move {
loop {
match swarm.select_next_some().await {
SwarmEvent::Behaviour(Event::InboundRequest { result, .. }) => {
result.unwrap();
},
_ => {},
}
}
}
.boxed()
.into()
})
.unwrap();
let (mut swarm, _) = swarms.remove(0);
pool.run_until(async move {
let mut response_receiver = None;
loop {
match swarm.select_next_some().await {
SwarmEvent::ConnectionEstablished { peer_id, .. } => {
let (sender, receiver) = oneshot::channel();
swarm.behaviour_mut().send_request(
&peer_id,
protocol_name.clone(),
b"this is a request".to_vec(),
None,
sender,
IfDisconnected::ImmediateError,
);
assert!(response_receiver.is_none());
response_receiver = Some(receiver);
},
SwarmEvent::Behaviour(Event::RequestFinished { result, .. }) => {
result.unwrap();
break
},
_ => {},
}
}
assert_eq!(
response_receiver.unwrap().await.unwrap().unwrap(),
(b"this is a response".to_vec(), protocol_name)
);
});
}
#[test]
fn max_response_size_exceeded() {
let protocol_name = ProtocolName::from("/test/req-resp/1");
let mut pool = LocalPool::new();
let mut swarms = (0..2)
.map(|_| {
let (tx, mut rx) = async_channel::bounded::<IncomingRequest>(64);
pool.spawner()
.spawn_obj(
async move {
while let Some(rq) = rx.next().await {
assert_eq!(rq.payload, b"this is a request");
let _ = rq.pending_response.send(super::OutgoingResponse {
result: Ok(b"this response exceeds the limit".to_vec()),
reputation_changes: Vec::new(),
sent_feedback: None,
});
}
}
.boxed()
.into(),
)
.unwrap();
let protocol_config = ProtocolConfig {
name: protocol_name.clone(),
fallback_names: Vec::new(),
max_request_size: 1024,
max_response_size: 8, request_timeout: Duration::from_secs(30),
inbound_queue: Some(tx),
};
build_swarm(iter::once(protocol_config))
})
.collect::<Vec<_>>();
{
let dial_addr = swarms[1].1.clone();
Swarm::dial(&mut swarms[0].0, dial_addr).unwrap();
}
let (mut swarm, _) = swarms.remove(0);
pool.spawner()
.spawn_obj({
async move {
loop {
match swarm.select_next_some().await {
SwarmEvent::Behaviour(Event::InboundRequest { result, .. }) => {
assert!(result.is_ok());
break
},
_ => {},
}
}
}
.boxed()
.into()
})
.unwrap();
let (mut swarm, _) = swarms.remove(0);
pool.run_until(async move {
let mut response_receiver = None;
loop {
match swarm.select_next_some().await {
SwarmEvent::ConnectionEstablished { peer_id, .. } => {
let (sender, receiver) = oneshot::channel();
swarm.behaviour_mut().send_request(
&peer_id,
protocol_name.clone(),
b"this is a request".to_vec(),
None,
sender,
IfDisconnected::ImmediateError,
);
assert!(response_receiver.is_none());
response_receiver = Some(receiver);
},
SwarmEvent::Behaviour(Event::RequestFinished { result, .. }) => {
assert!(result.is_err());
break
},
_ => {},
}
}
match response_receiver.unwrap().await.unwrap().unwrap_err() {
RequestFailure::Network(OutboundFailure::ConnectionClosed) => {},
_ => panic!(),
}
});
}
#[test]
fn request_id_collision() {
let protocol_name_1 = ProtocolName::from("/test/req-resp-1/1");
let protocol_name_2 = ProtocolName::from("/test/req-resp-2/1");
let mut pool = LocalPool::new();
let mut swarm_1 = {
let protocol_configs = vec![
ProtocolConfig {
name: protocol_name_1.clone(),
fallback_names: Vec::new(),
max_request_size: 1024,
max_response_size: 1024 * 1024,
request_timeout: Duration::from_secs(30),
inbound_queue: None,
},
ProtocolConfig {
name: protocol_name_2.clone(),
fallback_names: Vec::new(),
max_request_size: 1024,
max_response_size: 1024 * 1024,
request_timeout: Duration::from_secs(30),
inbound_queue: None,
},
];
build_swarm(protocol_configs.into_iter()).0
};
let (mut swarm_2, mut swarm_2_handler_1, mut swarm_2_handler_2, listen_add_2) = {
let (tx_1, rx_1) = async_channel::bounded(64);
let (tx_2, rx_2) = async_channel::bounded(64);
let protocol_configs = vec![
ProtocolConfig {
name: protocol_name_1.clone(),
fallback_names: Vec::new(),
max_request_size: 1024,
max_response_size: 1024 * 1024,
request_timeout: Duration::from_secs(30),
inbound_queue: Some(tx_1),
},
ProtocolConfig {
name: protocol_name_2.clone(),
fallback_names: Vec::new(),
max_request_size: 1024,
max_response_size: 1024 * 1024,
request_timeout: Duration::from_secs(30),
inbound_queue: Some(tx_2),
},
];
let (swarm, listen_addr) = build_swarm(protocol_configs.into_iter());
(swarm, rx_1, rx_2, listen_addr)
};
swarm_1.dial(listen_add_2).unwrap();
pool.spawner()
.spawn_obj(
async move {
loop {
match swarm_2.select_next_some().await {
SwarmEvent::Behaviour(Event::InboundRequest { result, .. }) => {
result.unwrap();
},
_ => {},
}
}
}
.boxed()
.into(),
)
.unwrap();
pool.spawner()
.spawn_obj(
async move {
let protocol_1_request = swarm_2_handler_1.next().await;
let protocol_2_request = swarm_2_handler_2.next().await;
protocol_1_request
.unwrap()
.pending_response
.send(OutgoingResponse {
result: Ok(b"this is a response".to_vec()),
reputation_changes: Vec::new(),
sent_feedback: None,
})
.unwrap();
protocol_2_request
.unwrap()
.pending_response
.send(OutgoingResponse {
result: Ok(b"this is a response".to_vec()),
reputation_changes: Vec::new(),
sent_feedback: None,
})
.unwrap();
}
.boxed()
.into(),
)
.unwrap();
pool.run_until(async move {
let mut response_receivers = None;
let mut num_responses = 0;
loop {
match swarm_1.select_next_some().await {
SwarmEvent::ConnectionEstablished { peer_id, .. } => {
let (sender_1, receiver_1) = oneshot::channel();
let (sender_2, receiver_2) = oneshot::channel();
swarm_1.behaviour_mut().send_request(
&peer_id,
protocol_name_1.clone(),
b"this is a request".to_vec(),
None,
sender_1,
IfDisconnected::ImmediateError,
);
swarm_1.behaviour_mut().send_request(
&peer_id,
protocol_name_2.clone(),
b"this is a request".to_vec(),
None,
sender_2,
IfDisconnected::ImmediateError,
);
assert!(response_receivers.is_none());
response_receivers = Some((receiver_1, receiver_2));
},
SwarmEvent::Behaviour(Event::RequestFinished { result, .. }) => {
num_responses += 1;
result.unwrap();
if num_responses == 2 {
break
}
},
_ => {},
}
}
let (response_receiver_1, response_receiver_2) = response_receivers.unwrap();
assert_eq!(
response_receiver_1.await.unwrap().unwrap(),
(b"this is a response".to_vec(), protocol_name_1)
);
assert_eq!(
response_receiver_2.await.unwrap().unwrap(),
(b"this is a response".to_vec(), protocol_name_2)
);
});
}
/// Checks that a request carrying a fallback is retried on the fallback protocol when the remote
/// does not support the primary protocol.
///
/// The "older" node only registers `/test/req-resp/1` and `/test/another`. The "new" node
/// additionally registers `/test/req-resp/2`. From the new node the test then:
///
/// 1. requests `/test/req-resp/2` with a `/test/req-resp/1` fallback. The fallback payload is
///    answered, and the response is reported with the fallback protocol name.
/// 2. requests `/test/req-resp/1` directly, with a dummy fallback payload. The primary succeeds;
///    the dummy payload would trip the responder's payload assertion if it were ever sent.
/// 3. requests `/test/req-resp/2` without a fallback. This fails with
///    `OutboundFailure::UnsupportedProtocols`, and the response channel receives an error.
/// 4. requests `/test/another`, which is unrelated to the fallback pair and succeeds.
#[test]
fn request_fallback() {
	let protocol_name_1 = ProtocolName::from("/test/req-resp/2");
	let protocol_name_1_fallback = ProtocolName::from("/test/req-resp/1");
	let protocol_name_2 = ProtocolName::from("/test/another");

	let mut pool = LocalPool::new();

	// All protocols in this test share the same limits and differ only by name.
	let make_config = |name: &ProtocolName| ProtocolConfig {
		name: name.clone(),
		fallback_names: Vec::new(),
		max_request_size: 1024,
		max_response_size: 1024 * 1024,
		request_timeout: Duration::from_secs(30),
		inbound_queue: None,
	};
	let protocol_config_1 = make_config(&protocol_name_1);
	let protocol_config_1_fallback = make_config(&protocol_name_1_fallback);
	let protocol_config_2 = make_config(&protocol_name_2);

	// The "older" node only knows `/test/req-resp/1` and `/test/another`. Its responder task
	// answers exactly two `/test/req-resp/1` requests, then one `/test/another` request. It waits
	// for the send feedback of each response before moving on.
	let mut older_swarm = {
		let (tx_1, mut rx_1) = async_channel::bounded::<IncomingRequest>(64);
		let (tx_2, mut rx_2) = async_channel::bounded::<IncomingRequest>(64);
		let mut protocol_config_1_fallback = protocol_config_1_fallback.clone();
		protocol_config_1_fallback.inbound_queue = Some(tx_1);
		let mut protocol_config_2 = protocol_config_2.clone();
		protocol_config_2.inbound_queue = Some(tx_2);

		pool.spawner()
			.spawn_obj(
				async move {
					for _ in 0..2 {
						if let Some(rq) = rx_1.next().await {
							let (fb_tx, fb_rx) = oneshot::channel();
							assert_eq!(rq.payload, b"request on protocol /test/req-resp/1");
							let _ = rq.pending_response.send(super::OutgoingResponse {
								result: Ok(
									b"this is a response on protocol /test/req-resp/1".to_vec()
								),
								reputation_changes: Vec::new(),
								sent_feedback: Some(fb_tx),
							});
							fb_rx.await.unwrap();
						}
					}

					if let Some(rq) = rx_2.next().await {
						let (fb_tx, fb_rx) = oneshot::channel();
						assert_eq!(rq.payload, b"request on protocol /test/other");
						let _ = rq.pending_response.send(super::OutgoingResponse {
							result: Ok(b"this is a response on protocol /test/other".to_vec()),
							reputation_changes: Vec::new(),
							sent_feedback: Some(fb_tx),
						});
						fb_rx.await.unwrap();
					}
				}
				.boxed()
				.into(),
			)
			.unwrap();

		build_swarm(vec![protocol_config_1_fallback, protocol_config_2].into_iter())
	};

	// The "new" node knows both versions of the first protocol plus `/test/another`.
	let mut new_swarm = build_swarm(
		vec![
			protocol_config_1.clone(),
			protocol_config_1_fallback.clone(),
			protocol_config_2.clone(),
		]
		.into_iter(),
	);

	{
		let dial_addr = older_swarm.1.clone();
		Swarm::dial(&mut new_swarm.0, dial_addr).unwrap();
	}

	// Drive the older node in the background; its events are irrelevant to the assertions.
	pool.spawner()
		.spawn_obj({
			async move {
				loop {
					_ = older_swarm.0.select_next_some().await;
				}
			}
			.boxed()
			.into()
		})
		.unwrap();

	let (mut swarm, _) = new_swarm;

	pool.run_until(async move {
		// Step 1: `/test/req-resp/2` is unknown to the remote, so the fallback must be used.
		let mut older_peer_id = None;
		let mut response_receiver = None;
		loop {
			match swarm.select_next_some().await {
				SwarmEvent::ConnectionEstablished { peer_id, .. } => {
					older_peer_id = Some(peer_id);
					let (sender, receiver) = oneshot::channel();
					swarm.behaviour_mut().send_request(
						&peer_id,
						protocol_name_1.clone(),
						b"request on protocol /test/req-resp/2".to_vec(),
						Some((
							b"request on protocol /test/req-resp/1".to_vec(),
							protocol_config_1_fallback.name.clone(),
						)),
						sender,
						IfDisconnected::ImmediateError,
					);
					response_receiver = Some(receiver);
				},
				SwarmEvent::Behaviour(Event::RequestFinished { result, .. }) => {
					result.unwrap();
					break
				},
				_ => {},
			}
		}
		let older_peer_id = older_peer_id
			.expect("a request can only finish after the connection was established");
		assert_eq!(
			response_receiver.unwrap().await.unwrap().unwrap(),
			(
				b"this is a response on protocol /test/req-resp/1".to_vec(),
				protocol_name_1_fallback.clone()
			)
		);

		// Step 2: the primary protocol is supported, so the dummy fallback must not be sent.
		let (sender, response_receiver) = oneshot::channel();
		swarm.behaviour_mut().send_request(
			&older_peer_id,
			protocol_name_1_fallback.clone(),
			b"request on protocol /test/req-resp/1".to_vec(),
			Some((
				b"dummy request, will fail if processed".to_vec(),
				protocol_config_1_fallback.name.clone(),
			)),
			sender,
			IfDisconnected::ImmediateError,
		);
		let result = loop {
			if let SwarmEvent::Behaviour(Event::RequestFinished { result, .. }) =
				swarm.select_next_some().await
			{
				break result
			}
		};
		result.unwrap();
		assert_eq!(
			response_receiver.await.unwrap().unwrap(),
			(
				b"this is a response on protocol /test/req-resp/1".to_vec(),
				protocol_name_1_fallback.clone()
			)
		);

		// Step 3: without a fallback, the unsupported primary protocol must fail outright.
		let (sender, response_receiver) = oneshot::channel();
		swarm.behaviour_mut().send_request(
			&older_peer_id,
			protocol_name_1.clone(),
			b"request on protocol /test/req-resp/2".to_vec(),
			None,
			sender,
			IfDisconnected::ImmediateError,
		);
		let result = loop {
			if let SwarmEvent::Behaviour(Event::RequestFinished { result, .. }) =
				swarm.select_next_some().await
			{
				break result
			}
		};
		assert_matches!(
			result.unwrap_err(),
			RequestFailure::Network(OutboundFailure::UnsupportedProtocols)
		);
		assert!(response_receiver.await.unwrap().is_err());

		// Step 4: an unrelated protocol keeps working normally.
		let (sender, response_receiver) = oneshot::channel();
		swarm.behaviour_mut().send_request(
			&older_peer_id,
			protocol_name_2.clone(),
			b"request on protocol /test/other".to_vec(),
			None,
			sender,
			IfDisconnected::ImmediateError,
		);
		let result = loop {
			if let SwarmEvent::Behaviour(Event::RequestFinished { result, .. }) =
				swarm.select_next_some().await
			{
				break result
			}
		};
		result.unwrap();
		assert_eq!(
			response_receiver.await.unwrap().unwrap(),
			(b"this is a response on protocol /test/other".to_vec(), protocol_name_2.clone())
		);
	});
}
}