libp2p_swarm/handler/
multi.rs

1// Copyright 2020 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! A [`ConnectionHandler`] implementation that combines multiple other [`ConnectionHandler`]s
22//! indexed by some key.
23
24use crate::handler::{
25    AddressChange, ConnectionEvent, ConnectionHandler, ConnectionHandlerEvent, DialUpgradeError,
26    FullyNegotiatedInbound, FullyNegotiatedOutbound, KeepAlive, ListenUpgradeError,
27    SubstreamProtocol,
28};
29use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, UpgradeInfoSend};
30use crate::Stream;
31use futures::{future::BoxFuture, prelude::*};
32use rand::Rng;
33use std::{
34    cmp,
35    collections::{HashMap, HashSet},
36    error,
37    fmt::{self, Debug},
38    hash::Hash,
39    iter::{self, FromIterator},
40    task::{Context, Poll},
41    time::Duration,
42};
43
/// A [`ConnectionHandler`] for multiple [`ConnectionHandler`]s of the same type.
///
/// Each handler is stored under a key of type `K`; events exchanged with the behaviour are
/// tagged with that key so they can be routed to the right handler.
// `Debug` is derived: it only needs `K: Debug` and `H: Debug`, unlike the former hand-written
// impl, which additionally (and needlessly) required `K: Eq + Hash`. The output is unchanged.
#[derive(Clone, Debug)]
pub struct MultiHandler<K, H> {
    handlers: HashMap<K, H>,
}
61
62impl<K, H> MultiHandler<K, H>
63where
64    K: Clone + Debug + Hash + Eq + Send + 'static,
65    H: ConnectionHandler,
66{
67    /// Create and populate a `MultiHandler` from the given handler iterator.
68    ///
69    /// It is an error for any two protocols handlers to share the same protocol name.
70    pub fn try_from_iter<I>(iter: I) -> Result<Self, DuplicateProtonameError>
71    where
72        I: IntoIterator<Item = (K, H)>,
73    {
74        let m = MultiHandler {
75            handlers: HashMap::from_iter(iter),
76        };
77        uniq_proto_names(
78            m.handlers
79                .values()
80                .map(|h| h.listen_protocol().into_upgrade().0),
81        )?;
82        Ok(m)
83    }
84
85    fn on_listen_upgrade_error(
86        &mut self,
87        ListenUpgradeError {
88            error: (key, error),
89            mut info,
90        }: ListenUpgradeError<
91            <Self as ConnectionHandler>::InboundOpenInfo,
92            <Self as ConnectionHandler>::InboundProtocol,
93        >,
94    ) {
95        if let Some(h) = self.handlers.get_mut(&key) {
96            if let Some(i) = info.take(&key) {
97                h.on_connection_event(ConnectionEvent::ListenUpgradeError(ListenUpgradeError {
98                    info: i,
99                    error,
100                }));
101            }
102        }
103    }
104}
105
106impl<K, H> ConnectionHandler for MultiHandler<K, H>
107where
108    K: Clone + Debug + Hash + Eq + Send + 'static,
109    H: ConnectionHandler,
110    H::InboundProtocol: InboundUpgradeSend,
111    H::OutboundProtocol: OutboundUpgradeSend,
112{
113    type FromBehaviour = (K, <H as ConnectionHandler>::FromBehaviour);
114    type ToBehaviour = (K, <H as ConnectionHandler>::ToBehaviour);
115    #[allow(deprecated)]
116    type Error = <H as ConnectionHandler>::Error;
117    type InboundProtocol = Upgrade<K, <H as ConnectionHandler>::InboundProtocol>;
118    type OutboundProtocol = <H as ConnectionHandler>::OutboundProtocol;
119    type InboundOpenInfo = Info<K, <H as ConnectionHandler>::InboundOpenInfo>;
120    type OutboundOpenInfo = (K, <H as ConnectionHandler>::OutboundOpenInfo);
121
122    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
123        let (upgrade, info, timeout) = self
124            .handlers
125            .iter()
126            .map(|(key, handler)| {
127                let proto = handler.listen_protocol();
128                let timeout = *proto.timeout();
129                let (upgrade, info) = proto.into_upgrade();
130                (key.clone(), (upgrade, info, timeout))
131            })
132            .fold(
133                (Upgrade::new(), Info::new(), Duration::from_secs(0)),
134                |(mut upg, mut inf, mut timeout), (k, (u, i, t))| {
135                    upg.upgrades.push((k.clone(), u));
136                    inf.infos.push((k, i));
137                    timeout = cmp::max(timeout, t);
138                    (upg, inf, timeout)
139                },
140            );
141        SubstreamProtocol::new(upgrade, info).with_timeout(timeout)
142    }
143
144    fn on_connection_event(
145        &mut self,
146        event: ConnectionEvent<
147            Self::InboundProtocol,
148            Self::OutboundProtocol,
149            Self::InboundOpenInfo,
150            Self::OutboundOpenInfo,
151        >,
152    ) {
153        match event {
154            ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound {
155                protocol,
156                info: (key, arg),
157            }) => {
158                if let Some(h) = self.handlers.get_mut(&key) {
159                    h.on_connection_event(ConnectionEvent::FullyNegotiatedOutbound(
160                        FullyNegotiatedOutbound {
161                            protocol,
162                            info: arg,
163                        },
164                    ));
165                } else {
166                    log::error!("FullyNegotiatedOutbound: no handler for key")
167                }
168            }
169            ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound {
170                protocol: (key, arg),
171                mut info,
172            }) => {
173                if let Some(h) = self.handlers.get_mut(&key) {
174                    if let Some(i) = info.take(&key) {
175                        h.on_connection_event(ConnectionEvent::FullyNegotiatedInbound(
176                            FullyNegotiatedInbound {
177                                protocol: arg,
178                                info: i,
179                            },
180                        ));
181                    }
182                } else {
183                    log::error!("FullyNegotiatedInbound: no handler for key")
184                }
185            }
186            ConnectionEvent::AddressChange(AddressChange { new_address }) => {
187                for h in self.handlers.values_mut() {
188                    h.on_connection_event(ConnectionEvent::AddressChange(AddressChange {
189                        new_address,
190                    }));
191                }
192            }
193            ConnectionEvent::DialUpgradeError(DialUpgradeError {
194                info: (key, arg),
195                error,
196            }) => {
197                if let Some(h) = self.handlers.get_mut(&key) {
198                    h.on_connection_event(ConnectionEvent::DialUpgradeError(DialUpgradeError {
199                        info: arg,
200                        error,
201                    }));
202                } else {
203                    log::error!("DialUpgradeError: no handler for protocol")
204                }
205            }
206            ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => {
207                self.on_listen_upgrade_error(listen_upgrade_error)
208            }
209            ConnectionEvent::LocalProtocolsChange(supported_protocols) => {
210                for h in self.handlers.values_mut() {
211                    h.on_connection_event(ConnectionEvent::LocalProtocolsChange(
212                        supported_protocols.clone(),
213                    ));
214                }
215            }
216            ConnectionEvent::RemoteProtocolsChange(supported_protocols) => {
217                for h in self.handlers.values_mut() {
218                    h.on_connection_event(ConnectionEvent::RemoteProtocolsChange(
219                        supported_protocols.clone(),
220                    ));
221                }
222            }
223        }
224    }
225
226    fn on_behaviour_event(&mut self, (key, event): Self::FromBehaviour) {
227        if let Some(h) = self.handlers.get_mut(&key) {
228            h.on_behaviour_event(event)
229        } else {
230            log::error!("on_behaviour_event: no handler for key")
231        }
232    }
233
234    fn connection_keep_alive(&self) -> KeepAlive {
235        self.handlers
236            .values()
237            .map(|h| h.connection_keep_alive())
238            .max()
239            .unwrap_or(KeepAlive::No)
240    }
241
242    #[allow(deprecated)]
243    fn poll(
244        &mut self,
245        cx: &mut Context<'_>,
246    ) -> Poll<
247        ConnectionHandlerEvent<
248            Self::OutboundProtocol,
249            Self::OutboundOpenInfo,
250            Self::ToBehaviour,
251            Self::Error,
252        >,
253    > {
254        // Calling `gen_range(0, 0)` (see below) would panic, so we have return early to avoid
255        // that situation.
256        if self.handlers.is_empty() {
257            return Poll::Pending;
258        }
259
260        // Not always polling handlers in the same order should give anyone the chance to make progress.
261        let pos = rand::thread_rng().gen_range(0..self.handlers.len());
262
263        for (k, h) in self.handlers.iter_mut().skip(pos) {
264            if let Poll::Ready(e) = h.poll(cx) {
265                let e = e
266                    .map_outbound_open_info(|i| (k.clone(), i))
267                    .map_custom(|p| (k.clone(), p));
268                return Poll::Ready(e);
269            }
270        }
271
272        for (k, h) in self.handlers.iter_mut().take(pos) {
273            if let Poll::Ready(e) = h.poll(cx) {
274                let e = e
275                    .map_outbound_open_info(|i| (k.clone(), i))
276                    .map_custom(|p| (k.clone(), p));
277                return Poll::Ready(e);
278            }
279        }
280
281        Poll::Pending
282    }
283}
284
285/// Split [`MultiHandler`] into parts.
286impl<K, H> IntoIterator for MultiHandler<K, H> {
287    type Item = <Self::IntoIter as Iterator>::Item;
288    type IntoIter = std::collections::hash_map::IntoIter<K, H>;
289
290    fn into_iter(self) -> Self::IntoIter {
291        self.handlers.into_iter()
292    }
293}
294
/// A protocol name paired with the index of the upgrade that advertises it; used as
/// `UpgradeInfo::Info`.
#[derive(Debug, Clone)]
pub struct IndexedProtoName<H>(usize, H);

impl<N> AsRef<str> for IndexedProtoName<N>
where
    N: AsRef<str>,
{
    /// Returns the protocol name, ignoring the index.
    fn as_ref(&self) -> &str {
        let IndexedProtoName(_, name) = self;
        name.as_ref()
    }
}
304
/// The aggregated `InboundOpenInfo`s of supported inbound substream protocols, each stored
/// under the key of the handler it belongs to.
#[derive(Clone)]
pub struct Info<K, I> {
    infos: Vec<(K, I)>,
}

impl<K: Eq, I> Info<K, I> {
    /// Creates an empty collection.
    fn new() -> Self {
        Self { infos: Vec::new() }
    }

    /// Removes and returns the first open info stored under `k`, or `None` if there is none.
    ///
    /// The relative order of the remaining entries is preserved.
    pub fn take(&mut self, k: &K) -> Option<I> {
        let index = self.infos.iter().position(|(candidate, _)| candidate == k)?;
        let (_, info) = self.infos.remove(index);
        Some(info)
    }
}
323
/// Inbound and outbound upgrade for all [`ConnectionHandler`]s, each stored under the key of
/// the handler it belongs to.
// `Debug` is derived: a `Vec<(K, H)>` only needs `K: Debug` and `H: Debug`, so the former
// hand-written impl's extra `K: Eq + Hash` bound was unnecessary. The output is unchanged.
#[derive(Clone, Debug)]
pub struct Upgrade<K, H> {
    upgrades: Vec<(K, H)>,
}

impl<K, H> Upgrade<K, H> {
    /// Creates an upgrade with no inner upgrades.
    fn new() -> Self {
        Upgrade {
            upgrades: Vec::new(),
        }
    }
}
349
350impl<K, H> UpgradeInfoSend for Upgrade<K, H>
351where
352    H: UpgradeInfoSend,
353    K: Send + 'static,
354{
355    type Info = IndexedProtoName<H::Info>;
356    type InfoIter = std::vec::IntoIter<Self::Info>;
357
358    fn protocol_info(&self) -> Self::InfoIter {
359        self.upgrades
360            .iter()
361            .enumerate()
362            .flat_map(|(i, (_, h))| iter::repeat(i).zip(h.protocol_info()))
363            .map(|(i, h)| IndexedProtoName(i, h))
364            .collect::<Vec<_>>()
365            .into_iter()
366    }
367}
368
369impl<K, H> InboundUpgradeSend for Upgrade<K, H>
370where
371    H: InboundUpgradeSend,
372    K: Send + 'static,
373{
374    type Output = (K, <H as InboundUpgradeSend>::Output);
375    type Error = (K, <H as InboundUpgradeSend>::Error);
376    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
377
378    fn upgrade_inbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
379        let IndexedProtoName(index, info) = info;
380        let (key, upgrade) = self.upgrades.remove(index);
381        upgrade
382            .upgrade_inbound(resource, info)
383            .map(move |out| match out {
384                Ok(o) => Ok((key, o)),
385                Err(e) => Err((key, e)),
386            })
387            .boxed()
388    }
389}
390
391impl<K, H> OutboundUpgradeSend for Upgrade<K, H>
392where
393    H: OutboundUpgradeSend,
394    K: Send + 'static,
395{
396    type Output = (K, <H as OutboundUpgradeSend>::Output);
397    type Error = (K, <H as OutboundUpgradeSend>::Error);
398    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
399
400    fn upgrade_outbound(mut self, resource: Stream, info: Self::Info) -> Self::Future {
401        let IndexedProtoName(index, info) = info;
402        let (key, upgrade) = self.upgrades.remove(index);
403        upgrade
404            .upgrade_outbound(resource, info)
405            .map(move |out| match out {
406                Ok(o) => Ok((key, o)),
407                Err(e) => Err((key, e)),
408            })
409            .boxed()
410    }
411}
412
413/// Check that no two protocol names are equal.
414fn uniq_proto_names<I, T>(iter: I) -> Result<(), DuplicateProtonameError>
415where
416    I: Iterator<Item = T>,
417    T: UpgradeInfoSend,
418{
419    let mut set = HashSet::new();
420    for infos in iter {
421        for i in infos.protocol_info() {
422            let v = Vec::from(i.as_ref());
423            if set.contains(&v) {
424                return Err(DuplicateProtonameError(v));
425            } else {
426                set.insert(v);
427            }
428        }
429    }
430    Ok(())
431}
432
/// Error returned when two handlers share the same protocol name.
#[derive(Debug, Clone)]
pub struct DuplicateProtonameError(Vec<u8>);

impl DuplicateProtonameError {
    /// The bytes of the protocol name that occurred in more than one handler.
    pub fn protocol_name(&self) -> &[u8] {
        self.0.as_slice()
    }
}

impl fmt::Display for DuplicateProtonameError {
    /// Shows the protocol name as text when it is valid UTF-8, otherwise as a list of bytes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match std::str::from_utf8(&self.0) {
            Ok(s) => write!(f, "duplicate protocol name: {s}"),
            Err(_) => write!(f, "duplicate protocol name: {:?}", self.0),
        }
    }
}

impl error::Error for DuplicateProtonameError {}