litep2p/protocol/
mdns.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2// Copyright 2023 litep2p developers
3//
4// Permission is hereby granted, free of charge, to any person obtaining a
5// copy of this software and associated documentation files (the "Software"),
6// to deal in the Software without restriction, including without limitation
7// the rights to use, copy, modify, merge, publish, distribute, sublicense,
8// and/or sell copies of the Software, and to permit persons to whom the
9// Software is furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
15// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20// DEALINGS IN THE SOFTWARE.
21
22//! [Multicast DNS](https://en.wikipedia.org/wiki/Multicast_DNS) implementation.
23
24use crate::{transport::manager::TransportManagerHandle, DEFAULT_CHANNEL_SIZE};
25
26use futures::Stream;
27use multiaddr::Multiaddr;
28use rand::{distributions::Alphanumeric, Rng};
29use simple_dns::{
30    rdata::{RData, PTR, TXT},
31    Name, Packet, PacketFlag, Question, ResourceRecord, CLASS, QCLASS, QTYPE, TYPE,
32};
33use socket2::{Domain, Protocol, Socket, Type};
34use tokio::{
35    net::UdpSocket,
36    sync::mpsc::{channel, Sender},
37};
38use tokio_stream::wrappers::ReceiverStream;
39
40use std::{
41    collections::HashSet,
42    net,
43    net::{IpAddr, Ipv4Addr, SocketAddr},
44    sync::Arc,
45    time::Duration,
46};
47
/// `tracing` target shared by every log event emitted from this module.
const LOG_TARGET: &str = "litep2p::mdns";

/// Service name this module queries for and advertises its listen addresses under.
const SERVICE_NAME: &str = "_p2p._udp.local";

/// The mDNS IPv4 multicast group (`224.0.0.251`); queries and responses are sent here.
const IPV4_MULTICAST_ADDRESS: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);

/// The mDNS UDP port; the socket binds to it and packets are sent to it on the multicast group.
const IPV4_MULTICAST_PORT: u16 = 5353;
59
60/// Events emitted by mDNS.
61// #[derive(Debug, Clone)]
62pub enum MdnsEvent {
63    /// One or more addresses discovered.
64    Discovered(Vec<Multiaddr>),
65}
66
67/// mDNS configuration.
68// #[derive(Debug)]
69pub struct Config {
70    /// How often the network should be queried for new peers.
71    query_interval: Duration,
72
73    /// TX channel for sending mDNS events to user.
74    tx: Sender<MdnsEvent>,
75}
76
77impl Config {
78    /// Create new [`Config`].
79    ///
80    /// Return the configuration and an event stream for receiving [`MdnsEvent`]s.
81    pub fn new(
82        query_interval: Duration,
83    ) -> (Self, Box<dyn Stream<Item = MdnsEvent> + Send + Unpin>) {
84        let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE);
85        (
86            Self { query_interval, tx },
87            Box::new(ReceiverStream::new(rx)),
88        )
89    }
90}
91
/// Main mDNS object.
///
/// Periodically multicasts a PTR query for [`SERVICE_NAME`], answers such queries with this
/// node's listen addresses, and forwards newly discovered peer addresses to the user.
pub(crate) struct Mdns {
    /// Timer that drives outbound queries. The event loop also waits on it before every
    /// attempt to (re)create the socket.
    query_interval: tokio::time::Interval,

    /// TX channel for sending events to user.
    event_tx: Sender<MdnsEvent>,

    /// Handle to `TransportManager`.
    // NOTE(review): stored but never read within this file.
    _transport_handle: TransportManagerHandle,

    /// Random 32-character alphanumeric name identifying this instance.
    ///
    /// Responses use it as the PTR target and as the owner name of the TXT records. Inbound
    /// responses pointing to it are ignored as our own.
    username: String,

    /// ID assigned to the next outbound query.
    next_query_id: u16,

    /// Buffer for incoming messages.
    receive_buffer: Vec<u8>,

    /// Listen addresses, pre-rendered as `dnsaddr=<multiaddr>` TXT attributes.
    listen_addresses: Vec<Arc<str>>,

    /// Addresses already reported to the user; used to suppress duplicate events.
    discovered: HashSet<Multiaddr>,
}
118
119impl Mdns {
120    /// Create new [`Mdns`].
121    pub(crate) fn new(
122        _transport_handle: TransportManagerHandle,
123        config: Config,
124        listen_addresses: Vec<Multiaddr>,
125    ) -> Self {
126        let mut query_interval = tokio::time::interval(config.query_interval);
127        query_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
128
129        Self {
130            _transport_handle,
131            event_tx: config.tx,
132            next_query_id: 1337u16,
133            discovered: HashSet::new(),
134            query_interval,
135            receive_buffer: vec![0u8; 4096],
136            username: rand::thread_rng()
137                .sample_iter(&Alphanumeric)
138                .take(32)
139                .map(char::from)
140                .collect(),
141            listen_addresses: listen_addresses
142                .into_iter()
143                .map(|address| format!("dnsaddr={address}").into())
144                .collect(),
145        }
146    }
147
148    /// Get next query ID.
149    fn next_query_id(&mut self) -> u16 {
150        let query_id = self.next_query_id;
151        self.next_query_id += 1;
152
153        query_id
154    }
155
156    /// Send mDNS query on the network.
157    async fn on_outbound_request(&mut self, socket: &UdpSocket) -> crate::Result<()> {
158        tracing::debug!(target: LOG_TARGET, "send outbound query");
159
160        let mut packet = Packet::new_query(self.next_query_id());
161
162        packet.questions.push(Question {
163            qname: Name::new_unchecked(SERVICE_NAME),
164            qtype: QTYPE::TYPE(TYPE::PTR),
165            qclass: QCLASS::CLASS(CLASS::IN),
166            unicast_response: false,
167        });
168
169        socket
170            .send_to(
171                &packet.build_bytes_vec().expect("valid packet"),
172                (IPV4_MULTICAST_ADDRESS, IPV4_MULTICAST_PORT),
173            )
174            .await
175            .map(|_| ())
176            .map_err(From::from)
177    }
178
179    /// Handle inbound query.
180    fn on_inbound_request(&self, packet: Packet) -> Option<Vec<u8>> {
181        tracing::debug!(target: LOG_TARGET, ?packet, "handle inbound request");
182
183        let mut packet = Packet::new_reply(packet.id());
184        let srv_name = Name::new_unchecked(SERVICE_NAME);
185
186        packet.answers.push(ResourceRecord::new(
187            srv_name.clone(),
188            CLASS::IN,
189            360,
190            RData::PTR(PTR(Name::new_unchecked(&self.username))),
191        ));
192
193        for address in &self.listen_addresses {
194            let mut record = TXT::new();
195            record.add_string(address).expect("valid string");
196
197            packet.additional_records.push(ResourceRecord {
198                name: Name::new_unchecked(&self.username),
199                class: CLASS::IN,
200                ttl: 360,
201                rdata: RData::TXT(record),
202                cache_flush: false,
203            });
204        }
205
206        Some(packet.build_bytes_vec().expect("valid packet"))
207    }
208
209    /// Handle inbound response.
210    fn on_inbound_response(&self, packet: Packet) -> Vec<Multiaddr> {
211        tracing::debug!(target: LOG_TARGET, "handle inbound response");
212
213        let names = packet
214            .answers
215            .iter()
216            .filter_map(|answer| {
217                if answer.name != Name::new_unchecked(SERVICE_NAME) {
218                    return None;
219                }
220
221                match answer.rdata {
222                    RData::PTR(PTR(ref name)) if name != &Name::new_unchecked(&self.username) =>
223                        Some(name),
224                    _ => None,
225                }
226            })
227            .collect::<Vec<&Name>>();
228
229        let name = match names.len() {
230            0 => return Vec::new(),
231            _ => {
232                tracing::debug!(
233                    target: LOG_TARGET,
234                    ?names,
235                    "response name"
236                );
237
238                names[0]
239            }
240        };
241
242        packet
243            .additional_records
244            .iter()
245            .flat_map(|record| {
246                if &record.name != name {
247                    return vec![];
248                }
249
250                // TODO: https://github.com/paritytech/litep2p/issues/333
251                // `filter_map` is not necessary as there's at most one entry
252                match &record.rdata {
253                    RData::TXT(text) => text
254                        .attributes()
255                        .values()
256                        .filter_map(|address| address.as_ref().and_then(|inner| inner.parse().ok()))
257                        .collect(),
258                    _ => vec![],
259                }
260            })
261            .collect()
262    }
263
264    /// Setup the socket.
265    fn setup_socket() -> crate::Result<UdpSocket> {
266        let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
267        socket.set_reuse_address(true)?;
268        #[cfg(unix)]
269        socket.set_reuse_port(true)?;
270        socket.bind(
271            &SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), IPV4_MULTICAST_PORT).into(),
272        )?;
273        socket.set_multicast_loop_v4(true)?;
274        socket.set_multicast_ttl_v4(255)?;
275        socket.join_multicast_v4(&IPV4_MULTICAST_ADDRESS, &Ipv4Addr::UNSPECIFIED)?;
276        socket.set_nonblocking(true)?;
277
278        UdpSocket::from_std(net::UdpSocket::from(socket)).map_err(Into::into)
279    }
280
281    /// Event loop for [`Mdns`].
282    pub(crate) async fn start(mut self) {
283        tracing::debug!(target: LOG_TARGET, "starting mdns event loop");
284
285        let mut socket_opt = None;
286
287        loop {
288            let socket = match socket_opt.take() {
289                Some(s) => s,
290                None => {
291                    let _ = self.query_interval.tick().await;
292                    match Self::setup_socket() {
293                        Ok(s) => s,
294                        Err(error) => {
295                            tracing::debug!(
296                                target: LOG_TARGET,
297                                ?error,
298                                "failed to setup mDNS socket, will try again"
299                            );
300                            continue;
301                        }
302                    }
303                }
304            };
305
306            tokio::select! {
307                _ = self.query_interval.tick() => {
308                    tracing::trace!(target: LOG_TARGET, "query interval ticked");
309
310                    if let Err(error) = self.on_outbound_request(&socket).await {
311                        tracing::debug!(target: LOG_TARGET, ?error, "failed to send mdns query");
312                        // Let's recreate the socket
313                        continue;
314                    }
315                },
316
317                result = socket.recv_from(&mut self.receive_buffer) => match result {
318                    Ok((nread, address)) => match Packet::parse(&self.receive_buffer[..nread]) {
319                        Ok(packet) => match packet.has_flags(PacketFlag::RESPONSE) {
320                            true => {
321                                let to_forward = self.on_inbound_response(packet).into_iter().filter_map(|address| {
322                                    self.discovered.insert(address.clone()).then_some(address)
323                                })
324                                .collect::<Vec<_>>();
325
326                                if !to_forward.is_empty() {
327                                    let _ = self.event_tx.send(MdnsEvent::Discovered(to_forward)).await;
328                                }
329                            }
330                            false => if let Some(response) = self.on_inbound_request(packet) {
331                                if let Err(error) = socket
332                                    .send_to(&response, (IPV4_MULTICAST_ADDRESS, IPV4_MULTICAST_PORT))
333                                    .await {
334                                    tracing::debug!(target: LOG_TARGET, ?error, "failed to send mdns response");
335                                    // Let's recreate the socket
336                                    continue;
337                                }
338                            }
339                        }
340                        Err(error) => tracing::debug!(
341                            target: LOG_TARGET,
342                            ?address,
343                            ?error,
344                            ?nread,
345                            "failed to parse mdns packet"
346                        ),
347                    }
348                    Err(error) => {
349                        tracing::debug!(target: LOG_TARGET, ?error, "failed to read from socket");
350                        // Let's recreate the socket
351                        continue;
352                    }
353                },
354            };
355
356            socket_opt = Some(socket);
357        }
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::manager::TransportManagerBuilder;
    use futures::StreamExt;
    use multiaddr::Protocol;

    /// Returns `true` if `addrs` is the announcement of the peer listening on `expected_port`.
    ///
    /// A match means exactly two addresses, the first of the form
    /// `/ip4|ip6/<addr>/tcp/<expected_port>/...`.
    fn is_peer_announcement(addrs: &[Multiaddr], expected_port: u16) -> bool {
        if addrs.len() != 2 {
            return false;
        }

        let mut iter = addrs[0].iter();

        if !std::matches!(iter.next(), Some(Protocol::Ip4(_) | Protocol::Ip6(_))) {
            return false;
        }

        std::matches!(iter.next(), Some(Protocol::Tcp(port)) if port == expected_port)
    }

    #[tokio::test]
    async fn mdns_works() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let (config1, mut stream1) = Config::new(Duration::from_secs(5));
        let manager1 = TransportManagerBuilder::new().build();

        let mdns1 = Mdns::new(
            manager1.transport_manager_handle(),
            config1,
            vec![
                "/ip6/::1/tcp/8888/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTaaaa"
                    .parse()
                    .unwrap(),
                "/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTaaaa"
                    .parse()
                    .unwrap(),
            ],
        );

        let (config2, mut stream2) = Config::new(Duration::from_secs(5));
        let manager2 = TransportManagerBuilder::new().build();

        let mdns2 = Mdns::new(
            manager2.transport_manager_handle(),
            config2,
            vec![
                "/ip6/::1/tcp/9999/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTbbbb"
                    .parse()
                    .unwrap(),
                "/ip4/127.0.0.1/tcp/9999/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTbbbb"
                    .parse()
                    .unwrap(),
            ],
        );

        tokio::spawn(mdns1.start());
        tokio::spawn(mdns2.start());

        let mut peer1_discovered = false;
        let mut peer2_discovered = false;

        // Keep polling until *both* peers have discovered each other. With `&&` the loop would
        // stop as soon as either side saw the other, so mutual discovery was never verified.
        while !peer1_discovered || !peer2_discovered {
            tokio::select! {
                event = stream1.next() => match event.unwrap() {
                    MdnsEvent::Discovered(addrs) => {
                        if is_peer_announcement(&addrs, 9999) {
                            peer1_discovered = true;
                        }
                    }
                },
                event = stream2.next() => match event.unwrap() {
                    MdnsEvent::Discovered(addrs) => {
                        if is_peer_announcement(&addrs, 8888) {
                            peer2_discovered = true;
                        }
                    }
                },
            }
        }
    }
}