trust_dns_resolver/name_server/connection_provider.rs

1// Copyright 2015-2019 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::io;
9use std::marker::Unpin;
10use std::net::SocketAddr;
11#[cfg(feature = "dns-over-quic")]
12use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{Context, Poll};
16
17use futures_util::future::{Future, FutureExt};
18use futures_util::ready;
19use futures_util::stream::{Stream, StreamExt};
20#[cfg(feature = "tokio-runtime")]
21use tokio::net::TcpStream as TokioTcpStream;
22#[cfg(all(feature = "dns-over-native-tls", not(feature = "dns-over-rustls")))]
23use tokio_native_tls::TlsStream as TokioTlsStream;
24#[cfg(all(
25    feature = "dns-over-openssl",
26    not(feature = "dns-over-rustls"),
27    not(feature = "dns-over-native-tls")
28))]
29use tokio_openssl::SslStream as TokioTlsStream;
30#[cfg(feature = "dns-over-rustls")]
31use tokio_rustls::client::TlsStream as TokioTlsStream;
32
33use crate::config::{NameServerConfig, Protocol, ResolverOpts};
34#[cfg(feature = "dns-over-https")]
35use proto::https::{HttpsClientConnect, HttpsClientStream};
36#[cfg(feature = "mdns")]
37use proto::multicast::{MdnsClientConnect, MdnsClientStream, MdnsQueryType};
38#[cfg(feature = "dns-over-quic")]
39use proto::quic::{QuicClientConnect, QuicClientStream};
40use proto::tcp::DnsTcpStream;
41use proto::udp::DnsUdpSocket;
42use proto::{
43    self,
44    error::ProtoError,
45    op::NoopMessageFinalizer,
46    tcp::TcpClientConnect,
47    tcp::TcpClientStream,
48    udp::UdpClientConnect,
49    udp::UdpClientStream,
50    xfer::{
51        DnsExchange, DnsExchangeConnect, DnsExchangeSend, DnsHandle, DnsMultiplexer,
52        DnsMultiplexerConnect, DnsRequest, DnsResponse,
53    },
54    Time,
55};
56#[cfg(feature = "tokio-runtime")]
57use proto::{iocompat::AsyncIoTokioAsStd, TokioTime};
58#[cfg(feature = "dns-over-quic")]
59use trust_dns_proto::udp::QuicLocalAddr;
60
61use crate::error::ResolveError;
62
63/// RuntimeProvider defines which async runtime that handles IO and timers.
64pub trait RuntimeProvider: Clone + Send + Sync + Unpin + 'static {
65    /// Handle to the executor;
66    type Handle: Clone + Send + Spawn + Sync + Unpin;
67
68    /// Timer
69    type Timer: Time + Send + Unpin;
70
71    #[cfg(not(feature = "dns-over-quic"))]
72    /// UdpSocket
73    type Udp: DnsUdpSocket + Send;
74    #[cfg(feature = "dns-over-quic")]
75    /// UdpSocket, where `QuicLocalAddr` is for `quinn` crate.
76    type Udp: DnsUdpSocket + QuicLocalAddr + Send;
77
78    /// TcpStream
79    type Tcp: DnsTcpStream;
80
81    /// Create a runtime handle
82    fn create_handle(&self) -> Self::Handle;
83
84    /// Create a TCP connection with custom configuration.
85    fn connect_tcp(
86        &self,
87        server_addr: SocketAddr,
88    ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>>;
89
90    /// Create a UDP socket bound to `local_addr`. The returned value should **not** be connected to `server_addr`.
91    /// *Notice: the future should be ready once returned at best effort. Otherwise UDP DNS may need much more retries.*
92    fn bind_udp(
93        &self,
94        local_addr: SocketAddr,
95        server_addr: SocketAddr,
96    ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>>;
97}
98
99/// Create `DnsHandle` with the help of `RuntimeProvider`.
100/// This trait is designed for customization.
101pub trait ConnectionProvider: 'static + Clone + Send + Sync + Unpin {
102    /// The handle to the connect for sending DNS requests.
103    type Conn: DnsHandle<Error = ResolveError> + Clone + Send + Sync + 'static;
104    /// Ths future is responsible for spawning any background tasks as necessary.
105    type FutureConn: Future<Output = Result<Self::Conn, ResolveError>> + Send + 'static;
106    /// Provider that handles the underlying I/O and timing.
107    type RuntimeProvider: RuntimeProvider;
108
109    /// Create a new connection.
110    fn new_connection(&self, config: &NameServerConfig, options: &ResolverOpts)
111        -> Self::FutureConn;
112}
113
114/// A type defines the Handle which can spawn future.
115pub trait Spawn {
116    /// Spawn a future in the background
117    fn spawn_bg<F>(&mut self, future: F)
118    where
119        F: Future<Output = Result<(), ProtoError>> + Send + 'static;
120}
121
// Gated exactly like its only user, `ConnectionConnect::Tls`: the alias names
// `AsyncIoTokioAsStd`, which is only imported with `tokio-runtime`.
#[cfg(all(feature = "dns-over-tls", feature = "tokio-runtime"))]
/// Predefined type for the TLS client stream: a TCP-framed DNS stream running
/// over a Tokio TLS stream that wraps the runtime's TCP stream `S`.
type TlsClientStream<S> =
    TcpClientStream<AsyncIoTokioAsStd<TokioTlsStream<proto::iocompat::AsyncIoStdAsTokio<S>>>>;
126
/// The variants of all supported connections for the Resolver.
///
/// Each variant holds the protocol-specific `DnsExchangeConnect` future.
/// `ConnectionFuture` polls it to obtain a `DnsExchange` plus a background
/// task to spawn.
// The variant types differ greatly in size and are deeply nested, hence the
// clippy allowances below.
#[allow(clippy::large_enum_variant, clippy::type_complexity)]
pub(crate) enum ConnectionConnect<R: RuntimeProvider> {
    /// Plain UDP, using the runtime's UDP socket type and timer.
    Udp(DnsExchangeConnect<UdpClientConnect<R::Udp>, UdpClientStream<R::Udp>, R::Timer>),
    /// Plain TCP, multiplexed over a single stream with `DnsMultiplexer`.
    Tcp(
        DnsExchangeConnect<
            DnsMultiplexerConnect<
                TcpClientConnect<<R as RuntimeProvider>::Tcp>,
                TcpClientStream<<R as RuntimeProvider>::Tcp>,
                NoopMessageFinalizer,
            >,
            DnsMultiplexer<TcpClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
            R::Timer,
        >,
    ),
    /// DNS-over-TLS, multiplexed with `DnsMultiplexer`; uses `TokioTime`, so
    /// it is only available together with `tokio-runtime`.
    #[cfg(all(feature = "dns-over-tls", feature = "tokio-runtime"))]
    Tls(
        DnsExchangeConnect<
            DnsMultiplexerConnect<
                Pin<
                    Box<
                        dyn Future<
                                Output = Result<
                                    TlsClientStream<<R as RuntimeProvider>::Tcp>,
                                    ProtoError,
                                >,
                            > + Send
                            + 'static,
                    >,
                >,
                TlsClientStream<<R as RuntimeProvider>::Tcp>,
                NoopMessageFinalizer,
            >,
            DnsMultiplexer<TlsClientStream<<R as RuntimeProvider>::Tcp>, NoopMessageFinalizer>,
            TokioTime,
        >,
    ),
    /// DNS-over-HTTPS over the runtime's TCP stream; requires `tokio-runtime`.
    #[cfg(all(feature = "dns-over-https", feature = "tokio-runtime"))]
    Https(DnsExchangeConnect<HttpsClientConnect<R::Tcp>, HttpsClientStream, TokioTime>),
    /// DNS-over-QUIC; requires `tokio-runtime`.
    #[cfg(all(feature = "dns-over-quic", feature = "tokio-runtime"))]
    Quic(DnsExchangeConnect<QuicClientConnect, QuicClientStream, TokioTime>),
    /// Multicast DNS, multiplexed with `DnsMultiplexer`.
    // NOTE(review): uses `TokioTime`, which is only imported under
    // `tokio-runtime`, although this variant is gated on `mdns` alone.
    #[cfg(feature = "mdns")]
    Mdns(
        DnsExchangeConnect<
            DnsMultiplexerConnect<MdnsClientConnect, MdnsClientStream, NoopMessageFinalizer>,
            DnsMultiplexer<MdnsClientStream, NoopMessageFinalizer>,
            TokioTime,
        >,
    ),
}
177
178/// Resolves to a new Connection
179#[must_use = "futures do nothing unless polled"]
180pub struct ConnectionFuture<R: RuntimeProvider> {
181    pub(crate) connect: ConnectionConnect<R>,
182    pub(crate) spawner: R::Handle,
183}
184
185impl<R: RuntimeProvider> Future for ConnectionFuture<R> {
186    type Output = Result<GenericConnection, ResolveError>;
187
188    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
189        Poll::Ready(Ok(match &mut self.connect {
190            ConnectionConnect::Udp(ref mut conn) => {
191                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
192                self.spawner.spawn_bg(bg);
193                GenericConnection(conn)
194            }
195            ConnectionConnect::Tcp(ref mut conn) => {
196                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
197                self.spawner.spawn_bg(bg);
198                GenericConnection(conn)
199            }
200            #[cfg(feature = "dns-over-tls")]
201            ConnectionConnect::Tls(ref mut conn) => {
202                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
203                self.spawner.spawn_bg(bg);
204                GenericConnection(conn)
205            }
206            #[cfg(feature = "dns-over-https")]
207            ConnectionConnect::Https(ref mut conn) => {
208                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
209                self.spawner.spawn_bg(bg);
210                GenericConnection(conn)
211            }
212            #[cfg(feature = "dns-over-quic")]
213            ConnectionConnect::Quic(ref mut conn) => {
214                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
215                self.spawner.spawn_bg(bg);
216                GenericConnection(conn)
217            }
218            #[cfg(feature = "mdns")]
219            ConnectionConnect::Mdns(ref mut conn) => {
220                let (conn, bg) = ready!(conn.poll_unpin(cx))?;
221                self.spawner.spawn_bg(bg);
222                GenericConnection(conn)
223            }
224        }))
225    }
226}
227
228/// A connected DNS handle
229#[derive(Clone)]
230pub struct GenericConnection(DnsExchange);
231
232impl DnsHandle for GenericConnection {
233    type Response = ConnectionResponse;
234    type Error = ResolveError;
235
236    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> Self::Response {
237        ConnectionResponse(self.0.send(request))
238    }
239}
240
241/// Default connector for `GenericConnection`
242#[derive(Clone)]
243pub struct GenericConnector<P: RuntimeProvider> {
244    runtime_provider: P,
245}
246
247impl<P: RuntimeProvider> GenericConnector<P> {
248    /// Create a new instance.
249    pub fn new(runtime_provider: P) -> Self {
250        Self { runtime_provider }
251    }
252}
253
254impl<P: RuntimeProvider + Default> Default for GenericConnector<P> {
255    fn default() -> Self {
256        Self {
257            runtime_provider: P::default(),
258        }
259    }
260}
261
impl<P: RuntimeProvider> ConnectionProvider for GenericConnector<P> {
    type Conn = GenericConnection;
    type FutureConn = ConnectionFuture<P>;
    type RuntimeProvider = P;

    /// Build the connect future for the protocol selected by `config.protocol`.
    ///
    /// The returned `ConnectionFuture` carries that protocol-specific connect,
    /// plus a spawner from `create_handle()`. When polled, it spawns the
    /// connection's background task and yields a `GenericConnection`.
    /// `options.timeout` is passed to the UDP stream, the TCP stream, and every
    /// `DnsMultiplexer` created here.
    fn new_connection(
        &self,
        config: &NameServerConfig,
        options: &ResolverOpts,
    ) -> Self::FutureConn {
        let dns_connect = match config.protocol {
            Protocol::Udp => {
                // The stream creates its sockets through this closure, so a
                // clone of the runtime provider is moved into it.
                let provider_handle = self.runtime_provider.clone();
                let closure = move |local_addr: SocketAddr, server_addr: SocketAddr| {
                    provider_handle.bind_udp(local_addr, server_addr)
                };
                // NOTE(review): the `None` argument is presumably the optional
                // message signer — confirm against `UdpClientStream::with_creator`.
                let stream = UdpClientStream::with_creator(
                    config.socket_addr,
                    None,
                    options.timeout,
                    Arc::new(closure),
                );
                let exchange = DnsExchange::connect(stream);
                ConnectionConnect::Udp(exchange)
            }
            Protocol::Tcp => {
                let socket_addr = config.socket_addr;
                let timeout = options.timeout;
                let tcp_future = self.runtime_provider.connect_tcp(socket_addr);

                let (stream, handle) =
                    TcpClientStream::with_future(tcp_future, socket_addr, timeout);
                // TODO: need config for Signer...
                let dns_conn = DnsMultiplexer::with_timeout(
                    stream,
                    handle,
                    timeout,
                    NoopMessageFinalizer::new(),
                );

                let exchange = DnsExchange::connect(dns_conn);
                ConnectionConnect::Tcp(exchange)
            }
            #[cfg(feature = "dns-over-tls")]
            Protocol::Tls => {
                let socket_addr = config.socket_addr;
                let timeout = options.timeout;
                // Falls back to the type's default when no TLS name is configured
                // (presumably an empty name — TODO confirm how the TLS layer treats it).
                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
                let tcp_future = self.runtime_provider.connect_tcp(socket_addr);

                // Only the rustls backend takes an explicit client config.
                #[cfg(feature = "dns-over-rustls")]
                let client_config = config.tls_config.clone();

                #[cfg(feature = "dns-over-rustls")]
                let (stream, handle) = {
                    crate::tls::new_tls_stream_with_future(
                        tcp_future,
                        socket_addr,
                        tls_dns_name,
                        client_config,
                    )
                };
                #[cfg(not(feature = "dns-over-rustls"))]
                let (stream, handle) = {
                    crate::tls::new_tls_stream_with_future(tcp_future, socket_addr, tls_dns_name)
                };

                let dns_conn = DnsMultiplexer::with_timeout(
                    stream,
                    handle,
                    timeout,
                    NoopMessageFinalizer::new(),
                );

                let exchange = DnsExchange::connect(dns_conn);
                ConnectionConnect::Tls(exchange)
            }
            #[cfg(feature = "dns-over-https")]
            Protocol::Https => {
                let socket_addr = config.socket_addr;
                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
                // NOTE(review): `client_config` is bound only with `dns-over-rustls`,
                // yet it is used unconditionally below. This branch presumably relies
                // on `dns-over-https` implying `dns-over-rustls` — confirm in Cargo.toml.
                #[cfg(feature = "dns-over-rustls")]
                let client_config = config.tls_config.clone();
                let tcp_future = self.runtime_provider.connect_tcp(socket_addr);

                let exchange = crate::https::new_https_stream_with_future(
                    tcp_future,
                    socket_addr,
                    tls_dns_name,
                    client_config,
                );
                ConnectionConnect::Https(exchange)
            }
            #[cfg(feature = "dns-over-quic")]
            Protocol::Quic => {
                let socket_addr = config.socket_addr;
                // Without an explicit bind address, bind to the unspecified
                // address of the server's IP family with an ephemeral port (0).
                let bind_addr = config.bind_addr.unwrap_or(match socket_addr {
                    SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0),
                    SocketAddr::V6(_) => {
                        SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 0)
                    }
                });
                let tls_dns_name = config.tls_dns_name.clone().unwrap_or_default();
                // NOTE(review): same `dns-over-rustls`-only binding as in the HTTPS branch.
                #[cfg(feature = "dns-over-rustls")]
                let client_config = config.tls_config.clone();
                let udp_future = self.runtime_provider.bind_udp(bind_addr, socket_addr);

                let exchange = crate::quic::new_quic_stream_with_future(
                    udp_future,
                    socket_addr,
                    tls_dns_name,
                    client_config,
                );
                ConnectionConnect::Quic(exchange)
            }
            #[cfg(feature = "mdns")]
            Protocol::Mdns => {
                let socket_addr = config.socket_addr;
                let timeout = options.timeout;

                let (stream, handle) =
                    MdnsClientStream::new(socket_addr, MdnsQueryType::OneShot, None, None, None);
                // TODO: need config for Signer...
                let dns_conn = DnsMultiplexer::with_timeout(
                    stream,
                    handle,
                    timeout,
                    NoopMessageFinalizer::new(),
                );

                let exchange = DnsExchange::connect(dns_conn);
                ConnectionConnect::Mdns(exchange)
            }
        };

        ConnectionFuture::<P> {
            connect: dns_connect,
            // Used by `ConnectionFuture::poll` to spawn the background task.
            spawner: self.runtime_provider.create_handle(),
        }
    }
}
403
404/// A stream of response to a DNS request.
405#[must_use = "steam do nothing unless polled"]
406pub struct ConnectionResponse(DnsExchangeSend);
407
408impl Stream for ConnectionResponse {
409    type Item = Result<DnsResponse, ResolveError>;
410
411    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
412        Poll::Ready(ready!(self.0.poll_next_unpin(cx)).map(|r| r.map_err(ResolveError::from)))
413    }
414}
415
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-runtime")))]
#[allow(unreachable_pub)]
pub mod tokio_runtime {
    use super::*;
    use std::sync::{Arc, Mutex};
    use tokio::net::UdpSocket as TokioUdpSocket;
    use tokio::task::JoinSet;

    /// A handle to the Tokio runtime.
    ///
    /// Background tasks are spawned into a `JoinSet` behind an `Arc<Mutex<_>>`,
    /// so all clones of a handle share the same set.
    #[derive(Clone, Default)]
    pub struct TokioHandle {
        join_set: Arc<Mutex<JoinSet<Result<(), ProtoError>>>>,
    }

    impl Spawn for TokioHandle {
        fn spawn_bg<F>(&mut self, future: F)
        where
            F: Future<Output = Result<(), ProtoError>> + Send + 'static,
        {
            // The guard is held only for synchronous work: spawning and a
            // non-blocking reap. Nothing is awaited while it is held.
            let mut join_set = self.join_set.lock().unwrap();
            join_set.spawn(future);
            // Discard results of already-finished tasks so they do not
            // accumulate in the set.
            reap_tasks(&mut join_set);
        }
    }

    /// The Tokio Runtime for async execution
    #[derive(Clone, Default)]
    pub struct TokioRuntimeProvider(TokioHandle);

    impl TokioRuntimeProvider {
        /// Create a Tokio runtime
        pub fn new() -> Self {
            Self::default()
        }
    }

    impl RuntimeProvider for TokioRuntimeProvider {
        type Handle = TokioHandle;
        type Timer = TokioTime;
        type Udp = TokioUdpSocket;
        type Tcp = AsyncIoTokioAsStd<TokioTcpStream>;

        fn create_handle(&self) -> Self::Handle {
            self.0.clone()
        }

        fn connect_tcp(
            &self,
            server_addr: SocketAddr,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Tcp>>>> {
            Box::pin(async move {
                TokioTcpStream::connect(server_addr)
                    .await
                    .map(AsyncIoTokioAsStd)
            })
        }

        fn bind_udp(
            &self,
            local_addr: SocketAddr,
            _server_addr: SocketAddr,
        ) -> Pin<Box<dyn Send + Future<Output = io::Result<Self::Udp>>>> {
            Box::pin(tokio::net::UdpSocket::bind(local_addr))
        }
    }

    /// Reap finished tasks from a `JoinSet`, without awaiting or blocking.
    ///
    /// `now_or_never(join_set.join_next())` has three outcomes:
    /// - `Some(Some(_))`: a finished task was reaped, so keep going.
    /// - `Some(None)`: the set is empty.
    /// - `None`: the remaining tasks are still running.
    ///
    /// Flattening stops the loop in both of the last two cases. Without it,
    /// an empty set would yield `Some(None)` forever and spin while the
    /// caller holds the mutex.
    fn reap_tasks(join_set: &mut JoinSet<Result<(), ProtoError>>) {
        while FutureExt::now_or_never(join_set.join_next())
            .flatten()
            .is_some()
        {}
    }

    /// Default ConnectionProvider with `GenericConnection`.
    pub type TokioConnectionProvider = GenericConnector<TokioRuntimeProvider>;
}
488    /// Default ConnectionProvider with `GenericConnection`.
489    pub type TokioConnectionProvider = GenericConnector<TokioRuntimeProvider>;
490}