litep2p/protocol/notification/
negotiation.rs

1// Copyright 2023 litep2p developers
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! Implementation of the notification handshaking.
22
23use crate::{substream::Substream, PeerId};
24
25use futures::{FutureExt, Sink, Stream};
26use futures_timer::Delay;
27use parking_lot::RwLock;
28
29use std::{
30    collections::{HashMap, VecDeque},
31    pin::Pin,
32    sync::Arc,
33    task::{Context, Poll},
34    time::Duration,
35};
36
/// Target used for every `tracing` event emitted by this module.
const LOG_TARGET: &str = "litep2p::notification::negotiation";

/// How long a substream may spend negotiating its handshake before the attempt is reported
/// as failed.
const NEGOTIATION_TIMEOUT: Duration = Duration::from_secs(10);
42
/// Which side of the connection opened a substream.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Direction {
    /// Substream opened by the local node.
    Outbound,

    /// Substream opened by the remote node.
    Inbound,
}
52
53/// Events emitted by [`HandshakeService`].
54#[derive(Debug)]
55pub enum HandshakeEvent {
56    /// Substream has been negotiated.
57    Negotiated {
58        /// Peer ID.
59        peer: PeerId,
60
61        /// Handshake.
62        handshake: Vec<u8>,
63
64        /// Substream.
65        substream: Substream,
66
67        /// Direction.
68        direction: Direction,
69    },
70
71    /// Outbound substream has been negotiated.
72    NegotiationError {
73        /// Peer ID.
74        peer: PeerId,
75
76        /// Direction.
77        direction: Direction,
78    },
79}
80
/// Position of a single substream in the handshake state machine.
///
/// Outbound substreams go `SendHandshake -> SinkReady -> HandshakeSent -> ReadHandshake`.
/// Inbound substreams either only send (`SendHandshake -> SinkReady -> HandshakeSent`) or
/// only read (`ReadHandshake`).
enum HandshakeState {
    /// Waiting for the substream's sink to become ready to accept the local handshake.
    SendHandshake,

    /// Sink is ready; the local handshake can be written.
    SinkReady,

    /// Local handshake was written and is being flushed.
    HandshakeSent,

    /// Waiting for the remote peer's handshake to arrive.
    ReadHandshake,
}
95
96/// Handshake service.
97pub(crate) struct HandshakeService {
98    /// Handshake.
99    handshake: Arc<RwLock<Vec<u8>>>,
100
101    /// Pending outbound substreams.
102    /// Substreams:
103    substreams: HashMap<(PeerId, Direction), (Substream, Delay, HandshakeState)>,
104
105    /// Ready substreams.
106    ready: VecDeque<(PeerId, Direction, Vec<u8>)>,
107}
108
109impl HandshakeService {
110    /// Create new [`HandshakeService`].
111    pub fn new(handshake: Arc<RwLock<Vec<u8>>>) -> Self {
112        Self {
113            handshake,
114            ready: VecDeque::new(),
115            substreams: HashMap::new(),
116        }
117    }
118
119    /// Remove outbound substream from [`HandshakeService`].
120    pub fn remove_outbound(&mut self, peer: &PeerId) -> Option<Substream> {
121        self.substreams
122            .remove(&(*peer, Direction::Outbound))
123            .map(|(substream, _, _)| substream)
124    }
125
126    /// Remove inbound substream from [`HandshakeService`].
127    pub fn remove_inbound(&mut self, peer: &PeerId) -> Option<Substream> {
128        self.substreams
129            .remove(&(*peer, Direction::Inbound))
130            .map(|(substream, _, _)| substream)
131    }
132
133    /// Negotiate outbound handshake.
134    pub fn negotiate_outbound(&mut self, peer: PeerId, substream: Substream) {
135        tracing::trace!(target: LOG_TARGET, ?peer, "negotiate outbound");
136
137        self.substreams.insert(
138            (peer, Direction::Outbound),
139            (
140                substream,
141                Delay::new(NEGOTIATION_TIMEOUT),
142                HandshakeState::SendHandshake,
143            ),
144        );
145    }
146
147    /// Read handshake from remote peer.
148    pub fn read_handshake(&mut self, peer: PeerId, substream: Substream) {
149        tracing::trace!(target: LOG_TARGET, ?peer, "read handshake");
150
151        self.substreams.insert(
152            (peer, Direction::Inbound),
153            (
154                substream,
155                Delay::new(NEGOTIATION_TIMEOUT),
156                HandshakeState::ReadHandshake,
157            ),
158        );
159    }
160
161    /// Write handshake to remote peer.
162    pub fn send_handshake(&mut self, peer: PeerId, substream: Substream) {
163        tracing::trace!(target: LOG_TARGET, ?peer, "send handshake");
164
165        self.substreams.insert(
166            (peer, Direction::Inbound),
167            (
168                substream,
169                Delay::new(NEGOTIATION_TIMEOUT),
170                HandshakeState::SendHandshake,
171            ),
172        );
173    }
174
175    /// Returns `true` if [`HandshakeService`] contains no elements.
176    pub fn is_empty(&self) -> bool {
177        self.substreams.is_empty()
178    }
179
180    /// Pop event from the event queue.
181    ///
182    /// The substream may not exist in the queue anymore as it may have been removed
183    /// by `NotificationProtocol` if either one of the substreams failed to negotiate.
184    fn pop_event(&mut self) -> Option<(PeerId, HandshakeEvent)> {
185        while let Some((peer, direction, handshake)) = self.ready.pop_front() {
186            if let Some((substream, _, _)) = self.substreams.remove(&(peer, direction)) {
187                return Some((
188                    peer,
189                    HandshakeEvent::Negotiated {
190                        peer,
191                        handshake,
192                        substream,
193                        direction,
194                    },
195                ));
196            }
197        }
198
199        None
200    }
201}
202
203impl Stream for HandshakeService {
204    type Item = (PeerId, HandshakeEvent);
205
206    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
207        let inner = Pin::into_inner(self);
208
209        if let Some(event) = inner.pop_event() {
210            return Poll::Ready(Some(event));
211        }
212
213        if inner.substreams.is_empty() {
214            return Poll::Pending;
215        }
216
217        'outer: for ((peer, direction), (ref mut substream, ref mut timer, state)) in
218            inner.substreams.iter_mut()
219        {
220            if let Poll::Ready(()) = timer.poll_unpin(cx) {
221                return Poll::Ready(Some((
222                    *peer,
223                    HandshakeEvent::NegotiationError {
224                        peer: *peer,
225                        direction: *direction,
226                    },
227                )));
228            }
229
230            loop {
231                let pinned = Pin::new(&mut *substream);
232
233                match state {
234                    HandshakeState::SendHandshake => match pinned.poll_ready(cx) {
235                        Poll::Ready(Ok(())) => {
236                            *state = HandshakeState::SinkReady;
237                            continue;
238                        }
239                        Poll::Ready(Err(_)) =>
240                            return Poll::Ready(Some((
241                                *peer,
242                                HandshakeEvent::NegotiationError {
243                                    peer: *peer,
244                                    direction: *direction,
245                                },
246                            ))),
247                        Poll::Pending => continue 'outer,
248                    },
249                    HandshakeState::SinkReady => {
250                        match pinned.start_send((*inner.handshake.read()).clone().into()) {
251                            Ok(()) => {
252                                *state = HandshakeState::HandshakeSent;
253                                continue;
254                            }
255                            Err(_) =>
256                                return Poll::Ready(Some((
257                                    *peer,
258                                    HandshakeEvent::NegotiationError {
259                                        peer: *peer,
260                                        direction: *direction,
261                                    },
262                                ))),
263                        }
264                    }
265                    HandshakeState::HandshakeSent => match pinned.poll_flush(cx) {
266                        Poll::Ready(Ok(())) => match direction {
267                            Direction::Outbound => {
268                                *state = HandshakeState::ReadHandshake;
269                                continue;
270                            }
271                            Direction::Inbound => {
272                                inner.ready.push_back((*peer, *direction, vec![]));
273                                continue 'outer;
274                            }
275                        },
276                        Poll::Ready(Err(_)) =>
277                            return Poll::Ready(Some((
278                                *peer,
279                                HandshakeEvent::NegotiationError {
280                                    peer: *peer,
281                                    direction: *direction,
282                                },
283                            ))),
284                        Poll::Pending => continue 'outer,
285                    },
286                    HandshakeState::ReadHandshake => match pinned.poll_next(cx) {
287                        Poll::Ready(Some(Ok(handshake))) => {
288                            inner.ready.push_back((*peer, *direction, handshake.freeze().into()));
289                            continue 'outer;
290                        }
291                        Poll::Ready(Some(Err(_))) | Poll::Ready(None) => {
292                            return Poll::Ready(Some((
293                                *peer,
294                                HandshakeEvent::NegotiationError {
295                                    peer: *peer,
296                                    direction: *direction,
297                                },
298                            )));
299                        }
300                        Poll::Pending => continue 'outer,
301                    },
302                }
303            }
304        }
305
306        if let Some((peer, direction, handshake)) = inner.ready.pop_front() {
307            let (substream, _, _) =
308                inner.substreams.remove(&(peer, direction)).expect("peer to exist");
309
310            return Poll::Ready(Some((
311                peer,
312                HandshakeEvent::Negotiated {
313                    peer,
314                    handshake,
315                    substream,
316                    direction,
317                },
318            )));
319        }
320
321        Poll::Pending
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        mock::substream::{DummySubstream, MockSubstream},
        types::SubstreamId,
    };
    use futures::StreamExt;

    /// Poll `service` exactly once and fail the test unless it returns `Poll::Pending`.
    async fn assert_pending(service: &mut HandshakeService) {
        futures::future::poll_fn(|cx| match service.poll_next_unpin(cx) {
            Poll::Pending => Poll::Ready(()),
            _ => panic!("invalid event received"),
        })
        .await
    }

    /// Await the next event and check that it reports a failed inbound negotiation with `peer`.
    async fn assert_inbound_failure(service: &mut HandshakeService, peer: PeerId) {
        match service.next().await {
            Some((
                failed_peer,
                HandshakeEvent::NegotiationError {
                    peer: event_peer,
                    direction,
                },
            )) => {
                assert_eq!(failed_peer, peer);
                assert_eq!(event_peer, peer);
                assert_eq!(direction, Direction::Inbound);
            }
            _ => panic!("invalid event received"),
        }
    }

    #[tokio::test]
    async fn substream_error_when_sending_handshake() {
        let mut service = HandshakeService::new(Arc::new(RwLock::new(vec![1, 2, 3, 4])));
        assert_pending(&mut service).await;

        // The sink becomes ready, but writing the handshake fails.
        let mut mock = MockSubstream::new();
        mock.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(())));
        mock.expect_start_send()
            .times(1)
            .return_once(|_| Err(crate::error::SubstreamError::ConnectionClosed));

        let peer = PeerId::random();
        service.send_handshake(
            peer,
            Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(mock)),
        );

        assert_inbound_failure(&mut service, peer).await;
    }

    #[tokio::test]
    async fn substream_error_when_flushing_substream() {
        let mut service = HandshakeService::new(Arc::new(RwLock::new(vec![1, 2, 3, 4])));
        assert_pending(&mut service).await;

        // The handshake is written successfully, but flushing it fails.
        let mut mock = MockSubstream::new();
        mock.expect_poll_ready().times(1).return_once(|_| Poll::Ready(Ok(())));
        mock.expect_start_send().times(1).return_once(|_| Ok(()));
        mock.expect_poll_flush()
            .times(1)
            .return_once(|_| Poll::Ready(Err(crate::error::SubstreamError::ConnectionClosed)));

        let peer = PeerId::random();
        service.send_handshake(
            peer,
            Substream::new_mock(peer, SubstreamId::from(0usize), Box::new(mock)),
        );

        assert_inbound_failure(&mut service, peer).await;
    }

    // The inbound substream finished negotiating and was pushed into `ready`, but the outbound
    // substream of the same peer failed to negotiate.
    #[tokio::test]
    async fn pop_event_but_substream_doesnt_exist() {
        let mut service = HandshakeService::new(Arc::new(RwLock::new(vec![1, 2, 3, 4])));
        let peer = PeerId::random();
        let dummy = || {
            Substream::new_mock(
                peer,
                SubstreamId::from(1337usize),
                Box::new(DummySubstream::new()),
            )
        };

        // The inbound substream has finished negotiating.
        service.ready.push_front((peer, Direction::Inbound, vec![]));
        service.substreams.insert(
            (peer, Direction::Inbound),
            (
                dummy(),
                Delay::new(NEGOTIATION_TIMEOUT),
                HandshakeState::HandshakeSent,
            ),
        );
        service.substreams.insert(
            (peer, Direction::Outbound),
            (
                dummy(),
                Delay::new(NEGOTIATION_TIMEOUT),
                HandshakeState::SendHandshake,
            ),
        );

        // The outbound substream failed, so `NotificationProtocol` removes both substreams
        // from `HandshakeService`.
        assert!(service.remove_outbound(&peer).is_some());
        assert!(service.remove_inbound(&peer).is_some());

        // The stale `ready` entry must be skipped rather than reported.
        assert_pending(&mut service).await
    }
}