1use crate::{transport::manager::TransportManagerHandle, DEFAULT_CHANNEL_SIZE};
25
26use futures::Stream;
27use multiaddr::Multiaddr;
28use rand::{distributions::Alphanumeric, Rng};
29use simple_dns::{
30 rdata::{RData, PTR, TXT},
31 Name, Packet, PacketFlag, Question, ResourceRecord, CLASS, QCLASS, QTYPE, TYPE,
32};
33use socket2::{Domain, Protocol, Socket, Type};
34use tokio::{
35 net::UdpSocket,
36 sync::mpsc::{channel, Sender},
37};
38use tokio_stream::wrappers::ReceiverStream;
39
40use std::{
41 collections::HashSet,
42 net,
43 net::{IpAddr, Ipv4Addr, SocketAddr},
44 sync::Arc,
45 time::Duration,
46};
47
/// `tracing` target used for every log event emitted by this module.
const LOG_TARGET: &str = "litep2p::mdns";

/// IPv4 multicast group that mDNS queries and responses are sent to.
const IPV4_MULTICAST_ADDRESS: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);

/// UDP port the mDNS socket binds to and sends to.
const IPV4_MULTICAST_PORT: u16 = 5_353;

/// Service name that is queried for and advertised via PTR records.
const SERVICE_NAME: &str = "_p2p._udp.local";
59
60pub enum MdnsEvent {
63 Discovered(Vec<Multiaddr>),
65}
66
67pub struct Config {
70 query_interval: Duration,
72
73 tx: Sender<MdnsEvent>,
75}
76
77impl Config {
78 pub fn new(
82 query_interval: Duration,
83 ) -> (Self, Box<dyn Stream<Item = MdnsEvent> + Send + Unpin>) {
84 let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE);
85 (
86 Self { query_interval, tx },
87 Box::new(ReceiverStream::new(rx)),
88 )
89 }
90}
91
92pub(crate) struct Mdns {
94 query_interval: tokio::time::Interval,
96
97 event_tx: Sender<MdnsEvent>,
99
100 _transport_handle: TransportManagerHandle,
102
103 username: String,
105
106 next_query_id: u16,
108
109 receive_buffer: Vec<u8>,
111
112 listen_addresses: Vec<Arc<str>>,
114
115 discovered: HashSet<Multiaddr>,
117}
118
119impl Mdns {
120 pub(crate) fn new(
122 _transport_handle: TransportManagerHandle,
123 config: Config,
124 listen_addresses: Vec<Multiaddr>,
125 ) -> Self {
126 let mut query_interval = tokio::time::interval(config.query_interval);
127 query_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
128
129 Self {
130 _transport_handle,
131 event_tx: config.tx,
132 next_query_id: 1337u16,
133 discovered: HashSet::new(),
134 query_interval,
135 receive_buffer: vec![0u8; 4096],
136 username: rand::thread_rng()
137 .sample_iter(&Alphanumeric)
138 .take(32)
139 .map(char::from)
140 .collect(),
141 listen_addresses: listen_addresses
142 .into_iter()
143 .map(|address| format!("dnsaddr={address}").into())
144 .collect(),
145 }
146 }
147
148 fn next_query_id(&mut self) -> u16 {
150 let query_id = self.next_query_id;
151 self.next_query_id += 1;
152
153 query_id
154 }
155
156 async fn on_outbound_request(&mut self, socket: &UdpSocket) -> crate::Result<()> {
158 tracing::debug!(target: LOG_TARGET, "send outbound query");
159
160 let mut packet = Packet::new_query(self.next_query_id());
161
162 packet.questions.push(Question {
163 qname: Name::new_unchecked(SERVICE_NAME),
164 qtype: QTYPE::TYPE(TYPE::PTR),
165 qclass: QCLASS::CLASS(CLASS::IN),
166 unicast_response: false,
167 });
168
169 socket
170 .send_to(
171 &packet.build_bytes_vec().expect("valid packet"),
172 (IPV4_MULTICAST_ADDRESS, IPV4_MULTICAST_PORT),
173 )
174 .await
175 .map(|_| ())
176 .map_err(From::from)
177 }
178
179 fn on_inbound_request(&self, packet: Packet) -> Option<Vec<u8>> {
181 tracing::debug!(target: LOG_TARGET, ?packet, "handle inbound request");
182
183 let mut packet = Packet::new_reply(packet.id());
184 let srv_name = Name::new_unchecked(SERVICE_NAME);
185
186 packet.answers.push(ResourceRecord::new(
187 srv_name.clone(),
188 CLASS::IN,
189 360,
190 RData::PTR(PTR(Name::new_unchecked(&self.username))),
191 ));
192
193 for address in &self.listen_addresses {
194 let mut record = TXT::new();
195 record.add_string(address).expect("valid string");
196
197 packet.additional_records.push(ResourceRecord {
198 name: Name::new_unchecked(&self.username),
199 class: CLASS::IN,
200 ttl: 360,
201 rdata: RData::TXT(record),
202 cache_flush: false,
203 });
204 }
205
206 Some(packet.build_bytes_vec().expect("valid packet"))
207 }
208
209 fn on_inbound_response(&self, packet: Packet) -> Vec<Multiaddr> {
211 tracing::debug!(target: LOG_TARGET, "handle inbound response");
212
213 let names = packet
214 .answers
215 .iter()
216 .filter_map(|answer| {
217 if answer.name != Name::new_unchecked(SERVICE_NAME) {
218 return None;
219 }
220
221 match answer.rdata {
222 RData::PTR(PTR(ref name)) if name != &Name::new_unchecked(&self.username) =>
223 Some(name),
224 _ => None,
225 }
226 })
227 .collect::<Vec<&Name>>();
228
229 let name = match names.len() {
230 0 => return Vec::new(),
231 _ => {
232 tracing::debug!(
233 target: LOG_TARGET,
234 ?names,
235 "response name"
236 );
237
238 names[0]
239 }
240 };
241
242 packet
243 .additional_records
244 .iter()
245 .flat_map(|record| {
246 if &record.name != name {
247 return vec![];
248 }
249
250 match &record.rdata {
253 RData::TXT(text) => text
254 .attributes()
255 .values()
256 .filter_map(|address| address.as_ref().and_then(|inner| inner.parse().ok()))
257 .collect(),
258 _ => vec![],
259 }
260 })
261 .collect()
262 }
263
264 fn setup_socket() -> crate::Result<UdpSocket> {
266 let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
267 socket.set_reuse_address(true)?;
268 #[cfg(unix)]
269 socket.set_reuse_port(true)?;
270 socket.bind(
271 &SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), IPV4_MULTICAST_PORT).into(),
272 )?;
273 socket.set_multicast_loop_v4(true)?;
274 socket.set_multicast_ttl_v4(255)?;
275 socket.join_multicast_v4(&IPV4_MULTICAST_ADDRESS, &Ipv4Addr::UNSPECIFIED)?;
276 socket.set_nonblocking(true)?;
277
278 UdpSocket::from_std(net::UdpSocket::from(socket)).map_err(Into::into)
279 }
280
281 pub(crate) async fn start(mut self) {
283 tracing::debug!(target: LOG_TARGET, "starting mdns event loop");
284
285 let mut socket_opt = None;
286
287 loop {
288 let socket = match socket_opt.take() {
289 Some(s) => s,
290 None => {
291 let _ = self.query_interval.tick().await;
292 match Self::setup_socket() {
293 Ok(s) => s,
294 Err(error) => {
295 tracing::debug!(
296 target: LOG_TARGET,
297 ?error,
298 "failed to setup mDNS socket, will try again"
299 );
300 continue;
301 }
302 }
303 }
304 };
305
306 tokio::select! {
307 _ = self.query_interval.tick() => {
308 tracing::trace!(target: LOG_TARGET, "query interval ticked");
309
310 if let Err(error) = self.on_outbound_request(&socket).await {
311 tracing::debug!(target: LOG_TARGET, ?error, "failed to send mdns query");
312 continue;
314 }
315 },
316
317 result = socket.recv_from(&mut self.receive_buffer) => match result {
318 Ok((nread, address)) => match Packet::parse(&self.receive_buffer[..nread]) {
319 Ok(packet) => match packet.has_flags(PacketFlag::RESPONSE) {
320 true => {
321 let to_forward = self.on_inbound_response(packet).into_iter().filter_map(|address| {
322 self.discovered.insert(address.clone()).then_some(address)
323 })
324 .collect::<Vec<_>>();
325
326 if !to_forward.is_empty() {
327 let _ = self.event_tx.send(MdnsEvent::Discovered(to_forward)).await;
328 }
329 }
330 false => if let Some(response) = self.on_inbound_request(packet) {
331 if let Err(error) = socket
332 .send_to(&response, (IPV4_MULTICAST_ADDRESS, IPV4_MULTICAST_PORT))
333 .await {
334 tracing::debug!(target: LOG_TARGET, ?error, "failed to send mdns response");
335 continue;
337 }
338 }
339 }
340 Err(error) => tracing::debug!(
341 target: LOG_TARGET,
342 ?address,
343 ?error,
344 ?nread,
345 "failed to parse mdns packet"
346 ),
347 }
348 Err(error) => {
349 tracing::debug!(target: LOG_TARGET, ?error, "failed to read from socket");
350 continue;
352 }
353 },
354 };
355
356 socket_opt = Some(socket);
357 }
358 }
359}
360
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::manager::TransportManagerBuilder;
    use futures::StreamExt;
    use multiaddr::Protocol;

    /// Returns `true` when `addrs` looks like the listen addresses of the peer on TCP `port`:
    /// exactly two addresses, the first of which starts with `/ip4` or `/ip6` followed by
    /// `/tcp/<port>`.
    fn is_peer_on_port(addrs: &[Multiaddr], port: u16) -> bool {
        if addrs.len() != 2 {
            return false;
        }

        let mut iter = addrs[0].iter();

        if !std::matches!(iter.next(), Some(Protocol::Ip4(_) | Protocol::Ip6(_))) {
            return false;
        }

        std::matches!(iter.next(), Some(Protocol::Tcp(tcp_port)) if tcp_port == port)
    }

    #[tokio::test]
    async fn mdns_works() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let (config1, mut stream1) = Config::new(Duration::from_secs(5));
        let manager1 = TransportManagerBuilder::new().build();

        let mdns1 = Mdns::new(
            manager1.transport_manager_handle(),
            config1,
            vec![
                "/ip6/::1/tcp/8888/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTaaaa"
                    .parse()
                    .unwrap(),
                "/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTaaaa"
                    .parse()
                    .unwrap(),
            ],
        );

        let (config2, mut stream2) = Config::new(Duration::from_secs(5));
        let manager2 = TransportManagerBuilder::new().build();

        let mdns2 = Mdns::new(
            manager2.transport_manager_handle(),
            config2,
            vec![
                "/ip6/::1/tcp/9999/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTbbbb"
                    .parse()
                    .unwrap(),
                "/ip4/127.0.0.1/tcp/9999/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTbbbb"
                    .parse()
                    .unwrap(),
            ],
        );

        tokio::spawn(mdns1.start());
        tokio::spawn(mdns2.start());

        // Both directions must succeed (the previous `!a && !b` condition stopped after the
        // first one). Bound the wait so a regression fails the test instead of hanging it.
        tokio::time::timeout(Duration::from_secs(30), async {
            let mut peer1_discovered = false;
            let mut peer2_discovered = false;

            while !(peer1_discovered && peer2_discovered) {
                tokio::select! {
                    event = stream1.next() => match event.unwrap() {
                        MdnsEvent::Discovered(addrs) => {
                            peer1_discovered |= is_peer_on_port(&addrs, 9999);
                        }
                    },
                    event = stream2.next() => match event.unwrap() {
                        MdnsEvent::Discovered(addrs) => {
                            peer2_discovered |= is_peer_on_port(&addrs, 8888);
                        }
                    },
                }
            }
        })
        .await
        .expect("peers did not discover each other in time");
    }
}