hickory_resolver/
caching_client.rs

// Copyright 2015-2023 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// https://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! Caching related functionality for the Resolver.
9
10use std::{borrow::Cow, future::Future, pin::Pin, sync::Arc, time::Instant};
11
12use futures_util::future::TryFutureExt;
13use once_cell::sync::Lazy;
14
15use crate::{
16    dns_lru::{self, DnsLru, TtlConfig},
17    error::ResolveError,
18    lookup::Lookup,
19    proto::{
20        op::{Query, ResponseCode},
21        rr::{
22            DNSClass, Name, RData, Record, RecordType,
23            domain::usage::{
24                DEFAULT, IN_ADDR_ARPA_127, INVALID, IP6_ARPA_1, LOCAL,
25                LOCALHOST as LOCALHOST_usage, ONION, ResolverUsage,
26            },
27            rdata::{A, AAAA, CNAME, PTR, SOA},
28            resource::RecordRef,
29        },
30        xfer::{DnsHandle, DnsRequestOptions, DnsResponse, FirstAnswer},
31        {ForwardNSData, ProtoError, ProtoErrorKind},
32    },
33};
34
35static LOCALHOST: Lazy<RData> =
36    Lazy::new(|| RData::PTR(PTR(Name::from_ascii("localhost.").unwrap())));
37static LOCALHOST_V4: Lazy<RData> = Lazy::new(|| RData::A(A::new(127, 0, 0, 1)));
38static LOCALHOST_V6: Lazy<RData> = Lazy::new(|| RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 1)));
39
/// Tracks how many CNAME hops have been followed while resolving one query.
///
/// Every hop derives a new tracker with [`DepthTracker::nest`]; once
/// [`DepthTracker::is_exhausted`] reports `true`, no further hop is followed.
#[derive(Default, Clone, Copy)]
struct DepthTracker {
    query_depth: u8,
}

impl DepthTracker {
    /// Upper bound on the chain length; the exact value is arbitrary.
    const MAX_QUERY_DEPTH: u8 = 8;

    /// Returns a copy of this tracker that is one hop deeper.
    fn nest(self) -> Self {
        let mut deeper = self;
        deeper.query_depth += 1;
        deeper
    }

    /// `true` when taking one more hop would reach `MAX_QUERY_DEPTH`.
    fn is_exhausted(self) -> bool {
        let next_depth = self.query_depth + 1;
        next_depth >= Self::MAX_QUERY_DEPTH
    }
}
59
60// TODO: need to consider this storage type as it compares to Authority in server...
61//       should it just be an variation on Authority?
62#[derive(Clone, Debug)]
63#[doc(hidden)]
64pub struct CachingClient<C>
65where
66    C: DnsHandle,
67{
68    lru: DnsLru,
69    client: C,
70    preserve_intermediates: bool,
71}
72
73impl<C> CachingClient<C>
74where
75    C: DnsHandle + Send + 'static,
76{
77    #[doc(hidden)]
78    pub fn new(max_size: usize, client: C, preserve_intermediates: bool) -> Self {
79        Self::with_cache(
80            DnsLru::new(max_size, TtlConfig::default()),
81            client,
82            preserve_intermediates,
83        )
84    }
85
86    pub(crate) fn with_cache(lru: DnsLru, client: C, preserve_intermediates: bool) -> Self {
87        Self {
88            lru,
89            client,
90            preserve_intermediates,
91        }
92    }
93
94    /// Perform a lookup against this caching client, looking first in the cache for a result
95    pub fn lookup(
96        &self,
97        query: Query,
98        options: DnsRequestOptions,
99    ) -> Pin<Box<dyn Future<Output = Result<Lookup, ResolveError>> + Send>> {
100        Box::pin(
101            Self::inner_lookup(
102                query,
103                options,
104                self.clone(),
105                vec![],
106                DepthTracker::default(),
107            )
108            .map_err(ResolveError::from),
109        )
110    }
111
112    async fn inner_lookup(
113        query: Query,
114        options: DnsRequestOptions,
115        mut client: Self,
116        preserved_records: Vec<(Record, u32)>,
117        depth: DepthTracker,
118    ) -> Result<Lookup, ProtoError> {
119        // see https://tools.ietf.org/html/rfc6761
120        //
121        // ```text
122        // Name resolution APIs and libraries SHOULD recognize localhost
123        // names as special and SHOULD always return the IP loopback address
124        // for address queries and negative responses for all other query
125        // types.  Name resolution APIs SHOULD NOT send queries for
126        // localhost names to their configured caching DNS server(s).
127        // ```
128        // special use rules only apply to the IN Class
129        if query.query_class() == DNSClass::IN {
130            let usage = match query.name() {
131                n if LOCALHOST_usage.zone_of(n) => &*LOCALHOST_usage,
132                n if IN_ADDR_ARPA_127.zone_of(n) => &*LOCALHOST_usage,
133                n if IP6_ARPA_1.zone_of(n) => &*LOCALHOST_usage,
134                n if INVALID.zone_of(n) => &*INVALID,
135                n if LOCAL.zone_of(n) => &*LOCAL,
136                n if ONION.zone_of(n) => &*ONION,
137                _ => &*DEFAULT,
138            };
139
140            match usage.resolver() {
141                ResolverUsage::Loopback => match query.query_type() {
142                    // TODO: look in hosts for these ips/names first...
143                    RecordType::A => return Ok(Lookup::from_rdata(query, LOCALHOST_V4.clone())),
144                    RecordType::AAAA => return Ok(Lookup::from_rdata(query, LOCALHOST_V6.clone())),
145                    RecordType::PTR => return Ok(Lookup::from_rdata(query, LOCALHOST.clone())),
146                    _ => {
147                        return Err(ProtoError::nx_error(
148                            Box::new(query),
149                            None,
150                            None,
151                            None,
152                            ResponseCode::NoError,
153                            false,
154                            None,
155                        ));
156                    } // Are there any other types we can use?
157                },
158                // TODO: this requires additional config, as Kubernetes and other systems misuse the .local. zone.
159                // when mdns is not enabled we will return errors on LinkLocal ("*.local.") names
160                ResolverUsage::LinkLocal => (),
161                ResolverUsage::NxDomain => {
162                    return Err(ProtoError::nx_error(
163                        Box::new(query),
164                        None,
165                        None,
166                        None,
167                        ResponseCode::NXDomain,
168                        false,
169                        None,
170                    ));
171                }
172                ResolverUsage::Normal => (),
173            }
174        }
175
176        let is_dnssec = client.client.is_verifying_dnssec();
177
178        // first transition any polling that is needed (mutable refs...)
179        if let Some(cached_lookup) = client.lookup_from_cache(&query) {
180            return cached_lookup;
181        };
182
183        let response_message = client
184            .client
185            .lookup(query.clone(), options)
186            .first_answer()
187            .await;
188
189        // TODO: technically this might be duplicating work, as name_server already performs this evaluation.
190        //  we may want to create a new type, if evaluated... but this is most generic to support any impl in LookupState...
191        let response_message = if let Ok(response) = response_message {
192            ProtoError::from_response(response, false)
193        } else {
194            response_message
195        };
196
197        // TODO: take all records and cache them?
198        //  if it's DNSSEC they must be signed, otherwise?
199        let records: Result<Records, ProtoError> = match response_message {
200            // this is the only cacheable form
201            Err(e) => {
202                match e.kind() {
203                    ProtoErrorKind::NoRecordsFound {
204                        query,
205                        soa,
206                        negative_ttl,
207                        response_code,
208                        trusted,
209                        ns,
210                        ..
211                    } => {
212                        Err(Self::handle_nxdomain(
213                            is_dnssec,
214                            false, /*tbd*/
215                            query.as_ref().clone(),
216                            soa.as_ref().map(Box::as_ref).cloned(),
217                            ns.clone(),
218                            *negative_ttl,
219                            *response_code,
220                            *trusted,
221                        ))
222                    }
223                    _ => return Err(e),
224                }
225            }
226            Ok(response_message) => {
227                // allow the handle_noerror function to deal with any error codes
228                let records = Self::handle_noerror(
229                    &mut client,
230                    options,
231                    is_dnssec,
232                    &query,
233                    response_message,
234                    preserved_records,
235                    depth,
236                )?;
237
238                Ok(records)
239            }
240        };
241
242        // after the request, evaluate if we have additional queries to perform
243        match records {
244            Ok(Records::CnameChain {
245                next: future,
246                min_ttl: ttl,
247            }) => match future.await {
248                Ok(lookup) => client.cname(lookup, query, ttl),
249                Err(e) => client.cache(query, Err(e)),
250            },
251            Ok(Records::Exists(rdata)) => client.cache(query, Ok(rdata)),
252            Err(e) => client.cache(query, Err(e)),
253        }
254    }
255
256    /// Check if this query is already cached
257    fn lookup_from_cache(&self, query: &Query) -> Option<Result<Lookup, ProtoError>> {
258        self.lru.get(query, Instant::now())
259    }
260
261    /// See https://tools.ietf.org/html/rfc2308
262    ///
263    /// For now we will regard NXDomain to strictly mean the query failed
264    ///  and a record for the name, regardless of CNAME presence, what have you
265    ///  ultimately does not exist.
266    ///
267    /// This also handles empty responses in the same way. When performing DNSSEC enabled queries, we should
268    ///  never enter here, and should never cache unless verified requests.
269    ///
270    /// TODO: should this should be expanded to do a forward lookup? Today, this will fail even if there are
271    ///   forwarding options.
272    ///
273    /// # Arguments
274    ///
275    /// * `message` - message to extract SOA, etc, from for caching failed requests
276    /// * `valid_nsec` - species that in DNSSEC mode, this request is safe to cache
277    /// * `negative_ttl` - this should be the SOA minimum for negative ttl
278    #[allow(clippy::too_many_arguments)]
279    fn handle_nxdomain(
280        is_dnssec: bool,
281        valid_nsec: bool,
282        query: Query,
283        soa: Option<Record<SOA>>,
284        ns: Option<Arc<[ForwardNSData]>>,
285        negative_ttl: Option<u32>,
286        response_code: ResponseCode,
287        trusted: bool,
288    ) -> ProtoError {
289        if valid_nsec || !is_dnssec {
290            // only trust if there were validated NSEC records
291            ProtoErrorKind::NoRecordsFound {
292                query: Box::new(query),
293                soa: soa.map(Box::new),
294                ns,
295                negative_ttl,
296                response_code,
297                trusted: true,
298                authorities: None,
299            }
300            .into()
301        } else {
302            // not cacheable, no ttl...
303            ProtoErrorKind::NoRecordsFound {
304                query: Box::new(query),
305                soa: soa.map(Box::new),
306                ns,
307                negative_ttl: None,
308                response_code,
309                trusted,
310                authorities: None,
311            }
312            .into()
313        }
314    }
315
316    /// Handle the case where there is no error returned
317    fn handle_noerror(
318        client: &mut Self,
319        options: DnsRequestOptions,
320        is_dnssec: bool,
321        query: &Query,
322        response: DnsResponse,
323        mut preserved_records: Vec<(Record, u32)>,
324        depth: DepthTracker,
325    ) -> Result<Records, ProtoError> {
326        // initial ttl is what CNAMES for min usage
327        const INITIAL_TTL: u32 = dns_lru::MAX_TTL;
328
329        // need to capture these before the subsequent and destructive record processing
330        let soa = response.soa().as_ref().map(RecordRef::to_owned);
331        let negative_ttl = response.negative_ttl();
332        let response_code = response.response_code();
333
334        // seek out CNAMES, this is only performed if the query is not a CNAME, ANY, or SRV
335        // FIXME: for SRV this evaluation is inadequate. CNAME is a single chain to a single record
336        //   for SRV, there could be many different targets. The search_name needs to be enhanced to
337        //   be a list of names found for SRV records.
338        let (search_name, cname_ttl, was_cname, preserved_records) = {
339            // this will only search for CNAMEs if the request was not meant to be for one of the triggers for recursion
340            let (search_name, cname_ttl, was_cname) =
341                if query.query_type().is_any() || query.query_type().is_cname() {
342                    (Cow::Borrowed(query.name()), INITIAL_TTL, false)
343                } else {
344                    // Folds any cnames from the answers section, into the final cname in the answers section
345                    //   this works by folding the last CNAME found into the final folded result.
346                    //   it assumes that the CNAMEs are in chained order in the DnsResponse Message...
347                    // For SRV, the name added for the search becomes the target name.
348                    //
349                    // TODO: should this include the additionals?
350                    response.answers().iter().fold(
351                        (Cow::Borrowed(query.name()), INITIAL_TTL, false),
352                        |(search_name, cname_ttl, was_cname), r| {
353                            match r.data() {
354                                RData::CNAME(CNAME(cname)) => {
355                                    // take the minimum TTL of the cname_ttl and the next record in the chain
356                                    let ttl = cname_ttl.min(r.ttl());
357                                    debug_assert_eq!(r.record_type(), RecordType::CNAME);
358                                    if search_name.as_ref() == r.name() {
359                                        return (Cow::Owned(cname.clone()), ttl, true);
360                                    }
361                                }
362                                RData::SRV(srv) => {
363                                    // take the minimum TTL of the cname_ttl and the next record in the chain
364                                    let ttl = cname_ttl.min(r.ttl());
365                                    debug_assert_eq!(r.record_type(), RecordType::SRV);
366
367                                    // the search name becomes the srv.target
368                                    return (Cow::Owned(srv.target().clone()), ttl, true);
369                                }
370                                _ => (),
371                            }
372
373                            (search_name, cname_ttl, was_cname)
374                        },
375                    )
376                };
377
378            // take all answers. // TODO: following CNAMES?
379            let mut response = response.into_message();
380            let answers = response.take_answers();
381            let additionals = response.take_additionals();
382            let name_servers = response.take_name_servers();
383
384            // set of names that still require resolution
385            // TODO: this needs to be enhanced for SRV
386            let mut found_name = false;
387
388            // After following all the CNAMES to the last one, try and lookup the final name
389            let records = answers
390                .into_iter()
391                // Chained records will generally exist in the additionals section
392                .chain(additionals)
393                .chain(name_servers)
394                .filter_map(|r| {
395                    // because this resolved potentially recursively, we want the min TTL from the chain
396                    let ttl = cname_ttl.min(r.ttl());
397                    // TODO: disable name validation with ResolverOpts? glibc feature...
398                    // restrict to the RData type requested
399                    if query.query_class() == r.dns_class() {
400                        // standard evaluation, it's an any type or it's the requested type and the search_name matches
401                        #[allow(clippy::suspicious_operation_groupings)]
402                        if (query.query_type().is_any() || query.query_type() == r.record_type())
403                            && (search_name.as_ref() == r.name() || query.name() == r.name())
404                        {
405                            found_name = true;
406                            return Some((r, ttl));
407                        }
408                        // CNAME evaluation, the record is from the CNAME lookup chain.
409                        if client.preserve_intermediates && r.record_type() == RecordType::CNAME {
410                            return Some((r, ttl));
411                        }
412                        // srv evaluation, it's an srv lookup and the srv_search_name/target matches this name
413                        //    and it's an IP
414                        if query.query_type().is_srv()
415                            && r.record_type().is_ip_addr()
416                            && search_name.as_ref() == r.name()
417                        {
418                            found_name = true;
419                            Some((r, ttl))
420                        } else if query.query_type().is_ns() && r.record_type().is_ip_addr() {
421                            Some((r, ttl))
422                        } else {
423                            None
424                        }
425                    } else {
426                        None
427                    }
428                })
429                .collect::<Vec<_>>();
430
431            // adding the newly collected records to the preserved records
432            preserved_records.extend(records);
433            if !preserved_records.is_empty() && found_name {
434                return Ok(Records::Exists(preserved_records));
435            }
436
437            (
438                search_name.into_owned(),
439                cname_ttl,
440                was_cname,
441                preserved_records,
442            )
443        };
444
445        // TODO: for SRV records we *could* do an implicit lookup, but, this requires knowing the type of IP desired
446        //    for now, we'll make the API require the user to perform a follow up to the lookups.
447        // It was a CNAME, but not included in the request...
448        if was_cname && !depth.is_exhausted() {
449            let next_query = Query::query(search_name, query.query_type());
450            Ok(Records::CnameChain {
451                next: Box::pin(Self::inner_lookup(
452                    next_query,
453                    options,
454                    client.clone(),
455                    preserved_records,
456                    depth.nest(),
457                )),
458                min_ttl: cname_ttl,
459            })
460        } else {
461            // TODO: review See https://tools.ietf.org/html/rfc2308 for NoData section
462            // Note on DNSSEC, in secure_client_handle, if verify_nsec fails then the request fails.
463            //   this will mean that no unverified negative caches will make it to this point and be stored
464            Err(Self::handle_nxdomain(
465                is_dnssec,
466                true,
467                query.clone(),
468                soa,
469                None,
470                negative_ttl,
471                response_code,
472                false,
473            ))
474        }
475    }
476
477    #[allow(clippy::unnecessary_wraps)]
478    fn cname(&self, lookup: Lookup, query: Query, cname_ttl: u32) -> Result<Lookup, ProtoError> {
479        // this duplicates the cache entry under the original query
480        Ok(self.lru.duplicate(query, lookup, cname_ttl, Instant::now()))
481    }
482
483    fn cache(
484        &self,
485        query: Query,
486        records: Result<Vec<(Record, u32)>, ProtoError>,
487    ) -> Result<Lookup, ProtoError> {
488        // this will put this object into an inconsistent state, but no one should call poll again...
489        match records {
490            Ok(rdata) => Ok(self.lru.insert(query, rdata, Instant::now())),
491            Err(err) => Err(self.lru.negative(query, err, Instant::now())),
492        }
493    }
494
495    /// Flushes/Removes all entries from the cache
496    pub fn clear_cache(&self) {
497        self.lru.clear();
498    }
499}
500
501enum Records {
502    /// The records exists, a vec of rdata with ttl
503    Exists(Vec<(Record, u32)>),
504    /// Future lookup for recursive cname records
505    CnameChain {
506        next: Pin<Box<dyn Future<Output = Result<Lookup, ProtoError>> + Send>>,
507        min_ttl: u32,
508    },
509}
510
511// see also the lookup_tests.rs in integration-tests crate
512#[cfg(test)]
513mod tests {
514    use std::net::*;
515    use std::str::FromStr;
516    use std::time::*;
517
518    use crate::proto::op::{Message, Query};
519    use crate::proto::rr::rdata::{NS, SRV};
520    use crate::proto::rr::{Name, Record};
521    use futures_executor::block_on;
522    use test_support::subscribe;
523
524    use super::*;
525    use crate::lookup_ip::tests::*;
526
527    #[test]
528    fn test_empty_cache() {
529        subscribe();
530        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
531        let client = mock(vec![empty()]);
532        let client = CachingClient::with_cache(cache, client, false);
533
534        if let ProtoErrorKind::NoRecordsFound {
535            query,
536            negative_ttl,
537            ..
538        } = block_on(CachingClient::inner_lookup(
539            Query::new(),
540            DnsRequestOptions::default(),
541            client,
542            vec![],
543            DepthTracker::default(),
544        ))
545        .unwrap_err()
546        .kind()
547        {
548            assert_eq!(*query, Box::new(Query::new()));
549            assert_eq!(*negative_ttl, None);
550        } else {
551            panic!("wrong error received")
552        }
553    }
554
555    #[test]
556    fn test_from_cache() {
557        subscribe();
558        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
559        let query = Query::new();
560        cache.insert(
561            query.clone(),
562            vec![(
563                Record::from_rdata(
564                    query.name().clone(),
565                    u32::MAX,
566                    RData::A(A::new(127, 0, 0, 1)),
567                ),
568                u32::MAX,
569            )],
570            Instant::now(),
571        );
572
573        let client = mock(vec![empty()]);
574        let client = CachingClient::with_cache(cache, client, false);
575
576        let ips = block_on(CachingClient::inner_lookup(
577            Query::new(),
578            DnsRequestOptions::default(),
579            client,
580            vec![],
581            DepthTracker::default(),
582        ))
583        .unwrap();
584
585        assert_eq!(
586            ips.iter().cloned().collect::<Vec<_>>(),
587            vec![RData::A(A::new(127, 0, 0, 1))]
588        );
589    }
590
591    #[test]
592    fn test_no_cache_insert() {
593        subscribe();
594        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
595        // first should come from client...
596        let client = mock(vec![v4_message()]);
597        let client = CachingClient::with_cache(cache.clone(), client, false);
598
599        let ips = block_on(CachingClient::inner_lookup(
600            Query::query(Name::root(), RecordType::A),
601            DnsRequestOptions::default(),
602            client,
603            vec![],
604            DepthTracker::default(),
605        ))
606        .unwrap();
607
608        assert_eq!(
609            ips.iter().cloned().collect::<Vec<_>>(),
610            vec![RData::A(A::new(127, 0, 0, 1))]
611        );
612
613        // next should come from cache...
614        let client = mock(vec![empty()]);
615        let client = CachingClient::with_cache(cache, client, false);
616
617        let ips = block_on(CachingClient::inner_lookup(
618            Query::query(Name::root(), RecordType::A),
619            DnsRequestOptions::default(),
620            client,
621            vec![],
622            DepthTracker::default(),
623        ))
624        .unwrap();
625
626        assert_eq!(
627            ips.iter().cloned().collect::<Vec<_>>(),
628            vec![RData::A(A::new(127, 0, 0, 1))]
629        );
630    }
631
632    #[allow(clippy::unnecessary_wraps)]
633    pub(crate) fn cname_message() -> Result<DnsResponse, ProtoError> {
634        let mut message = Message::new();
635        message.add_query(Query::query(
636            Name::from_str("www.example.com.").unwrap(),
637            RecordType::A,
638        ));
639        message.insert_answers(vec![Record::from_rdata(
640            Name::from_str("www.example.com.").unwrap(),
641            86400,
642            RData::CNAME(CNAME(Name::from_str("actual.example.com.").unwrap())),
643        )]);
644        Ok(DnsResponse::from_message(message).unwrap())
645    }
646
647    #[allow(clippy::unnecessary_wraps)]
648    pub(crate) fn srv_message() -> Result<DnsResponse, ProtoError> {
649        let mut message = Message::new();
650        message.add_query(Query::query(
651            Name::from_str("_443._tcp.www.example.com.").unwrap(),
652            RecordType::SRV,
653        ));
654        message.insert_answers(vec![Record::from_rdata(
655            Name::from_str("_443._tcp.www.example.com.").unwrap(),
656            86400,
657            RData::SRV(SRV::new(
658                1,
659                2,
660                443,
661                Name::from_str("www.example.com.").unwrap(),
662            )),
663        )]);
664        Ok(DnsResponse::from_message(message).unwrap())
665    }
666
667    #[allow(clippy::unnecessary_wraps)]
668    pub(crate) fn ns_message() -> Result<DnsResponse, ResolveError> {
669        let mut message = Message::new();
670        message.add_query(Query::query(
671            Name::from_str("www.example.com.").unwrap(),
672            RecordType::NS,
673        ));
674        message.insert_answers(vec![Record::from_rdata(
675            Name::from_str("www.example.com.").unwrap(),
676            86400,
677            RData::NS(NS(Name::from_str("www.example.com.").unwrap())),
678        )]);
679        Ok(DnsResponse::from_message(message).unwrap())
680    }
681
682    fn no_recursion_on_query_test(query_type: RecordType) {
683        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
684
685        // the cname should succeed, we shouldn't query again after that, which would cause an error...
686        let client = mock(vec![error(), cname_message()]);
687        let client = CachingClient::with_cache(cache, client, false);
688
689        let ips = block_on(CachingClient::inner_lookup(
690            Query::query(Name::from_str("www.example.com.").unwrap(), query_type),
691            DnsRequestOptions::default(),
692            client,
693            vec![],
694            DepthTracker::default(),
695        ))
696        .expect("lookup failed");
697
698        assert_eq!(
699            ips.iter().cloned().collect::<Vec<_>>(),
700            vec![RData::CNAME(CNAME(
701                Name::from_str("actual.example.com.").unwrap()
702            ))]
703        );
704    }
705
706    #[test]
707    fn test_no_recursion_on_cname_query() {
708        subscribe();
709        no_recursion_on_query_test(RecordType::CNAME);
710    }
711
712    #[test]
713    fn test_no_recursion_on_all_query() {
714        subscribe();
715        no_recursion_on_query_test(RecordType::ANY);
716    }
717
718    #[test]
719    fn test_non_recursive_srv_query() {
720        subscribe();
721
722        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
723
724        // the cname should succeed, we shouldn't query again after that, which would cause an error...
725        let client = mock(vec![error(), srv_message()]);
726        let client = CachingClient::with_cache(cache, client, false);
727
728        let ips = block_on(CachingClient::inner_lookup(
729            Query::query(
730                Name::from_str("_443._tcp.www.example.com.").unwrap(),
731                RecordType::SRV,
732            ),
733            DnsRequestOptions::default(),
734            client,
735            vec![],
736            DepthTracker::default(),
737        ))
738        .expect("lookup failed");
739
740        assert_eq!(
741            ips.iter().cloned().collect::<Vec<_>>(),
742            vec![RData::SRV(SRV::new(
743                1,
744                2,
745                443,
746                Name::from_str("www.example.com.").unwrap(),
747            ))]
748        );
749    }
750
751    #[test]
752    fn test_single_srv_query_response() {
753        subscribe();
754
755        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
756
757        let mut message = srv_message().unwrap().into_message();
758        message.add_answer(Record::from_rdata(
759            Name::from_str("www.example.com.").unwrap(),
760            86400,
761            RData::CNAME(CNAME(Name::from_str("actual.example.com.").unwrap())),
762        ));
763        message.insert_additionals(vec![
764            Record::from_rdata(
765                Name::from_str("actual.example.com.").unwrap(),
766                86400,
767                RData::A(A::new(127, 0, 0, 1)),
768            ),
769            Record::from_rdata(
770                Name::from_str("actual.example.com.").unwrap(),
771                86400,
772                RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 1)),
773            ),
774        ]);
775
776        let client = mock(vec![
777            error(),
778            Ok(DnsResponse::from_message(message).unwrap()),
779        ]);
780        let client = CachingClient::with_cache(cache, client, false);
781
782        let ips = block_on(CachingClient::inner_lookup(
783            Query::query(
784                Name::from_str("_443._tcp.www.example.com.").unwrap(),
785                RecordType::SRV,
786            ),
787            DnsRequestOptions::default(),
788            client,
789            vec![],
790            DepthTracker::default(),
791        ))
792        .expect("lookup failed");
793
794        assert_eq!(
795            ips.iter().cloned().collect::<Vec<_>>(),
796            vec![
797                RData::SRV(SRV::new(
798                    1,
799                    2,
800                    443,
801                    Name::from_str("www.example.com.").unwrap(),
802                )),
803                RData::A(A::new(127, 0, 0, 1)),
804                RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 1)),
805            ]
806        );
807    }
808
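    // TODO: if recursive lookups are ever enabled for SRV, restore the test below.
    // NOTE: it predates the current APIs (`DnsLru::new`, `Record::from_rdata`,
    // `QueryState`/`.wait()`) and will need to be rewritten before it compiles.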
810    // #[test]
811    // fn test_recursive_srv_query() {
812    //     let cache = Arc::new(Mutex::new(DnsLru::new(1)));
813
814    //     let mut message = Message::new();
815    //     message.add_answer(Record::from_rdata(
816    //         Name::from_str("www.example.com.").unwrap(),
817    //         86400,
818    //         RecordType::CNAME,
819    //         RData::CNAME(Name::from_str("actual.example.com.").unwrap()),
820    //     ));
821    //     message.insert_additionals(vec![
822    //         Record::from_rdata(
823    //             Name::from_str("actual.example.com.").unwrap(),
824    //             86400,
825    //             RecordType::A,
826    //             RData::A(Ipv4Addr::LOCALHOST),
827    //         ),
828    //     ]);
829
830    //     let mut client = mock(vec![error(), Ok(DnsResponse::from_message(message).unwrap()), srv_message()]);
831
832    //     let ips = QueryState::lookup(
833    //         Query::query(
834    //             Name::from_str("_443._tcp.www.example.com.").unwrap(),
835    //             RecordType::SRV,
836    //         ),
837    //         Default::default(),
838    //         &mut client,
839    //         cache.clone(),
840    //     ).wait()
841    //         .expect("lookup failed");
842
843    //     assert_eq!(
844    //         ips.iter().cloned().collect::<Vec<_>>(),
845    //         vec![
846    //             RData::SRV(SRV::new(
847    //                 1,
848    //                 2,
849    //                 443,
850    //                 Name::from_str("www.example.com.").unwrap(),
851    //             )),
852    //             RData::A(Ipv4Addr::LOCALHOST),
853    //             //RData::AAAA(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
854    //         ]
855    //     );
856    // }
857
858    #[test]
859    fn test_single_ns_query_response() {
860        subscribe();
861
862        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
863
864        let mut message = ns_message().unwrap().into_message();
865        message.add_answer(Record::from_rdata(
866            Name::from_str("www.example.com.").unwrap(),
867            86400,
868            RData::CNAME(CNAME(Name::from_str("actual.example.com.").unwrap())),
869        ));
870        message.insert_additionals(vec![
871            Record::from_rdata(
872                Name::from_str("actual.example.com.").unwrap(),
873                86400,
874                RData::A(A::new(127, 0, 0, 1)),
875            ),
876            Record::from_rdata(
877                Name::from_str("actual.example.com.").unwrap(),
878                86400,
879                RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 1)),
880            ),
881        ]);
882
883        let client = mock(vec![
884            error(),
885            Ok(DnsResponse::from_message(message).unwrap()),
886        ]);
887        let client = CachingClient::with_cache(cache, client, false);
888
889        let ips = block_on(CachingClient::inner_lookup(
890            Query::query(Name::from_str("www.example.com.").unwrap(), RecordType::NS),
891            DnsRequestOptions::default(),
892            client,
893            vec![],
894            DepthTracker::default(),
895        ))
896        .expect("lookup failed");
897
898        assert_eq!(
899            ips.iter().cloned().collect::<Vec<_>>(),
900            vec![
901                RData::NS(NS(Name::from_str("www.example.com.").unwrap())),
902                RData::A(A::new(127, 0, 0, 1)),
903                RData::AAAA(AAAA::new(0, 0, 0, 0, 0, 0, 0, 1)),
904            ]
905        );
906    }
907
908    fn cname_ttl_test(first: u32, second: u32) {
909        let lru = DnsLru::new(1, dns_lru::TtlConfig::default());
910        // expecting no queries to be performed
911        let mut client = CachingClient::with_cache(lru, mock(vec![error()]), false);
912
913        let mut message = Message::new();
914        message.insert_answers(vec![Record::from_rdata(
915            Name::from_str("ttl.example.com.").unwrap(),
916            first,
917            RData::CNAME(CNAME(Name::from_str("actual.example.com.").unwrap())),
918        )]);
919        message.insert_additionals(vec![Record::from_rdata(
920            Name::from_str("actual.example.com.").unwrap(),
921            second,
922            RData::A(A::new(127, 0, 0, 1)),
923        )]);
924
925        let records = CachingClient::handle_noerror(
926            &mut client,
927            DnsRequestOptions::default(),
928            false,
929            &Query::query(Name::from_str("ttl.example.com.").unwrap(), RecordType::A),
930            DnsResponse::from_message(message).unwrap(),
931            vec![],
932            DepthTracker::default(),
933        );
934
935        if let Ok(records) = records {
936            if let Records::Exists(records) = records {
937                for (record, ttl) in records.iter() {
938                    if record.record_type() == RecordType::CNAME {
939                        continue;
940                    }
941                    assert_eq!(ttl, &1);
942                }
943            } else {
944                panic!("records don't exist");
945            }
946        } else {
947            panic!("error getting records");
948        }
949    }
950
951    #[test]
952    fn test_cname_ttl() {
953        subscribe();
954        cname_ttl_test(1, 2);
955        cname_ttl_test(2, 1);
956    }
957
958    #[test]
959    fn test_early_return_localhost() {
960        subscribe();
961        let cache = DnsLru::new(0, dns_lru::TtlConfig::default());
962        let client = mock(vec![empty()]);
963        let client = CachingClient::with_cache(cache, client, false);
964
965        {
966            let query = Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::A);
967            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
968                .expect("should have returned localhost");
969            assert_eq!(lookup.query(), &query);
970            assert_eq!(
971                lookup.iter().cloned().collect::<Vec<_>>(),
972                vec![LOCALHOST_V4.clone()]
973            );
974        }
975
976        {
977            let query = Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::AAAA);
978            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
979                .expect("should have returned localhost");
980            assert_eq!(lookup.query(), &query);
981            assert_eq!(
982                lookup.iter().cloned().collect::<Vec<_>>(),
983                vec![LOCALHOST_V6.clone()]
984            );
985        }
986
987        {
988            let query = Query::query(Name::from(Ipv4Addr::LOCALHOST), RecordType::PTR);
989            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
990                .expect("should have returned localhost");
991            assert_eq!(lookup.query(), &query);
992            assert_eq!(
993                lookup.iter().cloned().collect::<Vec<_>>(),
994                vec![LOCALHOST.clone()]
995            );
996        }
997
998        {
999            let query = Query::query(
1000                Name::from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
1001                RecordType::PTR,
1002            );
1003            let lookup = block_on(client.lookup(query.clone(), DnsRequestOptions::default()))
1004                .expect("should have returned localhost");
1005            assert_eq!(lookup.query(), &query);
1006            assert_eq!(
1007                lookup.iter().cloned().collect::<Vec<_>>(),
1008                vec![LOCALHOST.clone()]
1009            );
1010        }
1011
1012        assert!(
1013            block_on(client.lookup(
1014                Query::query(Name::from_ascii("localhost.").unwrap(), RecordType::MX),
1015                DnsRequestOptions::default()
1016            ))
1017            .is_err()
1018        );
1019
1020        assert!(
1021            block_on(client.lookup(
1022                Query::query(Name::from(Ipv4Addr::LOCALHOST), RecordType::MX),
1023                DnsRequestOptions::default()
1024            ))
1025            .is_err()
1026        );
1027
1028        assert!(
1029            block_on(client.lookup(
1030                Query::query(
1031                    Name::from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
1032                    RecordType::MX
1033                ),
1034                DnsRequestOptions::default()
1035            ))
1036            .is_err()
1037        );
1038    }
1039
1040    #[test]
1041    fn test_early_return_invalid() {
1042        subscribe();
1043        let cache = DnsLru::new(0, dns_lru::TtlConfig::default());
1044        let client = mock(vec![empty()]);
1045        let client = CachingClient::with_cache(cache, client, false);
1046
1047        assert!(
1048            block_on(client.lookup(
1049                Query::query(
1050                    Name::from_ascii("horrible.invalid.").unwrap(),
1051                    RecordType::A,
1052                ),
1053                DnsRequestOptions::default()
1054            ))
1055            .is_err()
1056        );
1057    }
1058
1059    #[test]
1060    fn test_no_error_on_dot_local_no_mdns() {
1061        subscribe();
1062
1063        let cache = DnsLru::new(1, dns_lru::TtlConfig::default());
1064
1065        let mut message = srv_message().unwrap().into_message();
1066        message.add_query(Query::query(
1067            Name::from_ascii("www.example.local.").unwrap(),
1068            RecordType::A,
1069        ));
1070        message.add_answer(Record::from_rdata(
1071            Name::from_str("www.example.local.").unwrap(),
1072            86400,
1073            RData::A(A::new(127, 0, 0, 1)),
1074        ));
1075
1076        let client = mock(vec![
1077            error(),
1078            Ok(DnsResponse::from_message(message).unwrap()),
1079        ]);
1080        let client = CachingClient::with_cache(cache, client, false);
1081
1082        assert!(
1083            block_on(client.lookup(
1084                Query::query(
1085                    Name::from_ascii("www.example.local.").unwrap(),
1086                    RecordType::A,
1087                ),
1088                DnsRequestOptions::default()
1089            ))
1090            .is_ok()
1091        );
1092    }
1093}