1use std::io;
9use std::marker::Unpin;
10use std::net::SocketAddr;
11#[cfg(feature = "dns-over-quic")]
12use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{Context, Poll};
16
17use futures_util::future::{Future, FutureExt};
18use futures_util::ready;
19use futures_util::stream::{Stream, StreamExt};
20#[cfg(feature = "tokio-runtime")]
21use tokio::net::TcpStream as TokioTcpStream;
22#[cfg(all(feature = "dns-over-native-tls", not(feature = "dns-over-rustls")))]
23use tokio_native_tls::TlsStream as TokioTlsStream;
24#[cfg(all(
25 feature = "dns-over-openssl",
26 not(feature = "dns-over-rustls"),
27 not(feature = "dns-over-native-tls")
28))]
29use tokio_openssl::SslStream as TokioTlsStream;
30#[cfg(feature = "dns-over-rustls")]
31use tokio_rustls::client::TlsStream as TokioTlsStream;
32
33use crate::config::{NameServerConfig, Protocol, ResolverOpts};
34#[cfg(feature = "dns-over-https")]
35use proto::https::{HttpsClientConnect, HttpsClientStream};
36#[cfg(feature = "mdns")]
37use proto::multicast::{MdnsClientConnect, MdnsClientStream, MdnsQueryType};
38#[cfg(feature = "dns-over-quic")]
39use proto::quic::{QuicClientConnect, QuicClientStream};
40use proto::tcp::DnsTcpStream;
41use proto::udp::DnsUdpSocket;
42use proto::{
43 self,
44 error::ProtoError,
45 op::NoopMessageFinalizer,
46 tcp::TcpClientConnect,
47 tcp::TcpClientStream,
48 udp::UdpClientConnect,
49 udp::UdpClientStream,
50 xfer::{
51 DnsExchange, DnsExchangeConnect, DnsExchangeSend, DnsHandle, DnsMultiplexer,
52 DnsMultiplexerConnect, DnsRequest, DnsResponse,
53 },
54 Time,
55};
56#[cfg(feature = "tokio-runtime")]
57use proto::{iocompat::AsyncIoTokioAsStd, TokioTime};
58#[cfg(feature = "dns-over-quic")]
59use trust_dns_proto::udp::QuicLocalAddr;
60
61use crate::error::ResolveError;
62
63pub trait RuntimeProvider: Clone + Send + Sync + Unpin + 'static {
65 type Handle: Clone + Send + Spawn + Sync + Unpin;
67
68 type Timer: Time + Send + Unpin;
70
71 #[cfg(not(feature = "dns-over-quic"))]
72 type Udp: DnsUdpSocket + Send;
74 #[cfg(feature = "dns-over-quic")]
75 type Udp: DnsUdpSocket + QuicLocalAddr + Send;
77
78 type Tcp: DnsTcpStream;
80
81 fn create_handle(&self) -> Self::Handle;
83
84 fn connect_tcp(
86 &self,
87 server_addr: SocketAddr,
88 ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>>;
89
90 fn bind_udp(
93 &self,
94 local_addr: SocketAddr,
95 server_addr: SocketAddr,
96 ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>>;
97}
98
99pub trait ConnectionProvider: 'static + Clone + Send + Sync + Unpin {
102 type Conn: DnsHandle<Error = ResolveError> + Clone + Send + Sync + 'static;
104 type FutureConn: Future<Output = Result<Self::Conn, ResolveError>> + Send + 'static;
106 type RuntimeProvider: RuntimeProvider;
108
109 fn new_connection(&self, config: &NameServerConfig, options: &ResolverOpts)
111 -> Self::FutureConn;
112}
113
114pub trait Spawn {
116 fn spawn_bg<F>(&mut self, future: F)
118 where
119 F: Future<Output = Result<(), ProtoError>> + Send + 'static;
120}
121
/// TCP-framed DNS client stream running over a TLS session that wraps the raw
/// stream `S` (adapted from std-style I/O to Tokio I/O and back).
#[cfg(feature = "dns-over-tls")]
type TlsClientStream<S> = TcpClientStream<
    AsyncIoTokioAsStd<TokioTlsStream<proto::iocompat::AsyncIoStdAsTokio<S>>>,
>;
126
127#[allow(clippy::large_enum_variant, clippy::type_complexity)]
129pub(crate) enum ConnectionConnect<R: RuntimeProvider> {
130 Udp(DnsExchangeConnect<UdpClientConnect<R::Udp>, UdpClientStream<R::Udp>, R::Timer>),
131 Tcp(
132 DnsExchangeConnect<
133 DnsMultiplexerConnect<
134 TcpClientConnect<<R as RuntimeProvider>::Tcp>,
135 TcpClientStream<<R as RuntimeProvider>::Tcp>,
136 NoopMessageFinalizer,
137 >,
138 DnsMultiplexer<TcpClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
139 R::Timer,
140 >,
141 ),
142 #[cfg(all(feature = "dns-over-tls", feature = "tokio-runtime"))]
143 Tls(
144 DnsExchangeConnect<
145 DnsMultiplexerConnect<
146 Pin<
147 Box<
148 dyn Future<
149 Output = Result<
150 TlsClientStream<<R as RuntimeProvider>::Tcp>,
151 ProtoError,
152 >,
153 > + Send
154 + 'static,
155 >,
156 >,
157 TlsClientStream<<R as RuntimeProvider>::Tcp>,
158 NoopMessageFinalizer,
159 >,
160 DnsMultiplexer<TlsClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
161 TokioTime,
162 >,
163 ),
164 #[cfg(all(feature = "dns-over-https", feature = "tokio-runtime"))]
165 Https(DnsExchangeConnect<HttpsClientConnect<R::Tcp>, HttpsClientStream, TokioTime>),
166 #[cfg(all(feature = "dns-over-quic", feature = "tokio-runtime"))]
167 Quic(DnsExchangeConnect<QuicClientConnect, QuicClientStream, TokioTime>),
168 #[cfg(feature = "mdns")]
169 Mdns(
170 DnsExchangeConnect<
171 DnsMultiplexerConnect<MdnsClientConnect, MdnsClientStream, NoopMessageFinalizer>,
172 DnsMultiplexer<MdnsClientStream, NoopMessageFinalizer>,
173 TokioTime,
174 >,
175 ),
176}
177
178#[must_use = "futures do nothing unless polled"]
180pub struct ConnectionFuture<R: RuntimeProvider> {
181 pub(crate) connect: ConnectionConnect<R>,
182 pub(crate) spawner: R::Handle,
183}
184
185impl<R: RuntimeProvider> Future for ConnectionFuture<R> {
186 type Output = Result<GenericConnection, ResolveError>;
187
188 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
189 Poll::Ready(Ok(match &mut self.connect {
190 ConnectionConnect::Udp(ref mut conn) => {
191 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
192 self.spawner.spawn_bg(bg);
193 GenericConnection(conn)
194 }
195 ConnectionConnect::Tcp(ref mut conn) => {
196 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
197 self.spawner.spawn_bg(bg);
198 GenericConnection(conn)
199 }
200 #[cfg(feature = "dns-over-tls")]
201 ConnectionConnect::Tls(ref mut conn) => {
202 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
203 self.spawner.spawn_bg(bg);
204 GenericConnection(conn)
205 }
206 #[cfg(feature = "dns-over-https")]
207 ConnectionConnect::Https(ref mut conn) => {
208 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
209 self.spawner.spawn_bg(bg);
210 GenericConnection(conn)
211 }
212 #[cfg(feature = "dns-over-quic")]
213 ConnectionConnect::Quic(ref mut conn) => {
214 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
215 self.spawner.spawn_bg(bg);
216 GenericConnection(conn)
217 }
218 #[cfg(feature = "mdns")]
219 ConnectionConnect::Mdns(ref mut conn) => {
220 let (conn, bg) = ready!(conn.poll_unpin(cx))?;
221 self.spawner.spawn_bg(bg);
222 GenericConnection(conn)
223 }
224 }))
225 }
226}
227
228#[derive(Clone)]
230pub struct GenericConnection(DnsExchange);
231
232impl DnsHandle for GenericConnection {
233 type Response = ConnectionResponse;
234 type Error = ResolveError;
235
236 fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> Self::Response {
237 ConnectionResponse(self.0.send(request))
238 }
239}
240
241#[derive(Clone)]
243pub struct GenericConnector<P: RuntimeProvider> {
244 runtime_provider: P,
245}
246
247impl<P: RuntimeProvider> GenericConnector<P> {
248 pub fn new(runtime_provider: P) -> Self {
250 Self { runtime_provider }
251 }
252}
253
254impl<P: RuntimeProvider + Default> Default for GenericConnector<P> {
255 fn default() -> Self {
256 Self {
257 runtime_provider: P::default(),
258 }
259 }
260}
261
262impl<P: RuntimeProvider> ConnectionProvider for GenericConnector<P> {
263 type Conn = GenericConnection;
264 type FutureConn = ConnectionFuture<P>;
265 type RuntimeProvider = P;
266
267 fn new_connection(
268 &self,
269 config: &NameServerConfig,
270 options: &ResolverOpts,
271 ) -> Self::FutureConn {
272 let dns_connect = match config.protocol {
273 Protocol::Udp => {
274 let provider_handle = self.runtime_provider.clone();
275 let closure = move |local_addr: SocketAddr, server_addr: SocketAddr| {
276 provider_handle.bind_udp(local_addr, server_addr)
277 };
278 let stream = UdpClientStream::with_creator(
279 config.socket_addr,
280 None,
281 options.timeout,
282 Arc::new(closure),
283 );
284 let exchange = DnsExchange::connect(stream);
285 ConnectionConnect::Udp(exchange)
286 }
287 Protocol::Tcp => {
288 let socket_addr = config.socket_addr;
289 let timeout = options.timeout;
290 let tcp_future = self.runtime_provider.connect_tcp(socket_addr);
291
292 let (stream, handle) =
293 TcpClientStream::with_future(tcp_future, socket_addr, timeout);
294 let dns_conn = DnsMultiplexer::with_timeout(
296 stream,
297 handle,
298 timeout,
299 NoopMessageFinalizer::new(),
300 );
301
302 let exchange = DnsExchange::connect(dns_conn);
303 ConnectionConnect::Tcp(exchange)
304 }
305 #[cfg(feature = "dns-over-tls")]
306 Protocol::Tls => {
307 let socket_addr = config.socket_addr;
308 let timeout = options.timeout;
309 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
310 let tcp_future = self.runtime_provider.connect_tcp(socket_addr);
311
312 #[cfg(feature = "dns-over-rustls")]
313 let client_config = config.tls_config.clone();
314
315 #[cfg(feature = "dns-over-rustls")]
316 let (stream, handle) = {
317 crate::tls::new_tls_stream_with_future(
318 tcp_future,
319 socket_addr,
320 tls_dns_name,
321 client_config,
322 )
323 };
324 #[cfg(not(feature = "dns-over-rustls"))]
325 let (stream, handle) = {
326 crate::tls::new_tls_stream_with_future(tcp_future, socket_addr, tls_dns_name)
327 };
328
329 let dns_conn = DnsMultiplexer::with_timeout(
330 stream,
331 handle,
332 timeout,
333 NoopMessageFinalizer::new(),
334 );
335
336 let exchange = DnsExchange::connect(dns_conn);
337 ConnectionConnect::Tls(exchange)
338 }
339 #[cfg(feature = "dns-over-https")]
340 Protocol::Https => {
341 let socket_addr = config.socket_addr;
342 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
343 #[cfg(feature = "dns-over-rustls")]
344 let client_config = config.tls_config.clone();
345 let tcp_future = self.runtime_provider.connect_tcp(socket_addr);
346
347 let exchange = crate::https::new_https_stream_with_future(
348 tcp_future,
349 socket_addr,
350 tls_dns_name,
351 client_config,
352 );
353 ConnectionConnect::Https(exchange)
354 }
355 #[cfg(feature = "dns-over-quic")]
356 Protocol::Quic => {
357 let socket_addr = config.socket_addr;
358 let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
359 SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
360 SocketAddr::V6(_) => {
361 SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
362 }
363 });
364 let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
365 #[cfg(feature = "dns-over-rustls")]
366 let client_config = config.tls_config.clone();
367 let udp_future = self.runtime_provider.bind_udp(bind_addr, socket_addr);
368
369 let exchange = crate::quic::new_quic_stream_with_future(
370 udp_future,
371 socket_addr,
372 tls_dns_name,
373 client_config,
374 );
375 ConnectionConnect::Quic(exchange)
376 }
377 #[cfg(feature = "mdns")]
378 Protocol::Mdns => {
379 let socket_addr = config.socket_addr;
380 let timeout = options.timeout;
381
382 let (stream, handle) =
383 MdnsClientStream::new(socket_addr, MdnsQueryType::OneShot, None, None, None);
384 let dns_conn = DnsMultiplexer::with_timeout(
386 stream,
387 handle,
388 timeout,
389 NoopMessageFinalizer::new(),
390 );
391
392 let exchange = DnsExchange::connect(dns_conn);
393 ConnectionConnect::Mdns(exchange)
394 }
395 };
396
397 ConnectionFuture::<P> {
398 connect: dns_connect,
399 spawner: self.runtime_provider.create_handle(),
400 }
401 }
402}
403
404#[must_use = "steam do nothing unless polled"]
406pub struct ConnectionResponse(DnsExchangeSend);
407
408impl Stream for ConnectionResponse {
409 type Item = Result<DnsResponse, ResolveError>;
410
411 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
412 Poll::Ready(ready!(self.0.poll_next_unpin(cx)).map(|r| r.map_err(ResolveError::from)))
413 }
414}
415
/// Tokio-backed implementations of [`RuntimeProvider`] and [`Spawn`].
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-runtime")))]
#[allow(unreachable_pub)]
pub mod tokio_runtime {
    use super::*;
    use std::sync::{Arc, Mutex};
    use tokio::net::UdpSocket as TokioUdpSocket;
    use tokio::task::JoinSet;

    /// Spawner that tracks background tasks in a shared Tokio [`JoinSet`].
    ///
    /// All clones share the same set.
    ///
    /// NOTE: per Tokio's `JoinSet` documentation, dropping the set aborts its tasks.
    /// Background tasks therefore live only as long as some clone of this handle.
    #[derive(Clone, Default)]
    pub struct TokioHandle {
        join_set: Arc<Mutex<JoinSet<Result<(), ProtoError>>>>,
    }

    impl Spawn for TokioHandle {
        fn spawn_bg<F>(&mut self, future: F)
        where
            F: Future<Output = Result<(), ProtoError>> + Send + 'static,
        {
            let mut join_set = self.join_set.lock().unwrap();
            join_set.spawn(future);
            // Opportunistically discard already-finished tasks so completed results do
            // not accumulate in the set across many spawned connections.
            reap_tasks(&mut join_set);
        }
    }

    /// [`RuntimeProvider`] that uses Tokio sockets and timers.
    #[derive(Clone, Default)]
    pub struct TokioRuntimeProvider(TokioHandle);

    impl TokioRuntimeProvider {
        /// Creates a provider with a fresh, empty task set.
        pub fn new() -> Self {
            Self::default()
        }
    }

    impl RuntimeProvider for TokioRuntimeProvider {
        type Handle = TokioHandle;
        type Timer = TokioTime;
        type Udp = TokioUdpSocket;
        type Tcp = AsyncIoTokioAsStd<TokioTcpStream>;

        fn create_handle(&self) -> Self::Handle {
            self.0.clone()
        }

        fn connect_tcp(
            &self,
            server_addr: SocketAddr,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>> {
            Box::pin(async move {
                TokioTcpStream::connect(server_addr)
                    .await
                    .map(AsyncIoTokioAsStd)
            })
        }

        /// Binds a Tokio UDP socket on `local_addr`; the server address is not needed.
        fn bind_udp(
            &self,
            local_addr: SocketAddr,
            _server_addr: SocketAddr,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>> {
            Box::pin(TokioUdpSocket::bind(local_addr))
        }
    }

    /// Removes every task that has already finished from `join_set`, without blocking.
    fn reap_tasks(join_set: &mut JoinSet<Result<(), ProtoError>>) {
        // `join_next` resolves immediately to `None` once the set is empty. So
        // `now_or_never` returns `Some(None)` in that case. Flattening the nested `Option`
        // stops the loop both when the set is empty and when the remaining tasks are still
        // pending. Without it, an empty set would spin here forever while holding the lock.
        while FutureExt::now_or_never(join_set.join_next())
            .flatten()
            .is_some()
        {}
    }

    /// [`ConnectionProvider`] backed by the Tokio runtime.
    pub type TokioConnectionProvider = GenericConnector<TokioRuntimeProvider>;
}