use futures::{channel::oneshot, StreamExt};
use sc_network_types::PeerId;
use sc_network::{
request_responses::{IfDisconnected, RequestFailure},
types::ProtocolName,
NetworkPeers, NetworkRequest, ReputationChange,
};
use sc_utils::mpsc::{tracing_unbounded, TracingUnboundedReceiver, TracingUnboundedSender};
use std::sync::Arc;
/// Combined network interface consumed by [`NetworkServiceProvider::run`]: peer
/// management ([`NetworkPeers`]) together with outbound requests ([`NetworkRequest`]).
pub trait Network: NetworkPeers + NetworkRequest {}

/// Blanket implementation: any type offering both capabilities is a [`Network`].
impl<N: NetworkPeers + NetworkRequest> Network for N {}
/// Owns the receiving end of the command channel and forwards each received
/// [`ToServiceCommand`] to a network backend while `run` is being driven.
pub struct NetworkServiceProvider {
	/// Receiver for commands sent through [`NetworkServiceHandle`]s.
	rx: TracingUnboundedReceiver<ToServiceCommand>,
	/// Handle wrapping the sender half of the channel; clones are given out by `handle()`.
	handle: NetworkServiceHandle,
}
/// Commands sent from a [`NetworkServiceHandle`] to the [`NetworkServiceProvider`] loop,
/// each of which is forwarded to the matching method of the network backend.
#[derive(Debug)]
pub enum ToServiceCommand {
	/// Forwarded to `disconnect_peer(peer, protocol)` on the backend.
	DisconnectPeer(PeerId, ProtocolName),
	/// Forwarded to `report_peer(peer, reputation_change)` on the backend.
	ReportPeer(PeerId, ReputationChange),
	/// Forwarded to `start_request` on the backend.
	StartRequest(
		// Target peer.
		PeerId,
		// Protocol the request is sent on.
		ProtocolName,
		// Request payload bytes.
		Vec<u8>,
		// Receives the outcome: response bytes plus a `ProtocolName`
		// (presumably the protocol that served the response — confirm), or a failure.
		oneshot::Sender<Result<(Vec<u8>, ProtocolName), RequestFailure>>,
		// Policy to apply when the peer is not currently connected.
		IfDisconnected,
	),
}
/// Cloneable sender side of the command channel. Its methods enqueue
/// [`ToServiceCommand`]s for a [`NetworkServiceProvider`] to execute.
#[derive(Debug, Clone)]
pub struct NetworkServiceHandle {
	/// Sender half of the provider's command channel.
	tx: TracingUnboundedSender<ToServiceCommand>,
}
impl NetworkServiceHandle {
pub fn new(tx: TracingUnboundedSender<ToServiceCommand>) -> NetworkServiceHandle {
Self { tx }
}
pub fn report_peer(&self, who: PeerId, cost_benefit: ReputationChange) {
let _ = self.tx.unbounded_send(ToServiceCommand::ReportPeer(who, cost_benefit));
}
pub fn disconnect_peer(&self, who: PeerId, protocol: ProtocolName) {
let _ = self.tx.unbounded_send(ToServiceCommand::DisconnectPeer(who, protocol));
}
pub fn start_request(
&self,
who: PeerId,
protocol: ProtocolName,
request: Vec<u8>,
tx: oneshot::Sender<Result<(Vec<u8>, ProtocolName), RequestFailure>>,
connect: IfDisconnected,
) {
let _ = self
.tx
.unbounded_send(ToServiceCommand::StartRequest(who, protocol, request, tx, connect));
}
}
impl NetworkServiceProvider {
	/// Create a provider together with a fresh command channel.
	///
	/// Handles obtained via [`NetworkServiceProvider::handle`] feed the channel that
	/// [`NetworkServiceProvider::run`] drains.
	pub fn new() -> Self {
		// The second argument is presumably `tracing_unbounded`'s queue-length warning
		// threshold — confirm against `sc_utils::mpsc`.
		let (tx, rx) = tracing_unbounded("mpsc_network_service_provider", 100_000);

		Self { rx, handle: NetworkServiceHandle::new(tx) }
	}

	/// Return a new handle for sending commands to this provider.
	pub fn handle(&self) -> NetworkServiceHandle {
		self.handle.clone()
	}

	/// Drive the provider. Every received command is forwarded to `service` until the
	/// command channel yields `None`, which happens once all senders (every
	/// [`NetworkServiceHandle`] clone) have been dropped.
	pub async fn run(self, service: Arc<dyn Network + Send + Sync>) {
		let Self { mut rx, handle } = self;
		// The provider's own sender must go: if it stayed alive the channel could never
		// close, `rx.next()` would never return `None`, and this future would never finish.
		drop(handle);

		while let Some(inner) = rx.next().await {
			match inner {
				ToServiceCommand::DisconnectPeer(peer, protocol_name) =>
					service.disconnect_peer(peer, protocol_name),
				ToServiceCommand::ReportPeer(peer, reputation_change) =>
					service.report_peer(peer, reputation_change),
				// The `None` fills the backend parameter between the payload and the
				// response sender (presumably an optional fallback request — confirm).
				ToServiceCommand::StartRequest(peer, protocol, request, tx, connect) =>
					service.start_request(peer, protocol, request, None, tx, connect),
			}
		}
	}
}

/// Equivalent to [`NetworkServiceProvider::new`]; provided because `new` takes no
/// arguments (clippy `new_without_default`).
impl Default for NetworkServiceProvider {
	fn default() -> Self {
		Self::new()
	}
}
#[cfg(test)]
mod tests {
	use super::*;
	use crate::service::mock::MockNetwork;

	// Send a disconnect and a reputation report through a handle and verify that the
	// provider forwards each of them to the backend exactly once.
	#[tokio::test]
	async fn disconnect_and_report_peer() {
		let provider = NetworkServiceProvider::new();
		let handle = provider.handle();

		let peer = PeerId::random();
		let proto = ProtocolName::from("test-protocol");
		let proto_clone = proto.clone();
		let change = sc_network::ReputationChange::new_fatal("test-change");

		let mut mock_network = MockNetwork::new();
		mock_network
			.expect_disconnect_peer()
			.withf(move |in_peer, in_proto| &peer == in_peer && &proto == in_proto)
			.once()
			.returning(|_, _| ());
		mock_network
			.expect_report_peer()
			.withf(move |in_peer, in_change| &peer == in_peer && &change == in_change)
			.once()
			.returning(|_, _| ());

		let task = tokio::spawn(async move {
			provider.run(Arc::new(mock_network)).await;
		});

		handle.disconnect_peer(peer, proto_clone);
		handle.report_peer(peer, change);

		// Dropping the last sender closes the channel. `run` then drains both commands,
		// returns, and drops the mock, which checks the `.once()` expectations. Awaiting
		// the task makes the test wait for all of that and surfaces any expectation panic
		// instead of letting it vanish inside a detached task.
		drop(handle);
		task.await.expect("provider task panicked: mock expectations not met");
	}
}