1use std::future::Future;
13use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17use std::time::Instant;
18
19use futures_util::{FutureExt, future, future::Either};
20use tracing::debug;
21
22use crate::proto::op::Query;
23use crate::proto::rr::{Name, RData, Record, RecordType};
24use crate::proto::xfer::{DnsHandle, DnsRequestOptions};
25
26use crate::caching_client::CachingClient;
27use crate::config::LookupIpStrategy;
28use crate::dns_lru::MAX_TTL;
29use crate::error::*;
30use crate::hosts::Hosts;
31use crate::lookup::{Lookup, LookupIntoIter, LookupIter};
32
33#[derive(Debug, Clone)]
37pub struct LookupIp(Lookup);
38
39impl LookupIp {
40 pub fn iter(&self) -> LookupIpIter<'_> {
44 LookupIpIter(self.0.iter())
45 }
46
47 pub fn query(&self) -> &Query {
49 self.0.query()
50 }
51
52 pub fn valid_until(&self) -> Instant {
54 self.0.valid_until()
55 }
56
57 pub fn as_lookup(&self) -> &Lookup {
61 &self.0
62 }
63}
64
65impl From<Lookup> for LookupIp {
66 fn from(lookup: Lookup) -> Self {
67 Self(lookup)
68 }
69}
70
71impl From<LookupIp> for Lookup {
72 fn from(lookup: LookupIp) -> Self {
73 lookup.0
74 }
75}
76
77pub struct LookupIpIter<'i>(pub(crate) LookupIter<'i>);
79
80impl Iterator for LookupIpIter<'_> {
81 type Item = IpAddr;
82
83 fn next(&mut self) -> Option<Self::Item> {
84 let iter: &mut _ = &mut self.0;
85 iter.find_map(|rdata| match rdata {
86 RData::A(ip) => Some(IpAddr::from(Ipv4Addr::from(*ip))),
87 RData::AAAA(ip) => Some(IpAddr::from(Ipv6Addr::from(*ip))),
88 _ => None,
89 })
90 }
91}
92
93impl IntoIterator for LookupIp {
94 type Item = IpAddr;
95 type IntoIter = LookupIpIntoIter;
96
97 fn into_iter(self) -> Self::IntoIter {
99 LookupIpIntoIter(self.0.into_iter())
100 }
101}
102
103pub struct LookupIpIntoIter(LookupIntoIter);
105
106impl Iterator for LookupIpIntoIter {
107 type Item = IpAddr;
108
109 fn next(&mut self) -> Option<Self::Item> {
110 let iter: &mut _ = &mut self.0;
111 iter.find_map(|rdata| match rdata {
112 RData::A(ip) => Some(IpAddr::from(Ipv4Addr::from(ip))),
113 RData::AAAA(ip) => Some(IpAddr::from(Ipv6Addr::from(ip))),
114 _ => None,
115 })
116 }
117}
118
119pub struct LookupIpFuture<C: DnsHandle + 'static> {
123 client_cache: CachingClient<C>,
124 names: Vec<Name>,
125 strategy: LookupIpStrategy,
126 options: DnsRequestOptions,
127 query: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>>,
128 hosts: Arc<Hosts>,
129 finally_ip_addr: Option<RData>,
130}
131
132impl<C: DnsHandle + 'static> LookupIpFuture<C> {
133 pub fn lookup(
141 names: Vec<Name>,
142 strategy: LookupIpStrategy,
143 client_cache: CachingClient<C>,
144 options: DnsRequestOptions,
145 hosts: Arc<Hosts>,
146 finally_ip_addr: Option<RData>,
147 ) -> Self {
148 let empty =
149 ResolveError::from(ResolveErrorKind::Message("can not lookup IPs for no names"));
150 Self {
151 names,
152 strategy,
153 client_cache,
154 query: future::err(empty).boxed(),
157 options,
158 hosts,
159 finally_ip_addr,
160 }
161 }
162}
163
164impl<C: DnsHandle + 'static> Future for LookupIpFuture<C> {
165 type Output = Result<LookupIp, ResolveError>;
166
167 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
168 loop {
169 let query = self.query.as_mut().poll(cx);
171
172 let should_retry = match &query {
174 Poll::Pending => return Poll::Pending,
176 Poll::Ready(Ok(lookup)) => lookup.is_empty(),
180 Poll::Ready(Err(_)) => true,
182 };
183
184 if !should_retry {
185 return query.map(|f| f.map(LookupIp::from));
189 }
190
191 if let Some(name) = self.names.pop() {
192 self.query = LookupContext {
195 client: self.client_cache.clone(),
196 options: self.options,
197 hosts: self.hosts.clone(),
198 }
199 .strategic_lookup(name, self.strategy)
200 .boxed();
201 continue;
204 } else if let Some(ip_addr) = self.finally_ip_addr.take() {
205 let record = Record::from_rdata(Name::new(), MAX_TTL, ip_addr);
208 let lookup = Lookup::new_with_max_ttl(Query::new(), Arc::from([record]));
209 return Poll::Ready(Ok(lookup.into()));
210 }
211
212 return query.map(|f| f.map(LookupIp::from));
217 }
218 }
219}
220
221#[derive(Clone)]
222struct LookupContext<C: DnsHandle> {
223 client: CachingClient<C>,
224 options: DnsRequestOptions,
225 hosts: Arc<Hosts>,
226}
227
228impl<C: DnsHandle> LookupContext<C> {
229 async fn strategic_lookup(
231 self,
232 name: Name,
233 strategy: LookupIpStrategy,
234 ) -> Result<Lookup, ResolveError> {
235 match strategy {
236 LookupIpStrategy::Ipv4Only => self.ipv4_only(name).await,
237 LookupIpStrategy::Ipv6Only => self.ipv6_only(name).await,
238 LookupIpStrategy::Ipv4AndIpv6 => self.ipv4_and_ipv6(name).await,
239 LookupIpStrategy::Ipv6thenIpv4 => self.ipv6_then_ipv4(name).await,
240 LookupIpStrategy::Ipv4thenIpv6 => self.ipv4_then_ipv6(name).await,
241 }
242 }
243
244 async fn ipv4_only(&self, name: Name) -> Result<Lookup, ResolveError> {
246 self.hosts_lookup(Query::query(name, RecordType::A)).await
247 }
248
249 async fn ipv6_only(&self, name: Name) -> Result<Lookup, ResolveError> {
251 self.hosts_lookup(Query::query(name, RecordType::AAAA))
252 .await
253 }
254
255 async fn ipv4_and_ipv6(&self, name: Name) -> Result<Lookup, ResolveError> {
258 let sel_res = future::select(
259 self.hosts_lookup(Query::query(name.clone(), RecordType::A))
260 .boxed(),
261 self.hosts_lookup(Query::query(name, RecordType::AAAA))
262 .boxed(),
263 )
264 .await;
265
266 let (ips, remaining_query) = match sel_res {
267 Either::Left(ips_and_remaining) => ips_and_remaining,
268 Either::Right(ips_and_remaining) => ips_and_remaining,
269 };
270
271 let next_ips = remaining_query.await;
272
273 match (ips, next_ips) {
274 (Ok(ips), Ok(next_ips)) => {
275 let ips = ips.append(next_ips);
277 Ok(ips)
278 }
279 (Ok(ips), Err(e)) | (Err(e), Ok(ips)) => {
280 debug!(
281 "one of ipv4 or ipv6 lookup failed in ipv4_and_ipv6 strategy: {}",
282 e
283 );
284 Ok(ips)
285 }
286 (Err(e1), Err(e2)) => {
287 debug!(
288 "both of ipv4 or ipv6 lookup failed in ipv4_and_ipv6 strategy e1: {}, e2: {}",
289 e1, e2
290 );
291 Err(e1)
292 }
293 }
294 }
295
296 async fn ipv6_then_ipv4(&self, name: Name) -> Result<Lookup, ResolveError> {
298 self.rt_then_swap(name, RecordType::AAAA, RecordType::A)
299 .await
300 }
301
302 async fn ipv4_then_ipv6(&self, name: Name) -> Result<Lookup, ResolveError> {
304 self.rt_then_swap(name, RecordType::A, RecordType::AAAA)
305 .await
306 }
307
308 async fn rt_then_swap(
310 &self,
311 name: Name,
312 first_type: RecordType,
313 second_type: RecordType,
314 ) -> Result<Lookup, ResolveError> {
315 let res = self
316 .hosts_lookup(Query::query(name.clone(), first_type))
317 .await;
318
319 match res {
320 Ok(ips) => {
321 if ips.is_empty() {
322 self.hosts_lookup(Query::query(name.clone(), second_type))
324 .await
325 } else {
326 Ok(ips)
327 }
328 }
329 Err(_) => {
330 self.hosts_lookup(Query::query(name.clone(), second_type))
331 .await
332 }
333 }
334 }
335
336 async fn hosts_lookup(&self, query: Query) -> Result<Lookup, ResolveError> {
338 match self.hosts.lookup_static_host(&query) {
339 Some(lookup) => Ok(lookup),
340 None => self.client.lookup(query, self.options).await,
341 }
342 }
343}
344
#[cfg(test)]
pub(crate) mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
    use std::sync::{Arc, Mutex};

    use futures_executor::block_on;
    use futures_util::future;
    use futures_util::stream::{Stream, once};
    use test_support::subscribe;

    use crate::proto::ProtoError;
    use crate::proto::op::Message;
    use crate::proto::rr::{Name, RData, Record};
    use crate::proto::xfer::{DnsHandle, DnsRequest, DnsResponse};

    use super::*;

    /// The IPv4 address carried by [`v4_message`].
    const LOCALHOST_V4: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
    /// The IPv6 address (`::1`) carried by [`v6_message`].
    const LOCALHOST_V6: IpAddr = IpAddr::V6(Ipv6Addr::LOCALHOST);

    /// A `DnsHandle` that ignores the request and replies with canned
    /// responses, taken from the back of the list; once the list is empty it
    /// replies with [`empty`].
    #[derive(Clone)]
    pub(crate) struct MockDnsHandle {
        messages: Arc<Mutex<Vec<Result<DnsResponse, ProtoError>>>>,
    }

    impl DnsHandle for MockDnsHandle {
        type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send + Unpin>>;

        fn send<R: Into<DnsRequest>>(&self, _: R) -> Self::Response {
            let canned = self.messages.lock().unwrap().pop().unwrap_or_else(empty);
            Box::pin(once(future::ready(canned)))
        }
    }

    /// A response answering an `A` query for the root name with `127.0.0.1`.
    pub(crate) fn v4_message() -> Result<DnsResponse, ProtoError> {
        let answer = Record::from_rdata(Name::root(), 86400, RData::A(Ipv4Addr::LOCALHOST.into()));

        let mut message = Message::new();
        message.add_query(Query::query(Name::root(), RecordType::A));
        message.insert_answers(vec![answer]);

        let resp = DnsResponse::from_message(message).unwrap();
        assert!(resp.contains_answer());
        Ok(resp)
    }

    /// A response answering an `AAAA` query for the root name with `::1`.
    pub(crate) fn v6_message() -> Result<DnsResponse, ProtoError> {
        let answer = Record::from_rdata(
            Name::root(),
            86400,
            RData::AAAA(Ipv6Addr::LOCALHOST.into()),
        );

        let mut message = Message::new();
        message.add_query(Query::query(Name::root(), RecordType::AAAA));
        message.insert_answers(vec![answer]);

        let resp = DnsResponse::from_message(message).unwrap();
        assert!(resp.contains_answer());
        Ok(resp)
    }

    /// A response built from a blank message, i.e. one without any answers.
    pub(crate) fn empty() -> Result<DnsResponse, ProtoError> {
        Ok(DnsResponse::from_message(Message::new()).unwrap())
    }

    /// A failed response.
    pub(crate) fn error() -> Result<DnsResponse, ProtoError> {
        Err(ProtoError::from("forced test failure"))
    }

    /// Builds a [`MockDnsHandle`] that serves `messages` from last to first.
    pub(crate) fn mock(messages: Vec<Result<DnsResponse, ProtoError>>) -> MockDnsHandle {
        MockDnsHandle {
            messages: Arc::new(Mutex::new(messages)),
        }
    }

    /// Builds a lookup context backed by a mock handle serving `messages`,
    /// with default request options and a default hosts table.
    fn context_with(
        messages: Vec<Result<DnsResponse, ProtoError>>,
    ) -> LookupContext<MockDnsHandle> {
        LookupContext {
            client: CachingClient::new(0, mock(messages), false),
            options: DnsRequestOptions::default(),
            hosts: Arc::new(Hosts::default()),
        }
    }

    /// Collects the IP address of every entry in `lookup`, panicking on any
    /// entry that is not an address.
    fn resolved_ips(lookup: Lookup) -> Vec<IpAddr> {
        lookup.iter().map(|r| r.ip_addr().unwrap()).collect()
    }

    #[test]
    fn test_ipv4_only_strategy() {
        subscribe();

        let cx = context_with(vec![v4_message()]);
        let found = resolved_ips(block_on(cx.ipv4_only(Name::root())).unwrap());
        assert_eq!(found, vec![LOCALHOST_V4]);
    }

    #[test]
    fn test_ipv6_only_strategy() {
        subscribe();

        let cx = context_with(vec![v6_message()]);
        let found = resolved_ips(block_on(cx.ipv6_only(Name::root())).unwrap());
        assert_eq!(found, vec![LOCALHOST_V6]);
    }

    #[test]
    fn test_ipv4_and_ipv6_strategy() {
        subscribe();

        // Both record types answer.
        let cx = context_with(vec![v6_message(), v4_message()]);
        let found = resolved_ips(block_on(cx.ipv4_and_ipv6(Name::root())).unwrap());
        assert_eq!(found, vec![LOCALHOST_V4, LOCALHOST_V6]);

        // Only the v4 query has an answer; the other comes back empty.
        let cx = context_with(vec![empty(), v4_message()]);
        let found = resolved_ips(block_on(cx.ipv4_and_ipv6(Name::root())).unwrap());
        assert_eq!(found, vec![LOCALHOST_V4]);

        // The v4 query answers; the other fails.
        let cx = context_with(vec![error(), v4_message()]);
        let found = resolved_ips(block_on(cx.ipv4_and_ipv6(Name::root())).unwrap());
        assert_eq!(found, vec![LOCALHOST_V4]);

        // Only the v6 query has an answer; the other comes back empty.
        let cx = context_with(vec![v6_message(), empty()]);
        let found = resolved_ips(block_on(cx.ipv4_and_ipv6(Name::root())).unwrap());
        assert_eq!(found, vec![LOCALHOST_V6]);

        // The v6 query answers; the other fails.
        let cx = context_with(vec![v6_message(), error()]);
        let found = resolved_ips(block_on(cx.ipv4_and_ipv6(Name::root())).unwrap());
        assert_eq!(found, vec![LOCALHOST_V6]);
    }

    #[test]
    fn test_ipv6_then_ipv4_strategy() {
        subscribe();

        // The first (v6) query answers directly.
        let cx = context_with(vec![v6_message()]);
        let found = resolved_ips(block_on(cx.ipv6_then_ipv4(Name::root())).unwrap());
        assert_eq!(found, vec![LOCALHOST_V6]);

        // The first query is empty, so the v4 fallback is used.
        let cx = context_with(vec![v4_message(), empty()]);
        let found = resolved_ips(block_on(cx.ipv6_then_ipv4(Name::root())).unwrap());
        assert_eq!(found, vec![LOCALHOST_V4]);

        // The first query fails, so the v4 fallback is used.
        let cx = context_with(vec![v4_message(), error()]);
        let found = resolved_ips(block_on(cx.ipv6_then_ipv4(Name::root())).unwrap());
        assert_eq!(found, vec![LOCALHOST_V4]);
    }

    #[test]
    fn test_ipv4_then_ipv6_strategy() {
        subscribe();

        // The first (v4) query answers directly.
        let cx = context_with(vec![v4_message()]);
        let found = resolved_ips(block_on(cx.ipv4_then_ipv6(Name::root())).unwrap());
        assert_eq!(found, vec![LOCALHOST_V4]);

        // The first query is empty, so the v6 fallback is used.
        let cx = context_with(vec![v6_message(), empty()]);
        let found = resolved_ips(block_on(cx.ipv4_then_ipv6(Name::root())).unwrap());
        assert_eq!(found, vec![LOCALHOST_V6]);

        // The first query fails, so the v6 fallback is used.
        let cx = context_with(vec![v6_message(), error()]);
        let found = resolved_ips(block_on(cx.ipv4_then_ipv6(Name::root())).unwrap());
        assert_eq!(found, vec![LOCALHOST_V6]);
    }
}