1use crate::{transport::manager::TransportManagerHandle, DEFAULT_CHANNEL_SIZE};
25
26use futures::Stream;
27use multiaddr::Multiaddr;
28use rand::{distributions::Alphanumeric, Rng};
29use simple_dns::{
30 rdata::{RData, PTR, TXT},
31 Name, Packet, PacketFlag, Question, ResourceRecord, CLASS, QCLASS, QTYPE, TYPE,
32};
33use socket2::{Domain, Protocol, Socket, Type};
34use tokio::{
35 net::UdpSocket,
36 sync::mpsc::{channel, Sender},
37};
38use tokio_stream::wrappers::ReceiverStream;
39
40use std::{
41 collections::HashSet,
42 net,
43 net::{IpAddr, Ipv4Addr, SocketAddr},
44 sync::Arc,
45 time::Duration,
46};
47
/// Target used for all `tracing` output emitted by this module.
const LOG_TARGET: &str = "litep2p::mdns";

/// IPv4 multicast group that mDNS traffic is sent to and received from (RFC 6762).
const IPV4_MULTICAST_ADDRESS: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);

/// UDP port that mDNS traffic uses, both for binding and as the destination (RFC 6762).
const IPV4_MULTICAST_PORT: u16 = 5353;

/// Service name that queries ask for and that responses answer with a PTR record.
const SERVICE_NAME: &str = "_p2p._udp.local";
59
60pub enum MdnsEvent {
63 Discovered(Vec<Multiaddr>),
65}
66
67pub struct Config {
70 query_interval: Duration,
72
73 tx: Sender<MdnsEvent>,
75}
76
77impl Config {
78 pub fn new(
82 query_interval: Duration,
83 ) -> (Self, Box<dyn Stream<Item = MdnsEvent> + Send + Unpin>) {
84 let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE);
85 (
86 Self { query_interval, tx },
87 Box::new(ReceiverStream::new(rx)),
88 )
89 }
90}
91
92pub(crate) struct Mdns {
94 query_interval: tokio::time::Interval,
96
97 event_tx: Sender<MdnsEvent>,
99
100 _transport_handle: TransportManagerHandle,
102
103 username: String,
105
106 next_query_id: u16,
108
109 receive_buffer: Vec<u8>,
111
112 listen_addresses: Vec<Arc<str>>,
114
115 discovered: HashSet<Multiaddr>,
117}
118
119impl Mdns {
120 pub(crate) fn new(
122 _transport_handle: TransportManagerHandle,
123 config: Config,
124 listen_addresses: Vec<Multiaddr>,
125 ) -> Self {
126 let mut query_interval = tokio::time::interval(config.query_interval);
127 query_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
128
129 Self {
130 _transport_handle,
131 event_tx: config.tx,
132 next_query_id: 1337u16,
133 discovered: HashSet::new(),
134 query_interval,
135 receive_buffer: vec![0u8; 4096],
136 username: rand::thread_rng()
137 .sample_iter(&Alphanumeric)
138 .take(32)
139 .map(char::from)
140 .collect(),
141 listen_addresses: listen_addresses
142 .into_iter()
143 .map(|address| format!("dnsaddr={address}").into())
144 .collect(),
145 }
146 }
147
148 fn next_query_id(&mut self) -> u16 {
150 let query_id = self.next_query_id;
151 self.next_query_id += 1;
152
153 query_id
154 }
155
156 async fn on_outbound_request(&mut self, socket: &UdpSocket) -> crate::Result<()> {
158 tracing::debug!(target: LOG_TARGET, "send outbound query");
159
160 let mut packet = Packet::new_query(self.next_query_id());
161
162 packet.questions.push(Question {
163 qname: Name::new_unchecked(SERVICE_NAME),
164 qtype: QTYPE::TYPE(TYPE::PTR),
165 qclass: QCLASS::CLASS(CLASS::IN),
166 unicast_response: false,
167 });
168
169 socket
170 .send_to(
171 &packet.build_bytes_vec().expect("valid packet"),
172 (IPV4_MULTICAST_ADDRESS, IPV4_MULTICAST_PORT),
173 )
174 .await
175 .map(|_| ())
176 .map_err(From::from)
177 }
178
179 fn on_inbound_request(&self, packet: Packet) -> Option<Vec<u8>> {
181 tracing::debug!(target: LOG_TARGET, ?packet, "handle inbound request");
182
183 let mut packet = Packet::new_reply(packet.id());
184 let srv_name = Name::new_unchecked(SERVICE_NAME);
185
186 packet.answers.push(ResourceRecord::new(
187 srv_name.clone(),
188 CLASS::IN,
189 360,
190 RData::PTR(PTR(Name::new_unchecked(&self.username))),
191 ));
192
193 for address in &self.listen_addresses {
194 let mut record = TXT::new();
195 record.add_string(address).expect("valid string");
196
197 packet.additional_records.push(ResourceRecord {
198 name: Name::new_unchecked(&self.username),
199 class: CLASS::IN,
200 ttl: 360,
201 rdata: RData::TXT(record),
202 cache_flush: false,
203 });
204 }
205
206 Some(packet.build_bytes_vec().expect("valid packet"))
207 }
208
209 fn on_inbound_response(&self, packet: Packet) -> Vec<Multiaddr> {
211 tracing::debug!(target: LOG_TARGET, "handle inbound response");
212
213 let names = packet
214 .answers
215 .iter()
216 .filter_map(|answer| {
217 if answer.name != Name::new_unchecked(SERVICE_NAME) {
218 return None;
219 }
220
221 match answer.rdata {
222 RData::PTR(PTR(ref name)) if name != &Name::new_unchecked(&self.username) =>
223 Some(name),
224 _ => None,
225 }
226 })
227 .collect::<Vec<&Name>>();
228
229 let name = match names.len() {
230 0 => return Vec::new(),
231 _ => {
232 tracing::debug!(
233 target: LOG_TARGET,
234 ?names,
235 "response name"
236 );
237
238 names[0]
239 }
240 };
241
242 packet
243 .additional_records
244 .iter()
245 .flat_map(|record| {
246 if &record.name != name {
247 return vec![];
248 }
249
250 match &record.rdata {
253 RData::TXT(text) => text
254 .attributes()
255 .iter()
256 .filter_map(|(_, address)| {
257 address.as_ref().and_then(|inner| inner.parse().ok())
258 })
259 .collect(),
260 _ => vec![],
261 }
262 })
263 .collect()
264 }
265
266 fn setup_socket() -> crate::Result<UdpSocket> {
268 let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
269 socket.set_reuse_address(true)?;
270 #[cfg(unix)]
271 socket.set_reuse_port(true)?;
272 socket.bind(
273 &SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), IPV4_MULTICAST_PORT).into(),
274 )?;
275 socket.set_multicast_loop_v4(true)?;
276 socket.set_multicast_ttl_v4(255)?;
277 socket.join_multicast_v4(&IPV4_MULTICAST_ADDRESS, &Ipv4Addr::UNSPECIFIED)?;
278 socket.set_nonblocking(true)?;
279
280 UdpSocket::from_std(net::UdpSocket::from(socket)).map_err(Into::into)
281 }
282
283 pub(crate) async fn start(mut self) {
285 tracing::debug!(target: LOG_TARGET, "starting mdns event loop");
286
287 let mut socket_opt = None;
288
289 loop {
290 let socket = match socket_opt.take() {
291 Some(s) => s,
292 None => {
293 let _ = self.query_interval.tick().await;
294 match Self::setup_socket() {
295 Ok(s) => s,
296 Err(error) => {
297 tracing::debug!(
298 target: LOG_TARGET,
299 ?error,
300 "failed to setup mDNS socket, will try again"
301 );
302 continue;
303 }
304 }
305 }
306 };
307
308 tokio::select! {
309 _ = self.query_interval.tick() => {
310 tracing::trace!(target: LOG_TARGET, "query interval ticked");
311
312 if let Err(error) = self.on_outbound_request(&socket).await {
313 tracing::debug!(target: LOG_TARGET, ?error, "failed to send mdns query");
314 continue;
316 }
317 },
318
319 result = socket.recv_from(&mut self.receive_buffer) => match result {
320 Ok((nread, address)) => match Packet::parse(&self.receive_buffer[..nread]) {
321 Ok(packet) => match packet.has_flags(PacketFlag::RESPONSE) {
322 true => {
323 let to_forward = self.on_inbound_response(packet).into_iter().filter_map(|address| {
324 self.discovered.insert(address.clone()).then_some(address)
325 })
326 .collect::<Vec<_>>();
327
328 if !to_forward.is_empty() {
329 let _ = self.event_tx.send(MdnsEvent::Discovered(to_forward)).await;
330 }
331 }
332 false => if let Some(response) = self.on_inbound_request(packet) {
333 if let Err(error) = socket
334 .send_to(&response, (IPV4_MULTICAST_ADDRESS, IPV4_MULTICAST_PORT))
335 .await {
336 tracing::debug!(target: LOG_TARGET, ?error, "failed to send mdns response");
337 continue;
339 }
340 }
341 }
342 Err(error) => tracing::debug!(
343 target: LOG_TARGET,
344 ?address,
345 ?error,
346 ?nread,
347 "failed to parse mdns packet"
348 ),
349 }
350 Err(error) => {
351 tracing::debug!(target: LOG_TARGET, ?error, "failed to read from socket");
352 continue;
354 }
355 },
356 };
357
358 socket_opt = Some(socket);
359 }
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::manager::TransportManagerBuilder;
    use futures::StreamExt;
    use multiaddr::Protocol;

    /// Return `true` if `addrs` looks like the address pair advertised by the peer that
    /// listens on TCP port `expected_port`.
    ///
    /// Exactly two addresses are required. The first must begin with an IPv4 or IPv6
    /// component followed by a TCP component carrying `expected_port`.
    fn is_peer_with_port(addrs: &[Multiaddr], expected_port: u16) -> bool {
        if addrs.len() != 2 {
            return false;
        }

        let mut iter = addrs[0].iter();

        if !std::matches!(iter.next(), Some(Protocol::Ip4(_) | Protocol::Ip6(_))) {
            return false;
        }

        std::matches!(iter.next(), Some(Protocol::Tcp(port)) if port == expected_port)
    }

    #[tokio::test]
    async fn mdns_works() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let (config1, mut stream1) = Config::new(Duration::from_secs(5));
        let manager1 = TransportManagerBuilder::new().build();

        let mdns1 = Mdns::new(
            manager1.transport_manager_handle(),
            config1,
            vec![
                "/ip6/::1/tcp/8888/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTaaaa"
                    .parse()
                    .unwrap(),
                "/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTaaaa"
                    .parse()
                    .unwrap(),
            ],
        );

        let (config2, mut stream2) = Config::new(Duration::from_secs(5));
        let manager2 = TransportManagerBuilder::new().build();

        let mdns2 = Mdns::new(
            manager2.transport_manager_handle(),
            config2,
            vec![
                "/ip6/::1/tcp/9999/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTbbbb"
                    .parse()
                    .unwrap(),
                "/ip4/127.0.0.1/tcp/9999/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTbbbb"
                    .parse()
                    .unwrap(),
            ],
        );

        tokio::spawn(mdns1.start());
        tokio::spawn(mdns2.start());

        // Each instance must discover the *other* one. Peer 1 learns the addresses on port
        // 9999 and peer 2 learns those on port 8888. Keep polling until BOTH have happened.
        let discovery = async {
            let mut peer1_discovered = false;
            let mut peer2_discovered = false;

            while !peer1_discovered || !peer2_discovered {
                tokio::select! {
                    event = stream1.next() => match event.unwrap() {
                        MdnsEvent::Discovered(addrs) => {
                            peer1_discovered |= is_peer_with_port(&addrs, 9999);
                        }
                    },
                    event = stream2.next() => match event.unwrap() {
                        MdnsEvent::Discovered(addrs) => {
                            peer2_discovered |= is_peer_with_port(&addrs, 8888);
                        }
                    },
                }
            }
        };

        // Fail instead of hanging forever when multicast is unavailable (e.g. in some CI).
        tokio::time::timeout(Duration::from_secs(60), discovery)
            .await
            .expect("both mDNS peers should discover each other");
    }
}