libp2p_swarm/handler/
multi.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! A [`ConnectionHandler`] implementation that combines multiple other [`ConnectionHandler`]s
22//! indexed by some key.
23
24use crate::handler::{
25    AddressChange, ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent, DialUpgradeError,
26    FullyNegotiatedInbound, FullyNegotiatedOutbound, ListenUpgradeError, SubstreamProtocol,
27};
28use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, UpgradeInfoSend};
29use crate::Stream;
30use futures::{future::BoxFuture, prelude::*, ready};
31use rand::Rng;
32use std::{
33    cmp,
34    collections::{HashMap, HashSet},
35    error,
36    fmt::{self, Debug},
37    hash::Hash,
38    iter,
39    task::{Context, Poll},
40    time::Duration,
41};
42
/// A [`ConnectionHandler`] for multiple [`ConnectionHandler`]s of the same type.
///
/// Every inner handler is stored under a key of type `K`. Events exchanged with the
/// behaviour are tagged with that key so they can be routed to the right handler.
///
/// `Debug` is derived: it only needs `K: Debug` and `H: Debug`. The previous
/// hand-written impl additionally (and needlessly) required `K: Eq + Hash`.
#[derive(Clone, Debug)]
pub struct MultiHandler<K, H> {
    /// The inner handlers, indexed by their key.
    handlers: HashMap<K, H>,
}
60
61impl<K, H> MultiHandler<K, H>
62where
63    K: Clone + Debug + Hash + Eq + Send + 'static,
64    H: ConnectionHandler,
65{
66    /// Create and populate a `MultiHandler` from the given handler iterator.
67    ///
68    /// It is an error for any two protocols handlers to share the same protocol name.
69    pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
70    where
71        I: IntoIterator<Item = (K, H)>,
72    {
73        let m = MultiHandler {
74            handlers: HashMap::from_iter(iter),
75        };
76        uniq_proto_names(
77            m.handlers
78                .values()
79                .map(|h| h.listen_protocol().into_upgrade().0),
80        )?;
81        Ok(m)
82    }
83
84    fn on_listen_upgrade_error(
85        &mut self,
86        ListenUpgradeError {
87            error: (key, error),
88            mut info,
89        }: ListenUpgradeError<
90            <Self as ConnectionHandler>::InboundOpenInfo,
91            <Self as ConnectionHandler>::InboundProtocol,
92        >,
93    ) {
94        if let Some(h) = self.handlers.get_mut(&key) {
95            if let Some(i) = info.take(&key) {
96                h.on_connection_event(ConnectionEvent::ListenUpgradeError(ListenUpgradeError {
97                    info: i,
98                    error,
99                }));
100            }
101        }
102    }
103}
104
105impl<K, H> ConnectionHandler for MultiHandler<K, H>
106where
107    K: Clone + Debug + Hash + Eq + Send + 'static,
108    H: ConnectionHandler,
109    H::InboundProtocol: InboundUpgradeSend,
110    H::OutboundProtocol: OutboundUpgradeSend,
111{
112    type FromBehaviour = (K, <H as ConnectionHandler>::FromBehaviour);
113    type ToBehaviour = (K, <H as ConnectionHandler>::ToBehaviour);
114    type InboundProtocol = Upgrade<K, <H as ConnectionHandler>::InboundProtocol>;
115    type OutboundProtocol = <H as ConnectionHandler>::OutboundProtocol;
116    type InboundOpenInfo = Info<K, <H as ConnectionHandler>::InboundOpenInfo>;
117    type OutboundOpenInfo = (K, <H as ConnectionHandler>::OutboundOpenInfo);
118
119    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
120        let (upgrade, info, timeout) = self
121            .handlers
122            .iter()
123            .map(|(key, handler)| {
124                let proto = handler.listen_protocol();
125                let timeout = *proto.timeout();
126                let (upgrade, info) = proto.into_upgrade();
127                (key.clone(), (upgrade, info, timeout))
128            })
129            .fold(
130                (Upgrade::new(), Info::new(), Duration::from_secs(0)),
131                |(mut upg, mut inf, mut timeout), (k, (u, i, t))| {
132                    upg.upgrades.push((k.clone(), u));
133                    inf.infos.push((k, i));
134                    timeout = cmp::max(timeout, t);
135                    (upg, inf, timeout)
136                },
137            );
138        SubstreamProtocol::new(upgrade, info).with_timeout(timeout)
139    }
140
141    fn on_connection_event(
142        &mut self,
143        event: ConnectionEvent<
144            Self::InboundProtocol,
145            Self::OutboundProtocol,
146            Self::InboundOpenInfo,
147            Self::OutboundOpenInfo,
148        >,
149    ) {
150        match event {
151            ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
152                protocol,
153                info: (key, arg),
154            }) => {
155                if let Some(h) = self.handlers.get_mut(&key) {
156                    h.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
157                        FullyNegotiatedOutbound {
158                            protocol,
159                            info: arg,
160                        },
161                    ));
162                } else {
163                    tracing::error!("FullyNegotiatedOutbound: no handler for key")
164                }
165            }
166            ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
167                protocol: (key, arg),
168                mut info,
169            }) => {
170                if let Some(h) = self.handlers.get_mut(&key) {
171                    if let Some(i) = info.take(&key) {
172                        h.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
173                            FullyNegotiatedInbound {
174                                protocol: arg,
175                                info: i,
176                            },
177                        ));
178                    }
179                } else {
180                    tracing::error!("FullyNegotiatedInbound: no handler for key")
181                }
182            }
183            ConnectionEvent::AddressChange(AddressChange { new_address }) => {
184                for h in self.handlers.values_mut() {
185                    h.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
186                        new_address,
187                    }));
188                }
189            }
190            ConnectionEvent::DialUpgradeError(DialUpgradeError {
191                info: (key, arg),
192                error,
193            }) => {
194                if let Some(h) = self.handlers.get_mut(&key) {
195                    h.on_connection_event(ConnectionEvent::DialUpgradeError(DialUpgradeError {
196                        info: arg,
197                        error,
198                    }));
199                } else {
200                    tracing::error!("DialUpgradeError: no handler for protocol")
201                }
202            }
203            ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
204                self.on_listen_upgrade_error(listen_upgrade_error)
205            }
206            ConnectionEvent::LocalProtocolsChange(supported_protocols) => {
207                for h in self.handlers.values_mut() {
208                    h.on_connection_event(ConnectionEvent::LocalProtocolsChange(
209                        supported_protocols.clone(),
210                    ));
211                }
212            }
213            ConnectionEvent::RemoteProtocolsChange(supported_protocols) => {
214                for h in self.handlers.values_mut() {
215                    h.on_connection_event(ConnectionEvent::RemoteProtocolsChange(
216                        supported_protocols.clone(),
217                    ));
218                }
219            }
220        }
221    }
222
223    fn on_behaviour_event(&mut self, (key, event): Self::FromBehaviour) {
224        if let Some(h) = self.handlers.get_mut(&key) {
225            h.on_behaviour_event(event)
226        } else {
227            tracing::error!("on_behaviour_event: no handler for key")
228        }
229    }
230
231    fn connection_keep_alive(&self) -> bool {
232        self.handlers
233            .values()
234            .map(|h| h.connection_keep_alive())
235            .max()
236            .unwrap_or(false)
237    }
238
239    fn poll(
240        &mut self,
241        cx: &mut Context<'_>,
242    ) -> Poll<
243        ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
244    > {
245        // Calling `gen_range(0, 0)` (see below) would panic, so we have return early to avoid
246        // that situation.
247        if self.handlers.is_empty() {
248            return Poll::Pending;
249        }
250
251        // Not always polling handlers in the same order should give anyone the chance to make progress.
252        let pos = rand::thread_rng().gen_range(0..self.handlers.len());
253
254        for (k, h) in self.handlers.iter_mut().skip(pos) {
255            if let Poll::Ready(e) = h.poll(cx) {
256                let e = e
257                    .map_outbound_open_info(|i| (k.clone(), i))
258                    .map_custom(|p| (k.clone(), p));
259                return Poll::Ready(e);
260            }
261        }
262
263        for (k, h) in self.handlers.iter_mut().take(pos) {
264            if let Poll::Ready(e) = h.poll(cx) {
265                let e = e
266                    .map_outbound_open_info(|i| (k.clone(), i))
267                    .map_custom(|p| (k.clone(), p));
268                return Poll::Ready(e);
269            }
270        }
271
272        Poll::Pending
273    }
274
275    fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
276        for (k, h) in self.handlers.iter_mut() {
277            let Some(e) = ready!(h.poll_close(cx)) else {
278                continue;
279            };
280            return Poll::Ready(Some((k.clone(), e)));
281        }
282
283        Poll::Ready(None)
284    }
285}
286
287/// Split [`MultiHandler`] into parts.
288impl<K, H> IntoIterator for MultiHandler<K, H> {
289    type Item = <Self::IntoIter as Iterator>::Item;
290    type IntoIter = std::collections::hash_map::IntoIter<K, H>;
291
292    fn into_iter(self) -> Self::IntoIter {
293        self.handlers.into_iter()
294    }
295}
296
/// Index and protocol name pair used as `UpgradeInfo::Info`.
///
/// The `usize` is the position of the inner upgrade that advertised the name. The
/// wrapped value is that upgrade's own protocol info.
#[derive(Debug, Clone)]
pub struct IndexedProtoName<H>(usize, H);

impl<H: AsRef<str>> AsRef<str> for IndexedProtoName<H> {
    /// Exposes only the protocol name; the index is ignored.
    fn as_ref(&self) -> &str {
        let IndexedProtoName(_, name) = self;
        name.as_ref()
    }
}
306
/// The aggregated `InboundOpenInfo`s of supported inbound substream protocols.
///
/// Holds `(key, info)` entries in insertion order. `Debug` is derived for parity
/// with the other public helper types in this module.
#[derive(Clone, Debug)]
pub struct Info<K, I> {
    infos: Vec<(K, I)>,
}

impl<K: Eq, I> Info<K, I> {
    /// Creates an empty collection.
    fn new() -> Self {
        Info { infos: Vec::new() }
    }

    /// Removes and returns the info stored under `k`.
    ///
    /// If several entries share the key, the earliest one is taken. Returns `None`
    /// when no entry matches.
    pub fn take(&mut self, k: &K) -> Option<I> {
        let index = self.infos.iter().position(|(key, _)| key == k)?;
        Some(self.infos.remove(index).1)
    }
}
325
/// Inbound and outbound upgrade for all [`ConnectionHandler`]s.
///
/// Holds one `(key, upgrade)` pair per inner handler. `Debug` is derived, so it only
/// requires `K: Debug` and `H: Debug`. The previous hand-written impl also demanded
/// an unused `K: Eq + Hash`.
#[derive(Clone, Debug)]
pub struct Upgrade<K, H> {
    upgrades: Vec<(K, H)>,
}

impl<K, H> Upgrade<K, H> {
    /// Creates an upgrade with no inner upgrades.
    fn new() -> Self {
        Upgrade {
            upgrades: Vec::new(),
        }
    }
}
351
352impl<K, H> UpgradeInfoSend for Upgrade<K, H>
353where
354    H: UpgradeInfoSend,
355    K: Send + 'static,
356{
357    type Info = IndexedProtoName<H::Info>;
358    type InfoIter = std::vec::IntoIter<Self::Info>;
359
360    fn protocol_info(&self) -> Self::InfoIter {
361        self.upgrades
362            .iter()
363            .enumerate()
364            .flat_map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info()))
365            .map(|(i, h)| IndexedProtoName(i, h))
366            .collect::<Vec<_>>()
367            .into_iter()
368    }
369}
370
371impl<K, H> InboundUpgradeSend for Upgrade<K, H>
372where
373    H: InboundUpgradeSend,
374    K: Send + 'static,
375{
376    type Output = (K, <H as InboundUpgradeSend>::Output);
377    type Error = (K, <H as InboundUpgradeSend>::Error);
378    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
379
380    fn upgrade_inbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
381        let IndexedProtoName(index, info) = info;
382        let (key, upgrade) = self.upgrades.remove(index);
383        upgrade
384            .upgrade_inbound(resource, info)
385            .map(move |out| match out {
386                Ok(o) => Ok((key, o)),
387                Err(e) => Err((key, e)),
388            })
389            .boxed()
390    }
391}
392
393impl<K, H> OutboundUpgradeSend for Upgrade<K, H>
394where
395    H: OutboundUpgradeSend,
396    K: Send + 'static,
397{
398    type Output = (K, <H as OutboundUpgradeSend>::Output);
399    type Error = (K, <H as OutboundUpgradeSend>::Error);
400    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
401
402    fn upgrade_outbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
403        let IndexedProtoName(index, info) = info;
404        let (key, upgrade) = self.upgrades.remove(index);
405        upgrade
406            .upgrade_outbound(resource, info)
407            .map(move |out| match out {
408                Ok(o) => Ok((key, o)),
409                Err(e) => Err((key, e)),
410            })
411            .boxed()
412    }
413}
414
415/// Check that no two protocol names are equal.
416fn uniq_proto_names<I, T>(iter: I) -> Result<(), DuplicateProtonameError>
417where
418    I: Iterator<Item = T>,
419    T: UpgradeInfoSend,
420{
421    let mut set = HashSet::new();
422    for infos in iter {
423        for i in infos.protocol_info() {
424            let v = Vec::from(i.as_ref());
425            if set.contains(&v) {
426                return Err(DuplicateProtonameError(v));
427            } else {
428                set.insert(v);
429            }
430        }
431    }
432    Ok(())
433}
434
/// It is an error if two handlers share the same protocol name.
///
/// Carries the raw bytes of the protocol name that was found more than once.
#[derive(Debug, Clone)]
pub struct DuplicateProtonameError(Vec<u8>);

impl DuplicateProtonameError {
    /// The protocol name bytes that occurred in more than one handler.
    pub fn protocol_name(&self) -> &[u8] {
        self.0.as_slice()
    }
}

impl fmt::Display for DuplicateProtonameError {
    /// Prints the name as text when it is valid UTF-8, otherwise as a list of bytes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match std::str::from_utf8(&self.0) {
            Ok(s) => write!(f, "duplicate protocol name: {s}"),
            Err(_) => write!(f, "duplicate protocol name: {:?}", self.0),
        }
    }
}

impl error::Error for DuplicateProtonameError {}