litep2p/protocol/
mdns.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2// Copyright 2023 litep2p developers
3//
4// Permission is hereby granted, free of charge, to any person obtaining a
5// copy of this software and associated documentation files (the "Software"),
6// to deal in the Software without restriction, including without limitation
7// the rights to use, copy, modify, merge, publish, distribute, sublicense,
8// and/or sell copies of the Software, and to permit persons to whom the
9// Software is furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20// DEALINGS IN THE SOFTWARE.
21
22//! [Multicast DNS](https://en.wikipedia.org/wiki/Multicast_DNS) implementation.
23
24use crate::{transport::manager::TransportManagerHandle, DEFAULT_CHANNEL_SIZE};
25
26use futures::Stream;
27use multiaddr::Multiaddr;
28use rand::{distributions::Alphanumeric, Rng};
29use simple_dns::{
30    rdata::{RData, PTR, TXT},
31    Name, Packet, PacketFlag, Question, ResourceRecord, CLASS, QCLASS, QTYPE, TYPE,
32};
33use socket2::{Domain, Protocol, Socket, Type};
34use tokio::{
35    net::UdpSocket,
36    sync::mpsc::{channel, Sender},
37};
38use tokio_stream::wrappers::ReceiverStream;
39
40use std::{
41    collections::HashSet,
42    net,
43    net::{IpAddr, Ipv4Addr, SocketAddr},
44    sync::Arc,
45    time::Duration,
46};
47
/// Logging target for the file.
const LOG_TARGET: &str = "litep2p::mdns";

/// IPv4 multicast address whose group the mDNS socket joins and to which
/// outbound queries and responses are sent.
const IPV4_MULTICAST_ADDRESS: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);

/// IPv4 multicast port: the UDP port the mDNS socket binds to and the
/// destination port of every packet sent by this module.
const IPV4_MULTICAST_PORT: u16 = 5353;

/// Service name asked for in outbound queries and used as the owner name of
/// the PTR answer in responses.
const SERVICE_NAME: &str = "_p2p._udp.local";
59
60/// Events emitted by mDNS.
61// #[derive(Debug, Clone)]
62pub enum MdnsEvent {
63    /// One or more addresses discovered.
64    Discovered(Vec<Multiaddr>),
65}
66
67/// mDNS configuration.
68// #[derive(Debug)]
69pub struct Config {
70    /// How often the network should be queried for new peers.
71    query_interval: Duration,
72
73    /// TX channel for sending mDNS events to user.
74    tx: Sender<MdnsEvent>,
75}
76
77impl Config {
78    /// Create new [`Config`].
79    ///
80    /// Return the configuration and an event stream for receiving [`MdnsEvent`]s.
81    pub fn new(
82        query_interval: Duration,
83    ) -> (Self, Box<dyn Stream<Item = MdnsEvent> + Send + Unpin>) {
84        let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE);
85        (
86            Self { query_interval, tx },
87            Box::new(ReceiverStream::new(rx)),
88        )
89    }
90}
91
/// Main mDNS object.
///
/// Holds the state of the mDNS event loop: it periodically queries the network,
/// answers inbound queries with the local listen addresses and reports newly
/// discovered addresses to the user.
pub(crate) struct Mdns {
    /// Query interval.
    ///
    /// Each tick triggers one outbound query; a tick is also awaited before the
    /// socket is (re)created.
    query_interval: tokio::time::Interval,

    /// TX channel for sending events to user.
    event_tx: Sender<MdnsEvent>,

    /// Handle to `TransportManager`.
    ///
    /// Only stored; it is not read anywhere in this file.
    _transport_handle: TransportManagerHandle,

    /// Username.
    ///
    /// Random 32-character alphanumeric name generated in `new`. It is the target
    /// of the PTR answer and the owner of the TXT records in responses, and is
    /// used to ignore answers that point back at the local node.
    username: String,

    /// Next query ID.
    next_query_id: u16,

    /// Buffer for incoming messages.
    receive_buffer: Vec<u8>,

    /// Listen addresses.
    ///
    /// Stored pre-formatted as `dnsaddr=<multiaddr>` strings, ready to be placed
    /// in TXT records.
    listen_addresses: Vec<Arc<str>>,

    /// Discovered addresses.
    ///
    /// Used to make sure each address is reported to the user only once.
    discovered: HashSet<Multiaddr>,
}
118
119impl Mdns {
120    /// Create new [`Mdns`].
121    pub(crate) fn new(
122        _transport_handle: TransportManagerHandle,
123        config: Config,
124        listen_addresses: Vec<Multiaddr>,
125    ) -> Self {
126        let mut query_interval = tokio::time::interval(config.query_interval);
127        query_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
128
129        Self {
130            _transport_handle,
131            event_tx: config.tx,
132            next_query_id: 1337u16,
133            discovered: HashSet::new(),
134            query_interval,
135            receive_buffer: vec![0u8; 4096],
136            username: rand::thread_rng()
137                .sample_iter(&Alphanumeric)
138                .take(32)
139                .map(char::from)
140                .collect(),
141            listen_addresses: listen_addresses
142                .into_iter()
143                .map(|address| format!("dnsaddr={address}").into())
144                .collect(),
145        }
146    }
147
148    /// Get next query ID.
149    fn next_query_id(&mut self) -> u16 {
150        let query_id = self.next_query_id;
151        self.next_query_id += 1;
152
153        query_id
154    }
155
156    /// Send mDNS query on the network.
157    async fn on_outbound_request(&mut self, socket: &UdpSocket) -> crate::Result<()> {
158        tracing::debug!(target: LOG_TARGET, "send outbound query");
159
160        let mut packet = Packet::new_query(self.next_query_id());
161
162        packet.questions.push(Question {
163            qname: Name::new_unchecked(SERVICE_NAME),
164            qtype: QTYPE::TYPE(TYPE::PTR),
165            qclass: QCLASS::CLASS(CLASS::IN),
166            unicast_response: false,
167        });
168
169        socket
170            .send_to(
171                &packet.build_bytes_vec().expect("valid packet"),
172                (IPV4_MULTICAST_ADDRESS, IPV4_MULTICAST_PORT),
173            )
174            .await
175            .map(|_| ())
176            .map_err(From::from)
177    }
178
179    /// Handle inbound query.
180    fn on_inbound_request(&self, packet: Packet) -> Option<Vec<u8>> {
181        tracing::debug!(target: LOG_TARGET, ?packet, "handle inbound request");
182
183        let mut packet = Packet::new_reply(packet.id());
184        let srv_name = Name::new_unchecked(SERVICE_NAME);
185
186        packet.answers.push(ResourceRecord::new(
187            srv_name.clone(),
188            CLASS::IN,
189            360,
190            RData::PTR(PTR(Name::new_unchecked(&self.username))),
191        ));
192
193        for address in &self.listen_addresses {
194            let mut record = TXT::new();
195            record.add_string(address).expect("valid string");
196
197            packet.additional_records.push(ResourceRecord {
198                name: Name::new_unchecked(&self.username),
199                class: CLASS::IN,
200                ttl: 360,
201                rdata: RData::TXT(record),
202                cache_flush: false,
203            });
204        }
205
206        Some(packet.build_bytes_vec().expect("valid packet"))
207    }
208
209    /// Handle inbound response.
210    fn on_inbound_response(&self, packet: Packet) -> Vec<Multiaddr> {
211        tracing::debug!(target: LOG_TARGET, "handle inbound response");
212
213        let names = packet
214            .answers
215            .iter()
216            .filter_map(|answer| {
217                if answer.name != Name::new_unchecked(SERVICE_NAME) {
218                    return None;
219                }
220
221                match answer.rdata {
222                    RData::PTR(PTR(ref name)) if name != &Name::new_unchecked(&self.username) =>
223                        Some(name),
224                    _ => None,
225                }
226            })
227            .collect::<Vec<&Name>>();
228
229        let name = match names.len() {
230            0 => return Vec::new(),
231            _ => {
232                tracing::debug!(
233                    target: LOG_TARGET,
234                    ?names,
235                    "response name"
236                );
237
238                names[0]
239            }
240        };
241
242        packet
243            .additional_records
244            .iter()
245            .flat_map(|record| {
246                if &record.name != name {
247                    return vec![];
248                }
249
250                // TODO: https://github.com/paritytech/litep2p/issues/333
251                // `filter_map` is not necessary as there's at most one entry
252                match &record.rdata {
253                    RData::TXT(text) => text
254                        .attributes()
255                        .iter()
256                        .filter_map(|(_, address)| {
257                            address.as_ref().and_then(|inner| inner.parse().ok())
258                        })
259                        .collect(),
260                    _ => vec![],
261                }
262            })
263            .collect()
264    }
265
266    /// Setup the socket.
267    fn setup_socket() -> crate::Result<UdpSocket> {
268        let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
269        socket.set_reuse_address(true)?;
270        #[cfg(unix)]
271        socket.set_reuse_port(true)?;
272        socket.bind(
273            &SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), IPV4_MULTICAST_PORT).into(),
274        )?;
275        socket.set_multicast_loop_v4(true)?;
276        socket.set_multicast_ttl_v4(255)?;
277        socket.join_multicast_v4(&IPV4_MULTICAST_ADDRESS, &Ipv4Addr::UNSPECIFIED)?;
278        socket.set_nonblocking(true)?;
279
280        UdpSocket::from_std(net::UdpSocket::from(socket)).map_err(Into::into)
281    }
282
    /// Event loop for [`Mdns`].
    ///
    /// Consumes the object and loops forever; the function has no exit path.
    /// Each iteration either reuses the socket from the previous iteration or,
    /// if there is none, waits for a query-interval tick and tries to create one.
    /// It then waits for whichever happens first:
    ///
    /// - the query interval ticks: an outbound query is sent, or
    /// - a datagram arrives: it is parsed and handled as a response (new
    ///   addresses are forwarded to the user) or as a query (a reply is
    ///   multicast).
    ///
    /// Any socket send/receive error ends the iteration with `continue`, which
    /// skips handing the socket back to `socket_opt`; the socket is therefore
    /// dropped and a new one is created on the next iteration.
    pub(crate) async fn start(mut self) {
        tracing::debug!(target: LOG_TARGET, "starting mdns event loop");

        // Socket carried over between iterations; `None` means it must be (re)created.
        let mut socket_opt = None;

        loop {
            let socket = match socket_opt.take() {
                Some(s) => s,
                None => {
                    // Wait for a tick before (re)creating the socket so repeated
                    // failures are paced by the query interval.
                    let _ = self.query_interval.tick().await;
                    match Self::setup_socket() {
                        Ok(s) => s,
                        Err(error) => {
                            tracing::debug!(
                                target: LOG_TARGET,
                                ?error,
                                "failed to setup mDNS socket, will try again"
                            );
                            continue;
                        }
                    }
                }
            };

            tokio::select! {
                _ = self.query_interval.tick() => {
                    tracing::trace!(target: LOG_TARGET, "query interval ticked");

                    if let Err(error) = self.on_outbound_request(&socket).await {
                        tracing::debug!(target: LOG_TARGET, ?error, "failed to send mdns query");
                        // Let's recreate the socket
                        continue;
                    }
                },

                result = socket.recv_from(&mut self.receive_buffer) => match result {
                    Ok((nread, address)) => match Packet::parse(&self.receive_buffer[..nread]) {
                        Ok(packet) => match packet.has_flags(PacketFlag::RESPONSE) {
                            true => {
                                // Keep only addresses that were not in `discovered` yet;
                                // `insert` returns `true` for those.
                                let to_forward = self.on_inbound_response(packet).into_iter().filter_map(|address| {
                                    self.discovered.insert(address.clone()).then_some(address)
                                })
                                .collect::<Vec<_>>();

                                if !to_forward.is_empty() {
                                    // A send error is deliberately ignored here.
                                    let _ = self.event_tx.send(MdnsEvent::Discovered(to_forward)).await;
                                }
                            }
                            // Not a response, so treat the packet as a query and
                            // multicast our reply, if there is one.
                            false => if let Some(response) = self.on_inbound_request(packet) {
                                if let Err(error) = socket
                                    .send_to(&response, (IPV4_MULTICAST_ADDRESS, IPV4_MULTICAST_PORT))
                                    .await {
                                    tracing::debug!(target: LOG_TARGET, ?error, "failed to send mdns response");
                                    // Let's recreate the socket
                                    continue;
                                }
                            }
                        }
                        // A malformed packet is only logged; the socket stays in use.
                        Err(error) => tracing::debug!(
                            target: LOG_TARGET,
                            ?address,
                            ?error,
                            ?nread,
                            "failed to parse mdns packet"
                        ),
                    }
                    Err(error) => {
                        tracing::debug!(target: LOG_TARGET, ?error, "failed to read from socket");
                        // Let's recreate the socket
                        continue;
                    }
                },
            };

            // No socket error occurred: keep the socket for the next iteration.
            socket_opt = Some(socket);
        }
    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::manager::TransportManagerBuilder;
    use futures::StreamExt;
    use multiaddr::Protocol;

    /// Check whether `addrs` is the advertisement expected from the peer
    /// listening on `port`: exactly two addresses, the first of which starts with
    /// an IPv4 or IPv6 component followed by a TCP component with that port.
    fn is_expected_discovery(addrs: &[Multiaddr], port: u16) -> bool {
        if addrs.len() != 2 {
            return false;
        }

        let mut iter = addrs[0].iter();

        std::matches!(iter.next(), Some(Protocol::Ip4(_) | Protocol::Ip6(_)))
            && std::matches!(iter.next(), Some(Protocol::Tcp(found)) if found == port)
    }

    /// Two mDNS instances on the same host must discover each other.
    #[tokio::test]
    async fn mdns_works() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let (config1, mut stream1) = Config::new(Duration::from_secs(5));
        let manager1 = TransportManagerBuilder::new().build();

        let mdns1 = Mdns::new(
            manager1.transport_manager_handle(),
            config1,
            vec![
                "/ip6/::1/tcp/8888/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTaaaa"
                    .parse()
                    .unwrap(),
                "/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTaaaa"
                    .parse()
                    .unwrap(),
            ],
        );

        let (config2, mut stream2) = Config::new(Duration::from_secs(5));
        let manager2 = TransportManagerBuilder::new().build();

        let mdns2 = Mdns::new(
            manager2.transport_manager_handle(),
            config2,
            vec![
                "/ip6/::1/tcp/9999/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTbbbb"
                    .parse()
                    .unwrap(),
                "/ip4/127.0.0.1/tcp/9999/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTbbbb"
                    .parse()
                    .unwrap(),
            ],
        );

        tokio::spawn(mdns1.start());
        tokio::spawn(mdns2.start());

        let mut peer1_discovered = false;
        let mut peer2_discovered = false;

        // Keep polling until BOTH directions have been observed; stopping after
        // the first one would leave the other direction untested.
        while !(peer1_discovered && peer2_discovered) {
            tokio::select! {
                // Instance 1 must learn about instance 2 (port 9999).
                event = stream1.next() => {
                    let MdnsEvent::Discovered(addrs) = event.unwrap();
                    if is_expected_discovery(&addrs, 9999) {
                        peer1_discovered = true;
                    }
                },
                // Instance 2 must learn about instance 1 (port 8888).
                event = stream2.next() => {
                    let MdnsEvent::Discovered(addrs) = event.unwrap();
                    if is_expected_discovery(&addrs, 8888) {
                        peer2_discovered = true;
                    }
                }
            }
        }
    }
}
463}