1use std::marker::PhantomData;
11use std::pin::Pin;
12use std::task::{Context, Poll};
13
14use futures_channel::mpsc;
15use futures_util::future::{Future, FutureExt};
16use futures_util::stream::{Peekable, Stream, StreamExt};
17use tracing::{debug, warn};
18
19use crate::error::*;
20use crate::xfer::dns_handle::DnsHandle;
21use crate::xfer::DnsResponseReceiver;
22use crate::xfer::{
23 BufDnsRequestStreamHandle, DnsRequest, DnsRequestSender, DnsResponse, OneshotDnsRequest,
24 CHANNEL_BUFFER_SIZE,
25};
26use crate::Time;
27
/// Cloneable handle for submitting DNS requests over a multiplexed connection.
///
/// The handle only owns the sending side of the request channel (`sender`). Requests
/// pushed into it are forwarded to the underlying [`DnsRequestSender`] by the paired
/// [`DnsExchangeBackground`] future, which must be polled for any request to be sent.
// NOTE(review): no `Future` impl for `DnsExchange` appears in this file, so the
// "futures do nothing unless polled" message looks copied from the future types
// below; consider rewording it.
#[must_use = "futures do nothing unless polled"]
pub struct DnsExchange {
    /// Sending half of the request channel drained by `DnsExchangeBackground`.
    sender: BufDnsRequestStreamHandle,
}
35
36impl DnsExchange {
37 pub fn from_stream<S, TE>(stream: S) -> (Self, DnsExchangeBackground<S, TE>)
45 where
46 S: DnsRequestSender + 'static + Send + Unpin,
47 {
48 let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
49 let message_sender = BufDnsRequestStreamHandle { sender };
50
51 Self::from_stream_with_receiver(stream, outbound_messages, message_sender)
52 }
53
54 pub fn from_stream_with_receiver<S, TE>(
56 stream: S,
57 receiver: mpsc::Receiver<OneshotDnsRequest>,
58 sender: BufDnsRequestStreamHandle,
59 ) -> (Self, DnsExchangeBackground<S, TE>)
60 where
61 S: DnsRequestSender + 'static + Send + Unpin,
62 {
63 let background = DnsExchangeBackground {
64 io_stream: stream,
65 outbound_messages: receiver.peekable(),
66 marker: PhantomData,
67 };
68
69 (Self { sender }, background)
70 }
71
72 pub fn connect<F, S, TE>(connect_future: F) -> DnsExchangeConnect<F, S, TE>
76 where
77 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
78 S: DnsRequestSender + 'static + Send + Unpin,
79 TE: Time + Unpin,
80 {
81 let (sender, outbound_messages) = mpsc::channel(CHANNEL_BUFFER_SIZE);
82 let message_sender = BufDnsRequestStreamHandle { sender };
83
84 DnsExchangeConnect::connect(connect_future, outbound_messages, message_sender)
85 }
86}
87
88impl Clone for DnsExchange {
89 fn clone(&self) -> Self {
90 Self {
91 sender: self.sender.clone(),
92 }
93 }
94}
95
96impl DnsHandle for DnsExchange {
97 type Response = DnsExchangeSend;
98 type Error = ProtoError;
99
100 fn send<R: Into<DnsRequest> + Unpin + Send + 'static>(&mut self, request: R) -> Self::Response {
101 DnsExchangeSend {
102 result: self.sender.send(request),
103 _sender: self.sender.clone(), }
105 }
106}
107
/// Stream of responses for one request submitted through `DnsHandle::send` on a
/// [`DnsExchange`].
#[must_use = "futures do nothing unless polled"]
pub struct DnsExchangeSend {
    /// Receiver that yields the response(s) for the submitted request.
    result: DnsResponseReceiver,
    /// Clone of the exchange's request-channel handle, held for as long as this stream
    /// lives. NOTE(review): presumably this keeps the background from seeing the
    /// channel close (and shutting down its I/O stream) while a response is still
    /// outstanding; confirm.
    _sender: BufDnsRequestStreamHandle,
}
114
115impl Stream for DnsExchangeSend {
116 type Item = Result<DnsResponse, ProtoError>;
117
118 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
119 self.result.poll_next_unpin(cx)
121 }
122}
123
/// Future that owns the I/O stream and the receiving half of the request channel.
///
/// Requests submitted through a [`DnsExchange`] are only handed to `io_stream` while
/// this future is being polled.
#[must_use = "futures do nothing unless polled"]
pub struct DnsExchangeBackground<S, TE>
where
    S: DnsRequestSender + 'static + Send + Unpin,
{
    /// Underlying request sender; each dequeued request is passed to its `send_message`.
    io_stream: S,
    /// Requests waiting to be forwarded to `io_stream`.
    outbound_messages: Peekable<mpsc::Receiver<OneshotDnsRequest>>,
    /// Records the `TE` (`Time` implementation) type parameter without storing a value.
    marker: PhantomData<TE>,
}
136
137impl<S, TE> DnsExchangeBackground<S, TE>
138where
139 S: DnsRequestSender + 'static + Send + Unpin,
140{
141 fn pollable_split(&mut self) -> (&mut S, &mut Peekable<mpsc::Receiver<OneshotDnsRequest>>) {
142 (&mut self.io_stream, &mut self.outbound_messages)
143 }
144}
145
146impl<S, TE> Future for DnsExchangeBackground<S, TE>
147where
148 S: DnsRequestSender + 'static + Send + Unpin,
149 TE: Time + Unpin,
150{
151 type Output = Result<(), ProtoError>;
152
153 #[allow(clippy::unused_unit)]
154 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
155 let (io_stream, outbound_messages) = self.pollable_split();
156 let mut io_stream = Pin::new(io_stream);
157 let mut outbound_messages = Pin::new(outbound_messages);
158
159 loop {
162 match io_stream.as_mut().poll_next(cx) {
164 Poll::Ready(Some(Ok(()))) => (),
166 Poll::Pending => {
167 if io_stream.is_shutdown() {
168 return Poll::Pending;
170 }
171
172 ()
174 } Poll::Ready(None) => {
176 debug!("io_stream is done, shutting down");
177 return Poll::Ready(Ok(()));
180 }
181 Poll::Ready(Some(Err(err))) => {
182 debug!(
183 error = err.as_dyn(),
184 "io_stream hit an error, shutting down"
185 );
186
187 return Poll::Ready(Err(err));
188 }
189 }
190
191 match outbound_messages.as_mut().poll_next(cx) {
193 Poll::Ready(Some(dns_request)) => {
195 let (dns_request, serial_response): (DnsRequest, _) = dns_request.into_parts();
197
198 match serial_response.send_response(io_stream.send_message(dns_request)) {
203 Ok(()) => (),
204 Err(_) => {
205 warn!("failed to associate send_message response to the sender");
206 }
207 }
208 }
209 Poll::Pending => return Poll::Pending,
211 Poll::Ready(None) => {
212 io_stream.shutdown();
214
215 }
217 }
218
219 }
221 }
222}
223
/// Future returned by `DnsExchange::connect`.
///
/// Resolves to the connected `(DnsExchange, DnsExchangeBackground)` pair, or to the
/// `ProtoError` produced while connecting.
pub struct DnsExchangeConnect<F, S, TE>(DnsExchangeConnectInner<F, S, TE>)
where
    F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
    S: DnsRequestSender + 'static,
    TE: Time + Unpin;
237
238impl<F, S, TE> DnsExchangeConnect<F, S, TE>
239where
240 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
241 S: DnsRequestSender + 'static,
242 TE: Time + Unpin,
243{
244 fn connect(
245 connect_future: F,
246 outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
247 sender: BufDnsRequestStreamHandle,
248 ) -> Self {
249 Self(DnsExchangeConnectInner::Connecting {
250 connect_future,
251 outbound_messages: Some(outbound_messages),
252 sender: Some(sender),
253 })
254 }
255}
256
257#[allow(clippy::type_complexity)]
258impl<F, S, TE> Future for DnsExchangeConnect<F, S, TE>
259where
260 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
261 S: DnsRequestSender + 'static + Send + Unpin,
262 TE: Time + Unpin,
263{
264 type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
265
266 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
267 self.0.poll_unpin(cx)
268 }
269}
270
/// States of the [`DnsExchangeConnect`] future.
enum DnsExchangeConnectInner<F, S, TE>
where
    F: Future<Output = Result<S, ProtoError>> + 'static + Send,
    S: DnsRequestSender + 'static + Send,
    TE: Time + Unpin,
{
    /// Waiting for `connect_future` to produce a stream. The channel halves sit in
    /// `Option`s so they can be moved out (`take`n) when the connection resolves.
    Connecting {
        connect_future: F,
        outbound_messages: Option<mpsc::Receiver<OneshotDnsRequest>>,
        sender: Option<BufDnsRequestStreamHandle>,
    },
    /// Connection established. `background` is taken (left `None`) when the pair is
    /// handed to the caller.
    Connected {
        exchange: DnsExchange,
        background: Option<DnsExchangeBackground<S, TE>>,
    },
    /// Connection failed. Queued requests are answered with `error` until the
    /// channel is drained.
    FailAll {
        error: ProtoError,
        outbound_messages: mpsc::Receiver<OneshotDnsRequest>,
    },
}
291
292#[allow(clippy::type_complexity)]
293impl<F, S, TE> Future for DnsExchangeConnectInner<F, S, TE>
294where
295 F: Future<Output = Result<S, ProtoError>> + 'static + Send + Unpin,
296 S: DnsRequestSender + 'static + Send + Unpin,
297 TE: Time + Unpin,
298{
299 type Output = Result<(DnsExchange, DnsExchangeBackground<S, TE>), ProtoError>;
300
301 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
302 loop {
303 let next;
304 match *self {
305 Self::Connecting {
306 ref mut connect_future,
307 ref mut outbound_messages,
308 ref mut sender,
309 } => {
310 let connect_future = Pin::new(connect_future);
311 match connect_future.poll(cx) {
312 Poll::Ready(Ok(stream)) => {
313 let (exchange, background) = DnsExchange::from_stream_with_receiver(
316 stream,
317 outbound_messages
318 .take()
319 .expect("cannot poll after complete"),
320 sender.take().expect("cannot poll after complete"),
321 );
322
323 next = Self::Connected {
324 exchange,
325 background: Some(background),
326 };
327 }
328 Poll::Pending => return Poll::Pending,
329 Poll::Ready(Err(error)) => {
330 debug!(error = error.as_dyn(), "stream errored while connecting");
331 next = Self::FailAll {
332 error,
333 outbound_messages: outbound_messages
334 .take()
335 .expect("cannot poll after complete"),
336 }
337 }
338 };
339 }
340 Self::Connected {
341 ref exchange,
342 ref mut background,
343 } => {
344 let exchange = exchange.clone();
345 let background = background.take().expect("cannot poll after complete");
346
347 return Poll::Ready(Ok((exchange, background)));
348 }
349 Self::FailAll {
350 ref error,
351 ref mut outbound_messages,
352 } => {
353 while let Some(outbound_message) = match outbound_messages.poll_next_unpin(cx) {
354 Poll::Ready(opt) => opt,
355 Poll::Pending => return Poll::Pending,
356 } {
357 outbound_message
359 .into_parts()
360 .1
361 .send_response(error.clone().into())
362 .ok();
363 }
364
365 return Poll::Ready(Err(error.clone()));
366 }
367 }
368
369 *self = next;
370 }
371 }
372}