1use crate::handler::{
25 AddressChange, ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent, DialUpgradeError,
26 FullyNegotiatedInbound, FullyNegotiatedOutbound, KeepAlive, ListenUpgradeError,
27 SubstreamProtocol,
28};
29use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, UpgradeInfoSend};
30use crate::Stream;
31use futures::{future::BoxFuture, prelude::*};
32use rand::Rng;
33use std::{
34 cmp,
35 collections::{HashMap, HashSet},
36 error,
37 fmt::{self, Debug},
38 hash::Hash,
39 iter::{self, FromIterator},
40 task::{Context, Poll},
41 time::Duration,
42};
43
/// A [`ConnectionHandler`] that multiplexes several inner handlers of the same
/// type, each addressed by a key of type `K`.
#[derive(Clone)]
pub struct MultiHandler<K, H> {
    /// The inner handlers, indexed by their key.
    handlers: HashMap<K, H>,
}
49
50impl<K, H> fmt::Debug for MultiHandler<K, H>
51where
52 K: fmt::Debug + Eq + Hash,
53 H: fmt::Debug,
54{
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 f.debug_struct("MultiHandler")
57 .field("handlers", &self.handlers)
58 .finish()
59 }
60}
61
62impl<K, H> MultiHandler<K, H>
63where
64 K: Clone + Debug + Hash + Eq + Send + 'static,
65 H: ConnectionHandler,
66{
67 pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
71 where
72 I: IntoIterator<Item = (K, H)>,
73 {
74 let m = MultiHandler {
75 handlers: HashMap::from_iter(iter),
76 };
77 uniq_proto_names(
78 m.handlers
79 .values()
80 .map(|h| h.listen_protocol().into_upgrade().0),
81 )?;
82 Ok(m)
83 }
84
85 fn on_listen_upgrade_error(
86 &mut self,
87 ListenUpgradeError {
88 error: (key, error),
89 mut info,
90 }: ListenUpgradeError<
91 <Self as ConnectionHandler>::InboundOpenInfo,
92 <Self as ConnectionHandler>::InboundProtocol,
93 >,
94 ) {
95 if let Some(h) = self.handlers.get_mut(&key) {
96 if let Some(i) = info.take(&key) {
97 h.on_connection_event(ConnectionEvent::ListenUpgradeError(ListenUpgradeError {
98 info: i,
99 error,
100 }));
101 }
102 }
103 }
104}
105
106impl<K, H> ConnectionHandler for MultiHandler<K, H>
107where
108 K: Clone + Debug + Hash + Eq + Send + 'static,
109 H: ConnectionHandler,
110 H::InboundProtocol: InboundUpgradeSend,
111 H::OutboundProtocol: OutboundUpgradeSend,
112{
113 type FromBehaviour = (K, <H as ConnectionHandler>::FromBehaviour);
114 type ToBehaviour = (K, <H as ConnectionHandler>::ToBehaviour);
115 #[allow(deprecated)]
116 type Error = <H as ConnectionHandler>::Error;
117 type InboundProtocol = Upgrade<K, <H as ConnectionHandler>::InboundProtocol>;
118 type OutboundProtocol = <H as ConnectionHandler>::OutboundProtocol;
119 type InboundOpenInfo = Info<K, <H as ConnectionHandler>::InboundOpenInfo>;
120 type OutboundOpenInfo = (K, <H as ConnectionHandler>::OutboundOpenInfo);
121
122 fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
123 let (upgrade, info, timeout) = self
124 .handlers
125 .iter()
126 .map(|(key, handler)| {
127 let proto = handler.listen_protocol();
128 let timeout = *proto.timeout();
129 let (upgrade, info) = proto.into_upgrade();
130 (key.clone(), (upgrade, info, timeout))
131 })
132 .fold(
133 (Upgrade::new(), Info::new(), Duration::from_secs(0)),
134 |(mut upg, mut inf, mut timeout), (k, (u, i, t))| {
135 upg.upgrades.push((k.clone(), u));
136 inf.infos.push((k, i));
137 timeout = cmp::max(timeout, t);
138 (upg, inf, timeout)
139 },
140 );
141 SubstreamProtocol::new(upgrade, info).with_timeout(timeout)
142 }
143
144 fn on_connection_event(
145 &mut self,
146 event: ConnectionEvent<
147 Self::InboundProtocol,
148 Self::OutboundProtocol,
149 Self::InboundOpenInfo,
150 Self::OutboundOpenInfo,
151 >,
152 ) {
153 match event {
154 ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
155 protocol,
156 info: (key, arg),
157 }) => {
158 if let Some(h) = self.handlers.get_mut(&key) {
159 h.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
160 FullyNegotiatedOutbound {
161 protocol,
162 info: arg,
163 },
164 ));
165 } else {
166 log::error!("FullyNegotiatedOutbound: no handler for key")
167 }
168 }
169 ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
170 protocol: (key, arg),
171 mut info,
172 }) => {
173 if let Some(h) = self.handlers.get_mut(&key) {
174 if let Some(i) = info.take(&key) {
175 h.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
176 FullyNegotiatedInbound {
177 protocol: arg,
178 info: i,
179 },
180 ));
181 }
182 } else {
183 log::error!("FullyNegotiatedInbound: no handler for key")
184 }
185 }
186 ConnectionEvent::AddressChange(AddressChange { new_address }) => {
187 for h in self.handlers.values_mut() {
188 h.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
189 new_address,
190 }));
191 }
192 }
193 ConnectionEvent::DialUpgradeError(DialUpgradeError {
194 info: (key, arg),
195 error,
196 }) => {
197 if let Some(h) = self.handlers.get_mut(&key) {
198 h.on_connection_event(ConnectionEvent::DialUpgradeError(DialUpgradeError {
199 info: arg,
200 error,
201 }));
202 } else {
203 log::error!("DialUpgradeError: no handler for protocol")
204 }
205 }
206 ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
207 self.on_listen_upgrade_error(listen_upgrade_error)
208 }
209 ConnectionEvent::LocalProtocolsChange(supported_protocols) => {
210 for h in self.handlers.values_mut() {
211 h.on_connection_event(ConnectionEvent::LocalProtocolsChange(
212 supported_protocols.clone(),
213 ));
214 }
215 }
216 ConnectionEvent::RemoteProtocolsChange(supported_protocols) => {
217 for h in self.handlers.values_mut() {
218 h.on_connection_event(ConnectionEvent::RemoteProtocolsChange(
219 supported_protocols.clone(),
220 ));
221 }
222 }
223 }
224 }
225
226 fn on_behaviour_event(&mut self, (key, event): Self::FromBehaviour) {
227 if let Some(h) = self.handlers.get_mut(&key) {
228 h.on_behaviour_event(event)
229 } else {
230 log::error!("on_behaviour_event: no handler for key")
231 }
232 }
233
234 fn connection_keep_alive(&self) -> KeepAlive {
235 self.handlers
236 .values()
237 .map(|h| h.connection_keep_alive())
238 .max()
239 .unwrap_or(KeepAlive::No)
240 }
241
242 #[allow(deprecated)]
243 fn poll(
244 &mut self,
245 cx: &mut Context<'_>,
246 ) -> Poll<
247 ConnectionHandlerEvent<
248 Self::OutboundProtocol,
249 Self::OutboundOpenInfo,
250 Self::ToBehaviour,
251 Self::Error,
252 >,
253 > {
254 if self.handlers.is_empty() {
257 return Poll::Pending;
258 }
259
260 let pos = rand::thread_rng().gen_range(0..self.handlers.len());
262
263 for (k, h) in self.handlers.iter_mut().skip(pos) {
264 if let Poll::Ready(e) = h.poll(cx) {
265 let e = e
266 .map_outbound_open_info(|i| (k.clone(), i))
267 .map_custom(|p| (k.clone(), p));
268 return Poll::Ready(e);
269 }
270 }
271
272 for (k, h) in self.handlers.iter_mut().take(pos) {
273 if let Poll::Ready(e) = h.poll(cx) {
274 let e = e
275 .map_outbound_open_info(|i| (k.clone(), i))
276 .map_custom(|p| (k.clone(), p));
277 return Poll::Ready(e);
278 }
279 }
280
281 Poll::Pending
282 }
283}
284
285impl<K, H> IntoIterator for MultiHandler<K, H> {
287 type Item = <Self::IntoIter as Iterator>::Item;
288 type IntoIter = std::collections::hash_map::IntoIter<K, H>;
289
290 fn into_iter(self) -> Self::IntoIter {
291 self.handlers.into_iter()
292 }
293}
294
/// A protocol name paired with the position, within [`Upgrade`], of the inner
/// upgrade that advertised it.
#[derive(Debug, Clone)]
pub struct IndexedProtoName<H>(usize, H);
298
299impl<H: AsRef<str>> AsRef<str> for IndexedProtoName<H> {
300 fn as_ref(&self) -> &str {
301 self.1.as_ref()
302 }
303}
304
/// The inbound open infos of the inner handlers, each tagged with the key of
/// the handler it belongs to.
#[derive(Clone)]
pub struct Info<K, I> {
    /// `(key, info)` pairs, looked up by key in [`Info::take`].
    infos: Vec<(K, I)>,
}
310
311impl<K: Eq, I> Info<K, I> {
312 fn new() -> Self {
313 Info { infos: Vec::new() }
314 }
315
316 pub fn take(&mut self, k: &K) -> Option<I> {
317 if let Some(p) = self.infos.iter().position(|(key, _)| key == k) {
318 return Some(self.infos.remove(p).1);
319 }
320 None
321 }
322}
323
/// The combined inbound upgrade of all inner handlers, each inner upgrade
/// tagged with its handler's key.
#[derive(Clone)]
pub struct Upgrade<K, H> {
    /// `(key, upgrade)` pairs. A pair's position is the index carried by
    /// [`IndexedProtoName`].
    upgrades: Vec<(K, H)>,
}
329
330impl<K, H> Upgrade<K, H> {
331 fn new() -> Self {
332 Upgrade {
333 upgrades: Vec::new(),
334 }
335 }
336}
337
338impl<K, H> fmt::Debug for Upgrade<K, H>
339where
340 K: fmt::Debug + Eq + Hash,
341 H: fmt::Debug,
342{
343 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344 f.debug_struct("Upgrade")
345 .field("upgrades", &self.upgrades)
346 .finish()
347 }
348}
349
350impl<K, H> UpgradeInfoSend for Upgrade<K, H>
351where
352 H: UpgradeInfoSend,
353 K: Send + 'static,
354{
355 type Info = IndexedProtoName<H::Info>;
356 type InfoIter = std::vec::IntoIter<Self::Info>;
357
358 fn protocol_info(&self) -> Self::InfoIter {
359 self.upgrades
360 .iter()
361 .enumerate()
362 .flat_map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info()))
363 .map(|(i, h)| IndexedProtoName(i, h))
364 .collect::<Vec<_>>()
365 .into_iter()
366 }
367}
368
369impl<K, H> InboundUpgradeSend for Upgrade<K, H>
370where
371 H: InboundUpgradeSend,
372 K: Send + 'static,
373{
374 type Output = (K, <H as InboundUpgradeSend>::Output);
375 type Error = (K, <H as InboundUpgradeSend>::Error);
376 type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
377
378 fn upgrade_inbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
379 let IndexedProtoName(index, info) = info;
380 let (key, upgrade) = self.upgrades.remove(index);
381 upgrade
382 .upgrade_inbound(resource, info)
383 .map(move |out| match out {
384 Ok(o) => Ok((key, o)),
385 Err(e) => Err((key, e)),
386 })
387 .boxed()
388 }
389}
390
391impl<K, H> OutboundUpgradeSend for Upgrade<K, H>
392where
393 H: OutboundUpgradeSend,
394 K: Send + 'static,
395{
396 type Output = (K, <H as OutboundUpgradeSend>::Output);
397 type Error = (K, <H as OutboundUpgradeSend>::Error);
398 type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
399
400 fn upgrade_outbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
401 let IndexedProtoName(index, info) = info;
402 let (key, upgrade) = self.upgrades.remove(index);
403 upgrade
404 .upgrade_outbound(resource, info)
405 .map(move |out| match out {
406 Ok(o) => Ok((key, o)),
407 Err(e) => Err((key, e)),
408 })
409 .boxed()
410 }
411}
412
413fn uniq_proto_names<I, T>(iter: I) -> Result<(), DuplicateProtonameError>
415where
416 I: Iterator<Item = T>,
417 T: UpgradeInfoSend,
418{
419 let mut set = HashSet::new();
420 for infos in iter {
421 for i in infos.protocol_info() {
422 let v = Vec::from(i.as_ref());
423 if set.contains(&v) {
424 return Err(DuplicateProtonameError(v));
425 } else {
426 set.insert(v);
427 }
428 }
429 }
430 Ok(())
431}
432
/// Error returned by [`MultiHandler::try_from_iter`] when a protocol name
/// appears more than once across the inner handlers' listen protocols.
///
/// Holds the offending name as raw bytes.
#[derive(Debug, Clone)]
pub struct DuplicateProtonameError(Vec<u8>);
436
437impl DuplicateProtonameError {
438 pub fn protocol_name(&self) -> &[u8] {
440 &self.0
441 }
442}
443
444impl fmt::Display for DuplicateProtonameError {
445 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
446 if let Ok(s) = std::str::from_utf8(&self.0) {
447 write!(f, "duplicate protocol name: {s}")
448 } else {
449 write!(f, "duplicate protocol name: {:?}", self.0)
450 }
451 }
452}
453
454impl error::Error for DuplicateProtonameError {}