litep2p/protocol/notification/handle.rs

// Copyright 2023 litep2p developers
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
20
21use crate::{
22    error::Error,
23    protocol::notification::types::{
24        Direction, InnerNotificationEvent, NotificationCommand, NotificationError,
25        NotificationEvent, ValidationResult,
26    },
27    types::protocol::ProtocolName,
28    PeerId,
29};
30
31use bytes::BytesMut;
32use futures::Stream;
33use parking_lot::RwLock;
34use tokio::sync::{
35    mpsc::{error::TrySendError, Receiver, Sender},
36    oneshot,
37};
38
39use std::{
40    collections::{HashMap, HashSet},
41    pin::Pin,
42    sync::Arc,
43    task::{Context, Poll},
44};
45
/// `tracing` target under which every log event emitted by this module is recorded.
const LOG_TARGET: &str = "litep2p::notification::handle";
48
49#[derive(Debug, Clone)]
50pub(crate) struct NotificationEventHandle {
51    tx: Sender<InnerNotificationEvent>,
52}
53
54impl NotificationEventHandle {
55    /// Create new [`NotificationEventHandle`].
56    pub(crate) fn new(tx: Sender<InnerNotificationEvent>) -> Self {
57        Self { tx }
58    }
59
60    /// Validate inbound substream.
61    pub(crate) async fn report_inbound_substream(
62        &self,
63        protocol: ProtocolName,
64        fallback: Option<ProtocolName>,
65        peer: PeerId,
66        handshake: Vec<u8>,
67        tx: oneshot::Sender<ValidationResult>,
68    ) {
69        let _ = self
70            .tx
71            .send(InnerNotificationEvent::ValidateSubstream {
72                protocol,
73                fallback,
74                peer,
75                handshake,
76                tx,
77            })
78            .await;
79    }
80
81    /// Notification stream opened.
82    pub(crate) async fn report_notification_stream_opened(
83        &self,
84        protocol: ProtocolName,
85        fallback: Option<ProtocolName>,
86        direction: Direction,
87        peer: PeerId,
88        handshake: Vec<u8>,
89        sink: NotificationSink,
90    ) {
91        let _ = self
92            .tx
93            .send(InnerNotificationEvent::NotificationStreamOpened {
94                protocol,
95                fallback,
96                direction,
97                peer,
98                handshake,
99                sink,
100            })
101            .await;
102    }
103
104    /// Notification stream closed.
105    pub(crate) async fn report_notification_stream_closed(&self, peer: PeerId) {
106        let _ = self.tx.send(InnerNotificationEvent::NotificationStreamClosed { peer }).await;
107    }
108
109    /// Failed to open notification stream.
110    pub(crate) async fn report_notification_stream_open_failure(
111        &self,
112        peer: PeerId,
113        error: NotificationError,
114    ) {
115        let _ = self
116            .tx
117            .send(InnerNotificationEvent::NotificationStreamOpenFailure { peer, error })
118            .await;
119    }
120}
121
122/// Notification sink.
123///
124/// Allows the user to send notifications both synchronously and asynchronously.
125#[derive(Debug, Clone)]
126pub struct NotificationSink {
127    /// Peer ID.
128    peer: PeerId,
129
130    /// TX channel for sending notifications synchronously.
131    sync_tx: Sender<Vec<u8>>,
132
133    /// TX channel for sending notifications asynchronously.
134    async_tx: Sender<Vec<u8>>,
135}
136
137impl NotificationSink {
138    /// Create new [`NotificationSink`].
139    pub(crate) fn new(peer: PeerId, sync_tx: Sender<Vec<u8>>, async_tx: Sender<Vec<u8>>) -> Self {
140        Self {
141            peer,
142            async_tx,
143            sync_tx,
144        }
145    }
146
147    /// Send notification to `peer` synchronously.
148    ///
149    /// If the channel is clogged, [`NotificationError::ChannelClogged`] is returned.
150    pub fn send_sync_notification(&self, notification: Vec<u8>) -> Result<(), NotificationError> {
151        self.sync_tx.try_send(notification).map_err(|error| match error {
152            TrySendError::Closed(_) => NotificationError::NoConnection,
153            TrySendError::Full(_) => NotificationError::ChannelClogged,
154        })
155    }
156
157    /// Send notification to `peer` asynchronously, waiting for the channel to have capacity
158    /// if it's clogged.
159    ///
160    /// Returns [`Error::PeerDoesntExist(PeerId)`](crate::error::Error::PeerDoesntExist)
161    /// if the connection has been closed.
162    pub async fn send_async_notification(&self, notification: Vec<u8>) -> crate::Result<()> {
163        self.async_tx
164            .send(notification)
165            .await
166            .map_err(|_| Error::PeerDoesntExist(self.peer))
167    }
168}
169
170/// Handle allowing the user protocol to interact with the notification protocol.
171#[derive(Debug)]
172pub struct NotificationHandle {
173    /// RX channel for receiving events from the notification protocol.
174    event_rx: Receiver<InnerNotificationEvent>,
175
176    /// RX channel for receiving notifications from connection handlers.
177    notif_rx: Receiver<(PeerId, BytesMut)>,
178
179    /// TX channel for sending commands to the notification protocol.
180    command_tx: Sender<NotificationCommand>,
181
182    /// Peers.
183    peers: HashMap<PeerId, NotificationSink>,
184
185    /// Clogged peers.
186    clogged: HashSet<PeerId>,
187
188    /// Pending validations.
189    pending_validations: HashMap<PeerId, oneshot::Sender<ValidationResult>>,
190
191    /// Handshake.
192    handshake: Arc<RwLock<Vec<u8>>>,
193
194    /// Protocol name.
195    protocol_name: ProtocolName,
196}
197
198impl NotificationHandle {
199    /// Create new [`NotificationHandle`].
200    pub(crate) fn new(
201        event_rx: Receiver<InnerNotificationEvent>,
202        notif_rx: Receiver<(PeerId, BytesMut)>,
203        command_tx: Sender<NotificationCommand>,
204        handshake: Arc<RwLock<Vec<u8>>>,
205        protocol_name: ProtocolName,
206    ) -> Self {
207        Self {
208            event_rx,
209            notif_rx,
210            command_tx,
211            handshake,
212            peers: HashMap::new(),
213            clogged: HashSet::new(),
214            pending_validations: HashMap::new(),
215            protocol_name,
216        }
217    }
218
219    /// Open substream to `peer`.
220    ///
221    /// Returns [`Error::PeerAlreadyExists(PeerId)`](crate::error::Error::PeerAlreadyExists) if
222    /// substream is already open to `peer`.
223    ///
224    /// If connection to peer is closed, `NotificationProtocol` tries to dial the peer and if the
225    /// dial succeeds, tries to open a substream. This behavior can be disabled with
226    /// [`ConfigBuilder::with_dialing_enabled(false)`](super::config::ConfigBuilder::with_dialing_enabled()).
227    pub async fn open_substream(&self, peer: PeerId) -> crate::Result<()> {
228        tracing::trace!(target: LOG_TARGET, ?peer, protocol_name = ?self.protocol_name, "open substream");
229
230        if self.peers.contains_key(&peer) {
231            return Err(Error::PeerAlreadyExists(peer));
232        }
233
234        self.command_tx
235            .send(NotificationCommand::OpenSubstream {
236                peers: HashSet::from_iter([peer]),
237            })
238            .await
239            .map_or(Ok(()), |_| Ok(()))
240    }
241
242    /// Open substreams to multiple peers.
243    ///
244    /// Similar to [`NotificationHandle::open_substream()`] but multiple substreams are initiated
245    /// using a single call to `NotificationProtocol`.
246    ///
247    /// Peers who are already connected are ignored and returned as `Err(HashSet<PeerId>>)`.
248    pub async fn open_substream_batch(
249        &self,
250        peers: impl Iterator<Item = PeerId>,
251    ) -> Result<(), HashSet<PeerId>> {
252        let (to_add, to_ignore): (Vec<_>, Vec<_>) = peers
253            .map(|peer| match self.peers.contains_key(&peer) {
254                true => (None, Some(peer)),
255                false => (Some(peer), None),
256            })
257            .unzip();
258
259        let to_add = to_add.into_iter().flatten().collect::<HashSet<_>>();
260        let to_ignore = to_ignore.into_iter().flatten().collect::<HashSet<_>>();
261
262        tracing::trace!(
263            target: LOG_TARGET,
264            peers_to_add = ?to_add.len(),
265            peers_to_ignore = ?to_ignore.len(),
266            protocol_name = ?self.protocol_name,
267            "open substream",
268        );
269
270        let _ = self.command_tx.send(NotificationCommand::OpenSubstream { peers: to_add }).await;
271
272        match to_ignore.is_empty() {
273            true => Ok(()),
274            false => Err(to_ignore),
275        }
276    }
277
278    /// Try to open substreams to multiple peers.
279    ///
280    /// Similar to [`NotificationHandle::open_substream()`] but multiple substreams are initiated
281    /// using a single call to `NotificationProtocol`.
282    ///
283    /// If the channel is clogged, peers for whom a connection is not yet open are returned as
284    /// `Err(HashSet<PeerId>)`.
285    pub fn try_open_substream_batch(
286        &self,
287        peers: impl Iterator<Item = PeerId>,
288    ) -> Result<(), HashSet<PeerId>> {
289        let (to_add, to_ignore): (Vec<_>, Vec<_>) = peers
290            .map(|peer| match self.peers.contains_key(&peer) {
291                true => (None, Some(peer)),
292                false => (Some(peer), None),
293            })
294            .unzip();
295
296        let to_add = to_add.into_iter().flatten().collect::<HashSet<_>>();
297        let to_ignore = to_ignore.into_iter().flatten().collect::<HashSet<_>>();
298
299        tracing::trace!(
300            target: LOG_TARGET,
301            peers_to_add = ?to_add.len(),
302            peers_to_ignore = ?to_ignore.len(),
303            protocol_name = ?self.protocol_name,
304            "open substream",
305        );
306
307        self.command_tx
308            .try_send(NotificationCommand::OpenSubstream {
309                peers: to_add.clone(),
310            })
311            .map_err(|_| to_add)
312    }
313
314    /// Close substream to `peer`.
315    pub async fn close_substream(&self, peer: PeerId) {
316        tracing::trace!(target: LOG_TARGET, ?peer, protocol_name = ?self.protocol_name, "close substream");
317
318        if !self.peers.contains_key(&peer) {
319            return;
320        }
321
322        let _ = self
323            .command_tx
324            .send(NotificationCommand::CloseSubstream {
325                peers: HashSet::from_iter([peer]),
326            })
327            .await;
328    }
329
330    /// Close substream to multiple peers.
331    ///
332    /// Similar to [`NotificationHandle::close_substream()`] but multiple substreams are closed
333    /// using a single call to `NotificationProtocol`.
334    pub async fn close_substream_batch(&self, peers: impl Iterator<Item = PeerId>) {
335        let peers = peers.filter(|peer| self.peers.contains_key(peer)).collect::<HashSet<_>>();
336
337        if peers.is_empty() {
338            return;
339        }
340
341        tracing::trace!(
342            target: LOG_TARGET,
343            ?peers,
344            protocol_name = ?self.protocol_name,
345            "close substreams",
346        );
347
348        let _ = self.command_tx.send(NotificationCommand::CloseSubstream { peers }).await;
349    }
350
351    /// Try close substream to multiple peers.
352    ///
353    /// Similar to [`NotificationHandle::close_substream()`] but multiple substreams are closed
354    /// using a single call to `NotificationProtocol`.
355    ///
356    /// If the channel is clogged, `peers` is returned as `Err(HashSet<PeerId>)`.
357    ///
358    /// If `peers` is empty after filtering all already-connected peers,
359    /// `Err(HashMap::new())` is returned.
360    pub fn try_close_substream_batch(
361        &self,
362        peers: impl Iterator<Item = PeerId>,
363    ) -> Result<(), HashSet<PeerId>> {
364        let peers = peers.filter(|peer| self.peers.contains_key(peer)).collect::<HashSet<_>>();
365
366        if peers.is_empty() {
367            return Err(HashSet::new());
368        }
369
370        tracing::trace!(
371            target: LOG_TARGET,
372            ?peers,
373            protocol_name = ?self.protocol_name,
374            "close substreams",
375        );
376
377        self.command_tx
378            .try_send(NotificationCommand::CloseSubstream {
379                peers: peers.clone(),
380            })
381            .map_err(|_| peers)
382    }
383
384    /// Set new handshake.
385    pub fn set_handshake(&mut self, handshake: Vec<u8>) {
386        tracing::trace!(target: LOG_TARGET, ?handshake, protocol_name = ?self.protocol_name, "set handshake");
387
388        *self.handshake.write() = handshake;
389    }
390
391    /// Send validation result to the notification protocol for an inbound substream received from
392    /// `peer`.
393    pub fn send_validation_result(&mut self, peer: PeerId, result: ValidationResult) {
394        tracing::trace!(target: LOG_TARGET, ?peer, ?result, protocol_name = ?self.protocol_name, "send validation result");
395
396        self.pending_validations.remove(&peer).map(|tx| tx.send(result));
397    }
398
399    /// Send notification to `peer` synchronously.
400    ///
401    /// If the channel is clogged, [`NotificationError::ChannelClogged`] is returned.
402    pub fn send_sync_notification(
403        &mut self,
404        peer: PeerId,
405        notification: Vec<u8>,
406    ) -> Result<(), NotificationError> {
407        match self.peers.get_mut(&peer) {
408            Some(sink) => match sink.send_sync_notification(notification) {
409                Ok(()) => Ok(()),
410                Err(error) => match error {
411                    NotificationError::NoConnection => Err(NotificationError::NoConnection),
412                    NotificationError::ChannelClogged => {
413                        let _ = self.clogged.insert(peer).then(|| {
414                            tracing::warn!(
415                                target: LOG_TARGET,
416                                ?peer,
417                                protocol_name = ?self.protocol_name,
418                                "notification channel clogged, force close connection",
419                            );
420
421                            self.command_tx.try_send(NotificationCommand::ForceClose { peer })
422                        });
423
424                        Err(NotificationError::ChannelClogged)
425                    }
426                    // sink doesn't emit any other `NotificationError`s
427                    _ => unreachable!(),
428                },
429            },
430            None => Ok(()),
431        }
432    }
433
434    /// Send notification to `peer` asynchronously, waiting for the channel to have capacity
435    /// if it's clogged.
436    ///
437    /// Returns [`Error::PeerDoesntExist(PeerId)`](crate::error::Error::PeerDoesntExist) if the
438    /// connection has been closed.
439    pub async fn send_async_notification(
440        &mut self,
441        peer: PeerId,
442        notification: Vec<u8>,
443    ) -> crate::Result<()> {
444        match self.peers.get_mut(&peer) {
445            Some(sink) => sink.send_async_notification(notification).await,
446            None => Err(Error::PeerDoesntExist(peer)),
447        }
448    }
449
450    /// Get a copy of the underlying notification sink for the peer.
451    ///
452    /// `None` is returned if `peer` doesn't exist.
453    pub fn notification_sink(&self, peer: PeerId) -> Option<NotificationSink> {
454        self.peers.get(&peer).cloned()
455    }
456
457    #[cfg(feature = "fuzz")]
458    /// Expose functionality for fuzzing
459    pub async fn fuzz_send_message(&mut self, command: NotificationCommand) -> crate::Result<()> {
460        if let NotificationCommand::SendNotification { peer_id, notif } = command {
461            self.send_async_notification(peer_id, notif).await?;
462        } else {
463            let _ = self.command_tx.send(command).await;
464        }
465        Ok(())
466    }
467}
468
469impl Stream for NotificationHandle {
470    type Item = NotificationEvent;
471
472    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
473        loop {
474            match self.event_rx.poll_recv(cx) {
475                Poll::Pending => {}
476                Poll::Ready(None) => return Poll::Ready(None),
477                Poll::Ready(Some(event)) => match event {
478                    InnerNotificationEvent::NotificationStreamOpened {
479                        protocol,
480                        fallback,
481                        direction,
482                        peer,
483                        handshake,
484                        sink,
485                    } => {
486                        self.peers.insert(peer, sink);
487
488                        return Poll::Ready(Some(NotificationEvent::NotificationStreamOpened {
489                            protocol,
490                            fallback,
491                            direction,
492                            peer,
493                            handshake,
494                        }));
495                    }
496                    InnerNotificationEvent::NotificationStreamClosed { peer } => {
497                        self.peers.remove(&peer);
498                        self.clogged.remove(&peer);
499
500                        return Poll::Ready(Some(NotificationEvent::NotificationStreamClosed {
501                            peer,
502                        }));
503                    }
504                    InnerNotificationEvent::ValidateSubstream {
505                        protocol,
506                        fallback,
507                        peer,
508                        handshake,
509                        tx,
510                    } => {
511                        self.pending_validations.insert(peer, tx);
512
513                        return Poll::Ready(Some(NotificationEvent::ValidateSubstream {
514                            protocol,
515                            fallback,
516                            peer,
517                            handshake,
518                        }));
519                    }
520                    InnerNotificationEvent::NotificationStreamOpenFailure { peer, error } =>
521                        return Poll::Ready(Some(
522                            NotificationEvent::NotificationStreamOpenFailure { peer, error },
523                        )),
524                },
525            }
526
527            match futures::ready!(self.notif_rx.poll_recv(cx)) {
528                None => return Poll::Ready(None),
529                Some((peer, notification)) =>
530                    if self.peers.contains_key(&peer) {
531                        return Poll::Ready(Some(NotificationEvent::NotificationReceived {
532                            peer,
533                            notification,
534                        }));
535                    },
536            }
537        }
538    }
539}