use futures::prelude::*;
use libp2p::core::upgrade::{InboundUpgrade, UpgradeInfo};
use std::{
pin::Pin,
task::{Context, Poll},
vec,
};
/// A list of upgrades combined into a single upgrade.
///
/// Every protocol advertised by an element is tagged with that element's position in
/// the list, and the inbound upgrade result carries the same position so the caller can
/// tell which element was negotiated.
#[derive(Clone, Debug)]
pub struct UpgradeCollec<T>(pub Vec<T>);
impl<T> From<Vec<T>> for UpgradeCollec<T> {
fn from(list: Vec<T>) -> Self {
Self(list)
}
}
impl<T> FromIterator<T> for UpgradeCollec<T> {
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
Self(iter.into_iter().collect())
}
}
impl<T: UpgradeInfo> UpgradeInfo for UpgradeCollec<T> {
    type Info = ProtoNameWithUsize<T::Info>;
    type InfoIter = vec::IntoIter<Self::Info>;

    /// Lists the protocols of every sub-upgrade, in collection order.
    ///
    /// Each protocol is wrapped in [`ProtoNameWithUsize`] together with the index of the
    /// sub-upgrade that advertised it. Within one sub-upgrade, its own protocol order
    /// is preserved.
    fn protocol_info(&self) -> Self::InfoIter {
        let mut tagged = Vec::new();
        for (index, upgrade) in self.0.iter().enumerate() {
            tagged.extend(
                upgrade
                    .protocol_info()
                    .into_iter()
                    .map(|info| ProtoNameWithUsize(info, index)),
            );
        }
        tagged.into_iter()
    }
}
impl<T, C> InboundUpgrade<C> for UpgradeCollec<T>
where
T: InboundUpgrade<C>,
{
type Output = (T::Output, usize);
type Error = (T::Error, usize);
type Future = FutWithUsize<T::Future>;
fn upgrade_inbound(mut self, sock: C, info: Self::Info) -> Self::Future {
let fut = self.0.remove(info.1).upgrade_inbound(sock, info.0);
FutWithUsize(fut, info.1)
}
}
/// A protocol info value tagged with the index of the sub-upgrade that advertised it.
///
/// Field 0 is the wrapped protocol info and field 1 is the sub-upgrade's index.
#[derive(Clone, Debug, PartialEq)]
pub struct ProtoNameWithUsize<T>(T, usize);
impl<T: AsRef<str>> AsRef<str> for ProtoNameWithUsize<T> {
fn as_ref(&self) -> &str {
self.0.as_ref()
}
}
/// Future returned by `UpgradeCollec`'s inbound upgrade.
///
/// Field 0 is the selected sub-upgrade's future; it is structurally pinned via
/// `#[pin]`. Field 1 is the index of that sub-upgrade, which is attached to both the
/// success and the error value once the future resolves.
#[pin_project::pin_project]
pub struct FutWithUsize<T>(#[pin] T, usize);
impl<T, O, E> Future for FutWithUsize<T>
where
    T: Future<Output = Result<O, E>>,
{
    type Output = Result<(O, usize), (E, usize)>;

    /// Polls the wrapped future and, once it is ready, pairs its `Ok` or `Err` value
    /// with the stored index. `Pending` is passed through unchanged.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        let index = *projected.1;
        projected
            .0
            .poll(cx)
            .map(|outcome| outcome.map(|value| (value, index)).map_err(|err| (err, index)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ProtocolName as ProtoName;
    use libp2p::core::upgrade::UpgradeInfo;

    mockall::mock! {
        pub ProtocolUpgrade<T> {}
        impl<T: Clone + AsRef<str>> UpgradeInfo for ProtocolUpgrade<T> {
            type Info = T;
            type InfoIter = vec::IntoIter<T>;
            fn protocol_info(&self) -> vec::IntoIter<T>;
        }
    }

    /// Protocol info type advertised by each mocked sub-upgrade.
    ///
    /// It is itself a tagged name, so the tests also check that the inner tag is passed
    /// through untouched.
    type Inner = ProtoNameWithUsize<ProtoName>;

    /// Builds an inner protocol entry named `name` carrying the inner tag `tag`.
    fn entry(name: &str, tag: usize) -> Inner {
        ProtoNameWithUsize(ProtoName::from(name.to_string()), tag)
    }

    /// Builds the value `UpgradeCollec::protocol_info` should yield for an inner entry
    /// advertised by the sub-upgrade at position `index`.
    fn tagged(name: &str, tag: usize, index: usize) -> ProtoNameWithUsize<Inner> {
        ProtoNameWithUsize(entry(name, tag), index)
    }

    /// A mock sub-upgrade whose `protocol_info` yields the single entry `protocol{i}`
    /// tagged with `i`. The expectation is registered with `return_once`.
    fn single_protocol_upgrade(i: usize) -> MockProtocolUpgrade<Inner> {
        let mut upgrade = MockProtocolUpgrade::<Inner>::new();
        upgrade
            .expect_protocol_info()
            .return_once(move || vec![entry(&format!("protocol{i}"), i)].into_iter());
        upgrade
    }

    #[test]
    fn protocol_info() {
        let upgrade: UpgradeCollec<_> = (1..=3).map(single_protocol_upgrade).collect();

        let expected = vec![
            tagged("protocol1", 1, 0),
            tagged("protocol2", 2, 1),
            tagged("protocol3", 3, 2),
        ];
        let actual: Vec<_> = upgrade.protocol_info().collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn nested_protocol_info() {
        // The last sub-upgrade advertises several protocols; every one of them must be
        // tagged with that sub-upgrade's index (2).
        let mut multi = MockProtocolUpgrade::<Inner>::new();
        multi.expect_protocol_info().return_once(|| {
            vec![
                entry("protocol22", 1),
                entry("protocol33", 2),
                entry("protocol44", 3),
            ]
            .into_iter()
        });
        let upgrade: UpgradeCollec<_> = (1..=2)
            .map(single_protocol_upgrade)
            .chain(std::iter::once(multi))
            .collect();

        let expected = vec![
            tagged("protocol1", 1, 0),
            tagged("protocol2", 2, 1),
            tagged("protocol22", 1, 2),
            tagged("protocol33", 2, 2),
            tagged("protocol44", 3, 2),
        ];
        let actual: Vec<_> = upgrade.protocol_info().collect();
        assert_eq!(actual, expected);
    }
}