hickory_resolver/name_server/
name_server_pool.rs

1// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::cmp::Ordering;
9use std::pin::Pin;
10use std::sync::{
11    Arc,
12    atomic::{AtomicUsize, Ordering as AtomicOrdering},
13};
14use std::task::{Context, Poll};
15use std::time::Duration;
16
17use futures_util::future::FutureExt;
18use futures_util::stream::{FuturesUnordered, Stream, StreamExt, once};
19use smallvec::SmallVec;
20use tracing::debug;
21
22use crate::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts, ServerOrderingStrategy};
23use crate::name_server::connection_provider::{ConnectionProvider, GenericConnector};
24use crate::name_server::name_server::NameServer;
25use crate::proto::runtime::{RuntimeProvider, Time};
26use crate::proto::xfer::{DnsHandle, DnsRequest, DnsResponse, FirstAnswer};
27use crate::proto::{ProtoError, ProtoErrorKind};
28
/// A pool of NameServers whose connections are created by a [`GenericConnector`]
///
/// This is [`NameServerPool`] specialized to `GenericConnector<P>`.
///
/// This is not expected to be used directly, see [crate::Resolver].
pub type GenericNameServerPool<P> = NameServerPool<GenericConnector<P>>;
33
/// A pool of [`NameServer`]s, kept as two separate lists: datagram and stream connections.
///
/// Through its [`DnsHandle`] implementation the pool sends a request to the datagram name
/// servers first and, for truncated responses and selected errors, retries it on the stream
/// name servers.
///
/// The name server lists and the round-robin counters are held in `Arc`s, so clones of a pool
/// share them.
#[derive(Clone)]
pub struct NameServerPool<P: ConnectionProvider + Send + 'static> {
    // TODO: switch to FuturesMutex (Mutex will have some undesirable locking)
    // Name servers that are tried first when a request is sent.
    datagram_conns: Arc<[NameServer<P>]>, /* All NameServers must be the same type */
    // Name servers used for the retry after the datagram attempt.
    stream_conns: Arc<[NameServer<P>]>,   /* All NameServers must be the same type */
    // Resolver options; a clone is handed to every request sent through the pool.
    options: ResolverOpts,
    // Counter used to rotate `datagram_conns` under the `RoundRobin` ordering strategy.
    datagram_index: Arc<AtomicUsize>,
    // Counter used to rotate `stream_conns` under the `RoundRobin` ordering strategy.
    stream_index: Arc<AtomicUsize>,
}
44
45impl<P> NameServerPool<P>
46where
47    P: ConnectionProvider + 'static,
48{
49    pub(crate) fn from_config_with_provider(
50        config: &ResolverConfig,
51        options: ResolverOpts,
52        conn_provider: P,
53    ) -> Self {
54        let datagram_conns = config
55            .name_servers()
56            .iter()
57            .filter(|ns_config| ns_config.protocol.is_datagram())
58            .map(|ns_config| {
59                NameServer::new(ns_config.clone(), options.clone(), conn_provider.clone())
60            })
61            .collect();
62
63        let stream_conns = config
64            .name_servers()
65            .iter()
66            .filter(|ns_config| ns_config.protocol.is_stream())
67            .map(|ns_config| {
68                NameServer::new(ns_config.clone(), options.clone(), conn_provider.clone())
69            })
70            .collect();
71
72        Self {
73            datagram_conns,
74            stream_conns,
75            options,
76            datagram_index: Arc::from(AtomicUsize::new(0)),
77            stream_index: Arc::from(AtomicUsize::new(0)),
78        }
79    }
80
81    /// Construct a NameServerPool from a set of name server configs
82    pub fn from_config(
83        name_servers: NameServerConfigGroup,
84        options: ResolverOpts,
85        conn_provider: P,
86    ) -> Self {
87        let map_config_to_ns =
88            |ns_config| NameServer::new(ns_config, options.clone(), conn_provider.clone());
89
90        let (datagram, stream): (Vec<_>, Vec<_>) = name_servers
91            .into_inner()
92            .into_iter()
93            .partition(|ns| ns.protocol.is_datagram());
94
95        let datagram_conns: Vec<_> = datagram.into_iter().map(map_config_to_ns).collect();
96        let stream_conns: Vec<_> = stream.into_iter().map(map_config_to_ns).collect();
97
98        Self {
99            datagram_conns: Arc::from(datagram_conns),
100            stream_conns: Arc::from(stream_conns),
101            options,
102            datagram_index: Arc::from(AtomicUsize::new(0)),
103            stream_index: Arc::from(AtomicUsize::new(0)),
104        }
105    }
106
107    #[doc(hidden)]
108    pub fn from_nameservers(
109        options: ResolverOpts,
110        datagram_conns: Vec<NameServer<P>>,
111        stream_conns: Vec<NameServer<P>>,
112    ) -> Self {
113        Self {
114            datagram_conns: Arc::from(datagram_conns),
115            stream_conns: Arc::from(stream_conns),
116            options,
117            datagram_index: Arc::from(AtomicUsize::new(0)),
118            stream_index: Arc::from(AtomicUsize::new(0)),
119        }
120    }
121
    /// Returns a reference to the [`ResolverOpts`] stored in this pool.
    pub fn options(&self) -> &ResolverOpts {
        &self.options
    }
126
127    #[cfg(test)]
128    #[allow(dead_code)]
129    fn from_nameservers_test(
130        options: ResolverOpts,
131        datagram_conns: Arc<[NameServer<P>]>,
132        stream_conns: Arc<[NameServer<P>]>,
133    ) -> Self {
134        Self {
135            datagram_conns,
136            stream_conns,
137            options,
138            datagram_index: Arc::from(AtomicUsize::new(0)),
139            stream_index: Arc::from(AtomicUsize::new(0)),
140        }
141    }
142
143    async fn try_send(
144        opts: ResolverOpts,
145        conns: Arc<[NameServer<P>]>,
146        request: DnsRequest,
147        next_index: &Arc<AtomicUsize>,
148    ) -> Result<DnsResponse, ProtoError> {
149        let mut conns: Vec<NameServer<P>> = conns.to_vec();
150
151        match opts.server_ordering_strategy {
152            // select the highest priority connection
153            //   reorder the connections based on current view...
154            //   this reorders the inner set
155            ServerOrderingStrategy::QueryStatistics => {
156                conns.sort_by(|a, b| a.stats.decayed_srtt().total_cmp(&b.stats.decayed_srtt()));
157            }
158            ServerOrderingStrategy::UserProvidedOrder => {}
159            ServerOrderingStrategy::RoundRobin => {
160                let num_concurrent_reqs = if opts.num_concurrent_reqs > 1 {
161                    opts.num_concurrent_reqs
162                } else {
163                    1
164                };
165                if num_concurrent_reqs < conns.len() {
166                    let index = next_index.fetch_add(num_concurrent_reqs, AtomicOrdering::SeqCst)
167                        % conns.len();
168                    conns.rotate_left(index);
169                }
170            }
171        }
172        let request_loop = request.clone();
173
174        parallel_conn_loop(conns, request_loop, opts).await
175    }
176}
177
178impl<P> DnsHandle for NameServerPool<P>
179where
180    P: ConnectionProvider + 'static,
181{
182    type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>;
183
184    fn send<R: Into<DnsRequest>>(&self, request: R) -> Self::Response {
185        let opts = self.options.clone();
186        let request = request.into();
187        let datagram_conns = Arc::clone(&self.datagram_conns);
188        let stream_conns = Arc::clone(&self.stream_conns);
189        let datagram_index = Arc::clone(&self.datagram_index);
190        let stream_index = Arc::clone(&self.stream_index);
191        // TODO: remove this clone, return the Message in the error?
192        // TODO: remove this clone, return the Message in the error?
193        let tcp_message = request.clone();
194
195        // TODO: limited to only when mDNS is enabled, but this should probably always be enforced?
196        let mdns = Local::NotMdns(request);
197
198        // local queries are queried through mDNS
199        if mdns.is_local() {
200            return mdns.take_stream();
201        }
202
203        // TODO: should we allow mDNS to be used for standard lookups as well?
204
205        // it wasn't a local query, continue with standard lookup path
206        let request = mdns.take_request();
207        Box::pin(once(async move {
208            debug!("sending request: {:?}", request.queries());
209
210            // First try the UDP connections
211            let future = Self::try_send(opts.clone(), datagram_conns, request, &datagram_index);
212            let udp_res = match future.await {
213                Ok(response) if response.truncated() => {
214                    debug!("truncated response received, retrying over TCP");
215                    Err(ProtoError::from("received truncated response"))
216                }
217                Err(e)
218                    if (opts.try_tcp_on_error && e.is_io())
219                        || e.is_no_connections()
220                        || matches!(&*e.kind, ProtoErrorKind::QueryCaseMismatch) =>
221                {
222                    debug!("error from UDP, retrying over TCP: {}", e);
223                    Err(e)
224                }
225                result => return result,
226            };
227
228            if stream_conns.is_empty() {
229                debug!("no TCP connections available");
230                return udp_res;
231            }
232
233            // Try query over TCP, as response to query over UDP was either truncated or was an
234            // error.
235            Self::try_send(opts, stream_conns, tcp_message, &stream_index).await
236        }))
237    }
238}
239
240// TODO: we should be able to have a self-referential future here with Pin and not require cloned conns
241/// An async function that will loop over all the conns with a max parallel request count of ops.num_concurrent_req
242async fn parallel_conn_loop<P>(
243    mut conns: Vec<NameServer<P>>,
244    request: DnsRequest,
245    opts: ResolverOpts,
246) -> Result<DnsResponse, ProtoError>
247where
248    P: ConnectionProvider + 'static,
249{
250    let mut err = ProtoError::from(ProtoErrorKind::NoConnections);
251
252    // If the name server we're trying is giving us backpressure by returning ProtoErrorKind::Busy,
253    // we will first try the other name servers (as for other error types). However, if the other
254    // servers are also busy, we're going to wait for a little while and then retry each server that
255    // returned Busy in the previous round. If the server is still Busy, this continues, while
256    // the backoff increases exponentially (by a factor of 2), until it hits 300ms, in which case we
257    // give up. The request might still be retried by the caller (likely the DnsRetryHandle).
258    //
259    // TODO: more principled handling of timeouts. Currently, timeouts appear to be handled mostly
260    // close to the connection, which means the top level resolution might take substantially longer
261    // to fire than the timeout configured in `ResolverOpts`.
262    let mut backoff = Duration::from_millis(20);
263    let mut busy = SmallVec::<[NameServer<P>; 2]>::new();
264
265    loop {
266        let request_cont = request.clone();
267
268        // construct the parallel requests, 2 is the default
269        let mut par_conns = SmallVec::<[NameServer<P>; 2]>::new();
270        let count = conns.len().min(opts.num_concurrent_reqs.max(1));
271
272        // Shuffe DNS NameServers to avoid overloads to the first configured ones
273        for conn in conns.drain(..count) {
274            par_conns.push(conn);
275        }
276
277        if par_conns.is_empty() {
278            if !busy.is_empty() && backoff < Duration::from_millis(300) {
279                <<P as ConnectionProvider>::RuntimeProvider as RuntimeProvider>::Timer::delay_for(
280                    backoff,
281                )
282                .await;
283                conns.extend(busy.drain(..));
284                backoff *= 2;
285                continue;
286            }
287            return Err(err);
288        }
289
290        let mut requests = par_conns
291            .into_iter()
292            .map(move |conn| {
293                conn.send(request_cont.clone())
294                    .first_answer()
295                    .map(|result| result.map_err(|e| (conn, e)))
296            })
297            .collect::<FuturesUnordered<_>>();
298
299        while let Some(result) = requests.next().await {
300            let (conn, e) = match result {
301                Ok(sent) => return Ok(sent),
302                Err((conn, e)) => (conn, e),
303            };
304
305            match e.kind() {
306                ProtoErrorKind::NoRecordsFound {
307                    trusted, soa, ns, ..
308                } if *trusted || soa.is_some() || ns.is_some() => {
309                    return Err(e);
310                }
311                _ if e.is_busy() => {
312                    busy.push(conn);
313                }
314                // If our current error is the default err we start with, replace it with the
315                // new error under consideration. It was produced trying to make a connection
316                // and is more specific than the default.
317                _ if matches!(err.kind(), ProtoErrorKind::NoConnections) => {
318                    err = e;
319                }
320                _ if err.cmp_specificity(&e) == Ordering::Less => {
321                    err = e;
322                }
323                _ => {}
324            }
325        }
326    }
327}
328
/// Either a response stream that is already available, or a request that still has to be sent
/// to the name servers.
///
/// In the code visible in this file only `NotMdns` is ever constructed.
#[allow(clippy::large_enum_variant)]
pub(crate) enum Local {
    /// A response stream that is handed back to the caller as is.
    #[allow(dead_code)]
    ResolveStream(Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>),
    /// A request that is not answered locally and continues on the standard lookup path.
    NotMdns(DnsRequest),
}
335
336impl Local {
337    fn is_local(&self) -> bool {
338        matches!(*self, Self::ResolveStream(..))
339    }
340
341    /// Takes the stream
342    ///
343    /// # Panics
344    ///
345    /// Panics if this is in fact a Local::NotMdns
346    fn take_stream(self) -> Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>> {
347        match self {
348            Self::ResolveStream(future) => future,
349            _ => panic!("non Local queries have no future, see take_message()"),
350        }
351    }
352
353    /// Takes the message
354    ///
355    /// # Panics
356    ///
357    /// Panics if this is in fact a Local::ResolveStream
358    fn take_request(self) -> DnsRequest {
359        match self {
360            Self::NotMdns(request) => request,
361            _ => panic!("Local queries must be polled, see take_future()"),
362        }
363    }
364}
365
366impl Stream for Local {
367    type Item = Result<DnsResponse, ProtoError>;
368
369    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
370        match self.get_mut() {
371            Self::ResolveStream(ns) => ns.as_mut().poll_next(cx),
372            // TODO: making this a panic for now
373            Self::NotMdns(..) => panic!("Local queries that are not mDNS should not be polled"), //Local::NotMdns(message) => return Err(ResolveErrorKind::Message("not mDNS")),
374        }
375    }
376}
377
378#[cfg(test)]
379#[cfg(feature = "tokio")]
380mod tests {
381    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
382    use std::str::FromStr;
383
384    use test_support::subscribe;
385    use tokio::runtime::Runtime;
386
387    use super::*;
388    use crate::config::NameServerConfig;
389    use crate::name_server::GenericNameServer;
390    use crate::name_server::connection_provider::TokioConnectionProvider;
391    use crate::proto::op::Query;
392    use crate::proto::rr::{Name, RecordType};
393    use crate::proto::runtime::TokioRuntimeProvider;
394    use crate::proto::xfer::{DnsHandle, DnsRequestOptions, Protocol};
395
396    #[ignore]
397    // because of there is a real connection that needs a reasonable timeout
398    #[test]
399    #[allow(clippy::uninlined_format_args)]
400    fn test_failed_then_success_pool() {
401        subscribe();
402
403        let config1 = NameServerConfig {
404            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)), 253),
405            protocol: Protocol::Udp,
406            tls_dns_name: None,
407            http_endpoint: None,
408            trust_negative_responses: false,
409            bind_addr: None,
410        };
411
412        let config2 = NameServerConfig {
413            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
414            protocol: Protocol::Udp,
415            tls_dns_name: None,
416            http_endpoint: None,
417            trust_negative_responses: false,
418            bind_addr: None,
419        };
420
421        let mut resolver_config = ResolverConfig::new();
422        resolver_config.add_name_server(config1);
423        resolver_config.add_name_server(config2);
424
425        let io_loop = Runtime::new().unwrap();
426        let pool = GenericNameServerPool::tokio_from_config(
427            &resolver_config,
428            ResolverOpts::default(),
429            TokioRuntimeProvider::new(),
430        );
431
432        let name = Name::parse("www.example.com.", None).unwrap();
433
434        // TODO: it's not clear why there are two failures before the success
435        for i in 0..2 {
436            assert!(
437                io_loop
438                    .block_on(
439                        pool.lookup(
440                            Query::query(name.clone(), RecordType::A),
441                            DnsRequestOptions::default()
442                        )
443                        .first_answer()
444                    )
445                    .is_err(),
446                "iter: {}",
447                i
448            );
449        }
450
451        for i in 0..10 {
452            assert!(
453                io_loop
454                    .block_on(
455                        pool.lookup(
456                            Query::query(name.clone(), RecordType::A),
457                            DnsRequestOptions::default()
458                        )
459                        .first_answer()
460                    )
461                    .is_ok(),
462                "iter: {}",
463                i
464            );
465        }
466    }
467
468    #[tokio::test]
469    async fn test_multi_use_conns() {
470        subscribe();
471
472        let conn_provider = TokioConnectionProvider::default();
473
474        let tcp = NameServerConfig {
475            socket_addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
476            protocol: Protocol::Tcp,
477            tls_dns_name: None,
478            http_endpoint: None,
479            trust_negative_responses: false,
480            bind_addr: None,
481        };
482
483        let opts = ResolverOpts {
484            try_tcp_on_error: true,
485            ..ResolverOpts::default()
486        };
487        let ns_config = { tcp };
488        let name_server = GenericNameServer::new(ns_config, opts.clone(), conn_provider);
489        let name_servers: Arc<[_]> = Arc::from([name_server]);
490
491        let pool = GenericNameServerPool::from_nameservers_test(
492            opts,
493            Arc::from([]),
494            Arc::clone(&name_servers),
495        );
496
497        let name = Name::from_str("www.example.com.").unwrap();
498
499        // first lookup
500        let response = pool
501            .lookup(
502                Query::query(name.clone(), RecordType::A),
503                DnsRequestOptions::default(),
504            )
505            .first_answer()
506            .await
507            .expect("lookup failed");
508
509        assert!(!response.answers().is_empty());
510
511        assert!(
512            name_servers[0].is_connected(),
513            "if this is failing then the NameServers aren't being properly shared."
514        );
515
516        // first lookup
517        let response = pool
518            .lookup(
519                Query::query(name, RecordType::AAAA),
520                DnsRequestOptions::default(),
521            )
522            .first_answer()
523            .await
524            .expect("lookup failed");
525
526        assert!(!response.answers().is_empty());
527
528        assert!(
529            name_servers[0].is_connected(),
530            "if this is failing then the NameServers aren't being properly shared."
531        );
532    }
533
534    impl GenericNameServerPool<TokioRuntimeProvider> {
535        pub(crate) fn tokio_from_config(
536            config: &ResolverConfig,
537            options: ResolverOpts,
538            runtime: TokioRuntimeProvider,
539        ) -> Self {
540            Self::from_config_with_provider(config, options, GenericConnector::new(runtime))
541        }
542    }
543}