trust_dns_proto/xfer/dns_exchange.rs

// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! This module contains all the types for demuxing DNS-oriented streams.
9
10use std::marker::PhantomData;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use futures_channel::mpsc;
15use futures_util::future::{Future, FutureExt};
16use futures_util::stream::{Peekable, Stream, StreamExt};
17use tracing::{debug, warn};
18
19use crate::error::*;
20use crate::xfer::dns_handle::DnsHandle;
21use crate::xfer::DnsResponseReceiver;
22use crate::xfer::{
23    BufDnsRequestStreamHandle, DnsRequest, DnsRequestSender, DnsResponse, OneshotDnsRequest,
24    CHANNEL_BUFFER_SIZE,
25};
26use crate::Time;
27
28/// This is a generic Exchange implemented over multiplexed DNS connection providers.
29///
30/// The underlying `DnsRequestSender` is expected to multiplex any I/O connections. DnsExchange assumes that the underlying stream is responsible for this.
31#[must_use = "futures do nothing unless polled"]
32pub struct DnsExchange {
33    sender: BufDnsRequestStreamHandle,
34}
35
36impl DnsExchange {
37    /// Initializes a TcpStream with an existing tcp::TcpStream.
38    ///
39    /// This is intended for use with a TcpListener and Incoming.
40    ///
41    /// # Arguments
42    ///
43    /// * `stream` - the established IO stream for communication
44    pub fn from_stream<S, TE>(stream: S) -> (Self, DnsExchangeBackground<S, TE>)
45    where
46        S: DnsRequestSender + 'static + Send + Unpin,
47    {
48        let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
49        let message_sender = BufDnsRequestStreamHandle { sender };
50
51        Self::from_stream_with_receiver(stream, outbound_messages, message_sender)
52    }
53
54    /// Wraps a stream where a sender and receiver have already been established
55    pub fn from_stream_with_receiver<S, TE>(
56        stream: S,
57        receiver: mpsc::Receiver<OneshotDnsRequest>,
58        sender: BufDnsRequestStreamHandle,
59    ) -> (Self, DnsExchangeBackground<S, TE>)
60    where
61        S: DnsRequestSender + 'static + Send + Unpin,
62    {
63        let background = DnsExchangeBackground {
64            io_stream: stream,
65            outbound_messages: receiver.peekable(),
66            marker: PhantomData,
67        };
68
69        (Self { sender }, background)
70    }
71
72    /// Returns a future, which itself wraps a future which is awaiting connection.
73    ///
74    /// The connect_future should be lazy.
75    pub fn connect<F, S, TE>(connect_future: F) -> DnsExchangeConnect<F, S, TE>
76    where
77        F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
78        S: DnsRequestSender + 'static + Send + Unpin,
79        TE: Time + Unpin,
80    {
81        let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
82        let message_sender = BufDnsRequestStreamHandle { sender };
83
84        DnsExchangeConnect::connect(connect_future, outbound_messages, message_sender)
85    }
86}
87
88impl Clone for DnsExchange {
89    fn clone(&self) -> Self {
90        Self {
91            sender: self.sender.clone(),
92        }
93    }
94}
95
96impl DnsHandle for DnsExchange {
97    type Response = DnsExchangeSend;
98    type Error = ProtoError;
99
100    fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> Self::Response {
101        DnsExchangeSend {
102            result: self.sender.send(request),
103            _sender: self.sender.clone(), // TODO: this shouldn't be necessary, currently the presence of Senders is what allows the background to track current users, it generally is dropped right after send, this makes sure that there is at least one active after send
104        }
105    }
106}
107
108/// A Stream that will resolve to Responses after sending the request
109#[must_use = "futures do nothing unless polled"]
110pub struct DnsExchangeSend {
111    result: DnsResponseReceiver,
112    _sender: BufDnsRequestStreamHandle,
113}
114
115impl Stream for DnsExchangeSend {
116    type Item = Result<DnsResponse, ProtoError>;
117
118    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119        // as long as there is no result, poll the exchange
120        self.result.poll_next_unpin(cx)
121    }
122}
123
124/// This background future is responsible for driving all network operations for the DNS protocol.
125///
126/// It must be spawned before any DNS messages are sent.
127#[must_use = "futures do nothing unless polled"]
128pub struct DnsExchangeBackground<S, TE>
129where
130    S: DnsRequestSender + 'static + Send + Unpin,
131{
132    io_stream: S,
133    outbound_messages: Peekable<mpsc::Receiver<OneshotDnsRequest>>,
134    marker: PhantomData<TE>,
135}
136
137impl<S, TE> DnsExchangeBackground<S, TE>
138where
139    S: DnsRequestSender + 'static + Send + Unpin,
140{
141    fn pollable_split(&mut self) -> (&mut S, &mut Peekable<mpsc::Receiver<OneshotDnsRequest>>) {
142        (&mut self.io_stream, &mut self.outbound_messages)
143    }
144}
145
146impl<S, TE> Future for DnsExchangeBackground<S, TE>
147where
148    S: DnsRequestSender + 'static + Send + Unpin,
149    TE: Time + Unpin,
150{
151    type Output = Result<(), ProtoError>;
152
153    #[allow(clippy::unused_unit)]
154    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
155        let (io_stream, outbound_messages) = self.pollable_split();
156        let mut io_stream = Pin::new(io_stream);
157        let mut outbound_messages = Pin::new(outbound_messages);
158
159        // this will not accept incoming data while there is data to send
160        //  makes this self throttling.
161        loop {
162            // poll the underlying stream, to drive it...
163            match io_stream.as_mut().poll_next(cx) {
164                // The stream is ready
165                Poll::Ready(Some(Ok(()))) => (),
166                Poll::Pending => {
167                    if io_stream.is_shutdown() {
168                        // the io_stream is in a shutdown state, we are only waiting for final results...
169                        return Poll::Pending;
170                    }
171
172                    // NotReady and not shutdown, see if there are more messages to send
173                    ()
174                } // underlying stream is complete.
175                Poll::Ready(None) => {
176                    debug!("io_stream is done, shutting down");
177                    // TODO: return shutdown error to anything in the stream?
178
179                    return Poll::Ready(Ok(()));
180                }
181                Poll::Ready(Some(Err(err))) => {
182                    debug!(
183                        error = err.as_dyn(),
184                        "io_stream hit an error, shutting down"
185                    );
186
187                    return Poll::Ready(Err(err));
188                }
189            }
190
191            // then see if there is more to send
192            match outbound_messages.as_mut().poll_next(cx) {
193                // already handled above, here to make sure the poll() pops the next message
194                Poll::Ready(Some(dns_request)) => {
195                    // if there is no peer, this connection should die...
196                    let (dns_request, serial_response): (DnsRequest, _) = dns_request.into_parts();
197
198                    // Try to forward the `DnsResponseStream` to the requesting task. If we fail,
199                    // it must be because the requesting task has gone away / is no longer
200                    // interested. In that case, we can just log a warning, but there's no need
201                    // to take any more serious measures (such as shutting down this task).
202                    match serial_response.send_response(io_stream.send_message(dns_request)) {
203                        Ok(()) => (),
204                        Err(_) => {
205                            warn!("failed to associate send_message response to the sender");
206                        }
207                    }
208                }
209                // On not ready, this is our time to return...
210                Poll::Pending => return Poll::Pending,
211                Poll::Ready(None) => {
212                    // if there is nothing that can use this connection to send messages, then this is done...
213                    io_stream.shutdown();
214
215                    // now we'll await the stream to shutdown... see io_stream poll above
216                }
217            }
218
219            // else we loop to poll on the outbound_messages
220        }
221    }
222}
223
224/// A wrapper for a future DnsExchange connection.
225///
226/// DnsExchangeConnect is cloneable, making it possible to share this if the connection
227///  will be shared across threads.
228///
229/// The future will return a tuple of the DnsExchange (for sending messages) and a background
230///  for running the background tasks. The background is optional as only one thread should run
231///  the background. If returned, it must be spawned before any dns requests will function.
232pub struct DnsExchangeConnect<F, S, TE>(DnsExchangeConnectInner<F, S, TE>)
233where
234    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
235    S: DnsRequestSender + 'static,
236    TE: Time + Unpin;
237
238impl<F, S, TE> DnsExchangeConnect<F, S, TE>
239where
240    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
241    S: DnsRequestSender + 'static,
242    TE: Time + Unpin,
243{
244    fn connect(
245        connect_future: F,
246        outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
247        sender: BufDnsRequestStreamHandle,
248    ) -> Self {
249        Self(DnsExchangeConnectInner::Connecting {
250            connect_future,
251            outbound_messages: Some(outbound_messages),
252            sender: Some(sender),
253        })
254    }
255}
256
257#[allow(clippy::type_complexity)]
258impl<F, S, TE> Future for DnsExchangeConnect<F, S, TE>
259where
260    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
261    S: DnsRequestSender + 'static + Send + Unpin,
262    TE: Time + Unpin,
263{
264    type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
265
266    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
267        self.0.poll_unpin(cx)
268    }
269}
270
271enum DnsExchangeConnectInner<F, S, TE>
272where
273    F: Future<Output = Result<S, ProtoError>> + 'static + Send,
274    S: DnsRequestSender + 'static + Send,
275    TE: Time + Unpin,
276{
277    Connecting {
278        connect_future: F,
279        outbound_messages: Option<mpsc::Receiver<OneshotDnsRequest>>,
280        sender: Option<BufDnsRequestStreamHandle>,
281    },
282    Connected {
283        exchange: DnsExchange,
284        background: Option<DnsExchangeBackground<S, TE>>,
285    },
286    FailAll {
287        error: ProtoError,
288        outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
289    },
290}
291
292#[allow(clippy::type_complexity)]
293impl<F, S, TE> Future for DnsExchangeConnectInner<F, S, TE>
294where
295    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
296    S: DnsRequestSender + 'static + Send + Unpin,
297    TE: Time + Unpin,
298{
299    type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
300
301    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
302        loop {
303            let next;
304            match *self {
305                Self::Connecting {
306                    ref mut connect_future,
307                    ref mut outbound_messages,
308                    ref mut sender,
309                } => {
310                    let connect_future = Pin::new(connect_future);
311                    match connect_future.poll(cx) {
312                        Poll::Ready(Ok(stream)) => {
313                            //debug!("connection established: {}", stream);
314
315                            let (exchange, background) = DnsExchange::from_stream_with_receiver(
316                                stream,
317                                outbound_messages
318                                    .take()
319                                    .expect("cannot poll after complete"),
320                                sender.take().expect("cannot poll after complete"),
321                            );
322
323                            next = Self::Connected {
324                                exchange,
325                                background: Some(background),
326                            };
327                        }
328                        Poll::Pending => return Poll::Pending,
329                        Poll::Ready(Err(error)) => {
330                            debug!(error = error.as_dyn(), "stream errored while connecting");
331                            next = Self::FailAll {
332                                error,
333                                outbound_messages: outbound_messages
334                                    .take()
335                                    .expect("cannot poll after complete"),
336                            }
337                        }
338                    };
339                }
340                Self::Connected {
341                    ref exchange,
342                    ref mut background,
343                } => {
344                    let exchange = exchange.clone();
345                    let background = background.take().expect("cannot poll after complete");
346
347                    return Poll::Ready(Ok((exchange, background)));
348                }
349                Self::FailAll {
350                    ref error,
351                    ref mut outbound_messages,
352                } => {
353                    while let Some(outbound_message) = match outbound_messages.poll_next_unpin(cx) {
354                        Poll::Ready(opt) => opt,
355                        Poll::Pending => return Poll::Pending,
356                    } {
357                        // ignoring errors... best effort send...
358                        outbound_message
359                            .into_parts()
360                            .1
361                            .send_response(error.clone().into())
362                            .ok();
363                    }
364
365                    return Poll::Ready(Err(error.clone()));
366                }
367            }
368
369            *self = next;
370        }
371    }
372}