litep2p/protocol/libp2p/kademlia/executor.rs

1// Copyright 2023 litep2p developers
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::{protocol::libp2p::kademlia::query::QueryId, substream::Substream, PeerId};
22
23use bytes::{Bytes, BytesMut};
24use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt};
25
26use std::{
27    future::Future,
28    pin::Pin,
29    task::{Context, Poll, Waker},
30    time::Duration,
31};
32
/// Maximum time to wait for the next message on a substream.
///
/// Applies both to standalone reads ([`QueryExecutor::read_message`]) and to reading the
/// response after a request has been sent ([`QueryExecutor::send_request_read_response`]).
/// When it elapses, the operation resolves to [`QueryResult::Timeout`].
const READ_TIMEOUT: Duration = Duration::from_secs(15);
35
36/// Query result.
37#[derive(Debug)]
38pub enum QueryResult {
39    /// Message was sent to remote peer successfully.
40    SendSuccess {
41        /// Substream.
42        substream: Substream,
43    },
44
45    /// Message was read from the remote peer successfully.
46    ReadSuccess {
47        /// Substream.
48        substream: Substream,
49
50        /// Read message.
51        message: BytesMut,
52    },
53
54    /// Timeout while reading a response from the substream.
55    Timeout,
56
57    /// Substream was closed wile reading/writing message to remote peer.
58    SubstreamClosed,
59}
60
61/// Query result.
62#[derive(Debug)]
63pub struct QueryContext {
64    /// Peer ID.
65    pub peer: PeerId,
66
67    /// Query ID.
68    pub query_id: Option<QueryId>,
69
70    /// Query result.
71    pub result: QueryResult,
72}
73
74/// Wrapper around [`FuturesUnordered`] that wakes a task up automatically.
75#[derive(Default)]
76pub struct FuturesStream<F> {
77    futures: FuturesUnordered<F>,
78    waker: Option<Waker>,
79}
80
81impl<F> FuturesStream<F> {
82    /// Create new [`FuturesStream`].
83    pub fn new() -> Self {
84        Self {
85            futures: FuturesUnordered::new(),
86            waker: None,
87        }
88    }
89
90    /// Push a future for processing.
91    pub fn push(&mut self, future: F) {
92        self.futures.push(future);
93
94        if let Some(waker) = self.waker.take() {
95            waker.wake();
96        }
97    }
98}
99
100impl<F: Future> Stream for FuturesStream<F> {
101    type Item = <F as Future>::Output;
102
103    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
104        let Poll::Ready(Some(result)) = self.futures.poll_next_unpin(cx) else {
105            // We must save the current waker to wake up the task when new futures are inserted.
106            //
107            // Otherwise, simply returning `Poll::Pending` here would cause the task to never be
108            // woken up again.
109            //
110            // We were previously relying on some other task from the `loop tokio::select!` to
111            // finish.
112            self.waker = Some(cx.waker().clone());
113
114            return Poll::Pending;
115        };
116
117        Poll::Ready(Some(result))
118    }
119}
120
121/// Query executor.
122pub struct QueryExecutor {
123    /// Pending futures.
124    futures: FuturesStream<BoxFuture<'static, QueryContext>>,
125}
126
127impl QueryExecutor {
128    /// Create new [`QueryExecutor`]
129    pub fn new() -> Self {
130        Self {
131            futures: FuturesStream::new(),
132        }
133    }
134
135    /// Send message to remote peer.
136    pub fn send_message(&mut self, peer: PeerId, message: Bytes, mut substream: Substream) {
137        self.futures.push(Box::pin(async move {
138            match substream.send_framed(message).await {
139                Ok(_) => QueryContext {
140                    peer,
141                    query_id: None,
142                    result: QueryResult::SendSuccess { substream },
143                },
144                Err(_) => QueryContext {
145                    peer,
146                    query_id: None,
147                    result: QueryResult::SubstreamClosed,
148                },
149            }
150        }));
151    }
152
153    /// Read message from remote peer with timeout.
154    pub fn read_message(
155        &mut self,
156        peer: PeerId,
157        query_id: Option<QueryId>,
158        mut substream: Substream,
159    ) {
160        self.futures.push(Box::pin(async move {
161            match tokio::time::timeout(READ_TIMEOUT, substream.next()).await {
162                Err(_) => QueryContext {
163                    peer,
164                    query_id,
165                    result: QueryResult::Timeout,
166                },
167                Ok(Some(Ok(message))) => QueryContext {
168                    peer,
169                    query_id,
170                    result: QueryResult::ReadSuccess { substream, message },
171                },
172                Ok(None) | Ok(Some(Err(_))) => QueryContext {
173                    peer,
174                    query_id,
175                    result: QueryResult::SubstreamClosed,
176                },
177            }
178        }));
179    }
180
181    /// Send request to remote peer and read response.
182    pub fn send_request_read_response(
183        &mut self,
184        peer: PeerId,
185        query_id: Option<QueryId>,
186        message: Bytes,
187        mut substream: Substream,
188    ) {
189        self.futures.push(Box::pin(async move {
190            if let Err(_) = substream.send_framed(message).await {
191                let _ = substream.close().await;
192                return QueryContext {
193                    peer,
194                    query_id,
195                    result: QueryResult::SubstreamClosed,
196                };
197            }
198
199            match tokio::time::timeout(READ_TIMEOUT, substream.next()).await {
200                Err(_) => QueryContext {
201                    peer,
202                    query_id,
203                    result: QueryResult::Timeout,
204                },
205                Ok(Some(Ok(message))) => QueryContext {
206                    peer,
207                    query_id,
208                    result: QueryResult::ReadSuccess { substream, message },
209                },
210                Ok(None) | Ok(Some(Err(_))) => QueryContext {
211                    peer,
212                    query_id,
213                    result: QueryResult::SubstreamClosed,
214                },
215            }
216        }));
217    }
218}
219
220impl Stream for QueryExecutor {
221    type Item = QueryContext;
222
223    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
224        self.futures.poll_next_unpin(cx)
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{mock::substream::MockSubstream, types::SubstreamId};

    // NOTE: the timeout tests run in real time, so each takes about `READ_TIMEOUT` to finish;
    // the outer 20 second timeout only guards against the executor hanging.

    #[tokio::test]
    async fn substream_read_timeout() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();
        // prepare substream which never yields a message, so the read hits `READ_TIMEOUT`
        let mut substream = MockSubstream::new();
        substream.expect_poll_next().returning(|_| Poll::Pending);
        let substream = Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream));

        executor.read_message(peer, None, substream);

        match tokio::time::timeout(Duration::from_secs(20), executor.next()).await {
            Ok(Some(QueryContext {
                peer: queried_peer,
                query_id,
                result,
            })) => {
                assert_eq!(peer, queried_peer);
                assert!(query_id.is_none());
                assert!(std::matches!(result, QueryResult::Timeout));
            }
            result => panic!("invalid result received: {result:?}"),
        }
    }

    #[tokio::test]
    async fn substream_read_substream_closed() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();
        // prepare substream whose first read fails because the connection was closed
        let mut substream = MockSubstream::new();
        substream.expect_poll_next().times(1).return_once(|_| {
            Poll::Ready(Some(Err(crate::error::SubstreamError::ConnectionClosed)))
        });

        executor.read_message(
            peer,
            Some(QueryId(1338)),
            Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)),
        );

        match tokio::time::timeout(Duration::from_secs(20), executor.next()).await {
            Ok(Some(QueryContext {
                peer: queried_peer,
                query_id,
                result,
            })) => {
                assert_eq!(peer, queried_peer);
                assert_eq!(query_id, Some(QueryId(1338)));
                assert!(std::matches!(result, QueryResult::SubstreamClosed));
            }
            result => panic!("invalid result received: {result:?}"),
        }
    }

    #[tokio::test]
    async fn send_succeeds_no_message_read() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();

        // prepare substream which succeeds in sending the message but closes right after
        let mut substream = MockSubstream::new();
        substream.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(())));
        substream.expect_start_send().times(1).return_once(|_| Ok(()));
        substream.expect_poll_flush().times(1).return_once(|_| Poll::Ready(Ok(())));
        substream.expect_poll_next().times(1).return_once(|_| {
            Poll::Ready(Some(Err(crate::error::SubstreamError::ConnectionClosed)))
        });

        executor.send_request_read_response(
            peer,
            Some(QueryId(1337)),
            Bytes::from_static(b"hello, world"),
            Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)),
        );

        match tokio::time::timeout(Duration::from_secs(20), executor.next()).await {
            Ok(Some(QueryContext {
                peer: queried_peer,
                query_id,
                result,
            })) => {
                assert_eq!(peer, queried_peer);
                assert_eq!(query_id, Some(QueryId(1337)));
                assert!(std::matches!(result, QueryResult::SubstreamClosed));
            }
            result => panic!("invalid result received: {result:?}"),
        }
    }

    #[tokio::test]
    async fn send_fails_no_message_read() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();

        // prepare substream which fails to send the message (`poll_ready` errors); the
        // executor is then expected to close it exactly once
        let mut substream = MockSubstream::new();
        substream
            .expect_poll_ready()
            .times(1)
            .return_once(|_| Poll::Ready(Err(crate::error::SubstreamError::ConnectionClosed)));
        substream.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(())));

        executor.send_request_read_response(
            peer,
            Some(QueryId(1337)),
            Bytes::from_static(b"hello, world"),
            Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)),
        );

        match tokio::time::timeout(Duration::from_secs(20), executor.next()).await {
            Ok(Some(QueryContext {
                peer: queried_peer,
                query_id,
                result,
            })) => {
                assert_eq!(peer, queried_peer);
                assert_eq!(query_id, Some(QueryId(1337)));
                assert!(std::matches!(result, QueryResult::SubstreamClosed));
            }
            result => panic!("invalid result received: {result:?}"),
        }
    }

    #[tokio::test]
    async fn read_message_timeout() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();

        // prepare substream which never yields a message, so the read hits `READ_TIMEOUT`
        let mut substream = MockSubstream::new();
        substream.expect_poll_next().returning(|_| Poll::Pending);

        executor.read_message(
            peer,
            Some(QueryId(1336)),
            Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)),
        );

        match tokio::time::timeout(Duration::from_secs(20), executor.next()).await {
            Ok(Some(QueryContext {
                peer: queried_peer,
                query_id,
                result,
            })) => {
                assert_eq!(peer, queried_peer);
                assert_eq!(query_id, Some(QueryId(1336)));
                assert!(std::matches!(result, QueryResult::Timeout));
            }
            result => panic!("invalid result received: {result:?}"),
        }
    }

    #[tokio::test]
    async fn read_message_substream_closed() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();

        // prepare substream whose first read returns an error, which the executor reports as
        // a closed substream
        let mut substream = MockSubstream::new();
        substream
            .expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Err(crate::error::SubstreamError::ChannelClogged))));

        executor.read_message(
            peer,
            Some(QueryId(1335)),
            Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(substream)),
        );

        match tokio::time::timeout(Duration::from_secs(20), executor.next()).await {
            Ok(Some(QueryContext {
                peer: queried_peer,
                query_id,
                result,
            })) => {
                assert_eq!(peer, queried_peer);
                assert_eq!(query_id, Some(QueryId(1335)));
                assert!(std::matches!(result, QueryResult::SubstreamClosed));
            }
            result => panic!("invalid result received: {result:?}"),
        }
    }
}