1use std::io;
9use std::marker::Unpin;
10use std::net::SocketAddr;
11#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
12use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{Context, Poll};
16
17use futures_util::future::{Future, FutureExt};
18use futures_util::ready;
19use futures_util::stream::{Stream, StreamExt};
20#[cfg(feature = "tokio-runtime")]
21use tokio::net::TcpStream as TokioTcpStream;
22#[cfg(all(feature = "dns-over-native-tls", not(feature = "dns-over-rustls")))]
23use tokio_native_tls::TlsStream as TokioTlsStream;
24#[cfg(all(
25 feature = "dns-over-openssl",
26 not(feature = "dns-over-rustls"),
27 not(feature = "dns-over-native-tls")
28))]
29use tokio_openssl::SslStream as TokioTlsStream;
30#[cfg(feature = "dns-over-rustls")]
31use tokio_rustls::client::TlsStream as TokioTlsStream;
32
33use crate::config::{NameServerConfig, Protocol, ResolverOpts};
34#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
35use hickory_proto::udp::QuicLocalAddr;
36#[cfg(feature = "dns-over-https")]
37use proto::h2::{HttpsClientConnect, HttpsClientStream};
38#[cfg(feature = "dns-over-h3")]
39use proto::h3::{H3ClientConnect, H3ClientStream};
40#[cfg(feature = "mdns")]
41use proto::multicast::{MdnsClientConnect, MdnsClientStream, MdnsQueryType};
42#[cfg(feature = "dns-over-quic")]
43use proto::quic::{QuicClientConnect, QuicClientStream};
44use proto::tcp::DnsTcpStream;
45use proto::udp::DnsUdpSocket;
46use proto::{
47 self,
48 error::ProtoError,
49 op::NoopMessageFinalizer,
50 tcp::TcpClientConnect,
51 tcp::TcpClientStream,
52 udp::UdpClientConnect,
53 udp::UdpClientStream,
54 xfer::{
55 DnsExchange, DnsExchangeConnect, DnsExchangeSend, DnsHandle, DnsMultiplexer,
56 DnsMultiplexerConnect, DnsRequest, DnsResponse,
57 },
58 Time,
59};
60#[cfg(feature = "tokio-runtime")]
61use proto::{iocompat::AsyncIoTokioAsStd, TokioTime};
62
63use crate::error::ResolveError;
64
65pub trait RuntimeProvider: Clone + Send + Sync + Unpin + 'static {
67 type Handle: Clone + Send + Spawn + Sync + Unpin;
69
70 type Timer: Time + Send + Unpin;
72
73 #[cfg(not(any(feature = "dns-over-quic", feature = "dns-over-h3")))]
74 type Udp: DnsUdpSocket + Send;
76 #[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
77 type Udp: DnsUdpSocket + QuicLocalAddr + Send;
79
80 type Tcp: DnsTcpStream;
82
83 fn create_handle(&self) -> Self::Handle;
85
86 fn connect_tcp(
88 &self,
89 server_addr: SocketAddr,
90 ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>>;
91
92 fn bind_udp(
95 &self,
96 local_addr: SocketAddr,
97 server_addr: SocketAddr,
98 ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>>;
99}
100
101pub trait ConnectionProvider: 'static + Clone + Send + Sync + Unpin {
104 type Conn: DnsHandle<Error = ResolveError> + Clone + Send + Sync + 'static;
106 type FutureConn: Future<Output = Result<Self::Conn, ResolveError>> + Send + 'static;
108 type RuntimeProvider: RuntimeProvider;
110
111 fn new_connection(&self, config: &NameServerConfig, options: &ResolverOpts)
113 -> Self::FutureConn;
114}
115
116pub trait Spawn {
118 fn spawn_bg<F>(&mut self, future: F)
120 where
121 F: Future<Output = Result<(), ProtoError>> + Send + 'static;
122}
123
/// TCP client stream running over a TLS session that wraps the runtime TCP stream `S`.
///
/// Gated on `tokio-runtime` as well as `dns-over-tls`: the alias names
/// `AsyncIoTokioAsStd`, which is only imported with `tokio-runtime`, and the
/// `ConnectionConnect::Tls` variant that uses it carries the same pair of features.
#[cfg(all(feature = "dns-over-tls", feature = "tokio-runtime"))]
type TlsClientStream<S> =
    TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<proto::iocompat::AsyncIoStdAsTokio<S>>>>;
128
129#[allow(clippy::large_enum_variant, clippy::type_complexity)]
131pub(crate) enum ConnectionConnect<R: RuntimeProvider> {
132 Udp(DnsExchangeConnect<UdpClientConnect<R::Udp>, UdpClientStream<R::Udp>, R::Timer>),
133 Tcp(
134 DnsExchangeConnect<
135 DnsMultiplexerConnect<
136 TcpClientConnect<<R as RuntimeProvider>::Tcp>,
137 TcpClientStream<<R as RuntimeProvider>::Tcp>,
138 NoopMessageFinalizer,
139 >,
140 DnsMultiplexer<TcpClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
141 R::Timer,
142 >,
143 ),
144 #[cfg(all(feature = "dns-over-tls", feature = "tokio-runtime"))]
145 Tls(
146 DnsExchangeConnect<
147 DnsMultiplexerConnect<
148 Pin<
149 Box<
150 dyn Future<
151 Output = Result<
152 TlsClientStream<<R as RuntimeProvider>::Tcp>,
153 ProtoError,
154 >,
155 > + Send
156 + 'static,
157 >,
158 >,
159 TlsClientStream<<R as RuntimeProvider>::Tcp>,
160 NoopMessageFinalizer,
161 >,
162 DnsMultiplexer<TlsClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
163 TokioTime,
164 >,
165 ),
166 #[cfg(all(feature = "dns-over-https", feature = "tokio-runtime"))]
167 Https(DnsExchangeConnect<HttpsClientConnect<R::Tcp>, HttpsClientStream, TokioTime>),
168 #[cfg(all(feature = "dns-over-quic", feature = "tokio-runtime"))]
169 Quic(DnsExchangeConnect<QuicClientConnect, QuicClientStream, TokioTime>),
170 #[cfg(all(feature = "dns-over-h3", feature = "tokio-runtime"))]
171 H3(DnsExchangeConnect<H3ClientConnect, H3ClientStream, TokioTime>),
172 #[cfg(feature = "mdns")]
173 Mdns(
174 DnsExchangeConnect<
175 DnsMultiplexerConnect<MdnsClientConnect, MdnsClientStream, NoopMessageFinalizer>,
176 DnsMultiplexer<MdnsClientStream, NoopMessageFinalizer>,
177 TokioTime,
178 >,
179 ),
180}
181
182#[must_use = "futures do nothing unless polled"]
184pub struct ConnectionFuture<R: RuntimeProvider> {
185 pub(crate) connect: ConnectionConnect<R>,
186 pub(crate) spawner: R::Handle,
187}
188
189impl<R: RuntimeProvider> Future for ConnectionFuture<R> {
190 type Output = Result<GenericConnection, ResolveError>;
191
192 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
193 Poll::Ready(Ok(match &mut self.connect {
194 ConnectionConnect::Udp(ref mut conn) => {
195 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
196 self.spawner.spawn_bg(bg);
197 GenericConnection(conn)
198 }
199 ConnectionConnect::Tcp(ref mut conn) => {
200 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
201 self.spawner.spawn_bg(bg);
202 GenericConnection(conn)
203 }
204 #[cfg(feature = "dns-over-tls")]
205 ConnectionConnect::Tls(ref mut conn) => {
206 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
207 self.spawner.spawn_bg(bg);
208 GenericConnection(conn)
209 }
210 #[cfg(feature = "dns-over-https")]
211 ConnectionConnect::Https(ref mut conn) => {
212 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
213 self.spawner.spawn_bg(bg);
214 GenericConnection(conn)
215 }
216 #[cfg(feature = "dns-over-quic")]
217 ConnectionConnect::Quic(ref mut conn) => {
218 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
219 self.spawner.spawn_bg(bg);
220 GenericConnection(conn)
221 }
222 #[cfg(feature = "dns-over-h3")]
223 ConnectionConnect::H3(ref mut conn) => {
224 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
225 self.spawner.spawn_bg(bg);
226 GenericConnection(conn)
227 }
228 #[cfg(feature = "mdns")]
229 ConnectionConnect::Mdns(ref mut conn) => {
230 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
231 self.spawner.spawn_bg(bg);
232 GenericConnection(conn)
233 }
234 }))
235 }
236}
237
238#[derive(Clone)]
240pub struct GenericConnection(DnsExchange);
241
242impl DnsHandle for GenericConnection {
243 type Response = ConnectionResponse;
244 type Error = ResolveError;
245
246 fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
247 ConnectionResponse(self.0.send(request))
248 }
249}
250
251#[derive(Clone)]
253pub struct GenericConnector<P: RuntimeProvider> {
254 runtime_provider: P,
255}
256
257impl<P: RuntimeProvider> GenericConnector<P> {
258 pub fn new(runtime_provider: P) -> Self {
260 Self { runtime_provider }
261 }
262}
263
264impl<P: RuntimeProvider + Default> Default for GenericConnector<P> {
265 fn default() -> Self {
266 Self {
267 runtime_provider: P::default(),
268 }
269 }
270}
271
272impl<P: RuntimeProvider> ConnectionProvider for GenericConnector<P> {
273 type Conn = GenericConnection;
274 type FutureConn = ConnectionFuture<P>;
275 type RuntimeProvider = P;
276
277 fn new_connection(
278 &self,
279 config: &NameServerConfig,
280 options: &ResolverOpts,
281 ) -> Self::FutureConn {
282 let dns_connect = match config.protocol {
283 Protocol::Udp => {
284 let provider_handle = self.runtime_provider.clone();
285 let closure = move |local_addr: SocketAddr, server_addr: SocketAddr| {
286 provider_handle.bind_udp(local_addr, server_addr)
287 };
288 let stream = UdpClientStream::with_creator(
289 config.socket_addr,
290 None,
291 options.timeout,
292 Arc::new(closure),
293 );
294 let exchange = DnsExchange::connect(stream);
295 ConnectionConnect::Udp(exchange)
296 }
297 Protocol::Tcp => {
298 let socket_addr = config.socket_addr;
299 let timeout = options.timeout;
300 let tcp_future = self.runtime_provider.connect_tcp(socket_addr);
301
302 let (stream, handle) =
303 TcpClientStream::with_future(tcp_future, socket_addr, timeout);
304 let dns_conn = DnsMultiplexer::with_timeout(
306 stream,
307 handle,
308 timeout,
309 NoopMessageFinalizer::new(),
310 );
311
312 let exchange = DnsExchange::connect(dns_conn);
313 ConnectionConnect::Tcp(exchange)
314 }
315 #[cfg(feature = "dns-over-tls")]
316 Protocol::Tls => {
317 let socket_addr = config.socket_addr;
318 let timeout = options.timeout;
319 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
320 let tcp_future = self.runtime_provider.connect_tcp(socket_addr);
321
322 #[cfg(feature = "dns-over-rustls")]
323 let client_config = config.tls_config.clone();
324
325 #[cfg(feature = "dns-over-rustls")]
326 let (stream, handle) = {
327 crate::tls::new_tls_stream_with_future(
328 tcp_future,
329 socket_addr,
330 tls_dns_name,
331 client_config,
332 )
333 };
334 #[cfg(not(feature = "dns-over-rustls"))]
335 let (stream, handle) = {
336 crate::tls::new_tls_stream_with_future(tcp_future, socket_addr, tls_dns_name)
337 };
338
339 let dns_conn = DnsMultiplexer::with_timeout(
340 stream,
341 handle,
342 timeout,
343 NoopMessageFinalizer::new(),
344 );
345
346 let exchange = DnsExchange::connect(dns_conn);
347 ConnectionConnect::Tls(exchange)
348 }
349 #[cfg(feature = "dns-over-https")]
350 Protocol::Https => {
351 let socket_addr = config.socket_addr;
352 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
353 #[cfg(feature = "dns-over-rustls")]
354 let client_config = config.tls_config.clone();
355 let tcp_future = self.runtime_provider.connect_tcp(socket_addr);
356
357 let exchange = crate::h2::new_https_stream_with_future(
358 tcp_future,
359 socket_addr,
360 tls_dns_name,
361 client_config,
362 );
363 ConnectionConnect::Https(exchange)
364 }
365 #[cfg(feature = "dns-over-quic")]
366 Protocol::Quic => {
367 let socket_addr = config.socket_addr;
368 let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
369 SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
370 SocketAddr::V6(_) => {
371 SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
372 }
373 });
374 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
375 #[cfg(feature = "dns-over-rustls")]
376 let client_config = config.tls_config.clone();
377 let udp_future = self.runtime_provider.bind_udp(bind_addr, socket_addr);
378
379 let exchange = crate::quic::new_quic_stream_with_future(
380 udp_future,
381 socket_addr,
382 tls_dns_name,
383 client_config,
384 );
385 ConnectionConnect::Quic(exchange)
386 }
387 #[cfg(feature = "dns-over-h3")]
388 Protocol::H3 => {
389 let socket_addr = config.socket_addr;
390 let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
391 SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
392 SocketAddr::V6(_) => {
393 SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
394 }
395 });
396 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
397 let client_config = config.tls_config.clone();
398 let udp_future = self.runtime_provider.bind_udp(bind_addr, socket_addr);
399
400 let exchange = crate::h3::new_h3_stream_with_future(
401 udp_future,
402 socket_addr,
403 tls_dns_name,
404 client_config,
405 );
406 ConnectionConnect::H3(exchange)
407 }
408 #[cfg(feature = "mdns")]
409 Protocol::Mdns => {
410 let socket_addr = config.socket_addr;
411 let timeout = options.timeout;
412
413 let (stream, handle) =
414 MdnsClientStream::new(socket_addr, MdnsQueryType::OneShot, None, None, None);
415 let dns_conn = DnsMultiplexer::with_timeout(
417 stream,
418 handle,
419 timeout,
420 NoopMessageFinalizer::new(),
421 );
422
423 let exchange = DnsExchange::connect(dns_conn);
424 ConnectionConnect::Mdns(exchange)
425 }
426 };
427
428 ConnectionFuture::<P> {
429 connect: dns_connect,
430 spawner: self.runtime_provider.create_handle(),
431 }
432 }
433}
434
435#[must_use = "steam do nothing unless polled"]
437pub struct ConnectionResponse(DnsExchangeSend);
438
439impl Stream for ConnectionResponse {
440 type Item = Result<DnsResponse, ResolveError>;
441
442 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
443 Poll::Ready(ready!(self.0.poll_next_unpin(cx)).map(|r| r.map_err(ResolveError::from)))
444 }
445}
446
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-runtime")))]
#[allow(unreachable_pub)]
pub mod tokio_runtime {
    //! Tokio-backed implementations of [`RuntimeProvider`] and [`Spawn`].
    use super::*;
    use std::sync::{Arc, Mutex};
    use tokio::net::UdpSocket as TokioUdpSocket;
    use tokio::task::JoinSet;

    /// Spawner that keeps background tasks in a shared `JoinSet`.
    ///
    /// Clones share the same set through the `Arc`.
    #[derive(Clone, Default)]
    pub struct TokioHandle {
        join_set: Arc<Mutex<JoinSet<Result<(), ProtoError>>>>,
    }

    impl Spawn for TokioHandle {
        fn spawn_bg<F>(&mut self, future: F)
        where
            F: Future<Output = Result<(), ProtoError>> + Send + 'static,
        {
            let mut tasks = self.join_set.lock().unwrap();
            tasks.spawn(future);
            // Opportunistically drop tasks that have already completed so the set
            // does not grow without bound.
            reap_tasks(&mut tasks);
        }
    }

    /// [`RuntimeProvider`] backed by Tokio; clones share one [`TokioHandle`].
    #[derive(Clone, Default)]
    pub struct TokioRuntimeProvider(TokioHandle);

    impl TokioRuntimeProvider {
        /// Same as `TokioRuntimeProvider::default()`.
        pub fn new() -> Self {
            Default::default()
        }
    }

    impl RuntimeProvider for TokioRuntimeProvider {
        type Handle = TokioHandle;
        type Timer = TokioTime;
        type Udp = TokioUdpSocket;
        type Tcp = AsyncIoTokioAsStd<TokioTcpStream>;

        fn create_handle(&self) -> Self::Handle {
            TokioHandle::clone(&self.0)
        }

        fn connect_tcp(
            &self,
            server_addr: SocketAddr,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>> {
            Box::pin(async move {
                let connected = TokioTcpStream::connect(server_addr).await;
                connected.map(AsyncIoTokioAsStd)
            })
        }

        /// Binds a Tokio UDP socket on `local_addr`; `_server_addr` is not used.
        fn bind_udp(
            &self,
            local_addr: SocketAddr,
            _server_addr: SocketAddr,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>> {
            Box::pin(TokioUdpSocket::bind(local_addr))
        }
    }

    /// Removes every task in `join_set` that has already finished, without waiting.
    ///
    /// The results of reaped tasks (including errors and join failures) are discarded.
    fn reap_tasks(join_set: &mut JoinSet<Result<(), ProtoError>>) {
        while let Some(Some(_)) = join_set.join_next().now_or_never() {}
    }

    /// Connection provider that runs on Tokio.
    pub type TokioConnectionProvider = GenericConnector<TokioRuntimeProvider>;
}