1use std::cmp::Ordering;
9use std::pin::Pin;
10use std::sync::{
11 Arc,
12 atomic::{AtomicUsize, Ordering as AtomicOrdering},
13};
14use std::task::{Context, Poll};
15use std::time::Duration;
16
17use futures_util::future::FutureExt;
18use futures_util::stream::{FuturesUnordered, Stream, StreamExt, once};
19use smallvec::SmallVec;
20use tracing::debug;
21
22use crate::config::{NameServerConfigGroup, ResolverConfig, ResolverOpts, ServerOrderingStrategy};
23use crate::name_server::connection_provider::{ConnectionProvider, GenericConnector};
24use crate::name_server::name_server::NameServer;
25use crate::proto::runtime::{RuntimeProvider, Time};
26use crate::proto::xfer::{DnsHandle, DnsRequest, DnsResponse, FirstAnswer};
27use crate::proto::{ProtoError, ProtoErrorKind};
28
29pub type GenericNameServerPool<P> = NameServerPool<GenericConnector<P>>;
33
34#[derive(Clone)]
36pub struct NameServerPool<P: ConnectionProvider + Send + 'static> {
37 datagram_conns: Arc<[NameServer<P>]>, stream_conns: Arc<[NameServer<P>]>, options: ResolverOpts,
41 datagram_index: Arc<AtomicUsize>,
42 stream_index: Arc<AtomicUsize>,
43}
44
45impl<P> NameServerPool<P>
46where
47 P: ConnectionProvider + 'static,
48{
49 pub(crate) fn from_config_with_provider(
50 config: &ResolverConfig,
51 options: ResolverOpts,
52 conn_provider: P,
53 ) -> Self {
54 let datagram_conns = config
55 .name_servers()
56 .iter()
57 .filter(|ns_config| ns_config.protocol.is_datagram())
58 .map(|ns_config| {
59 NameServer::new(ns_config.clone(), options.clone(), conn_provider.clone())
60 })
61 .collect();
62
63 let stream_conns = config
64 .name_servers()
65 .iter()
66 .filter(|ns_config| ns_config.protocol.is_stream())
67 .map(|ns_config| {
68 NameServer::new(ns_config.clone(), options.clone(), conn_provider.clone())
69 })
70 .collect();
71
72 Self {
73 datagram_conns,
74 stream_conns,
75 options,
76 datagram_index: Arc::from(AtomicUsize::new(0)),
77 stream_index: Arc::from(AtomicUsize::new(0)),
78 }
79 }
80
81 pub fn from_config(
83 name_servers: NameServerConfigGroup,
84 options: ResolverOpts,
85 conn_provider: P,
86 ) -> Self {
87 let map_config_to_ns =
88 |ns_config| NameServer::new(ns_config, options.clone(), conn_provider.clone());
89
90 let (datagram, stream): (Vec<_>, Vec<_>) = name_servers
91 .into_inner()
92 .into_iter()
93 .partition(|ns| ns.protocol.is_datagram());
94
95 let datagram_conns: Vec<_> = datagram.into_iter().map(map_config_to_ns).collect();
96 let stream_conns: Vec<_> = stream.into_iter().map(map_config_to_ns).collect();
97
98 Self {
99 datagram_conns: Arc::from(datagram_conns),
100 stream_conns: Arc::from(stream_conns),
101 options,
102 datagram_index: Arc::from(AtomicUsize::new(0)),
103 stream_index: Arc::from(AtomicUsize::new(0)),
104 }
105 }
106
107 #[doc(hidden)]
108 pub fn from_nameservers(
109 options: ResolverOpts,
110 datagram_conns: Vec<NameServer<P>>,
111 stream_conns: Vec<NameServer<P>>,
112 ) -> Self {
113 Self {
114 datagram_conns: Arc::from(datagram_conns),
115 stream_conns: Arc::from(stream_conns),
116 options,
117 datagram_index: Arc::from(AtomicUsize::new(0)),
118 stream_index: Arc::from(AtomicUsize::new(0)),
119 }
120 }
121
122 pub fn options(&self) -> &ResolverOpts {
124 &self.options
125 }
126
127 #[cfg(test)]
128 #[allow(dead_code)]
129 fn from_nameservers_test(
130 options: ResolverOpts,
131 datagram_conns: Arc<[NameServer<P>]>,
132 stream_conns: Arc<[NameServer<P>]>,
133 ) -> Self {
134 Self {
135 datagram_conns,
136 stream_conns,
137 options,
138 datagram_index: Arc::from(AtomicUsize::new(0)),
139 stream_index: Arc::from(AtomicUsize::new(0)),
140 }
141 }
142
143 async fn try_send(
144 opts: ResolverOpts,
145 conns: Arc<[NameServer<P>]>,
146 request: DnsRequest,
147 next_index: &Arc<AtomicUsize>,
148 ) -> Result<DnsResponse, ProtoError> {
149 let mut conns: Vec<NameServer<P>> = conns.to_vec();
150
151 match opts.server_ordering_strategy {
152 ServerOrderingStrategy::QueryStatistics => {
156 conns.sort_by(|a, b| a.stats.decayed_srtt().total_cmp(&b.stats.decayed_srtt()));
157 }
158 ServerOrderingStrategy::UserProvidedOrder => {}
159 ServerOrderingStrategy::RoundRobin => {
160 let num_concurrent_reqs = if opts.num_concurrent_reqs > 1 {
161 opts.num_concurrent_reqs
162 } else {
163 1
164 };
165 if num_concurrent_reqs < conns.len() {
166 let index = next_index.fetch_add(num_concurrent_reqs, AtomicOrdering::SeqCst)
167 % conns.len();
168 conns.rotate_left(index);
169 }
170 }
171 }
172 let request_loop = request.clone();
173
174 parallel_conn_loop(conns, request_loop, opts).await
175 }
176}
177
178impl<P> DnsHandle for NameServerPool<P>
179where
180 P: ConnectionProvider + 'static,
181{
182 type Response = Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>;
183
184 fn send<R: Into<DnsRequest>>(&self, request: R) -> Self::Response {
185 let opts = self.options.clone();
186 let request = request.into();
187 let datagram_conns = Arc::clone(&self.datagram_conns);
188 let stream_conns = Arc::clone(&self.stream_conns);
189 let datagram_index = Arc::clone(&self.datagram_index);
190 let stream_index = Arc::clone(&self.stream_index);
191 let tcp_message = request.clone();
194
195 let mdns = Local::NotMdns(request);
197
198 if mdns.is_local() {
200 return mdns.take_stream();
201 }
202
203 let request = mdns.take_request();
207 Box::pin(once(async move {
208 debug!("sending request: {:?}", request.queries());
209
210 let future = Self::try_send(opts.clone(), datagram_conns, request, &datagram_index);
212 let udp_res = match future.await {
213 Ok(response) if response.truncated() => {
214 debug!("truncated response received, retrying over TCP");
215 Err(ProtoError::from("received truncated response"))
216 }
217 Err(e)
218 if (opts.try_tcp_on_error && e.is_io())
219 || e.is_no_connections()
220 || matches!(&*e.kind, ProtoErrorKind::QueryCaseMismatch) =>
221 {
222 debug!("error from UDP, retrying over TCP: {}", e);
223 Err(e)
224 }
225 result => return result,
226 };
227
228 if stream_conns.is_empty() {
229 debug!("no TCP connections available");
230 return udp_res;
231 }
232
233 Self::try_send(opts, stream_conns, tcp_message, &stream_index).await
236 }))
237 }
238}
239
240async fn parallel_conn_loop<P>(
243 mut conns: Vec<NameServer<P>>,
244 request: DnsRequest,
245 opts: ResolverOpts,
246) -> Result<DnsResponse, ProtoError>
247where
248 P: ConnectionProvider + 'static,
249{
250 let mut err = ProtoError::from(ProtoErrorKind::NoConnections);
251
252 let mut backoff = Duration::from_millis(20);
263 let mut busy = SmallVec::<[NameServer<P>; 2]>::new();
264
265 loop {
266 let request_cont = request.clone();
267
268 let mut par_conns = SmallVec::<[NameServer<P>; 2]>::new();
270 let count = conns.len().min(opts.num_concurrent_reqs.max(1));
271
272 for conn in conns.drain(..count) {
274 par_conns.push(conn);
275 }
276
277 if par_conns.is_empty() {
278 if !busy.is_empty() && backoff < Duration::from_millis(300) {
279 <<P as ConnectionProvider>::RuntimeProvider as RuntimeProvider>::Timer::delay_for(
280 backoff,
281 )
282 .await;
283 conns.extend(busy.drain(..));
284 backoff *= 2;
285 continue;
286 }
287 return Err(err);
288 }
289
290 let mut requests = par_conns
291 .into_iter()
292 .map(move |conn| {
293 conn.send(request_cont.clone())
294 .first_answer()
295 .map(|result| result.map_err(|e| (conn, e)))
296 })
297 .collect::<FuturesUnordered<_>>();
298
299 while let Some(result) = requests.next().await {
300 let (conn, e) = match result {
301 Ok(sent) => return Ok(sent),
302 Err((conn, e)) => (conn, e),
303 };
304
305 match e.kind() {
306 ProtoErrorKind::NoRecordsFound {
307 trusted, soa, ns, ..
308 } if *trusted || soa.is_some() || ns.is_some() => {
309 return Err(e);
310 }
311 _ if e.is_busy() => {
312 busy.push(conn);
313 }
314 _ if matches!(err.kind(), ProtoErrorKind::NoConnections) => {
318 err = e;
319 }
320 _ if err.cmp_specificity(&e) == Ordering::Less => {
321 err = e;
322 }
323 _ => {}
324 }
325 }
326 }
327}
328
329#[allow(clippy::large_enum_variant)]
330pub(crate) enum Local {
331 #[allow(dead_code)]
332 ResolveStream(Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>>),
333 NotMdns(DnsRequest),
334}
335
336impl Local {
337 fn is_local(&self) -> bool {
338 matches!(*self, Self::ResolveStream(..))
339 }
340
341 fn take_stream(self) -> Pin<Box<dyn Stream<Item = Result<DnsResponse, ProtoError>> + Send>> {
347 match self {
348 Self::ResolveStream(future) => future,
349 _ => panic!("non Local queries have no future, see take_message()"),
350 }
351 }
352
353 fn take_request(self) -> DnsRequest {
359 match self {
360 Self::NotMdns(request) => request,
361 _ => panic!("Local queries must be polled, see take_future()"),
362 }
363 }
364}
365
366impl Stream for Local {
367 type Item = Result<DnsResponse, ProtoError>;
368
369 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
370 match self.get_mut() {
371 Self::ResolveStream(ns) => ns.as_mut().poll_next(cx),
372 Self::NotMdns(..) => panic!("Local queries that are not mDNS should not be polled"), }
375 }
376}
377
#[cfg(test)]
#[cfg(feature = "tokio")]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, SocketAddr};
    use std::str::FromStr;

    use test_support::subscribe;
    use tokio::runtime::Runtime;

    use super::*;
    use crate::config::NameServerConfig;
    use crate::name_server::GenericNameServer;
    use crate::name_server::connection_provider::TokioConnectionProvider;
    use crate::proto::op::Query;
    use crate::proto::rr::{Name, RecordType};
    use crate::proto::runtime::TokioRuntimeProvider;
    use crate::proto::xfer::{DnsHandle, DnsRequestOptions, Protocol};

    /// Plain server config: no TLS name, no HTTP endpoint, untrusted negatives, default bind.
    fn plain_ns_config(socket_addr: SocketAddr, protocol: Protocol) -> NameServerConfig {
        NameServerConfig {
            socket_addr,
            protocol,
            tls_dns_name: None,
            http_endpoint: None,
            trust_negative_responses: false,
            bind_addr: None,
        }
    }

    /// Expects the first two lookups to fail and the next ten to succeed.
    ///
    /// Uses an unreachable local server plus a public one; requires network access.
    #[ignore]
    #[test]
    #[allow(clippy::uninlined_format_args)]
    fn test_failed_then_success_pool() {
        subscribe();

        let mut resolver_config = ResolverConfig::new();
        resolver_config.add_name_server(plain_ns_config(
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 252)), 253),
            Protocol::Udp,
        ));
        resolver_config.add_name_server(plain_ns_config(
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            Protocol::Udp,
        ));

        let io_loop = Runtime::new().unwrap();
        let pool = GenericNameServerPool::tokio_from_config(
            &resolver_config,
            ResolverOpts::default(),
            TokioRuntimeProvider::new(),
        );

        let name = Name::parse("www.example.com.", None).unwrap();
        let lookup_a = || {
            io_loop.block_on(
                pool.lookup(
                    Query::query(name.clone(), RecordType::A),
                    DnsRequestOptions::default(),
                )
                .first_answer(),
            )
        };

        for i in 0..2 {
            assert!(lookup_a().is_err(), "iter: {}", i);
        }

        for i in 0..10 {
            assert!(lookup_a().is_ok(), "iter: {}", i);
        }
    }

    /// Checks that the pool reuses the caller's `NameServer`.
    ///
    /// After each lookup, the server handle kept by the test must report a live connection.
    /// Requires network access.
    #[tokio::test]
    async fn test_multi_use_conns() {
        subscribe();

        let opts = ResolverOpts {
            try_tcp_on_error: true,
            ..ResolverOpts::default()
        };

        let tcp = plain_ns_config(
            SocketAddr::new(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 53),
            Protocol::Tcp,
        );
        let name_server =
            GenericNameServer::new(tcp, opts.clone(), TokioConnectionProvider::default());
        let name_servers: Arc<[_]> = Arc::from([name_server]);

        let pool = GenericNameServerPool::from_nameservers_test(
            opts,
            Arc::from([]),
            Arc::clone(&name_servers),
        );

        let name = Name::from_str("www.example.com.").unwrap();

        for record_type in [RecordType::A, RecordType::AAAA] {
            let response = pool
                .lookup(
                    Query::query(name.clone(), record_type),
                    DnsRequestOptions::default(),
                )
                .first_answer()
                .await
                .expect("lookup failed");

            assert!(!response.answers().is_empty());

            assert!(
                name_servers[0].is_connected(),
                "if this is failing then the NameServers aren't being properly shared."
            );
        }
    }

    impl GenericNameServerPool<TokioRuntimeProvider> {
        /// Test helper: builds a pool over the Tokio runtime from a `ResolverConfig`.
        pub(crate) fn tokio_from_config(
            config: &ResolverConfig,
            options: ResolverOpts,
            runtime: TokioRuntimeProvider,
        ) -> Self {
            Self::from_config_with_provider(config, options, GenericConnector::new(runtime))
        }
    }
}