litep2p/protocol/libp2p/kademlia/
executor.rs1use crate::{protocol::libp2p::kademlia::query::QueryId, substream::Substream, PeerId};
22
23use bytes::{Bytes, BytesMut};
24use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt};
25
26use std::{
27 future::Future,
28 pin::Pin,
29 task::{Context, Poll, Waker},
30 time::Duration,
31};
32
/// Longest time a read waits for the remote peer's next frame before the operation is
/// reported as [`QueryResult::Timeout`].
const READ_TIMEOUT: Duration = Duration::from_secs(15);
35
36#[derive(Debug)]
38pub enum QueryResult {
39 SendSuccess {
41 substream: Substream,
43 },
44
45 ReadSuccess {
47 substream: Substream,
49
50 message: BytesMut,
52 },
53
54 Timeout,
56
57 SubstreamClosed,
59}
60
61#[derive(Debug)]
63pub struct QueryContext {
64 pub peer: PeerId,
66
67 pub query_id: Option<QueryId>,
69
70 pub result: QueryResult,
72}
73
74#[derive(Default)]
76pub struct FuturesStream<F> {
77 futures: FuturesUnordered<F>,
78 waker: Option<Waker>,
79}
80
81impl<F> FuturesStream<F> {
82 pub fn new() -> Self {
84 Self {
85 futures: FuturesUnordered::new(),
86 waker: None,
87 }
88 }
89
90 pub fn push(&mut self, future: F) {
92 self.futures.push(future);
93
94 if let Some(waker) = self.waker.take() {
95 waker.wake();
96 }
97 }
98}
99
100impl<F: Future> Stream for FuturesStream<F> {
101 type Item = <F as Future>::Output;
102
103 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
104 let Poll::Ready(Some(result)) = self.futures.poll_next_unpin(cx) else {
105 self.waker = Some(cx.waker().clone());
113
114 return Poll::Pending;
115 };
116
117 Poll::Ready(Some(result))
118 }
119}
120
121pub struct QueryExecutor {
123 futures: FuturesStream<BoxFuture<'static, QueryContext>>,
125}
126
127impl QueryExecutor {
128 pub fn new() -> Self {
130 Self {
131 futures: FuturesStream::new(),
132 }
133 }
134
135 pub fn send_message(&mut self, peer: PeerId, message: Bytes, mut substream: Substream) {
137 self.futures.push(Box::pin(async move {
138 match substream.send_framed(message).await {
139 Ok(_) => QueryContext {
140 peer,
141 query_id: None,
142 result: QueryResult::SendSuccess { substream },
143 },
144 Err(_) => QueryContext {
145 peer,
146 query_id: None,
147 result: QueryResult::SubstreamClosed,
148 },
149 }
150 }));
151 }
152
153 pub fn read_message(
155 &mut self,
156 peer: PeerId,
157 query_id: Option<QueryId>,
158 mut substream: Substream,
159 ) {
160 self.futures.push(Box::pin(async move {
161 match tokio::time::timeout(READ_TIMEOUT, substream.next()).await {
162 Err(_) => QueryContext {
163 peer,
164 query_id,
165 result: QueryResult::Timeout,
166 },
167 Ok(Some(Ok(message))) => QueryContext {
168 peer,
169 query_id,
170 result: QueryResult::ReadSuccess { substream, message },
171 },
172 Ok(None) | Ok(Some(Err(_))) => QueryContext {
173 peer,
174 query_id,
175 result: QueryResult::SubstreamClosed,
176 },
177 }
178 }));
179 }
180
181 pub fn send_request_read_response(
183 &mut self,
184 peer: PeerId,
185 query_id: Option<QueryId>,
186 message: Bytes,
187 mut substream: Substream,
188 ) {
189 self.futures.push(Box::pin(async move {
190 if let Err(_) = substream.send_framed(message).await {
191 let _ = substream.close().await;
192 return QueryContext {
193 peer,
194 query_id,
195 result: QueryResult::SubstreamClosed,
196 };
197 }
198
199 match tokio::time::timeout(READ_TIMEOUT, substream.next()).await {
200 Err(_) => QueryContext {
201 peer,
202 query_id,
203 result: QueryResult::Timeout,
204 },
205 Ok(Some(Ok(message))) => QueryContext {
206 peer,
207 query_id,
208 result: QueryResult::ReadSuccess { substream, message },
209 },
210 Ok(None) | Ok(Some(Err(_))) => QueryContext {
211 peer,
212 query_id,
213 result: QueryResult::SubstreamClosed,
214 },
215 }
216 }));
217 }
218}
219
220impl Stream for QueryExecutor {
221 type Item = QueryContext;
222
223 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
224 self.futures.poll_next_unpin(cx)
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{mock::substream::MockSubstream, types::SubstreamId};

    /// How long each test waits for the executor; longer than `READ_TIMEOUT` so the timeout
    /// tests can observe `QueryResult::Timeout`.
    const TEST_DEADLINE: Duration = Duration::from_secs(20);

    /// Wraps `mock` into a [`Substream`] for `peer` with substream ID 0.
    fn mock_substream(peer: PeerId, mock: MockSubstream) -> Substream {
        Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(mock))
    }

    /// Returns the next [`QueryContext`] yielded by `executor`, panicking if the stream ends or
    /// nothing arrives within [`TEST_DEADLINE`].
    async fn next_context(executor: &mut QueryExecutor) -> QueryContext {
        match tokio::time::timeout(TEST_DEADLINE, executor.next()).await {
            Ok(Some(context)) => context,
            result => panic!("invalid result received: {result:?}"),
        }
    }

    #[tokio::test]
    async fn substream_read_timeout() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();
        let mut mock = MockSubstream::new();
        mock.expect_poll_next().returning(|_| Poll::Pending);

        executor.read_message(peer, None, mock_substream(peer, mock));

        let QueryContext {
            peer: queried_peer,
            query_id,
            result,
        } = next_context(&mut executor).await;
        assert_eq!(peer, queried_peer);
        assert!(query_id.is_none());
        assert!(std::matches!(result, QueryResult::Timeout));
    }

    #[tokio::test]
    async fn substream_read_substream_closed() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();
        let mut mock = MockSubstream::new();
        mock.expect_poll_next().times(1).return_once(|_| {
            Poll::Ready(Some(Err(crate::error::SubstreamError::ConnectionClosed)))
        });

        executor.read_message(peer, Some(QueryId(1338)), mock_substream(peer, mock));

        let QueryContext {
            peer: queried_peer,
            query_id,
            result,
        } = next_context(&mut executor).await;
        assert_eq!(peer, queried_peer);
        assert_eq!(query_id, Some(QueryId(1338)));
        assert!(std::matches!(result, QueryResult::SubstreamClosed));
    }

    #[tokio::test]
    async fn send_succeeds_no_message_read() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();

        // The send goes through, but the follow-up read sees the connection close.
        let mut mock = MockSubstream::new();
        mock.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(())));
        mock.expect_start_send().times(1).return_once(|_| Ok(()));
        mock.expect_poll_flush().times(1).return_once(|_| Poll::Ready(Ok(())));
        mock.expect_poll_next().times(1).return_once(|_| {
            Poll::Ready(Some(Err(crate::error::SubstreamError::ConnectionClosed)))
        });

        executor.send_request_read_response(
            peer,
            Some(QueryId(1337)),
            Bytes::from_static(b"hello, world"),
            mock_substream(peer, mock),
        );

        let QueryContext {
            peer: queried_peer,
            query_id,
            result,
        } = next_context(&mut executor).await;
        assert_eq!(peer, queried_peer);
        assert_eq!(query_id, Some(QueryId(1337)));
        assert!(std::matches!(result, QueryResult::SubstreamClosed));
    }

    #[tokio::test]
    async fn send_fails_no_message_read() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();

        // The send fails immediately, so the substream must be closed and never read.
        let mut mock = MockSubstream::new();
        mock.expect_poll_ready()
            .times(1)
            .return_once(|_| Poll::Ready(Err(crate::error::SubstreamError::ConnectionClosed)));
        mock.expect_poll_close().times(1).return_once(|_| Poll::Ready(Ok(())));

        executor.send_request_read_response(
            peer,
            Some(QueryId(1337)),
            Bytes::from_static(b"hello, world"),
            mock_substream(peer, mock),
        );

        let QueryContext {
            peer: queried_peer,
            query_id,
            result,
        } = next_context(&mut executor).await;
        assert_eq!(peer, queried_peer);
        assert_eq!(query_id, Some(QueryId(1337)));
        assert!(std::matches!(result, QueryResult::SubstreamClosed));
    }

    #[tokio::test]
    async fn read_message_timeout() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();

        // The substream never produces a frame.
        let mut mock = MockSubstream::new();
        mock.expect_poll_next().returning(|_| Poll::Pending);

        executor.read_message(peer, Some(QueryId(1336)), mock_substream(peer, mock));

        let QueryContext {
            peer: queried_peer,
            query_id,
            result,
        } = next_context(&mut executor).await;
        assert_eq!(peer, queried_peer);
        assert_eq!(query_id, Some(QueryId(1336)));
        assert!(std::matches!(result, QueryResult::Timeout));
    }

    #[tokio::test]
    async fn read_message_substream_closed() {
        let mut executor = QueryExecutor::new();
        let peer = PeerId::random();

        // Any read error, not just a closed connection, is reported as `SubstreamClosed`.
        let mut mock = MockSubstream::new();
        mock.expect_poll_next()
            .times(1)
            .return_once(|_| Poll::Ready(Some(Err(crate::error::SubstreamError::ChannelClogged))));

        executor.read_message(peer, Some(QueryId(1335)), mock_substream(peer, mock));

        let QueryContext {
            peer: queried_peer,
            query_id,
            result,
        } = next_context(&mut executor).await;
        assert_eq!(peer, queried_peer);
        assert_eq!(query_id, Some(QueryId(1335)));
        assert!(std::matches!(result, QueryResult::SubstreamClosed));
    }
}