use futures::{channel::oneshot, StreamExt};
use libp2p::PeerId;
use sc_network::{
request_responses::{IfDisconnected, RequestFailure},
types::ProtocolName,
NetworkNotification, NetworkPeers, NetworkRequest, ReputationChange,
};
use sc_utils::mpsc::{tracing_unbounded, TracingUnboundedReceiver, TracingUnboundedSender};
use std::sync::Arc;
/// Umbrella trait for a network backend that implements all three `sc_network`
/// capability traits whose methods this module forwards commands to:
/// [`NetworkPeers`], [`NetworkRequest`] and [`NetworkNotification`].
pub trait Network: NetworkPeers + NetworkRequest + NetworkNotification {}

/// Blanket implementation: any type providing all three capability traits is a [`Network`].
impl<N: NetworkPeers + NetworkRequest + NetworkNotification> Network for N {}
/// Receiving side of the command channel created by [`NetworkServiceProvider::new`].
///
/// Commands enqueued through a [`NetworkServiceHandle`] are drained and forwarded to a
/// network service by [`NetworkServiceProvider::run`].
pub struct NetworkServiceProvider {
/// Receiver for commands sent by [`NetworkServiceHandle`]s.
rx: TracingUnboundedReceiver<ToServiceCommand>,
}
/// Commands a [`NetworkServiceHandle`] sends to a [`NetworkServiceProvider`].
///
/// [`NetworkServiceProvider::run`] forwards each variant, with its payload unchanged, to the
/// corresponding method of the network service.
pub enum ToServiceCommand {
/// Forwarded to `disconnect_peer(peer, protocol)`.
DisconnectPeer(PeerId, ProtocolName),
/// Forwarded to `report_peer(peer, reputation_change)`.
ReportPeer(PeerId, ReputationChange),
/// Forwarded to `start_request(peer, protocol, request, tx, connect)`.
StartRequest(
// Target peer.
PeerId,
// Protocol the request is issued on.
ProtocolName,
// Request payload bytes.
Vec<u8>,
// Passed through to the network unchanged; presumably the network delivers the
// response bytes or a `RequestFailure` on it — confirm in `NetworkRequest`.
oneshot::Sender<Result<Vec<u8>, RequestFailure>>,
// Policy applied when the peer is not connected (see `IfDisconnected`).
IfDisconnected,
),
/// Forwarded to `write_notification(peer, protocol, message)`.
WriteNotification(PeerId, ProtocolName, Vec<u8>),
/// Forwarded to `set_notification_handshake(protocol, handshake)`.
SetNotificationHandshake(ProtocolName, Vec<u8>),
}
/// Cloneable handle for sending [`ToServiceCommand`]s to a [`NetworkServiceProvider`].
///
/// Every method only enqueues a command on an unbounded channel. The actual network call
/// happens later, when the provider's `run` loop processes the command.
#[derive(Clone)]
pub struct NetworkServiceHandle {
/// Sending half of the command channel.
tx: TracingUnboundedSender<ToServiceCommand>,
}
impl NetworkServiceHandle {
pub fn new(tx: TracingUnboundedSender<ToServiceCommand>) -> NetworkServiceHandle {
Self { tx }
}
pub fn report_peer(&self, who: PeerId, cost_benefit: ReputationChange) {
let _ = self.tx.unbounded_send(ToServiceCommand::ReportPeer(who, cost_benefit));
}
pub fn disconnect_peer(&self, who: PeerId, protocol: ProtocolName) {
let _ = self.tx.unbounded_send(ToServiceCommand::DisconnectPeer(who, protocol));
}
pub fn start_request(
&self,
who: PeerId,
protocol: ProtocolName,
request: Vec<u8>,
tx: oneshot::Sender<Result<Vec<u8>, RequestFailure>>,
connect: IfDisconnected,
) {
let _ = self
.tx
.unbounded_send(ToServiceCommand::StartRequest(who, protocol, request, tx, connect));
}
pub fn write_notification(&self, who: PeerId, protocol: ProtocolName, message: Vec<u8>) {
let _ = self
.tx
.unbounded_send(ToServiceCommand::WriteNotification(who, protocol, message));
}
pub fn set_notification_handshake(&self, protocol: ProtocolName, handshake: Vec<u8>) {
let _ = self
.tx
.unbounded_send(ToServiceCommand::SetNotificationHandshake(protocol, handshake));
}
}
impl NetworkServiceProvider {
pub fn new() -> (Self, NetworkServiceHandle) {
let (tx, rx) = tracing_unbounded("mpsc_network_service_provider", 100_000);
(Self { rx }, NetworkServiceHandle::new(tx))
}
pub async fn run(mut self, service: Arc<dyn Network + Send + Sync>) {
while let Some(inner) = self.rx.next().await {
match inner {
ToServiceCommand::DisconnectPeer(peer, protocol_name) =>
service.disconnect_peer(peer, protocol_name),
ToServiceCommand::ReportPeer(peer, reputation_change) =>
service.report_peer(peer, reputation_change),
ToServiceCommand::StartRequest(peer, protocol, request, tx, connect) =>
service.start_request(peer, protocol, request, tx, connect),
ToServiceCommand::WriteNotification(peer, protocol, message) =>
service.write_notification(peer, protocol, message),
ToServiceCommand::SetNotificationHandshake(protocol, handshake) =>
service.set_notification_handshake(protocol, handshake),
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::service::mock::MockNetwork;

    // Typical usage pattern: a peer is disconnected and then reported. Each call issued on
    // the handle must reach the underlying network exactly once, with the same arguments.
    #[tokio::test]
    async fn disconnect_and_report_peer() {
        let (provider, handle) = NetworkServiceProvider::new();

        let peer = PeerId::random();
        let proto = ProtocolName::from("test-protocol");
        let proto_clone = proto.clone();
        let change = sc_network::ReputationChange::new_fatal("test-change");

        let mut mock_network = MockNetwork::new();
        mock_network
            .expect_disconnect_peer()
            .withf(move |in_peer, in_proto| &peer == in_peer && &proto == in_proto)
            .once()
            .returning(|_, _| ());
        mock_network
            .expect_report_peer()
            .withf(move |in_peer, in_change| &peer == in_peer && &change == in_change)
            .once()
            .returning(|_, _| ());

        let provider_task = tokio::spawn(async move {
            provider.run(Arc::new(mock_network)).await;
        });

        handle.disconnect_peer(peer, proto_clone);
        handle.report_peer(peer, change);

        // Close the channel and wait for the provider to finish. Dropping the only handle
        // makes `run` return after it has processed both queued commands. The mock is then
        // dropped inside the task, where mockall checks the `.once()` expectations and
        // panics on a mismatch. Awaiting the task turns that panic into a test failure.
        // Returning early instead would let the runtime cancel the task, and the
        // expectations would never be enforced.
        drop(handle);
        provider_task.await.expect("network service provider task failed");
    }
}