1use crate::{substream::Substream, PeerId};
24
25use futures::{FutureExt, Sink, Stream};
26use futures_timer::Delay;
27use parking_lot::RwLock;
28
29use std::{
30 collections::{HashMap, VecDeque},
31 pin::Pin,
32 sync::Arc,
33 task::{Context, Poll},
34 time::Duration,
35};
36
/// Logging target used by the `tracing` calls in this file.
const LOG_TARGET: &str = "litep2p::notification::negotiation";

/// Maximum time a substream may spend in negotiation.
///
/// A timer with this duration is started when a substream is registered; once it fires, the
/// negotiation is reported as [`HandshakeEvent::NegotiationError`].
const NEGOTIATION_TIMEOUT: Duration = Duration::from_secs(10);
42
/// Direction of the substream being negotiated.
///
/// Together with the peer ID it forms the key under which a substream is tracked.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Substream registered with `negotiate_outbound()`: the local handshake is sent first and
    /// the remote handshake is read afterwards.
    Outbound,

    /// Substream registered with `read_handshake()` or `send_handshake()`.
    Inbound,
}
52
/// Events emitted by [`HandshakeService`].
#[derive(Debug)]
pub enum HandshakeEvent {
    /// Negotiation of the substream finished successfully.
    Negotiated {
        /// Peer the substream belongs to.
        peer: PeerId,

        /// Handshake bytes read from the substream.
        ///
        /// Empty when an inbound substream completed by sending (and flushing) the local
        /// handshake, since nothing is read in that case.
        handshake: Vec<u8>,

        /// The negotiated substream, handed back to the caller.
        substream: Substream,

        /// Direction under which the substream was registered.
        direction: Direction,
    },

    /// Negotiation failed: the timeout expired, the substream returned an error while sending,
    /// flushing or reading, or the substream ended before a handshake was read.
    NegotiationError {
        /// Peer the substream belongs to.
        peer: PeerId,

        /// Direction under which the substream was registered.
        direction: Direction,
    },
}
80
/// Progress of a single substream through the handshake exchange.
enum HandshakeState {
    /// Waiting for the substream's sink to become ready (`poll_ready`).
    SendHandshake,

    /// Sink reported readiness; the local handshake is handed to it next (`start_send`).
    SinkReady,

    /// Local handshake was handed to the sink; waiting for it to be flushed (`poll_flush`).
    HandshakeSent,

    /// Waiting to read a handshake from the substream (`poll_next`).
    ReadHandshake,
}
95
/// Drives handshake exchange on registered substreams and reports the outcome of each
/// negotiation through its `Stream` implementation.
pub(crate) struct HandshakeService {
    /// Local handshake, shared with the owner of the service.
    ///
    /// It is read and cloned each time a handshake is handed to a substream's sink.
    handshake: Arc<RwLock<Vec<u8>>>,

    /// Substreams under negotiation, keyed by peer and direction, each stored with its timeout
    /// timer and current [`HandshakeState`].
    substreams: HashMap<(PeerId, Direction), (Substream, Delay, HandshakeState)>,

    /// Finished negotiations (peer, direction, received handshake bytes) that have not been
    /// reported to the caller yet.
    ready: VecDeque<(PeerId, Direction, Vec<u8>)>,
}
108
109impl HandshakeService {
110 pub fn new(handshake: Arc<RwLock<Vec<u8>>>) -> Self {
112 Self {
113 handshake,
114 ready: VecDeque::new(),
115 substreams: HashMap::new(),
116 }
117 }
118
119 pub fn remove_outbound(&mut self, peer: &PeerId) -> Option<Substream> {
121 self.substreams
122 .remove(&(*peer, Direction::Outbound))
123 .map(|(substream, _, _)| substream)
124 }
125
126 pub fn remove_inbound(&mut self, peer: &PeerId) -> Option<Substream> {
128 self.substreams
129 .remove(&(*peer, Direction::Inbound))
130 .map(|(substream, _, _)| substream)
131 }
132
133 pub fn negotiate_outbound(&mut self, peer: PeerId, substream: Substream) {
135 tracing::trace!(target: LOG_TARGET, ?peer, "negotiate outbound");
136
137 self.substreams.insert(
138 (peer, Direction::Outbound),
139 (
140 substream,
141 Delay::new(NEGOTIATION_TIMEOUT),
142 HandshakeState::SendHandshake,
143 ),
144 );
145 }
146
147 pub fn read_handshake(&mut self, peer: PeerId, substream: Substream) {
149 tracing::trace!(target: LOG_TARGET, ?peer, "read handshake");
150
151 self.substreams.insert(
152 (peer, Direction::Inbound),
153 (
154 substream,
155 Delay::new(NEGOTIATION_TIMEOUT),
156 HandshakeState::ReadHandshake,
157 ),
158 );
159 }
160
161 pub fn send_handshake(&mut self, peer: PeerId, substream: Substream) {
163 tracing::trace!(target: LOG_TARGET, ?peer, "send handshake");
164
165 self.substreams.insert(
166 (peer, Direction::Inbound),
167 (
168 substream,
169 Delay::new(NEGOTIATION_TIMEOUT),
170 HandshakeState::SendHandshake,
171 ),
172 );
173 }
174
175 pub fn is_empty(&self) -> bool {
177 self.substreams.is_empty()
178 }
179
180 fn pop_event(&mut self) -> Option<(PeerId, HandshakeEvent)> {
185 while let Some((peer, direction, handshake)) = self.ready.pop_front() {
186 if let Some((substream, _, _)) = self.substreams.remove(&(peer, direction)) {
187 return Some((
188 peer,
189 HandshakeEvent::Negotiated {
190 peer,
191 handshake,
192 substream,
193 direction,
194 },
195 ));
196 }
197 }
198
199 None
200 }
201}
202
203impl Stream for HandshakeService {
204 type Item = (PeerId, HandshakeEvent);
205
206 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
207 let inner = Pin::into_inner(self);
208
209 if let Some(event) = inner.pop_event() {
210 return Poll::Ready(Some(event));
211 }
212
213 if inner.substreams.is_empty() {
214 return Poll::Pending;
215 }
216
217 'outer: for ((peer, direction), (ref mut substream, ref mut timer, state)) in
218 inner.substreams.iter_mut()
219 {
220 if let Poll::Ready(()) = timer.poll_unpin(cx) {
221 return Poll::Ready(Some((
222 *peer,
223 HandshakeEvent::NegotiationError {
224 peer: *peer,
225 direction: *direction,
226 },
227 )));
228 }
229
230 loop {
231 let pinned = Pin::new(&mut *substream);
232
233 match state {
234 HandshakeState::SendHandshake => match pinned.poll_ready(cx) {
235 Poll::Ready(Ok(())) => {
236 *state = HandshakeState::SinkReady;
237 continue;
238 }
239 Poll::Ready(Err(_)) =>
240 return Poll::Ready(Some((
241 *peer,
242 HandshakeEvent::NegotiationError {
243 peer: *peer,
244 direction: *direction,
245 },
246 ))),
247 Poll::Pending => continue 'outer,
248 },
249 HandshakeState::SinkReady => {
250 match pinned.start_send((*inner.handshake.read()).clone().into()) {
251 Ok(()) => {
252 *state = HandshakeState::HandshakeSent;
253 continue;
254 }
255 Err(_) =>
256 return Poll::Ready(Some((
257 *peer,
258 HandshakeEvent::NegotiationError {
259 peer: *peer,
260 direction: *direction,
261 },
262 ))),
263 }
264 }
265 HandshakeState::HandshakeSent => match pinned.poll_flush(cx) {
266 Poll::Ready(Ok(())) => match direction {
267 Direction::Outbound => {
268 *state = HandshakeState::ReadHandshake;
269 continue;
270 }
271 Direction::Inbound => {
272 inner.ready.push_back((*peer, *direction, vec![]));
273 continue 'outer;
274 }
275 },
276 Poll::Ready(Err(_)) =>
277 return Poll::Ready(Some((
278 *peer,
279 HandshakeEvent::NegotiationError {
280 peer: *peer,
281 direction: *direction,
282 },
283 ))),
284 Poll::Pending => continue 'outer,
285 },
286 HandshakeState::ReadHandshake => match pinned.poll_next(cx) {
287 Poll::Ready(Some(Ok(handshake))) => {
288 inner.ready.push_back((*peer, *direction, handshake.freeze().into()));
289 continue 'outer;
290 }
291 Poll::Ready(Some(Err(_))) | Poll::Ready(None) => {
292 return Poll::Ready(Some((
293 *peer,
294 HandshakeEvent::NegotiationError {
295 peer: *peer,
296 direction: *direction,
297 },
298 )));
299 }
300 Poll::Pending => continue 'outer,
301 },
302 }
303 }
304 }
305
306 if let Some((peer, direction, handshake)) = inner.ready.pop_front() {
307 let (substream, _, _) =
308 inner.substreams.remove(&(peer, direction)).expect("peer to exist");
309
310 return Poll::Ready(Some((
311 peer,
312 HandshakeEvent::Negotiated {
313 peer,
314 handshake,
315 substream,
316 direction,
317 },
318 )));
319 }
320
321 Poll::Pending
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        mock::substream::{DummySubstream, MockSubstream},
        types::SubstreamId,
    };
    use futures::StreamExt;

    /// Create a service whose local handshake is `[1, 2, 3, 4]`.
    fn make_service() -> HandshakeService {
        HandshakeService::new(Arc::new(RwLock::new(vec![1, 2, 3, 4])))
    }

    /// Poll the service once and verify that it has nothing to report.
    async fn assert_pending(service: &mut HandshakeService) {
        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
            Poll::Pending => Poll::Ready(()),
            _ => panic!("invalid event received"),
        })
        .await;
    }

    /// Verify that the next event is an inbound negotiation error for `peer`.
    async fn assert_inbound_negotiation_error(service: &mut HandshakeService, peer: PeerId) {
        match service.next().await {
            Some((
                failed_peer,
                HandshakeEvent::NegotiationError {
                    peer: event_peer,
                    direction,
                },
            )) => {
                assert_eq!(failed_peer, peer);
                assert_eq!(event_peer, peer);
                assert_eq!(direction, Direction::Inbound);
            }
            _ => panic!("invalid event received"),
        }
    }

    /// Build a `substreams` map value backed by a dummy substream in the given state.
    fn dummy_entry(peer: PeerId, state: HandshakeState) -> (Substream, Delay, HandshakeState) {
        (
            Substream::new_mock(
                peer,
                SubstreamId::from(1337usize),
                Box::new(DummySubstream::new()),
            ),
            Delay::new(NEGOTIATION_TIMEOUT),
            state,
        )
    }

    #[tokio::test]
    async fn substream_error_when_sending_handshake() {
        let mut service = make_service();
        assert_pending(&mut service).await;

        // The sink becomes ready but rejects the handshake.
        let mut mock = MockSubstream::new();
        mock.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(())));
        mock.expect_start_send()
            .times(1)
            .return_once(|_| Err(crate::error::SubstreamError::ConnectionClosed));

        let peer = PeerId::random();
        let substream = Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(mock));

        service.send_handshake(peer, substream);
        assert_inbound_negotiation_error(&mut service, peer).await;
    }

    #[tokio::test]
    async fn substream_error_when_flushing_substream() {
        let mut service = make_service();
        assert_pending(&mut service).await;

        // The handshake is accepted by the sink but flushing it fails.
        let mut mock = MockSubstream::new();
        mock.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(())));
        mock.expect_start_send().times(1).return_once(|_| Ok(()));
        mock.expect_poll_flush()
            .times(1)
            .return_once(|_| Poll::Ready(Err(crate::error::SubstreamError::ConnectionClosed)));

        let peer = PeerId::random();
        let substream = Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(mock));

        service.send_handshake(peer, substream);
        assert_inbound_negotiation_error(&mut service, peer).await;
    }

    #[tokio::test]
    async fn pop_event_but_substream_doesnt_exist() {
        let mut service = make_service();
        let peer = PeerId::random();

        // Queue a finished inbound negotiation and track both directions for the peer.
        service.ready.push_front((peer, Direction::Inbound, vec![]));
        service.substreams.insert(
            (peer, Direction::Inbound),
            dummy_entry(peer, HandshakeState::HandshakeSent),
        );
        service.substreams.insert(
            (peer, Direction::Outbound),
            dummy_entry(peer, HandshakeState::SendHandshake),
        );

        // Remove both substreams before the queued result is reported.
        assert!(service.remove_outbound(&peer).is_some());
        assert!(service.remove_inbound(&peer).is_some());

        // The queued result no longer has a substream, so nothing must be reported.
        assert_pending(&mut service).await;
    }
}