hickory_resolver/
caching_client.rs

1// Copyright 2015-2023 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! Caching related functionality for the Resolver.
9
10use std::{
11    borrow::Cow,
12    error::Error,
13    pin::Pin,
14    sync::{
15        atomic::{AtomicU8, Ordering},
16        Arc,
17    },
18    time::Instant,
19};
20
21use futures_util::future::Future;
22use once_cell::sync::Lazy;
23
24use crate::{
25    dns_lru::{self, DnsLru, TtlConfig},
26    error::{ResolveError, ResolveErrorKind},
27    lookup::Lookup,
28    proto::{
29        error::ProtoError,
30        op::{Query, ResponseCode},
31        rr::{
32            domain::usage::{
33                ResolverUsage, DEFAULT, INVALID, IN_ADDR_ARPA_127, IP6_ARPA_1, LOCAL,
34                LOCALHOST as LOCALHOST_usage, ONION,
35            },
36            rdata::{A, AAAA, CNAME, PTR, SOA},
37            resource::RecordRef,
38            DNSClass, Name, RData, Record, RecordType,
39        },
40        xfer::{DnsHandle, DnsRequestOptions, DnsResponse, FirstAnswer},
41    },
42};
43
44const MAX_QUERY_DEPTH: u8 = 8; // arbitrarily chosen number...
45
46static LOCALHOST: Lazy<RData> =
47    Lazy::new(|| RData::PTR(PTR(Name::from_ascii("localhost.").unwrap())));
48static LOCALHOST_V4: Lazy<RData> = Lazy::new(|| RData::A(A::new(127, 0, 0, 1)));
49static LOCALHOST_V6: Lazy<RData> = Lazy::new(|| RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 1)));
50
/// RAII guard over the shared query-depth counter.
///
/// Creating a tracker increments the counter by one; dropping it decrements the counter by
/// one again. The guard therefore has to be bound to a variable for as long as the tracked
/// lookup is running, which is why the type is marked `#[must_use]`.
#[must_use = "the counter is decremented again as soon as the tracker is dropped"]
struct DepthTracker {
    /// Counter shared with the `CachingClient` (and its clones) that created this tracker.
    query_depth: Arc<AtomicU8>,
}

impl DepthTracker {
    /// Increments `query_depth` and returns the guard that undoes the increment on drop.
    fn track(query_depth: Arc<AtomicU8>) -> Self {
        query_depth.fetch_add(1, Ordering::Release);
        Self { query_depth }
    }
}

impl Drop for DepthTracker {
    /// Releases the depth slot taken in [`DepthTracker::track`].
    fn drop(&mut self) {
        self.query_depth.fetch_sub(1, Ordering::Release);
    }
}
67
// TODO: need to consider this storage type as it compares to Authority in server...
//       should it just be a variation on Authority?
/// A DNS client wrapper that answers lookups from an LRU cache where possible and otherwise
/// forwards them to the wrapped [`DnsHandle`], caching the outcome.
#[derive(Clone, Debug)]
#[doc(hidden)]
pub struct CachingClient<C, E>
where
    C: DnsHandle<Error = E>,
    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
{
    /// Cache consulted before, and populated after, every upstream query.
    lru: DnsLru,
    /// Upstream handle used when the cache has no entry for a query.
    client: C,
    /// Counter incremented for the duration of each lookup (see `DepthTracker`); it is held
    /// in an `Arc`, so clones of this client share the same counter.
    query_depth: Arc<AtomicU8>,
    /// When `true`, CNAME records found in a response are kept in the returned record set.
    preserve_intermediates: bool,
}
82
83impl<C, E> CachingClient<C, E>
84where
85    C: DnsHandle<Error = E> + Send + 'static,
86    E: Into<ResolveError> + From<ProtoError> + Error + Clone + Send + Unpin + 'static,
87{
88    #[doc(hidden)]
89    pub fn new(max_size: usize, client: C, preserve_intermediates: bool) -> Self {
90        Self::with_cache(
91            DnsLru::new(max_size, TtlConfig::default()),
92            client,
93            preserve_intermediates,
94        )
95    }
96
97    pub(crate) fn with_cache(lru: DnsLru, client: C, preserve_intermediates: bool) -> Self {
98        let query_depth = Arc::new(AtomicU8::new(0));
99        Self {
100            lru,
101            client,
102            query_depth,
103            preserve_intermediates,
104        }
105    }
106
107    /// Perform a lookup against this caching client, looking first in the cache for a result
108    pub fn lookup(
109        &mut self,
110        query: Query,
111        options: DnsRequestOptions,
112    ) -> Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>> {
113        Box::pin(Self::inner_lookup(query, options, self.clone(), vec![]))
114    }
115
116    async fn inner_lookup(
117        query: Query,
118        options: DnsRequestOptions,
119        mut client: Self,
120        preserved_records: Vec<(Record, u32)>,
121    ) -> Result<Lookup, ResolveError> {
122        // see https://tools.ietf.org/html/rfc6761
123        //
124        // ```text
125        // Name resolution APIs and libraries SHOULD recognize localhost
126        // names as special and SHOULD always return the IP loopback address
127        // for address queries and negative responses for all other query
128        // types.  Name resolution APIs SHOULD NOT send queries for
129        // localhost names to their configured caching DNS server(s).
130        // ```
131        // special use rules only apply to the IN Class
132        if query.query_class() == DNSClass::IN {
133            let usage = match query.name() {
134                n if LOCALHOST_usage.zone_of(n) => &*LOCALHOST_usage,
135                n if IN_ADDR_ARPA_127.zone_of(n) => &*LOCALHOST_usage,
136                n if IP6_ARPA_1.zone_of(n) => &*LOCALHOST_usage,
137                n if INVALID.zone_of(n) => &*INVALID,
138                n if LOCAL.zone_of(n) => &*LOCAL,
139                n if ONION.zone_of(n) => &*ONION,
140                _ => &*DEFAULT,
141            };
142
143            match usage.resolver() {
144                ResolverUsage::Loopback => match query.query_type() {
145                    // TODO: look in hosts for these ips/names first...
146                    RecordType::A => return Ok(Lookup::from_rdata(query, LOCALHOST_V4.clone())),
147                    RecordType::AAAA => return Ok(Lookup::from_rdata(query, LOCALHOST_V6.clone())),
148                    RecordType::PTR => return Ok(Lookup::from_rdata(query, LOCALHOST.clone())),
149                    _ => {
150                        return Err(ResolveError::nx_error(
151                            query,
152                            None,
153                            None,
154                            ResponseCode::NoError,
155                            false,
156                        ))
157                    } // Are there any other types we can use?
158                },
159                // when mdns is enabled we will follow a standard query path
160                #[cfg(feature = "mdns")]
161                ResolverUsage::LinkLocal => (),
162                // TODO: this requires additional config, as Kubernetes and other systems misuse the .local. zone.
163                // when mdns is not enabled we will return errors on LinkLocal ("*.local.") names
164                #[cfg(not(feature = "mdns"))]
165                ResolverUsage::LinkLocal => (),
166                ResolverUsage::NxDomain => {
167                    return Err(ResolveError::nx_error(
168                        query,
169                        None,
170                        None,
171                        ResponseCode::NXDomain,
172                        false,
173                    ))
174                }
175                ResolverUsage::Normal => (),
176            }
177        }
178
179        let _tracker = DepthTracker::track(client.query_depth.clone());
180        let is_dnssec = client.client.is_verifying_dnssec();
181
182        // first transition any polling that is needed (mutable refs...)
183        if let Some(cached_lookup) = client.lookup_from_cache(&query) {
184            return cached_lookup;
185        };
186
187        let response_message = client
188            .client
189            .lookup(query.clone(), options)
190            .first_answer()
191            .await
192            .map_err(E::into);
193
194        // TODO: technically this might be duplicating work, as name_server already performs this evaluation.
195        //  we may want to create a new type, if evaluated... but this is most generic to support any impl in LookupState...
196        let response_message = if let Ok(response) = response_message {
197            ResolveError::from_response(response, false)
198        } else {
199            response_message
200        };
201
202        // TODO: take all records and cache them?
203        //  if it's DNSSEC they must be signed, otherwise?
204        let records: Result<Records, ResolveError> = match response_message {
205            // this is the only cacheable form
206            Err(ResolveError {
207                kind:
208                    ResolveErrorKind::NoRecordsFound {
209                        query,
210                        soa,
211                        negative_ttl,
212                        response_code,
213                        trusted,
214                    },
215                ..
216            }) => {
217                Err(Self::handle_nxdomain(
218                    is_dnssec,
219                    false, /*tbd*/
220                    *query,
221                    soa.map(|v| *v),
222                    negative_ttl,
223                    response_code,
224                    trusted,
225                ))
226            }
227            Err(e) => return Err(e),
228            Ok(response_message) => {
229                // allow the handle_noerror function to deal with any error codes
230                let records = Self::handle_noerror(
231                    &mut client,
232                    options,
233                    is_dnssec,
234                    &query,
235                    response_message,
236                    preserved_records,
237                )?;
238
239                Ok(records)
240            }
241        };
242
243        // after the request, evaluate if we have additional queries to perform
244        match records {
245            Ok(Records::CnameChain {
246                next: future,
247                min_ttl: ttl,
248            }) => match future.await {
249                Ok(lookup) => client.cname(lookup, query, ttl),
250                Err(e) => client.cache(query, Err(e)),
251            },
252            Ok(Records::Exists(rdata)) => client.cache(query, Ok(rdata)),
253            Err(e) => client.cache(query, Err(e)),
254        }
255    }
256
257    /// Check if this query is already cached
258    fn lookup_from_cache(&self, query: &Query) -> Option<Result<Lookup, ResolveError>> {
259        self.lru.get(query, Instant::now())
260    }
261
262    /// See https://tools.ietf.org/html/rfc2308
263    ///
264    /// For now we will regard NXDomain to strictly mean the query failed
265    ///  and a record for the name, regardless of CNAME presence, what have you
266    ///  ultimately does not exist.
267    ///
268    /// This also handles empty responses in the same way. When performing DNSSEC enabled queries, we should
269    ///  never enter here, and should never cache unless verified requests.
270    ///
271    /// TODO: should this should be expanded to do a forward lookup? Today, this will fail even if there are
272    ///   forwarding options.
273    ///
274    /// # Arguments
275    ///
276    /// * `message` - message to extract SOA, etc, from for caching failed requests
277    /// * `valid_nsec` - species that in DNSSEC mode, this request is safe to cache
278    /// * `negative_ttl` - this should be the SOA minimum for negative ttl
279    fn handle_nxdomain(
280        is_dnssec: bool,
281        valid_nsec: bool,
282        query: Query,
283        soa: Option<Record<SOA>>,
284        negative_ttl: Option<u32>,
285        response_code: ResponseCode,
286        trusted: bool,
287    ) -> ResolveError {
288        if valid_nsec || !is_dnssec {
289            // only trust if there were validated NSEC records
290            ResolveErrorKind::NoRecordsFound {
291                query: Box::new(query),
292                soa: soa.map(Box::new),
293                negative_ttl,
294                response_code,
295                trusted: true,
296            }
297            .into()
298        } else {
299            // not cacheable, no ttl...
300            ResolveErrorKind::NoRecordsFound {
301                query: Box::new(query),
302                soa: soa.map(Box::new),
303                negative_ttl: None,
304                response_code,
305                trusted,
306            }
307            .into()
308        }
309    }
310
311    /// Handle the case where there is no error returned
312    fn handle_noerror(
313        client: &mut Self,
314        options: DnsRequestOptions,
315        is_dnssec: bool,
316        query: &Query,
317        response: DnsResponse,
318        mut preserved_records: Vec<(Record, u32)>,
319    ) -> Result<Records, ResolveError> {
320        // initial ttl is what CNAMES for min usage
321        const INITIAL_TTL: u32 = dns_lru::MAX_TTL;
322
323        // need to capture these before the subsequent and destructive record processing
324        let soa = response.soa().as_ref().map(RecordRef::to_owned);
325        let negative_ttl = response.negative_ttl();
326        let response_code = response.response_code();
327
328        // seek out CNAMES, this is only performed if the query is not a CNAME, ANY, or SRV
329        // FIXME: for SRV this evaluation is inadequate. CNAME is a single chain to a single record
330        //   for SRV, there could be many different targets. The search_name needs to be enhanced to
331        //   be a list of names found for SRV records.
332        let (search_name, cname_ttl, was_cname, preserved_records) = {
333            // this will only search for CNAMEs if the request was not meant to be for one of the triggers for recursion
334            let (search_name, cname_ttl, was_cname) =
335                if query.query_type().is_any() || query.query_type().is_cname() {
336                    (Cow::Borrowed(query.name()), INITIAL_TTL, false)
337                } else {
338                    // Folds any cnames from the answers section, into the final cname in the answers section
339                    //   this works by folding the last CNAME found into the final folded result.
340                    //   it assumes that the CNAMEs are in chained order in the DnsResponse Message...
341                    // For SRV, the name added for the search becomes the target name.
342                    //
343                    // TODO: should this include the additionals?
344                    response.answers().iter().fold(
345                        (Cow::Borrowed(query.name()), INITIAL_TTL, false),
346                        |(search_name, cname_ttl, was_cname), r| {
347                            match r.data() {
348                                Some(RData::CNAME(CNAME(ref cname))) => {
349                                    // take the minimum TTL of the cname_ttl and the next record in the chain
350                                    let ttl = cname_ttl.min(r.ttl());
351                                    debug_assert_eq!(r.record_type(), RecordType::CNAME);
352                                    if search_name.as_ref() == r.name() {
353                                        return (Cow::Owned(cname.clone()), ttl, true);
354                                    }
355                                }
356                                Some(RData::SRV(ref srv)) => {
357                                    // take the minimum TTL of the cname_ttl and the next record in the chain
358                                    let ttl = cname_ttl.min(r.ttl());
359                                    debug_assert_eq!(r.record_type(), RecordType::SRV);
360
361                                    // the search name becomes the srv.target
362                                    return (Cow::Owned(srv.target().clone()), ttl, true);
363                                }
364                                _ => (),
365                            }
366
367                            (search_name, cname_ttl, was_cname)
368                        },
369                    )
370                };
371
372            // take all answers. // TODO: following CNAMES?
373            let mut response = response.into_message();
374            let answers = response.take_answers();
375            let additionals = response.take_additionals();
376            let name_servers = response.take_name_servers();
377
378            // set of names that still require resolution
379            // TODO: this needs to be enhanced for SRV
380            let mut found_name = false;
381
382            // After following all the CNAMES to the last one, try and lookup the final name
383            let records = answers
384                .into_iter()
385                // Chained records will generally exist in the additionals section
386                .chain(additionals)
387                .chain(name_servers)
388                .filter_map(|r| {
389                    // because this resolved potentially recursively, we want the min TTL from the chain
390                    let ttl = cname_ttl.min(r.ttl());
391                    // TODO: disable name validation with ResolverOpts? glibc feature...
392                    // restrict to the RData type requested
393                    if query.query_class() == r.dns_class() {
394                        // standard evaluation, it's an any type or it's the requested type and the search_name matches
395                        #[allow(clippy::suspicious_operation_groupings)]
396                        if (query.query_type().is_any() || query.query_type() == r.record_type())
397                            && (search_name.as_ref() == r.name() || query.name() == r.name())
398                        {
399                            found_name = true;
400                            return Some((r, ttl));
401                        }
402                        // CNAME evaluation, the record is from the CNAME lookup chain.
403                        if client.preserve_intermediates && r.record_type() == RecordType::CNAME {
404                            return Some((r, ttl));
405                        }
406                        // srv evaluation, it's an srv lookup and the srv_search_name/target matches this name
407                        //    and it's an IP
408                        if query.query_type().is_srv()
409                            && r.record_type().is_ip_addr()
410                            && search_name.as_ref() == r.name()
411                        {
412                            found_name = true;
413                            Some((r, ttl))
414                        } else if query.query_type().is_ns() && r.record_type().is_ip_addr() {
415                            Some((r, ttl))
416                        } else {
417                            None
418                        }
419                    } else {
420                        None
421                    }
422                })
423                .collect::<Vec<_>>();
424
425            // adding the newly collected records to the preserved records
426            preserved_records.extend(records);
427            if !preserved_records.is_empty() && found_name {
428                return Ok(Records::Exists(preserved_records));
429            }
430
431            (
432                search_name.into_owned(),
433                cname_ttl,
434                was_cname,
435                preserved_records,
436            )
437        };
438
439        // TODO: for SRV records we *could* do an implicit lookup, but, this requires knowing the type of IP desired
440        //    for now, we'll make the API require the user to perform a follow up to the lookups.
441        // It was a CNAME, but not included in the request...
442        if was_cname && client.query_depth.load(Ordering::Acquire) < MAX_QUERY_DEPTH {
443            let next_query = Query::query(search_name, query.query_type());
444            Ok(Records::CnameChain {
445                next: Box::pin(Self::inner_lookup(
446                    next_query,
447                    options,
448                    client.clone(),
449                    preserved_records,
450                )),
451                min_ttl: cname_ttl,
452            })
453        } else {
454            // TODO: review See https://tools.ietf.org/html/rfc2308 for NoData section
455            // Note on DNSSEC, in secure_client_handle, if verify_nsec fails then the request fails.
456            //   this will mean that no unverified negative caches will make it to this point and be stored
457            Err(Self::handle_nxdomain(
458                is_dnssec,
459                true,
460                query.clone(),
461                soa,
462                negative_ttl,
463                response_code,
464                false,
465            ))
466        }
467    }
468
469    #[allow(clippy::unnecessary_wraps)]
470    fn cname(&self, lookup: Lookup, query: Query, cname_ttl: u32) -> Result<Lookup, ResolveError> {
471        // this duplicates the cache entry under the original query
472        Ok(self.lru.duplicate(query, lookup, cname_ttl, Instant::now()))
473    }
474
475    fn cache(
476        &self,
477        query: Query,
478        records: Result<Vec<(Record, u32)>, ResolveError>,
479    ) -> Result<Lookup, ResolveError> {
480        // this will put this object into an inconsistent state, but no one should call poll again...
481        match records {
482            Ok(rdata) => Ok(self.lru.insert(query, rdata, Instant::now())),
483            Err(err) => Err(self.lru.negative(query, err, Instant::now())),
484        }
485    }
486
487    /// Flushes/Removes all entries from the cache
488    pub fn clear_cache(&self) {
489        self.lru.clear();
490    }
491}
492
/// Result of evaluating an upstream response in `handle_noerror`.
enum Records {
    /// The records exist: a vec of records, each paired with the TTL to cache it with
    Exists(Vec<(Record, u32)>),
    /// Future lookup for recursive cname records
    CnameChain {
        /// The pending lookup of the CNAME/SRV target
        next: Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>>,
        /// Minimum TTL seen along the chain so far
        min_ttl: u32,
    },
}
502
503// see also the lookup_tests.rs in integration-tests crate
504#[cfg(test)]
505mod tests {
506    use std::net::*;
507    use std::str::FromStr;
508    use std::time::*;
509
510    use futures_executor::block_on;
511    use proto::op::{Message, Query};
512    use proto::rr::rdata::{NS, SRV};
513    use proto::rr::{Name, Record};
514
515    use super::*;
516    use crate::lookup_ip::tests::*;
517
518    #[test]
519    fn test_empty_cache() {
520        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
521        let client = mock(vec![empty()]);
522        let client = CachingClient::with_cache(cache, client, false);
523
524        if let ResolveErrorKind::NoRecordsFound {
525            query,
526            negative_ttl,
527            ..
528        } = block_on(CachingClient::inner_lookup(
529            Query::new(),
530            DnsRequestOptions::default(),
531            client,
532            vec![],
533        ))
534        .unwrap_err()
535        .kind()
536        {
537            assert_eq!(**query, Query::new());
538            assert_eq!(*negative_ttl, None);
539        } else {
540            panic!("wrong error received")
541        }
542    }
543
544    #[test]
545    fn test_from_cache() {
546        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
547        let query = Query::new();
548        cache.insert(
549            query.clone(),
550            vec![(
551                Record::from_rdata(
552                    query.name().clone(),
553                    u32::max_value(),
554                    RData::A(A::new(127, 0, 0, 1)),
555                ),
556                u32::max_value(),
557            )],
558            Instant::now(),
559        );
560
561        let client = mock(vec![empty()]);
562        let client = CachingClient::with_cache(cache, client, false);
563
564        let ips = block_on(CachingClient::inner_lookup(
565            Query::new(),
566            DnsRequestOptions::default(),
567            client,
568            vec![],
569        ))
570        .unwrap();
571
572        assert_eq!(
573            ips.iter().cloned().collect::<Vec<_>>(),
574            vec![RData::A(A::new(127, 0, 0, 1))]
575        );
576    }
577
578    #[test]
579    fn test_no_cache_insert() {
580        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
581        // first should come from client...
582        let client = mock(vec![v4_message()]);
583        let client = CachingClient::with_cache(cache.clone(), client, false);
584
585        let ips = block_on(CachingClient::inner_lookup(
586            Query::new(),
587            DnsRequestOptions::default(),
588            client,
589            vec![],
590        ))
591        .unwrap();
592
593        assert_eq!(
594            ips.iter().cloned().collect::<Vec<_>>(),
595            vec![RData::A(A::new(127, 0, 0, 1))]
596        );
597
598        // next should come from cache...
599        let client = mock(vec![empty()]);
600        let client = CachingClient::with_cache(cache, client, false);
601
602        let ips = block_on(CachingClient::inner_lookup(
603            Query::new(),
604            DnsRequestOptions::default(),
605            client,
606            vec![],
607        ))
608        .unwrap();
609
610        assert_eq!(
611            ips.iter().cloned().collect::<Vec<_>>(),
612            vec![RData::A(A::new(127, 0, 0, 1))]
613        );
614    }
615
616    #[allow(clippy::unnecessary_wraps)]
617    pub(crate) fn cname_message() -> Result<DnsResponse, ResolveError> {
618        let mut message = Message::new();
619        message.add_query(Query::query(
620            Name::from_str("www.example.com.").unwrap(),
621            RecordType::A,
622        ));
623        message.insert_answers(vec![Record::from_rdata(
624            Name::from_str("www.example.com.").unwrap(),
625            86400,
626            RData::CNAME(CNAME(Name::from_str("actual.example.com.").unwrap())),
627        )]);
628        Ok(DnsResponse::from_message(message).unwrap())
629    }
630
631    #[allow(clippy::unnecessary_wraps)]
632    pub(crate) fn srv_message() -> Result<DnsResponse, ResolveError> {
633        let mut message = Message::new();
634        message.add_query(Query::query(
635            Name::from_str("_443._tcp.www.example.com.").unwrap(),
636            RecordType::SRV,
637        ));
638        message.insert_answers(vec![Record::from_rdata(
639            Name::from_str("_443._tcp.www.example.com.").unwrap(),
640            86400,
641            RData::SRV(SRV::new(
642                1,
643                2,
644                443,
645                Name::from_str("www.example.com.").unwrap(),
646            )),
647        )]);
648        Ok(DnsResponse::from_message(message).unwrap())
649    }
650
651    #[allow(clippy::unnecessary_wraps)]
652    pub(crate) fn ns_message() -> Result<DnsResponse, ResolveError> {
653        let mut message = Message::new();
654        message.add_query(Query::query(
655            Name::from_str("www.example.com.").unwrap(),
656            RecordType::NS,
657        ));
658        message.insert_answers(vec![Record::from_rdata(
659            Name::from_str("www.example.com.").unwrap(),
660            86400,
661            RData::NS(NS(Name::from_str("www.example.com.").unwrap())),
662        )]);
663        Ok(DnsResponse::from_message(message).unwrap())
664    }
665
666    fn no_recursion_on_query_test(query_type: RecordType) {
667        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
668
669        // the cname should succeed, we shouldn't query again after that, which would cause an error...
670        let client = mock(vec![error(), cname_message()]);
671        let client = CachingClient::with_cache(cache, client, false);
672
673        let ips = block_on(CachingClient::inner_lookup(
674            Query::query(Name::from_str("www.example.com.").unwrap(), query_type),
675            DnsRequestOptions::default(),
676            client,
677            vec![],
678        ))
679        .expect("lookup failed");
680
681        assert_eq!(
682            ips.iter().cloned().collect::<Vec<_>>(),
683            vec![RData::CNAME(CNAME(
684                Name::from_str("actual.example.com.").unwrap()
685            ))]
686        );
687    }
688
    /// A `CNAME` query must return the CNAME record itself rather than following it.
    #[test]
    fn test_no_recursion_on_cname_query() {
        no_recursion_on_query_test(RecordType::CNAME);
    }
693
    /// An `ANY` query must return the CNAME record itself rather than following it.
    #[test]
    fn test_no_recursion_on_all_query() {
        no_recursion_on_query_test(RecordType::ANY);
    }
698
699    #[test]
700    fn test_non_recursive_srv_query() {
701        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
702
703        // the cname should succeed, we shouldn't query again after that, which would cause an error...
704        let client = mock(vec![error(), srv_message()]);
705        let client = CachingClient::with_cache(cache, client, false);
706
707        let ips = block_on(CachingClient::inner_lookup(
708            Query::query(
709                Name::from_str("_443._tcp.www.example.com.").unwrap(),
710                RecordType::SRV,
711            ),
712            DnsRequestOptions::default(),
713            client,
714            vec![],
715        ))
716        .expect("lookup failed");
717
718        assert_eq!(
719            ips.iter().cloned().collect::<Vec<_>>(),
720            vec![RData::SRV(SRV::new(
721                1,
722                2,
723                443,
724                Name::from_str("www.example.com.").unwrap(),
725            ))]
726        );
727    }
728
729    #[test]
730    fn test_single_srv_query_response() {
731        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
732
733        let mut message = srv_message().unwrap().into_message();
734        message.add_answer(Record::from_rdata(
735            Name::from_str("www.example.com.").unwrap(),
736            86400,
737            RData::CNAME(CNAME(Name::from_str("actual.example.com.").unwrap())),
738        ));
739        message.insert_additionals(vec![
740            Record::from_rdata(
741                Name::from_str("actual.example.com.").unwrap(),
742                86400,
743                RData::A(A::new(127, 0, 0, 1)),
744            ),
745            Record::from_rdata(
746                Name::from_str("actual.example.com.").unwrap(),
747                86400,
748                RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 1)),
749            ),
750        ]);
751
752        let client = mock(vec![
753            error(),
754            Ok(DnsResponse::from_message(message).unwrap()),
755        ]);
756        let client = CachingClient::with_cache(cache, client, false);
757
758        let ips = block_on(CachingClient::inner_lookup(
759            Query::query(
760                Name::from_str("_443._tcp.www.example.com.").unwrap(),
761                RecordType::SRV,
762            ),
763            DnsRequestOptions::default(),
764            client,
765            vec![],
766        ))
767        .expect("lookup failed");
768
769        assert_eq!(
770            ips.iter().cloned().collect::<Vec<_>>(),
771            vec![
772                RData::SRV(SRV::new(
773                    1,
774                    2,
775                    443,
776                    Name::from_str("www.example.com.").unwrap(),
777                )),
778                RData::A(A::new(127, 0, 0, 1)),
779                RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 1)),
780            ]
781        );
782    }
783
784    // TODO: if we ever enable recursive lookups for SRV, here are the tests...
785    // #[test]
786    // fn test_recursive_srv_query() {
787    //     let cache = Arc::new(Mutex::new(DnsLru::new(1)));
788
789    //     let mut message = Message::new();
790    //     message.add_answer(Record::from_rdata(
791    //         Name::from_str("www.example.com.").unwrap(),
792    //         86400,
793    //         RecordType::CNAME,
794    //         RData::CNAME(Name::from_str("actual.example.com.").unwrap()),
795    //     ));
796    //     message.insert_additionals(vec![
797    //         Record::from_rdata(
798    //             Name::from_str("actual.example.com.").unwrap(),
799    //             86400,
800    //             RecordType::A,
801    //             RData::A(Ipv4Addr::new(127, 0, 0, 1)),
802    //         ),
803    //     ]);
804
805    //     let mut client = mock(vec![error(), Ok(DnsResponse::from_message(message).unwrap()), srv_message()]);
806
807    //     let ips = QueryState::lookup(
808    //         Query::query(
809    //             Name::from_str("_443._tcp.www.example.com.").unwrap(),
810    //             RecordType::SRV,
811    //         ),
812    //         Default::default(),
813    //         &mut client,
814    //         cache.clone(),
815    //     ).wait()
816    //         .expect("lookup failed");
817
818    //     assert_eq!(
819    //         ips.iter().cloned().collect::<Vec<_>>(),
820    //         vec![
821    //             RData::SRV(SRV::new(
822    //                 1,
823    //                 2,
824    //                 443,
825    //                 Name::from_str("www.example.com.").unwrap(),
826    //             )),
827    //             RData::A(Ipv4Addr::new(127, 0, 0, 1)),
828    //             //RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
829    //         ]
830    //     );
831    // }
832
833    #[test]
834    fn test_single_ns_query_response() {
835        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
836
837        let mut message = ns_message().unwrap().into_message();
838        message.add_answer(Record::from_rdata(
839            Name::from_str("www.example.com.").unwrap(),
840            86400,
841            RData::CNAME(CNAME(Name::from_str("actual.example.com.").unwrap())),
842        ));
843        message.insert_additionals(vec![
844            Record::from_rdata(
845                Name::from_str("actual.example.com.").unwrap(),
846                86400,
847                RData::A(A::new(127, 0, 0, 1)),
848            ),
849            Record::from_rdata(
850                Name::from_str("actual.example.com.").unwrap(),
851                86400,
852                RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 1)),
853            ),
854        ]);
855
856        let client = mock(vec![
857            error(),
858            Ok(DnsResponse::from_message(message).unwrap()),
859        ]);
860        let client = CachingClient::with_cache(cache, client, false);
861
862        let ips = block_on(CachingClient::inner_lookup(
863            Query::query(Name::from_str("www.example.com.").unwrap(), RecordType::NS),
864            DnsRequestOptions::default(),
865            client,
866            vec![],
867        ))
868        .expect("lookup failed");
869
870        assert_eq!(
871            ips.iter().cloned().collect::<Vec<_>>(),
872            vec![
873                RData::NS(NS(Name::from_str("www.example.com.").unwrap())),
874                RData::A(A::new(127, 0, 0, 1)),
875                RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 1)),
876            ]
877        );
878    }
879
880    fn cname_ttl_test(first: u32, second: u32) {
881        let lru = DnsLru::new(1, dns_lru::TtlConfig::default());
882        // expecting no queries to be performed
883        let mut client = CachingClient::with_cache(lru, mock(vec![error()]), false);
884
885        let mut message = Message::new();
886        message.insert_answers(vec![Record::from_rdata(
887            Name::from_str("ttl.example.com.").unwrap(),
888            first,
889            RData::CNAME(CNAME(Name::from_str("actual.example.com.").unwrap())),
890        )]);
891        message.insert_additionals(vec![Record::from_rdata(
892            Name::from_str("actual.example.com.").unwrap(),
893            second,
894            RData::A(A::new(127, 0, 0, 1)),
895        )]);
896
897        let records = CachingClient::handle_noerror(
898            &mut client,
899            DnsRequestOptions::default(),
900            false,
901            &Query::query(Name::from_str("ttl.example.com.").unwrap(), RecordType::A),
902            DnsResponse::from_message(message).unwrap(),
903            vec![],
904        );
905
906        if let Ok(records) = records {
907            if let Records::Exists(records) = records {
908                for (record, ttl) in records.iter() {
909                    if record.record_type() == RecordType::CNAME {
910                        continue;
911                    }
912                    assert_eq!(ttl, &1);
913                }
914            } else {
915                panic!("records don't exist");
916            }
917        } else {
918            panic!("error getting records");
919        }
920    }
921
922    #[test]
923    fn test_cname_ttl() {
924        cname_ttl_test(1, 2);
925        cname_ttl_test(2, 1);
926    }
927
928    #[test]
929    fn test_early_return_localhost() {
930        let cache = DnsLru::new(0, dns_lru::TtlConfig::default());
931        let client = mock(vec![empty()]);
932        let mut client = CachingClient::with_cache(cache, client, false);
933
934        {
935            let query = Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::A);
936            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
937                .expect("should have returned localhost");
938            assert_eq!(lookup.query(), &query);
939            assert_eq!(
940                lookup.iter().cloned().collect::<Vec<_>>(),
941                vec![LOCALHOST_V4.clone()]
942            );
943        }
944
945        {
946            let query = Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::AAAA);
947            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
948                .expect("should have returned localhost");
949            assert_eq!(lookup.query(), &query);
950            assert_eq!(
951                lookup.iter().cloned().collect::<Vec<_>>(),
952                vec![LOCALHOST_V6.clone()]
953            );
954        }
955
956        {
957            let query = Query::query(Name::from(Ipv4Addr::new(127, 0, 0, 1)), RecordType::PTR);
958            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
959                .expect("should have returned localhost");
960            assert_eq!(lookup.query(), &query);
961            assert_eq!(
962                lookup.iter().cloned().collect::<Vec<_>>(),
963                vec![LOCALHOST.clone()]
964            );
965        }
966
967        {
968            let query = Query::query(
969                Name::from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
970                RecordType::PTR,
971            );
972            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
973                .expect("should have returned localhost");
974            assert_eq!(lookup.query(), &query);
975            assert_eq!(
976                lookup.iter().cloned().collect::<Vec<_>>(),
977                vec![LOCALHOST.clone()]
978            );
979        }
980
981        assert!(block_on(client.lookup(
982            Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::MX),
983            DnsRequestOptions::default()
984        ))
985        .is_err());
986
987        assert!(block_on(client.lookup(
988            Query::query(Name::from(Ipv4Addr::new(127, 0, 0, 1)), RecordType::MX),
989            DnsRequestOptions::default()
990        ))
991        .is_err());
992
993        assert!(block_on(client.lookup(
994            Query::query(
995                Name::from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
996                RecordType::MX
997            ),
998            DnsRequestOptions::default()
999        ))
1000        .is_err());
1001    }
1002
1003    #[test]
1004    fn test_early_return_invalid() {
1005        let cache = DnsLru::new(0, dns_lru::TtlConfig::default());
1006        let client = mock(vec![empty()]);
1007        let mut client = CachingClient::with_cache(cache, client, false);
1008
1009        assert!(block_on(client.lookup(
1010            Query::query(
1011                Name::from_ascii("horrible.invalid.").unwrap(),
1012                RecordType::A,
1013            ),
1014            DnsRequestOptions::default()
1015        ))
1016        .is_err());
1017    }
1018
1019    #[test]
1020    fn test_no_error_on_dot_local_no_mdns() {
1021        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
1022
1023        let mut message = srv_message().unwrap().into_message();
1024        message.add_query(Query::query(
1025            Name::from_ascii("www.example.local.").unwrap(),
1026            RecordType::A,
1027        ));
1028        message.add_answer(Record::from_rdata(
1029            Name::from_str("www.example.local.").unwrap(),
1030            86400,
1031            RData::A(A::new(127, 0, 0, 1)),
1032        ));
1033
1034        let client = mock(vec![
1035            error(),
1036            Ok(DnsResponse::from_message(message).unwrap()),
1037        ]);
1038        let mut client = CachingClient::with_cache(cache, client, false);
1039
1040        assert!(block_on(client.lookup(
1041            Query::query(
1042                Name::from_ascii("www.example.local.").unwrap(),
1043                RecordType::A,
1044            ),
1045            DnsRequestOptions::default()
1046        ))
1047        .is_ok());
1048    }
1049}