hickory_resolver/name_server/
connection_provider.rs

1// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::io;
9use std::marker::Unpin;
10use std::net::SocketAddr;
11#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
12use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{Context, Poll};
16
17use futures_util::future::{Future, FutureExt};
18use futures_util::ready;
19use futures_util::stream::{Stream, StreamExt};
20#[cfg(feature = "tokio-runtime")]
21use tokio::net::TcpStream as TokioTcpStream;
22#[cfg(all(feature = "dns-over-native-tls", not(feature = "dns-over-rustls")))]
23use tokio_native_tls::TlsStream as TokioTlsStream;
24#[cfg(all(
25    feature = "dns-over-openssl",
26    not(feature = "dns-over-rustls"),
27    not(feature = "dns-over-native-tls")
28))]
29use tokio_openssl::SslStream as TokioTlsStream;
30#[cfg(feature = "dns-over-rustls")]
31use tokio_rustls::client::TlsStream as TokioTlsStream;
32
33use crate::config::{NameServerConfig, Protocol, ResolverOpts};
34#[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
35use hickory_proto::udp::QuicLocalAddr;
36#[cfg(feature = "dns-over-https")]
37use proto::h2::{HttpsClientConnect, HttpsClientStream};
38#[cfg(feature = "dns-over-h3")]
39use proto::h3::{H3ClientConnect, H3ClientStream};
40#[cfg(feature = "mdns")]
41use proto::multicast::{MdnsClientConnect, MdnsClientStream, MdnsQueryType};
42#[cfg(feature = "dns-over-quic")]
43use proto::quic::{QuicClientConnect, QuicClientStream};
44use proto::tcp::DnsTcpStream;
45use proto::udp::DnsUdpSocket;
46use proto::{
47    self,
48    error::ProtoError,
49    op::NoopMessageFinalizer,
50    tcp::TcpClientConnect,
51    tcp::TcpClientStream,
52    udp::UdpClientConnect,
53    udp::UdpClientStream,
54    xfer::{
55        DnsExchange, DnsExchangeConnect, DnsExchangeSend, DnsHandle, DnsMultiplexer,
56        DnsMultiplexerConnect, DnsRequest, DnsResponse,
57    },
58    Time,
59};
60#[cfg(feature = "tokio-runtime")]
61use proto::{iocompat::AsyncIoTokioAsStd, TokioTime};
62
63use crate::error::ResolveError;
64
65/// RuntimeProvider defines which async runtime that handles IO and timers.
66pub trait RuntimeProvider: Clone + Send + Sync + Unpin + 'static {
67    /// Handle to the executor;
68    type Handle: Clone + Send + Spawn + Sync + Unpin;
69
70    /// Timer
71    type Timer: Time + Send + Unpin;
72
73    #[cfg(not(any(feature = "dns-over-quic", feature = "dns-over-h3")))]
74    /// UdpSocket
75    type Udp: DnsUdpSocket + Send;
76    #[cfg(any(feature = "dns-over-quic", feature = "dns-over-h3"))]
77    /// UdpSocket, where `QuicLocalAddr` is for `quinn` crate.
78    type Udp: DnsUdpSocket + QuicLocalAddr + Send;
79
80    /// TcpStream
81    type Tcp: DnsTcpStream;
82
83    /// Create a runtime handle
84    fn create_handle(&self) -> Self::Handle;
85
86    /// Create a TCP connection with custom configuration.
87    fn connect_tcp(
88        &self,
89        server_addr: SocketAddr,
90    ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>>;
91
92    /// Create a UDP socket bound to `local_addr`. The returned value should **not** be connected to `server_addr`.
93    /// *Notice: the future should be ready once returned at best effort. Otherwise UDP DNS may need much more retries.*
94    fn bind_udp(
95        &self,
96        local_addr: SocketAddr,
97        server_addr: SocketAddr,
98    ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>>;
99}
100
101/// Create `DnsHandle` with the help of `RuntimeProvider`.
102/// This trait is designed for customization.
103pub trait ConnectionProvider: 'static + Clone + Send + Sync + Unpin {
104    /// The handle to the connect for sending DNS requests.
105    type Conn: DnsHandle<Error = ResolveError> + Clone + Send + Sync + 'static;
106    /// Ths future is responsible for spawning any background tasks as necessary.
107    type FutureConn: Future<Output = Result<Self::Conn, ResolveError>> + Send + 'static;
108    /// Provider that handles the underlying I/O and timing.
109    type RuntimeProvider: RuntimeProvider;
110
111    /// Create a new connection.
112    fn new_connection(&self, config: &NameServerConfig, options: &ResolverOpts)
113        -> Self::FutureConn;
114}
115
116/// A type defines the Handle which can spawn future.
117pub trait Spawn {
118    /// Spawn a future in the background
119    fn spawn_bg<F>(&mut self, future: F)
120    where
121        F: Future<Output = Result<(), ProtoError>> + Send + 'static;
122}
123
/// TCP client stream layered over a TLS session, which itself wraps the runtime's
/// std-style I/O type `Io`.
#[cfg(feature = "dns-over-tls")]
type TlsClientStream<Io> = TcpClientStream<
    AsyncIoTokioAsStd<TokioTlsStream<proto::iocompat::AsyncIoStdAsTokio<Io>>>,
>;
128
129/// The variants of all supported connections for the Resolver
130#[allow(clippy::large_enum_variant, clippy::type_complexity)]
131pub(crate) enum ConnectionConnect<R: RuntimeProvider> {
132    Udp(DnsExchangeConnect<UdpClientConnect<R::Udp>, UdpClientStream<R::Udp>, R::Timer>),
133    Tcp(
134        DnsExchangeConnect<
135            DnsMultiplexerConnect<
136                TcpClientConnect<<R as RuntimeProvider>::Tcp>,
137                TcpClientStream<<R as RuntimeProvider>::Tcp>,
138                NoopMessageFinalizer,
139            >,
140            DnsMultiplexer<TcpClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
141            R::Timer,
142        >,
143    ),
144    #[cfg(all(feature = "dns-over-tls", feature = "tokio-runtime"))]
145    Tls(
146        DnsExchangeConnect<
147            DnsMultiplexerConnect<
148                Pin<
149                    Box<
150                        dyn Future<
151                                Output = Result<
152                                    TlsClientStream<<R as RuntimeProvider>::Tcp>,
153                                    ProtoError,
154                                >,
155                            > + Send
156                            + 'static,
157                    >,
158                >,
159                TlsClientStream<<R as RuntimeProvider>::Tcp>,
160                NoopMessageFinalizer,
161            >,
162            DnsMultiplexer<TlsClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
163            TokioTime,
164        >,
165    ),
166    #[cfg(all(feature = "dns-over-https", feature = "tokio-runtime"))]
167    Https(DnsExchangeConnect<HttpsClientConnect<R::Tcp>, HttpsClientStream, TokioTime>),
168    #[cfg(all(feature = "dns-over-quic", feature = "tokio-runtime"))]
169    Quic(DnsExchangeConnect<QuicClientConnect, QuicClientStream, TokioTime>),
170    #[cfg(all(feature = "dns-over-h3", feature = "tokio-runtime"))]
171    H3(DnsExchangeConnect<H3ClientConnect, H3ClientStream, TokioTime>),
172    #[cfg(feature = "mdns")]
173    Mdns(
174        DnsExchangeConnect<
175            DnsMultiplexerConnect<MdnsClientConnect, MdnsClientStream, NoopMessageFinalizer>,
176            DnsMultiplexer<MdnsClientStream, NoopMessageFinalizer>,
177            TokioTime,
178        >,
179    ),
180}
181
182/// Resolves to a new Connection
183#[must_use = "futures do nothing unless polled"]
184pub struct ConnectionFuture<R: RuntimeProvider> {
185    pub(crate) connect: ConnectionConnect<R>,
186    pub(crate) spawner: R::Handle,
187}
188
189impl<R: RuntimeProvider> Future for ConnectionFuture<R> {
190    type Output = Result<GenericConnection, ResolveError>;
191
192    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
193        Poll::Ready(Ok(match &mut self.connect {
194            ConnectionConnect::Udp(ref mut conn) => {
195                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
196                self.spawner.spawn_bg(bg);
197                GenericConnection(conn)
198            }
199            ConnectionConnect::Tcp(ref mut conn) => {
200                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
201                self.spawner.spawn_bg(bg);
202                GenericConnection(conn)
203            }
204            #[cfg(feature = "dns-over-tls")]
205            ConnectionConnect::Tls(ref mut conn) => {
206                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
207                self.spawner.spawn_bg(bg);
208                GenericConnection(conn)
209            }
210            #[cfg(feature = "dns-over-https")]
211            ConnectionConnect::Https(ref mut conn) => {
212                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
213                self.spawner.spawn_bg(bg);
214                GenericConnection(conn)
215            }
216            #[cfg(feature = "dns-over-quic")]
217            ConnectionConnect::Quic(ref mut conn) => {
218                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
219                self.spawner.spawn_bg(bg);
220                GenericConnection(conn)
221            }
222            #[cfg(feature = "dns-over-h3")]
223            ConnectionConnect::H3(ref mut conn) => {
224                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
225                self.spawner.spawn_bg(bg);
226                GenericConnection(conn)
227            }
228            #[cfg(feature = "mdns")]
229            ConnectionConnect::Mdns(ref mut conn) => {
230                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
231                self.spawner.spawn_bg(bg);
232                GenericConnection(conn)
233            }
234        }))
235    }
236}
237
238/// A connected DNS handle
239#[derive(Clone)]
240pub struct GenericConnection(DnsExchange);
241
242impl DnsHandle for GenericConnection {
243    type Response = ConnectionResponse;
244    type Error = ResolveError;
245
246    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&self, request: R) -> Self::Response {
247        ConnectionResponse(self.0.send(request))
248    }
249}
250
251/// Default connector for `GenericConnection`
252#[derive(Clone)]
253pub struct GenericConnector<P: RuntimeProvider> {
254    runtime_provider: P,
255}
256
257impl<P: RuntimeProvider> GenericConnector<P> {
258    /// Create a new instance.
259    pub fn new(runtime_provider: P) -> Self {
260        Self { runtime_provider }
261    }
262}
263
264impl<P: RuntimeProvider + Default> Default for GenericConnector<P> {
265    fn default() -> Self {
266        Self {
267            runtime_provider: P::default(),
268        }
269    }
270}
271
272impl<P: RuntimeProvider> ConnectionProvider for GenericConnector<P> {
273    type Conn = GenericConnection;
274    type FutureConn = ConnectionFuture<P>;
275    type RuntimeProvider = P;
276
277    fn new_connection(
278        &self,
279        config: &NameServerConfig,
280        options: &ResolverOpts,
281    ) -> Self::FutureConn {
282        let dns_connect = match config.protocol {
283            Protocol::Udp => {
284                let provider_handle = self.runtime_provider.clone();
285                let closure = move |local_addr: SocketAddr, server_addr: SocketAddr| {
286                    provider_handle.bind_udp(local_addr, server_addr)
287                };
288                let stream = UdpClientStream::with_creator(
289                    config.socket_addr,
290                    None,
291                    options.timeout,
292                    Arc::new(closure),
293                );
294                let exchange = DnsExchange::connect(stream);
295                ConnectionConnect::Udp(exchange)
296            }
297            Protocol::Tcp => {
298                let socket_addr = config.socket_addr;
299                let timeout = options.timeout;
300                let tcp_future = self.runtime_provider.connect_tcp(socket_addr);
301
302                let (stream, handle) =
303                    TcpClientStream::with_future(tcp_future, socket_addr, timeout);
304                // TODO: need config for Signer...
305                let dns_conn = DnsMultiplexer::with_timeout(
306                    stream,
307                    handle,
308                    timeout,
309                    NoopMessageFinalizer::new(),
310                );
311
312                let exchange = DnsExchange::connect(dns_conn);
313                ConnectionConnect::Tcp(exchange)
314            }
315            #[cfg(feature = "dns-over-tls")]
316            Protocol::Tls => {
317                let socket_addr = config.socket_addr;
318                let timeout = options.timeout;
319                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
320                let tcp_future = self.runtime_provider.connect_tcp(socket_addr);
321
322                #[cfg(feature = "dns-over-rustls")]
323                let client_config = config.tls_config.clone();
324
325                #[cfg(feature = "dns-over-rustls")]
326                let (stream, handle) = {
327                    crate::tls::new_tls_stream_with_future(
328                        tcp_future,
329                        socket_addr,
330                        tls_dns_name,
331                        client_config,
332                    )
333                };
334                #[cfg(not(feature = "dns-over-rustls"))]
335                let (stream, handle) = {
336                    crate::tls::new_tls_stream_with_future(tcp_future, socket_addr, tls_dns_name)
337                };
338
339                let dns_conn = DnsMultiplexer::with_timeout(
340                    stream,
341                    handle,
342                    timeout,
343                    NoopMessageFinalizer::new(),
344                );
345
346                let exchange = DnsExchange::connect(dns_conn);
347                ConnectionConnect::Tls(exchange)
348            }
349            #[cfg(feature = "dns-over-https")]
350            Protocol::Https => {
351                let socket_addr = config.socket_addr;
352                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
353                #[cfg(feature = "dns-over-rustls")]
354                let client_config = config.tls_config.clone();
355                let tcp_future = self.runtime_provider.connect_tcp(socket_addr);
356
357                let exchange = crate::h2::new_https_stream_with_future(
358                    tcp_future,
359                    socket_addr,
360                    tls_dns_name,
361                    client_config,
362                );
363                ConnectionConnect::Https(exchange)
364            }
365            #[cfg(feature = "dns-over-quic")]
366            Protocol::Quic => {
367                let socket_addr = config.socket_addr;
368                let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
369                    SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
370                    SocketAddr::V6(_) => {
371                        SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
372                    }
373                });
374                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
375                #[cfg(feature = "dns-over-rustls")]
376                let client_config = config.tls_config.clone();
377                let udp_future = self.runtime_provider.bind_udp(bind_addr, socket_addr);
378
379                let exchange = crate::quic::new_quic_stream_with_future(
380                    udp_future,
381                    socket_addr,
382                    tls_dns_name,
383                    client_config,
384                );
385                ConnectionConnect::Quic(exchange)
386            }
387            #[cfg(feature = "dns-over-h3")]
388            Protocol::H3 => {
389                let socket_addr = config.socket_addr;
390                let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
391                    SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
392                    SocketAddr::V6(_) => {
393                        SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
394                    }
395                });
396                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
397                let client_config = config.tls_config.clone();
398                let udp_future = self.runtime_provider.bind_udp(bind_addr, socket_addr);
399
400                let exchange = crate::h3::new_h3_stream_with_future(
401                    udp_future,
402                    socket_addr,
403                    tls_dns_name,
404                    client_config,
405                );
406                ConnectionConnect::H3(exchange)
407            }
408            #[cfg(feature = "mdns")]
409            Protocol::Mdns => {
410                let socket_addr = config.socket_addr;
411                let timeout = options.timeout;
412
413                let (stream, handle) =
414                    MdnsClientStream::new(socket_addr, MdnsQueryType::OneShot, None, None, None);
415                // TODO: need config for Signer...
416                let dns_conn = DnsMultiplexer::with_timeout(
417                    stream,
418                    handle,
419                    timeout,
420                    NoopMessageFinalizer::new(),
421                );
422
423                let exchange = DnsExchange::connect(dns_conn);
424                ConnectionConnect::Mdns(exchange)
425            }
426        };
427
428        ConnectionFuture::<P> {
429            connect: dns_connect,
430            spawner: self.runtime_provider.create_handle(),
431        }
432    }
433}
434
435/// A stream of response to a DNS request.
436#[must_use = "steam do nothing unless polled"]
437pub struct ConnectionResponse(DnsExchangeSend);
438
439impl Stream for ConnectionResponse {
440    type Item = Result<DnsResponse, ResolveError>;
441
442    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
443        Poll::Ready(ready!(self.0.poll_next_unpin(cx)).map(|r| r.map_err(ResolveError::from)))
444    }
445}
446
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-runtime")))]
#[allow(unreachable_pub)]
pub mod tokio_runtime {
    use super::*;
    use std::sync::{Arc, Mutex, PoisonError};
    use tokio::net::UdpSocket as TokioUdpSocket;
    use tokio::task::JoinSet;

    /// A handle to the Tokio runtime.
    ///
    /// Background tasks spawned through this handle are tracked in a `JoinSet` shared by all
    /// clones of the handle.
    #[derive(Clone, Default)]
    pub struct TokioHandle {
        join_set: Arc<Mutex<JoinSet<Result<(), ProtoError>>>>,
    }

    impl Spawn for TokioHandle {
        fn spawn_bg<F>(&mut self, future: F)
        where
            F: Future<Output = Result<(), ProtoError>> + Send + 'static,
        {
            // A panic while the guard is held (e.g. `JoinSet::spawn` panics when called outside
            // a Tokio runtime, before anything is inserted) poisons the mutex. Recover the guard
            // instead of letting that one failure turn every later spawn on every clone into a
            // panic.
            let mut join_set = self
                .join_set
                .lock()
                .unwrap_or_else(PoisonError::into_inner);
            join_set.spawn(future);
            reap_tasks(&mut join_set);
        }
    }

    /// The Tokio runtime for async execution.
    #[derive(Clone, Default)]
    pub struct TokioRuntimeProvider(TokioHandle);

    impl TokioRuntimeProvider {
        /// Create a Tokio runtime provider.
        pub fn new() -> Self {
            Self::default()
        }
    }

    impl RuntimeProvider for TokioRuntimeProvider {
        type Handle = TokioHandle;
        type Timer = TokioTime;
        type Udp = TokioUdpSocket;
        type Tcp = AsyncIoTokioAsStd<TokioTcpStream>;

        fn create_handle(&self) -> Self::Handle {
            self.0.clone()
        }

        fn connect_tcp(
            &self,
            server_addr: SocketAddr,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>> {
            Box::pin(async move {
                TokioTcpStream::connect(server_addr)
                    .await
                    .map(AsyncIoTokioAsStd)
            })
        }

        /// Bind a Tokio UDP socket to `local_addr`; the server address is not used here.
        fn bind_udp(
            &self,
            local_addr: SocketAddr,
            _server_addr: SocketAddr,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>> {
            Box::pin(TokioUdpSocket::bind(local_addr))
        }
    }

    /// Reap finished tasks from a `JoinSet`, without awaiting or blocking.
    ///
    /// The results of reaped tasks (including `JoinError`s) are discarded.
    fn reap_tasks(join_set: &mut JoinSet<Result<(), ProtoError>>) {
        while FutureExt::now_or_never(join_set.join_next())
            .flatten()
            .is_some()
        {}
    }

    /// Default ConnectionProvider with `GenericConnection`.
    pub type TokioConnectionProvider = GenericConnector<TokioRuntimeProvider>;
}