1use crate::handler::{
25 AddressChange, ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent, DialUpgradeError,
26 FullyNegotiatedInbound, FullyNegotiatedOutbound, ListenUpgradeError, SubstreamProtocol,
27};
28use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, UpgradeInfoSend};
29use crate::Stream;
30use futures::{future::BoxFuture, prelude::*, ready};
31use rand::Rng;
32use std::{
33 cmp,
34 collections::{HashMap, HashSet},
35 error,
36 fmt::{self, Debug},
37 hash::Hash,
38 iter,
39 task::{Context, Poll},
40 time::Duration,
41};
42
/// A [`ConnectionHandler`] that multiplexes several handlers of the same type,
/// each one identified by a key of type `K`.
///
/// `Debug` is derived so that it only requires `K: Debug` and `H: Debug`
/// (the same bounds `HashMap`'s own `Debug` impl needs); the output format is
/// `MultiHandler { handlers: {..} }`.
#[derive(Clone, Debug)]
pub struct MultiHandler<K, H> {
    handlers: HashMap<K, H>,
}
60
61impl<K, H> MultiHandler<K, H>
62where
63 K: Clone + Debug + Hash + Eq + Send + 'static,
64 H: ConnectionHandler,
65{
66 pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
70 where
71 I: IntoIterator<Item = (K, H)>,
72 {
73 let m = MultiHandler {
74 handlers: HashMap::from_iter(iter),
75 };
76 uniq_proto_names(
77 m.handlers
78 .values()
79 .map(|h| h.listen_protocol().into_upgrade().0),
80 )?;
81 Ok(m)
82 }
83
84 fn on_listen_upgrade_error(
85 &mut self,
86 ListenUpgradeError {
87 error: (key, error),
88 mut info,
89 }: ListenUpgradeError<
90 <Self as ConnectionHandler>::InboundOpenInfo,
91 <Self as ConnectionHandler>::InboundProtocol,
92 >,
93 ) {
94 if let Some(h) = self.handlers.get_mut(&key) {
95 if let Some(i) = info.take(&key) {
96 h.on_connection_event(ConnectionEvent::ListenUpgradeError(ListenUpgradeError {
97 info: i,
98 error,
99 }));
100 }
101 }
102 }
103}
104
105impl<K, H> ConnectionHandler for MultiHandler<K, H>
106where
107 K: Clone + Debug + Hash + Eq + Send + 'static,
108 H: ConnectionHandler,
109 H::InboundProtocol: InboundUpgradeSend,
110 H::OutboundProtocol: OutboundUpgradeSend,
111{
112 type FromBehaviour = (K, <H as ConnectionHandler>::FromBehaviour);
113 type ToBehaviour = (K, <H as ConnectionHandler>::ToBehaviour);
114 type InboundProtocol = Upgrade<K, <H as ConnectionHandler>::InboundProtocol>;
115 type OutboundProtocol = <H as ConnectionHandler>::OutboundProtocol;
116 type InboundOpenInfo = Info<K, <H as ConnectionHandler>::InboundOpenInfo>;
117 type OutboundOpenInfo = (K, <H as ConnectionHandler>::OutboundOpenInfo);
118
119 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
120 let (upgrade, info, timeout) = self
121 .handlers
122 .iter()
123 .map(|(key, handler)| {
124 let proto = handler.listen_protocol();
125 let timeout = *proto.timeout();
126 let (upgrade, info) = proto.into_upgrade();
127 (key.clone(), (upgrade, info, timeout))
128 })
129 .fold(
130 (Upgrade::new(), Info::new(), Duration::from_secs(0)),
131 |(mut upg, mut inf, mut timeout), (k, (u, i, t))| {
132 upg.upgrades.push((k.clone(), u));
133 inf.infos.push((k, i));
134 timeout = cmp::max(timeout, t);
135 (upg, inf, timeout)
136 },
137 );
138 SubstreamProtocol::new(upgrade, info).with_timeout(timeout)
139 }
140
141 fn on_connection_event(
142 &mut self,
143 event: ConnectionEvent<
144 Self::InboundProtocol,
145 Self::OutboundProtocol,
146 Self::InboundOpenInfo,
147 Self::OutboundOpenInfo,
148 >,
149 ) {
150 match event {
151 ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
152 protocol,
153 info: (key, arg),
154 }) => {
155 if let Some(h) = self.handlers.get_mut(&key) {
156 h.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
157 FullyNegotiatedOutbound {
158 protocol,
159 info: arg,
160 },
161 ));
162 } else {
163 tracing::error!("FullyNegotiatedOutbound: no handler for key")
164 }
165 }
166 ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
167 protocol: (key, arg),
168 mut info,
169 }) => {
170 if let Some(h) = self.handlers.get_mut(&key) {
171 if let Some(i) = info.take(&key) {
172 h.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
173 FullyNegotiatedInbound {
174 protocol: arg,
175 info: i,
176 },
177 ));
178 }
179 } else {
180 tracing::error!("FullyNegotiatedInbound: no handler for key")
181 }
182 }
183 ConnectionEvent::AddressChange(AddressChange { new_address }) => {
184 for h in self.handlers.values_mut() {
185 h.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
186 new_address,
187 }));
188 }
189 }
190 ConnectionEvent::DialUpgradeError(DialUpgradeError {
191 info: (key, arg),
192 error,
193 }) => {
194 if let Some(h) = self.handlers.get_mut(&key) {
195 h.on_connection_event(ConnectionEvent::DialUpgradeError(DialUpgradeError {
196 info: arg,
197 error,
198 }));
199 } else {
200 tracing::error!("DialUpgradeError: no handler for protocol")
201 }
202 }
203 ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
204 self.on_listen_upgrade_error(listen_upgrade_error)
205 }
206 ConnectionEvent::LocalProtocolsChange(supported_protocols) => {
207 for h in self.handlers.values_mut() {
208 h.on_connection_event(ConnectionEvent::LocalProtocolsChange(
209 supported_protocols.clone(),
210 ));
211 }
212 }
213 ConnectionEvent::RemoteProtocolsChange(supported_protocols) => {
214 for h in self.handlers.values_mut() {
215 h.on_connection_event(ConnectionEvent::RemoteProtocolsChange(
216 supported_protocols.clone(),
217 ));
218 }
219 }
220 }
221 }
222
223 fn on_behaviour_event(&mut self, (key, event): Self::FromBehaviour) {
224 if let Some(h) = self.handlers.get_mut(&key) {
225 h.on_behaviour_event(event)
226 } else {
227 tracing::error!("on_behaviour_event: no handler for key")
228 }
229 }
230
231 fn connection_keep_alive(&self) -> bool {
232 self.handlers
233 .values()
234 .map(|h| h.connection_keep_alive())
235 .max()
236 .unwrap_or(false)
237 }
238
239 fn poll(
240 &mut self,
241 cx: &mut Context<'_>,
242 ) -> Poll<
243 ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
244 > {
245 if self.handlers.is_empty() {
248 return Poll::Pending;
249 }
250
251 let pos = rand::thread_rng().gen_range(0..self.handlers.len());
253
254 for (k, h) in self.handlers.iter_mut().skip(pos) {
255 if let Poll::Ready(e) = h.poll(cx) {
256 let e = e
257 .map_outbound_open_info(|i| (k.clone(), i))
258 .map_custom(|p| (k.clone(), p));
259 return Poll::Ready(e);
260 }
261 }
262
263 for (k, h) in self.handlers.iter_mut().take(pos) {
264 if let Poll::Ready(e) = h.poll(cx) {
265 let e = e
266 .map_outbound_open_info(|i| (k.clone(), i))
267 .map_custom(|p| (k.clone(), p));
268 return Poll::Ready(e);
269 }
270 }
271
272 Poll::Pending
273 }
274
275 fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
276 for (k, h) in self.handlers.iter_mut() {
277 let Some(e) = ready!(h.poll_close(cx)) else {
278 continue;
279 };
280 return Poll::Ready(Some((k.clone(), e)));
281 }
282
283 Poll::Ready(None)
284 }
285}
286
287impl<K, H> IntoIterator for MultiHandler<K, H> {
289 type Item = <Self::IntoIter as Iterator>::Item;
290 type IntoIter = std::collections::hash_map::IntoIter<K, H>;
291
292 fn into_iter(self) -> Self::IntoIter {
293 self.handlers.into_iter()
294 }
295}
296
/// A protocol name paired with the position of the upgrade (inside an
/// [`Upgrade`]) that advertised it.
#[derive(Clone, Debug)]
pub struct IndexedProtoName<H>(usize, H);

/// Exposes only the wrapped protocol name; the index is ignored.
impl<H> AsRef<str> for IndexedProtoName<H>
where
    H: AsRef<str>,
{
    fn as_ref(&self) -> &str {
        let IndexedProtoName(_, name) = self;
        name.as_ref()
    }
}
306
/// Inbound open info of every handler, stored as `(key, info)` pairs.
#[derive(Clone)]
pub struct Info<K, I> {
    infos: Vec<(K, I)>,
}

impl<K: Eq, I> Info<K, I> {
    /// Creates an empty collection.
    fn new() -> Self {
        Self { infos: Vec::new() }
    }

    /// Removes and returns the info stored under `k`, or `None` if there is none.
    ///
    /// If several entries share the key, the earliest one is removed and the
    /// order of the remaining entries is preserved.
    pub fn take(&mut self, k: &K) -> Option<I> {
        let index = self.infos.iter().position(|(key, _)| key == k)?;
        let (_, info) = self.infos.remove(index);
        Some(info)
    }
}
325
/// Inbound and outbound upgrade covering all handlers, stored as
/// `(key, upgrade)` pairs.
///
/// `Debug` is derived so that it only requires `K: Debug` and `H: Debug`; the
/// output format is `Upgrade { upgrades: [..] }`.
#[derive(Clone, Debug)]
pub struct Upgrade<K, H> {
    upgrades: Vec<(K, H)>,
}

impl<K, H> Upgrade<K, H> {
    /// Creates an upgrade with no entries.
    fn new() -> Self {
        Upgrade {
            upgrades: Vec::new(),
        }
    }
}
351
352impl<K, H> UpgradeInfoSend for Upgrade<K, H>
353where
354 H: UpgradeInfoSend,
355 K: Send + 'static,
356{
357 type Info = IndexedProtoName<H::Info>;
358 type InfoIter = std::vec::IntoIter<Self::Info>;
359
360 fn protocol_info(&self) -> Self::InfoIter {
361 self.upgrades
362 .iter()
363 .enumerate()
364 .flat_map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info()))
365 .map(|(i, h)| IndexedProtoName(i, h))
366 .collect::<Vec<_>>()
367 .into_iter()
368 }
369}
370
371impl<K, H> InboundUpgradeSend for Upgrade<K, H>
372where
373 H: InboundUpgradeSend,
374 K: Send + 'static,
375{
376 type Output = (K, <H as InboundUpgradeSend>::Output);
377 type Error = (K, <H as InboundUpgradeSend>::Error);
378 type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
379
380 fn upgrade_inbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
381 let IndexedProtoName(index, info) = info;
382 let (key, upgrade) = self.upgrades.remove(index);
383 upgrade
384 .upgrade_inbound(resource, info)
385 .map(move |out| match out {
386 Ok(o) => Ok((key, o)),
387 Err(e) => Err((key, e)),
388 })
389 .boxed()
390 }
391}
392
393impl<K, H> OutboundUpgradeSend for Upgrade<K, H>
394where
395 H: OutboundUpgradeSend,
396 K: Send + 'static,
397{
398 type Output = (K, <H as OutboundUpgradeSend>::Output);
399 type Error = (K, <H as OutboundUpgradeSend>::Error);
400 type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
401
402 fn upgrade_outbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
403 let IndexedProtoName(index, info) = info;
404 let (key, upgrade) = self.upgrades.remove(index);
405 upgrade
406 .upgrade_outbound(resource, info)
407 .map(move |out| match out {
408 Ok(o) => Ok((key, o)),
409 Err(e) => Err((key, e)),
410 })
411 .boxed()
412 }
413}
414
415fn uniq_proto_names<I, T>(iter: I) -> Result<(), DuplicateProtonameError>
417where
418 I: Iterator<Item = T>,
419 T: UpgradeInfoSend,
420{
421 let mut set = HashSet::new();
422 for infos in iter {
423 for i in infos.protocol_info() {
424 let v = Vec::from(i.as_ref());
425 if set.contains(&v) {
426 return Err(DuplicateProtonameError(v));
427 } else {
428 set.insert(v);
429 }
430 }
431 }
432 Ok(())
433}
434
/// Error returned by `MultiHandler::try_from_iter` when a protocol name is
/// advertised more than once.
#[derive(Clone, Debug)]
pub struct DuplicateProtonameError(Vec<u8>);

impl DuplicateProtonameError {
    /// The protocol name that was found more than once, as raw bytes.
    pub fn protocol_name(&self) -> &[u8] {
        self.0.as_slice()
    }
}

/// Shows the name as text when it is valid UTF-8, otherwise as a byte list.
impl fmt::Display for DuplicateProtonameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match std::str::from_utf8(&self.0) {
            Ok(s) => write!(f, "duplicate protocol name: {s}"),
            Err(_) => write!(f, "duplicate protocol name: {:?}", self.0),
        }
    }
}

impl error::Error for DuplicateProtonameError {}