1use crate::{error::Error, transport::manager::TransportManagerHandle, DEFAULT_CHANNEL_SIZE};
25
26use futures::Stream;
27use multiaddr::Multiaddr;
28use rand::{distributions::Alphanumeric, Rng};
29use simple_dns::{
30 rdata::{RData, PTR, TXT},
31 Name, Packet, PacketFlag, Question, ResourceRecord, CLASS, QCLASS, QTYPE, TYPE,
32};
33use socket2::{Domain, Protocol, Socket, Type};
34use tokio::{
35 net::UdpSocket,
36 sync::mpsc::{channel, Sender},
37};
38use tokio_stream::wrappers::ReceiverStream;
39
40use std::{
41 collections::HashSet,
42 net,
43 net::{IpAddr, Ipv4Addr, SocketAddr},
44 sync::Arc,
45 time::Duration,
46};
47
/// Logging target used by every trace emitted from this module.
const LOG_TARGET: &str = "litep2p::mdns";

/// IPv4 multicast group mDNS traffic is sent to (`224.0.0.251`).
const IPV4_MULTICAST_ADDRESS: Ipv4Addr = Ipv4Addr::new(0xe0, 0x00, 0x00, 0xfb);

/// UDP port the mDNS socket binds to and sends to.
const IPV4_MULTICAST_PORT: u16 = 5_353;

/// Service name this module queries for and answers queries about.
const SERVICE_NAME: &str = "_p2p._udp.local";
59
60pub enum MdnsEvent {
63 Discovered(Vec<Multiaddr>),
65}
66
67pub struct Config {
70 query_interval: Duration,
72
73 tx: Sender<MdnsEvent>,
75}
76
77impl Config {
78 pub fn new(
82 query_interval: Duration,
83 ) -> (Self, Box<dyn Stream<Item = MdnsEvent> + Send + Unpin>) {
84 let (tx, rx) = channel(DEFAULT_CHANNEL_SIZE);
85 (
86 Self { query_interval, tx },
87 Box::new(ReceiverStream::new(rx)),
88 )
89 }
90}
91
92pub(crate) struct Mdns {
94 socket: UdpSocket,
96
97 query_interval: Duration,
99
100 event_tx: Sender<MdnsEvent>,
102
103 _transport_handle: TransportManagerHandle,
105
106 username: String,
108
109 next_query_id: u16,
111
112 receive_buffer: Vec<u8>,
114
115 listen_addresses: Vec<Arc<str>>,
117
118 discovered: HashSet<Multiaddr>,
120}
121
122impl Mdns {
123 pub(crate) fn new(
125 _transport_handle: TransportManagerHandle,
126 config: Config,
127 listen_addresses: Vec<Multiaddr>,
128 ) -> crate::Result<Self> {
129 let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
130 socket.set_reuse_address(true)?;
131 #[cfg(unix)]
132 socket.set_reuse_port(true)?;
133 socket.bind(
134 &SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), IPV4_MULTICAST_PORT).into(),
135 )?;
136 socket.set_multicast_loop_v4(true)?;
137 socket.set_multicast_ttl_v4(255)?;
138 socket.join_multicast_v4(&IPV4_MULTICAST_ADDRESS, &Ipv4Addr::UNSPECIFIED)?;
139 socket.set_nonblocking(true)?;
140
141 Ok(Self {
142 _transport_handle,
143 event_tx: config.tx,
144 next_query_id: 1337u16,
145 discovered: HashSet::new(),
146 query_interval: config.query_interval,
147 receive_buffer: vec![0u8; 4096],
148 username: rand::thread_rng()
149 .sample_iter(&Alphanumeric)
150 .take(32)
151 .map(char::from)
152 .collect(),
153 socket: UdpSocket::from_std(net::UdpSocket::from(socket))?,
154 listen_addresses: listen_addresses
155 .into_iter()
156 .map(|address| format!("dnsaddr={address}").into())
157 .collect(),
158 })
159 }
160
161 fn next_query_id(&mut self) -> u16 {
163 let query_id = self.next_query_id;
164 self.next_query_id += 1;
165
166 query_id
167 }
168
169 async fn on_outbound_request(&mut self) -> crate::Result<()> {
171 tracing::debug!(target: LOG_TARGET, "send outbound query");
172
173 let mut packet = Packet::new_query(self.next_query_id());
174
175 packet.questions.push(Question {
176 qname: Name::new_unchecked(SERVICE_NAME),
177 qtype: QTYPE::TYPE(TYPE::PTR),
178 qclass: QCLASS::CLASS(CLASS::IN),
179 unicast_response: false,
180 });
181
182 self.socket
183 .send_to(
184 &packet.build_bytes_vec().expect("valid packet"),
185 (IPV4_MULTICAST_ADDRESS, IPV4_MULTICAST_PORT),
186 )
187 .await
188 .map(|_| ())
189 .map_err(From::from)
190 }
191
192 fn on_inbound_request(&self, packet: Packet) -> Option<Vec<u8>> {
194 tracing::debug!(target: LOG_TARGET, ?packet, "handle inbound request");
195
196 let mut packet = Packet::new_reply(packet.id());
197 let srv_name = Name::new_unchecked(SERVICE_NAME);
198
199 packet.answers.push(ResourceRecord::new(
200 srv_name.clone(),
201 CLASS::IN,
202 360,
203 RData::PTR(PTR(Name::new_unchecked(&self.username))),
204 ));
205
206 for address in &self.listen_addresses {
207 let mut record = TXT::new();
208 record.add_string(address).expect("valid string");
209
210 packet.additional_records.push(ResourceRecord {
211 name: Name::new_unchecked(&self.username),
212 class: CLASS::IN,
213 ttl: 360,
214 rdata: RData::TXT(record),
215 cache_flush: false,
216 });
217 }
218
219 Some(packet.build_bytes_vec().expect("valid packet"))
220 }
221
222 fn on_inbound_response(&self, packet: Packet) -> Vec<Multiaddr> {
224 tracing::debug!(target: LOG_TARGET, "handle inbound response");
225
226 let names = packet
227 .answers
228 .iter()
229 .filter_map(|answer| {
230 if answer.name != Name::new_unchecked(SERVICE_NAME) {
231 return None;
232 }
233
234 match answer.rdata {
235 RData::PTR(PTR(ref name)) if name != &Name::new_unchecked(&self.username) =>
236 Some(name),
237 _ => None,
238 }
239 })
240 .collect::<Vec<&Name>>();
241
242 let name = match names.len() {
243 0 => return Vec::new(),
244 _ => {
245 tracing::debug!(
246 target: LOG_TARGET,
247 ?names,
248 "response name"
249 );
250
251 names[0]
252 }
253 };
254
255 packet
256 .additional_records
257 .iter()
258 .flat_map(|record| {
259 if &record.name != name {
260 return vec![];
261 }
262
263 match &record.rdata {
265 RData::TXT(text) => text
266 .attributes()
267 .iter()
268 .filter_map(|(_, address)| {
269 address.as_ref().and_then(|inner| inner.parse().ok())
270 })
271 .collect(),
272 _ => vec![],
273 }
274 })
275 .collect()
276 }
277
278 pub(crate) async fn start(mut self) -> crate::Result<()> {
280 tracing::debug!(target: LOG_TARGET, "starting mdns event loop");
281
282 self.on_outbound_request().await?;
286
287 loop {
288 tokio::select! {
289 _ = tokio::time::sleep(self.query_interval) => {
290 tracing::trace!(target: LOG_TARGET, "timeout expired");
291
292 if let Err(error) = self.on_outbound_request().await {
293 tracing::error!(target: LOG_TARGET, ?error, "failed to send mdns query");
294 return Err(error);
295 }
296 }
297 result = self.socket.recv_from(&mut self.receive_buffer) => match result {
298 Ok((nread, address)) => match Packet::parse(&self.receive_buffer[..nread]) {
299 Ok(packet) => match packet.has_flags(PacketFlag::RESPONSE) {
300 true => {
301 let to_forward = self.on_inbound_response(packet).into_iter().filter_map(|address| {
302 self.discovered.insert(address.clone()).then_some(address)
303 })
304 .collect::<Vec<_>>();
305
306 if !to_forward.is_empty() {
307 let _ = self.event_tx.send(MdnsEvent::Discovered(to_forward)).await;
308 }
309 }
310 false => if let Some(response) = self.on_inbound_request(packet) {
311 self.socket
312 .send_to(&response, (IPV4_MULTICAST_ADDRESS, IPV4_MULTICAST_PORT))
313 .await?;
314 }
315 }
316 Err(error) => tracing::debug!(
317 target: LOG_TARGET,
318 ?address,
319 ?error,
320 ?nread,
321 "failed to parse mdns packet"
322 ),
323 }
324 Err(error) => {
325 tracing::error!(target: LOG_TARGET, ?error, "failed to read from socket");
326 return Err(Error::from(error));
327 }
328 },
329 }
330 }
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        crypto::ed25519::Keypair,
        transport::manager::{limits::ConnectionLimitsConfig, TransportManager},
        BandwidthSink,
    };
    use futures::StreamExt;
    use multiaddr::Protocol;

    /// Check whether `addrs` looks like the address pair advertised by a test peer listening
    /// on TCP `expected_port`. There must be exactly two addresses, and the first must start
    /// with `/ip4/..` or `/ip6/..` followed by `/tcp/<expected_port>`.
    fn is_peer_with_port(addrs: &[Multiaddr], expected_port: u16) -> bool {
        if addrs.len() != 2 {
            return false;
        }

        let mut protocols = addrs[0].iter();

        std::matches!(protocols.next(), Some(Protocol::Ip4(_) | Protocol::Ip6(_)))
            && std::matches!(protocols.next(), Some(Protocol::Tcp(port)) if port == expected_port)
    }

    #[tokio::test]
    async fn mdns_works() {
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
            .try_init();

        let (config1, mut stream1) = Config::new(Duration::from_secs(5));
        let (_manager1, handle1) = TransportManager::new(
            Keypair::generate(),
            HashSet::new(),
            BandwidthSink::new(),
            8usize,
            ConnectionLimitsConfig::default(),
        );

        let mdns1 = Mdns::new(
            handle1,
            config1,
            vec![
                "/ip6/::1/tcp/8888/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTaaaa"
                    .parse()
                    .unwrap(),
                "/ip4/127.0.0.1/tcp/8888/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTaaaa"
                    .parse()
                    .unwrap(),
            ],
        )
        .unwrap();

        let (config2, mut stream2) = Config::new(Duration::from_secs(5));
        let (_manager2, handle2) = TransportManager::new(
            Keypair::generate(),
            HashSet::new(),
            BandwidthSink::new(),
            8usize,
            ConnectionLimitsConfig::default(),
        );

        let mdns2 = Mdns::new(
            handle2,
            config2,
            vec![
                "/ip6/::1/tcp/9999/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTbbbb"
                    .parse()
                    .unwrap(),
                "/ip4/127.0.0.1/tcp/9999/p2p/12D3KooWNP463TyS3vUpmekjjZ2dg7xy1WHNMM7MqfsMevMTbbbb"
                    .parse()
                    .unwrap(),
            ],
        )
        .unwrap();

        tokio::spawn(mdns1.start());
        tokio::spawn(mdns2.start());

        let discovery = async {
            let mut peer1_discovered = false;
            let mut peer2_discovered = false;

            // Keep going until *both* peers have discovered each other.
            while !peer1_discovered || !peer2_discovered {
                tokio::select! {
                    event = stream1.next() => match event.expect("mdns1 event stream to stay open") {
                        MdnsEvent::Discovered(addrs) => {
                            peer1_discovered |= is_peer_with_port(&addrs, 9999);
                        }
                    },
                    event = stream2.next() => match event.expect("mdns2 event stream to stay open") {
                        MdnsEvent::Discovered(addrs) => {
                            peer2_discovered |= is_peer_with_port(&addrs, 8888);
                        }
                    },
                }
            }
        };

        // Fail instead of hanging forever if discovery never completes.
        tokio::time::timeout(Duration::from_secs(60), discovery)
            .await
            .expect("peers to discover each other within the timeout");
    }
}