referrerpolicy=no-referrer-when-downgrade

sc_network/protocol/notifications/upgrade/
collec.rs

1// This file is part of Substrate.
2
3// Copyright (C) Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
5
6// This program is free software: you can redistribute it and/or modify
7// it under the terms of the GNU General Public License as published by
8// the Free Software Foundation, either version 3 of the License, or
9// (at your option) any later version.
10
11// This program is distributed in the hope that it will be useful,
12// but WITHOUT ANY WARRANTY; without even the implied warranty of
13// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14// GNU General Public License for more details.
15
16// You should have received a copy of the GNU General Public License
17// along with this program. If not, see <https://www.gnu.org/licenses/>.
18
19use futures::prelude::*;
20use libp2p::core::upgrade::{InboundUpgrade, UpgradeInfo};
21use std::{
22	pin::Pin,
23	task::{Context, Poll},
24	vec,
25};
26
27// TODO: move this to libp2p => https://github.com/libp2p/rust-libp2p/issues/1445
28
/// Upgrade that combines multiple upgrades of the same type into one. Supports all the protocols
/// supported by any of the sub-upgrades.
///
/// Every advertised protocol is tagged with the index of the sub-upgrade it came from, and that
/// index is reported back alongside the output (or error) of the inbound upgrade.
#[derive(Debug, Clone)]
pub struct UpgradeCollec<T>(pub Vec<T>);

impl<T> From<Vec<T>> for UpgradeCollec<T> {
	/// Wraps an existing list of sub-upgrades as-is; their positions become their indices.
	fn from(list: Vec<T>) -> Self {
		Self(list)
	}
}

impl<T> FromIterator<T> for UpgradeCollec<T> {
	/// Collects sub-upgrades in iteration order; their positions become their indices.
	fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
		Self(iter.into_iter().collect())
	}
}
45
46impl<T: UpgradeInfo> UpgradeInfo for UpgradeCollec<T> {
47	type Info = ProtoNameWithUsize<T::Info>;
48	type InfoIter = vec::IntoIter<Self::Info>;
49
50	fn protocol_info(&self) -> Self::InfoIter {
51		self.0
52			.iter()
53			.enumerate()
54			.flat_map(|(n, p)| p.protocol_info().into_iter().map(move |i| ProtoNameWithUsize(i, n)))
55			.collect::<Vec<_>>()
56			.into_iter()
57	}
58}
59
60impl<T, C> InboundUpgrade<C> for UpgradeCollec<T>
61where
62	T: InboundUpgrade<C>,
63{
64	type Output = (T::Output, usize);
65	type Error = (T::Error, usize);
66	type Future = FutWithUsize<T::Future>;
67
68	fn upgrade_inbound(mut self, sock: C, info: Self::Info) -> Self::Future {
69		let fut = self.0.remove(info.1).upgrade_inbound(sock, info.0);
70		FutWithUsize(fut, info.1)
71	}
72}
73
/// Pairs a protocol info value (typically a protocol name) with a `usize` index.
///
/// Within this module the index is the position, inside an `UpgradeCollec`, of the sub-upgrade
/// that advertised the protocol.
#[derive(Clone, Debug, PartialEq)]
pub struct ProtoNameWithUsize<T>(T, usize);

impl<T: AsRef<str>> AsRef<str> for ProtoNameWithUsize<T> {
	/// Exposes the wrapped protocol name; the index does not contribute to it.
	fn as_ref(&self) -> &str {
		<T as AsRef<str>>::as_ref(&self.0)
	}
}
83
84/// Equivalent to `fut.map_ok(|v| (v, num)).map_err(|e| (e, num))`, where `fut` and `num` are
85/// the two fields of this struct.
86#[pin_project::pin_project]
87pub struct FutWithUsize<T>(#[pin] T, usize);
88
89impl<T: Future<Output = Result<O, E>>, O, E> Future for FutWithUsize<T> {
90	type Output = Result<(O, usize), (E, usize)>;
91
92	fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
93		let this = self.project();
94		match Future::poll(this.0, cx) {
95			Poll::Ready(Ok(v)) => Poll::Ready(Ok((v, *this.1))),
96			Poll::Ready(Err(e)) => Poll::Ready(Err((e, *this.1))),
97			Poll::Pending => Poll::Pending,
98		}
99	}
100}
101
#[cfg(test)]
mod tests {
	use super::*;
	use crate::types::ProtocolName as ProtoName;
	use libp2p::core::upgrade::UpgradeInfo;

	// TODO: move to mocks
	mockall::mock! {
		pub ProtocolUpgrade<T> {}

		impl<T: Clone + AsRef<str>> UpgradeInfo for ProtocolUpgrade<T> {
			type Info = T;
			type InfoIter = vec::IntoIter<T>;
			fn protocol_info(&self) -> vec::IntoIter<T>;
		}
	}

	/// Protocol info type reported by the mocked sub-upgrades.
	type InnerInfo = ProtoNameWithUsize<ProtoName>;

	/// Builds a mocked sub-upgrade whose `protocol_info` yields one entry per `(name, n)` pair,
	/// in the given order.
	fn mock_upgrade(entries: Vec<(String, usize)>) -> MockProtocolUpgrade<InnerInfo> {
		let mut upgrade = MockProtocolUpgrade::<InnerInfo>::new();
		upgrade.expect_protocol_info().return_once(move || {
			entries
				.into_iter()
				.map(|(name, n)| ProtoNameWithUsize(ProtoName::from(name), n))
				.collect::<Vec<_>>()
				.into_iter()
		});
		upgrade
	}

	/// Builds the entry the collection is expected to yield for inner protocol `(name, inner)`
	/// advertised by the sub-upgrade at position `outer`.
	fn expected(name: &str, inner: usize, outer: usize) -> ProtoNameWithUsize<InnerInfo> {
		ProtoNameWithUsize(ProtoNameWithUsize(ProtoName::from(name.to_string()), inner), outer)
	}

	#[test]
	fn protocol_info() {
		let upgrade: UpgradeCollec<_> =
			(1..=3).map(|i| mock_upgrade(vec![(format!("protocol{i}"), i)])).collect();

		let actual = upgrade.protocol_info().collect::<Vec<_>>();
		let wanted = vec![
			expected("protocol1", 1, 0),
			expected("protocol2", 2, 1),
			expected("protocol3", 3, 2),
		];
		assert_eq!(actual, wanted);
	}

	#[test]
	fn nested_protocol_info() {
		let multi = mock_upgrade(vec![
			("protocol22".to_string(), 1),
			("protocol33".to_string(), 2),
			("protocol44".to_string(), 3),
		]);
		let upgrade: UpgradeCollec<_> = (1..=2)
			.map(|i| mock_upgrade(vec![(format!("protocol{i}"), i)]))
			.chain(std::iter::once(multi))
			.collect();

		let actual = upgrade.protocol_info().collect::<Vec<_>>();
		let wanted = vec![
			expected("protocol1", 1, 0),
			expected("protocol2", 2, 1),
			expected("protocol22", 1, 2),
			expected("protocol33", 2, 2),
			expected("protocol44", 3, 2),
		];
		assert_eq!(actual, wanted);
	}
}