// libp2p_mdns/behaviour.rs

// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
20
21mod iface;
22mod socket;
23mod timer;
24
25use self::iface::InterfaceState;
26use crate::behaviour::{socket::AsyncSocket, timer::Builder};
27use crate::Config;
28use futures::Stream;
29use if_watch::IfEvent;
30use libp2p_core::{Endpoint, Multiaddr};
31use libp2p_identity::PeerId;
32use libp2p_swarm::behaviour::FromSwarm;
33use libp2p_swarm::{
34    dummy, ConnectionDenied, ConnectionId, ListenAddresses, NetworkBehaviour, PollParameters,
35    THandler, THandlerInEvent, THandlerOutEvent, ToSwarm,
36};
37use smallvec::SmallVec;
38use std::collections::hash_map::{Entry, HashMap};
39use std::{cmp, fmt, io, net::IpAddr, pin::Pin, task::Context, task::Poll, time::Instant};
40
41/// An abstraction to allow for compatibility with various async runtimes.
42pub trait Provider: 'static {
43    /// The Async Socket type.
44    type Socket: AsyncSocket;
45    /// The Async Timer type.
46    type Timer: Builder + Stream;
47    /// The IfWatcher type.
48    type Watcher: Stream<Item = std::io::Result<IfEvent>> + fmt::Debug + Unpin;
49
50    /// Create a new instance of the `IfWatcher` type.
51    fn new_watcher() -> Result<Self::Watcher, std::io::Error>;
52}
53
/// Provides a [`Behaviour`] driven by the `async-io` runtime.
#[cfg(feature = "async-io")]
pub mod async_io {
    use super::Provider;
    use crate::behaviour::socket::asio::AsyncUdpSocket;
    use crate::behaviour::timer::asio::AsyncTimer;
    use if_watch::smol::IfWatcher;

    /// The mDNS behaviour wired to the `async-io` socket, timer and interface watcher.
    pub type Behaviour = super::Behaviour<AsyncIo>;

    /// Uninhabited marker type that selects the `async-io` implementations.
    #[doc(hidden)]
    pub enum AsyncIo {}

    impl Provider for AsyncIo {
        type Socket = AsyncUdpSocket;
        type Timer = AsyncTimer;
        type Watcher = IfWatcher;

        fn new_watcher() -> std::io::Result<Self::Watcher> {
            IfWatcher::new()
        }
    }
}
76
/// Provides a [`Behaviour`] driven by the `tokio` runtime.
#[cfg(feature = "tokio")]
pub mod tokio {
    use super::Provider;
    use crate::behaviour::socket::tokio::TokioUdpSocket;
    use crate::behaviour::timer::tokio::TokioTimer;
    use if_watch::tokio::IfWatcher;

    /// The mDNS behaviour wired to the `tokio` socket, timer and interface watcher.
    pub type Behaviour = super::Behaviour<Tokio>;

    /// Uninhabited marker type that selects the `tokio` implementations.
    #[doc(hidden)]
    pub enum Tokio {}

    impl Provider for Tokio {
        type Socket = TokioUdpSocket;
        type Timer = TokioTimer;
        type Watcher = IfWatcher;

        fn new_watcher() -> std::io::Result<Self::Watcher> {
            IfWatcher::new()
        }
    }
}
99
100/// A `NetworkBehaviour` for mDNS. Automatically discovers peers on the local network and adds
101/// them to the topology.
102#[derive(Debug)]
103pub struct Behaviour<P>
104where
105    P: Provider,
106{
107    /// InterfaceState config.
108    config: Config,
109
110    /// Iface watcher.
111    if_watch: P::Watcher,
112
113    /// Mdns interface states.
114    iface_states: HashMap<IpAddr, InterfaceState<P::Socket, P::Timer>>,
115
116    /// List of nodes that we have discovered, the address, and when their TTL expires.
117    ///
118    /// Each combination of `PeerId` and `Multiaddr` can only appear once, but the same `PeerId`
119    /// can appear multiple times.
120    discovered_nodes: SmallVec<[(PeerId, Multiaddr, Instant); 8]>,
121
122    /// Future that fires when the TTL of at least one node in `discovered_nodes` expires.
123    ///
124    /// `None` if `discovered_nodes` is empty.
125    closest_expiration: Option<P::Timer>,
126
127    listen_addresses: ListenAddresses,
128
129    local_peer_id: PeerId,
130}
131
132impl<P> Behaviour<P>
133where
134    P: Provider,
135{
136    /// Builds a new `Mdns` behaviour.
137    pub fn new(config: Config, local_peer_id: PeerId) -> io::Result<Self> {
138        Ok(Self {
139            config,
140            if_watch: P::new_watcher()?,
141            iface_states: Default::default(),
142            discovered_nodes: Default::default(),
143            closest_expiration: Default::default(),
144            listen_addresses: Default::default(),
145            local_peer_id,
146        })
147    }
148
149    /// Returns true if the given `PeerId` is in the list of nodes discovered through mDNS.
150    pub fn has_node(&self, peer_id: &PeerId) -> bool {
151        self.discovered_nodes().any(|p| p == peer_id)
152    }
153
154    /// Returns the list of nodes that we have discovered through mDNS and that are not expired.
155    pub fn discovered_nodes(&self) -> impl ExactSizeIterator<Item = &PeerId> {
156        self.discovered_nodes.iter().map(|(p, _, _)| p)
157    }
158
159    /// Expires a node before the ttl.
160    pub fn expire_node(&mut self, peer_id: &PeerId) {
161        let now = Instant::now();
162        for (peer, _addr, expires) in &mut self.discovered_nodes {
163            if peer == peer_id {
164                *expires = now;
165            }
166        }
167        self.closest_expiration = Some(P::Timer::at(now));
168    }
169}
170
171impl<P> NetworkBehaviour for Behaviour<P>
172where
173    P: Provider,
174{
175    type ConnectionHandler = dummy::ConnectionHandler;
176    type ToSwarm = Event;
177
178    fn handle_established_inbound_connection(
179        &mut self,
180        _: ConnectionId,
181        _: PeerId,
182        _: &Multiaddr,
183        _: &Multiaddr,
184    ) -> Result<THandler<Self>, ConnectionDenied> {
185        Ok(dummy::ConnectionHandler)
186    }
187
188    fn handle_pending_outbound_connection(
189        &mut self,
190        _connection_id: ConnectionId,
191        maybe_peer: Option<PeerId>,
192        _addresses: &[Multiaddr],
193        _effective_role: Endpoint,
194    ) -> Result<Vec<Multiaddr>, ConnectionDenied> {
195        let peer_id = match maybe_peer {
196            None => return Ok(vec![]),
197            Some(peer) => peer,
198        };
199
200        Ok(self
201            .discovered_nodes
202            .iter()
203            .filter(|(peer, _, _)| peer == &peer_id)
204            .map(|(_, addr, _)| addr.clone())
205            .collect())
206    }
207
208    fn handle_established_outbound_connection(
209        &mut self,
210        _: ConnectionId,
211        _: PeerId,
212        _: &Multiaddr,
213        _: Endpoint,
214    ) -> Result<THandler<Self>, ConnectionDenied> {
215        Ok(dummy::ConnectionHandler)
216    }
217
218    fn on_connection_handler_event(
219        &mut self,
220        _: PeerId,
221        _: ConnectionId,
222        ev: THandlerOutEvent<Self>,
223    ) {
224        void::unreachable(ev)
225    }
226
227    fn on_swarm_event(&mut self, event: FromSwarm<Self::ConnectionHandler>) {
228        self.listen_addresses.on_swarm_event(&event);
229
230        match event {
231            FromSwarm::NewListener(_) => {
232                log::trace!("waking interface state because listening address changed");
233                for iface in self.iface_states.values_mut() {
234                    iface.fire_timer();
235                }
236            }
237            FromSwarm::ConnectionClosed(_)
238            | FromSwarm::ConnectionEstablished(_)
239            | FromSwarm::DialFailure(_)
240            | FromSwarm::AddressChange(_)
241            | FromSwarm::ListenFailure(_)
242            | FromSwarm::NewListenAddr(_)
243            | FromSwarm::ExpiredListenAddr(_)
244            | FromSwarm::ListenerError(_)
245            | FromSwarm::ListenerClosed(_)
246            | FromSwarm::NewExternalAddrCandidate(_)
247            | FromSwarm::ExternalAddrExpired(_)
248            | FromSwarm::ExternalAddrConfirmed(_) => {}
249        }
250    }
251
252    fn poll(
253        &mut self,
254        cx: &mut Context<'_>,
255        _: &mut impl PollParameters,
256    ) -> Poll<ToSwarm<Self::ToSwarm, THandlerInEvent<Self>>> {
257        // Poll ifwatch.
258        while let Poll::Ready(Some(event)) = Pin::new(&mut self.if_watch).poll_next(cx) {
259            match event {
260                Ok(IfEvent::Up(inet)) => {
261                    let addr = inet.addr();
262                    if addr.is_loopback() {
263                        continue;
264                    }
265                    if addr.is_ipv4() && self.config.enable_ipv6
266                        || addr.is_ipv6() && !self.config.enable_ipv6
267                    {
268                        continue;
269                    }
270                    if let Entry::Vacant(e) = self.iface_states.entry(addr) {
271                        match InterfaceState::new(addr, self.config.clone(), self.local_peer_id) {
272                            Ok(iface_state) => {
273                                e.insert(iface_state);
274                            }
275                            Err(err) => log::error!("failed to create `InterfaceState`: {}", err),
276                        }
277                    }
278                }
279                Ok(IfEvent::Down(inet)) => {
280                    if self.iface_states.contains_key(&inet.addr()) {
281                        log::info!("dropping instance {}", inet.addr());
282                        self.iface_states.remove(&inet.addr());
283                    }
284                }
285                Err(err) => log::error!("if watch returned an error: {}", err),
286            }
287        }
288        // Emit discovered event.
289        let mut discovered = Vec::new();
290        for iface_state in self.iface_states.values_mut() {
291            while let Poll::Ready((peer, addr, expiration)) =
292                iface_state.poll(cx, &self.listen_addresses)
293            {
294                if let Some((_, _, cur_expires)) = self
295                    .discovered_nodes
296                    .iter_mut()
297                    .find(|(p, a, _)| *p == peer && *a == addr)
298                {
299                    *cur_expires = cmp::max(*cur_expires, expiration);
300                } else {
301                    log::info!("discovered: {} {}", peer, addr);
302                    self.discovered_nodes.push((peer, addr.clone(), expiration));
303                    discovered.push((peer, addr));
304                }
305            }
306        }
307        if !discovered.is_empty() {
308            let event = Event::Discovered(discovered);
309            return Poll::Ready(ToSwarm::GenerateEvent(event));
310        }
311        // Emit expired event.
312        let now = Instant::now();
313        let mut closest_expiration = None;
314        let mut expired = Vec::new();
315        self.discovered_nodes.retain(|(peer, addr, expiration)| {
316            if *expiration <= now {
317                log::info!("expired: {} {}", peer, addr);
318                expired.push((*peer, addr.clone()));
319                return false;
320            }
321            closest_expiration = Some(closest_expiration.unwrap_or(*expiration).min(*expiration));
322            true
323        });
324        if !expired.is_empty() {
325            let event = Event::Expired(expired);
326            return Poll::Ready(ToSwarm::GenerateEvent(event));
327        }
328        if let Some(closest_expiration) = closest_expiration {
329            let mut timer = P::Timer::at(closest_expiration);
330            let _ = Pin::new(&mut timer).poll_next(cx);
331
332            self.closest_expiration = Some(timer);
333        }
334        Poll::Pending
335    }
336}
337
338/// Event that can be produced by the `Mdns` behaviour.
339#[derive(Debug, Clone)]
340pub enum Event {
341    /// Discovered nodes through mDNS.
342    Discovered(Vec<(PeerId, Multiaddr)>),
343
344    /// The given combinations of `PeerId` and `Multiaddr` have expired.
345    ///
346    /// Each discovered record has a time-to-live. When this TTL expires and the address hasn't
347    /// been refreshed, we remove it from the list and emit it as an `Expired` event.
348    Expired(Vec<(PeerId, Multiaddr)>),
349}