1use crate::{
22 error::Error,
23 protocol::notification::types::{
24 Direction, InnerNotificationEvent, NotificationCommand, NotificationError,
25 NotificationEvent, ValidationResult,
26 },
27 types::protocol::ProtocolName,
28 PeerId,
29};
30
31use bytes::BytesMut;
32use futures::Stream;
33use parking_lot::RwLock;
34use tokio::sync::{
35 mpsc::{error::TrySendError, Receiver, Sender},
36 oneshot,
37};
38
39use std::{
40 collections::{HashMap, HashSet},
41 pin::Pin,
42 sync::Arc,
43 task::{Context, Poll},
44};
45
/// Logging target used by every `tracing` call in this module.
const LOG_TARGET: &str = "litep2p::notification::handle";
48
49#[derive(Debug, Clone)]
50pub(crate) struct NotificationEventHandle {
51 tx: Sender<InnerNotificationEvent>,
52}
53
54impl NotificationEventHandle {
55 pub(crate) fn new(tx: Sender<InnerNotificationEvent>) -> Self {
57 Self { tx }
58 }
59
60 pub(crate) async fn report_inbound_substream(
62 &self,
63 protocol: ProtocolName,
64 fallback: Option<ProtocolName>,
65 peer: PeerId,
66 handshake: Vec<u8>,
67 tx: oneshot::Sender<ValidationResult>,
68 ) {
69 let _ = self
70 .tx
71 .send(InnerNotificationEvent::ValidateSubstream {
72 protocol,
73 fallback,
74 peer,
75 handshake,
76 tx,
77 })
78 .await;
79 }
80
81 pub(crate) async fn report_notification_stream_opened(
83 &self,
84 protocol: ProtocolName,
85 fallback: Option<ProtocolName>,
86 direction: Direction,
87 peer: PeerId,
88 handshake: Vec<u8>,
89 sink: NotificationSink,
90 ) {
91 let _ = self
92 .tx
93 .send(InnerNotificationEvent::NotificationStreamOpened {
94 protocol,
95 fallback,
96 direction,
97 peer,
98 handshake,
99 sink,
100 })
101 .await;
102 }
103
104 pub(crate) async fn report_notification_stream_closed(&self, peer: PeerId) {
106 let _ = self.tx.send(InnerNotificationEvent::NotificationStreamClosed { peer }).await;
107 }
108
109 pub(crate) async fn report_notification_stream_open_failure(
111 &self,
112 peer: PeerId,
113 error: NotificationError,
114 ) {
115 let _ = self
116 .tx
117 .send(InnerNotificationEvent::NotificationStreamOpenFailure { peer, error })
118 .await;
119 }
120}
121
122#[derive(Debug, Clone)]
126pub struct NotificationSink {
127 peer: PeerId,
129
130 sync_tx: Sender<Vec<u8>>,
132
133 async_tx: Sender<Vec<u8>>,
135}
136
137impl NotificationSink {
138 pub(crate) fn new(peer: PeerId, sync_tx: Sender<Vec<u8>>, async_tx: Sender<Vec<u8>>) -> Self {
140 Self {
141 peer,
142 async_tx,
143 sync_tx,
144 }
145 }
146
147 pub fn send_sync_notification(&self, notification: Vec<u8>) -> Result<(), NotificationError> {
151 self.sync_tx.try_send(notification).map_err(|error| match error {
152 TrySendError::Closed(_) => NotificationError::NoConnection,
153 TrySendError::Full(_) => NotificationError::ChannelClogged,
154 })
155 }
156
157 pub async fn send_async_notification(&self, notification: Vec<u8>) -> crate::Result<()> {
163 self.async_tx
164 .send(notification)
165 .await
166 .map_err(|_| Error::PeerDoesntExist(self.peer))
167 }
168}
169
170#[derive(Debug)]
172pub struct NotificationHandle {
173 event_rx: Receiver<InnerNotificationEvent>,
175
176 notif_rx: Receiver<(PeerId, BytesMut)>,
178
179 command_tx: Sender<NotificationCommand>,
181
182 peers: HashMap<PeerId, NotificationSink>,
184
185 clogged: HashSet<PeerId>,
187
188 pending_validations: HashMap<PeerId, oneshot::Sender<ValidationResult>>,
190
191 handshake: Arc<RwLock<Vec<u8>>>,
193
194 protocol_name: ProtocolName,
196}
197
198impl NotificationHandle {
199 pub(crate) fn new(
201 event_rx: Receiver<InnerNotificationEvent>,
202 notif_rx: Receiver<(PeerId, BytesMut)>,
203 command_tx: Sender<NotificationCommand>,
204 handshake: Arc<RwLock<Vec<u8>>>,
205 protocol_name: ProtocolName,
206 ) -> Self {
207 Self {
208 event_rx,
209 notif_rx,
210 command_tx,
211 handshake,
212 peers: HashMap::new(),
213 clogged: HashSet::new(),
214 pending_validations: HashMap::new(),
215 protocol_name,
216 }
217 }
218
219 pub async fn open_substream(&self, peer: PeerId) -> crate::Result<()> {
228 tracing::trace!(target: LOG_TARGET, ?peer, protocol_name = ?self.protocol_name, "open substream");
229
230 if self.peers.contains_key(&peer) {
231 return Err(Error::PeerAlreadyExists(peer));
232 }
233
234 self.command_tx
235 .send(NotificationCommand::OpenSubstream {
236 peers: HashSet::from_iter([peer]),
237 })
238 .await
239 .map_or(Ok(()), |_| Ok(()))
240 }
241
242 pub async fn open_substream_batch(
249 &self,
250 peers: impl Iterator<Item = PeerId>,
251 ) -> Result<(), HashSet<PeerId>> {
252 let (to_add, to_ignore): (Vec<_>, Vec<_>) = peers
253 .map(|peer| match self.peers.contains_key(&peer) {
254 true => (None, Some(peer)),
255 false => (Some(peer), None),
256 })
257 .unzip();
258
259 let to_add = to_add.into_iter().flatten().collect::<HashSet<_>>();
260 let to_ignore = to_ignore.into_iter().flatten().collect::<HashSet<_>>();
261
262 tracing::trace!(
263 target: LOG_TARGET,
264 peers_to_add = ?to_add.len(),
265 peers_to_ignore = ?to_ignore.len(),
266 protocol_name = ?self.protocol_name,
267 "open substream",
268 );
269
270 let _ = self.command_tx.send(NotificationCommand::OpenSubstream { peers: to_add }).await;
271
272 match to_ignore.is_empty() {
273 true => Ok(()),
274 false => Err(to_ignore),
275 }
276 }
277
278 pub fn try_open_substream_batch(
286 &self,
287 peers: impl Iterator<Item = PeerId>,
288 ) -> Result<(), HashSet<PeerId>> {
289 let (to_add, to_ignore): (Vec<_>, Vec<_>) = peers
290 .map(|peer| match self.peers.contains_key(&peer) {
291 true => (None, Some(peer)),
292 false => (Some(peer), None),
293 })
294 .unzip();
295
296 let to_add = to_add.into_iter().flatten().collect::<HashSet<_>>();
297 let to_ignore = to_ignore.into_iter().flatten().collect::<HashSet<_>>();
298
299 tracing::trace!(
300 target: LOG_TARGET,
301 peers_to_add = ?to_add.len(),
302 peers_to_ignore = ?to_ignore.len(),
303 protocol_name = ?self.protocol_name,
304 "open substream",
305 );
306
307 self.command_tx
308 .try_send(NotificationCommand::OpenSubstream {
309 peers: to_add.clone(),
310 })
311 .map_err(|_| to_add)
312 }
313
314 pub async fn close_substream(&self, peer: PeerId) {
316 tracing::trace!(target: LOG_TARGET, ?peer, protocol_name = ?self.protocol_name, "close substream");
317
318 if !self.peers.contains_key(&peer) {
319 return;
320 }
321
322 let _ = self
323 .command_tx
324 .send(NotificationCommand::CloseSubstream {
325 peers: HashSet::from_iter([peer]),
326 })
327 .await;
328 }
329
330 pub async fn close_substream_batch(&self, peers: impl Iterator<Item = PeerId>) {
335 let peers = peers.filter(|peer| self.peers.contains_key(peer)).collect::<HashSet<_>>();
336
337 if peers.is_empty() {
338 return;
339 }
340
341 tracing::trace!(
342 target: LOG_TARGET,
343 ?peers,
344 protocol_name = ?self.protocol_name,
345 "close substreams",
346 );
347
348 let _ = self.command_tx.send(NotificationCommand::CloseSubstream { peers }).await;
349 }
350
351 pub fn try_close_substream_batch(
361 &self,
362 peers: impl Iterator<Item = PeerId>,
363 ) -> Result<(), HashSet<PeerId>> {
364 let peers = peers.filter(|peer| self.peers.contains_key(peer)).collect::<HashSet<_>>();
365
366 if peers.is_empty() {
367 return Err(HashSet::new());
368 }
369
370 tracing::trace!(
371 target: LOG_TARGET,
372 ?peers,
373 protocol_name = ?self.protocol_name,
374 "close substreams",
375 );
376
377 self.command_tx
378 .try_send(NotificationCommand::CloseSubstream {
379 peers: peers.clone(),
380 })
381 .map_err(|_| peers)
382 }
383
384 pub fn set_handshake(&mut self, handshake: Vec<u8>) {
386 tracing::trace!(target: LOG_TARGET, ?handshake, protocol_name = ?self.protocol_name, "set handshake");
387
388 *self.handshake.write() = handshake;
389 }
390
391 pub fn send_validation_result(&mut self, peer: PeerId, result: ValidationResult) {
394 tracing::trace!(target: LOG_TARGET, ?peer, ?result, protocol_name = ?self.protocol_name, "send validation result");
395
396 self.pending_validations.remove(&peer).map(|tx| tx.send(result));
397 }
398
399 pub fn send_sync_notification(
403 &mut self,
404 peer: PeerId,
405 notification: Vec<u8>,
406 ) -> Result<(), NotificationError> {
407 match self.peers.get_mut(&peer) {
408 Some(sink) => match sink.send_sync_notification(notification) {
409 Ok(()) => Ok(()),
410 Err(error) => match error {
411 NotificationError::NoConnection => Err(NotificationError::NoConnection),
412 NotificationError::ChannelClogged => {
413 let _ = self.clogged.insert(peer).then(|| {
414 tracing::warn!(
415 target: LOG_TARGET,
416 ?peer,
417 protocol_name = ?self.protocol_name,
418 "notification channel clogged, force close connection",
419 );
420
421 self.command_tx.try_send(NotificationCommand::ForceClose { peer })
422 });
423
424 Err(NotificationError::ChannelClogged)
425 }
426 _ => unreachable!(),
428 },
429 },
430 None => Ok(()),
431 }
432 }
433
434 pub async fn send_async_notification(
440 &mut self,
441 peer: PeerId,
442 notification: Vec<u8>,
443 ) -> crate::Result<()> {
444 match self.peers.get_mut(&peer) {
445 Some(sink) => sink.send_async_notification(notification).await,
446 None => Err(Error::PeerDoesntExist(peer)),
447 }
448 }
449
450 pub fn notification_sink(&self, peer: PeerId) -> Option<NotificationSink> {
454 self.peers.get(&peer).cloned()
455 }
456
457 #[cfg(feature = "fuzz")]
458 pub async fn fuzz_send_message(&mut self, command: NotificationCommand) -> crate::Result<()> {
460 if let NotificationCommand::SendNotification { peer_id, notif } = command {
461 self.send_async_notification(peer_id, notif).await?;
462 } else {
463 let _ = self.command_tx.send(command).await;
464 }
465 Ok(())
466 }
467}
468
469impl Stream for NotificationHandle {
470 type Item = NotificationEvent;
471
472 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
473 loop {
474 match self.event_rx.poll_recv(cx) {
475 Poll::Pending => {}
476 Poll::Ready(None) => return Poll::Ready(None),
477 Poll::Ready(Some(event)) => match event {
478 InnerNotificationEvent::NotificationStreamOpened {
479 protocol,
480 fallback,
481 direction,
482 peer,
483 handshake,
484 sink,
485 } => {
486 self.peers.insert(peer, sink);
487
488 return Poll::Ready(Some(NotificationEvent::NotificationStreamOpened {
489 protocol,
490 fallback,
491 direction,
492 peer,
493 handshake,
494 }));
495 }
496 InnerNotificationEvent::NotificationStreamClosed { peer } => {
497 self.peers.remove(&peer);
498 self.clogged.remove(&peer);
499
500 return Poll::Ready(Some(NotificationEvent::NotificationStreamClosed {
501 peer,
502 }));
503 }
504 InnerNotificationEvent::ValidateSubstream {
505 protocol,
506 fallback,
507 peer,
508 handshake,
509 tx,
510 } => {
511 self.pending_validations.insert(peer, tx);
512
513 return Poll::Ready(Some(NotificationEvent::ValidateSubstream {
514 protocol,
515 fallback,
516 peer,
517 handshake,
518 }));
519 }
520 InnerNotificationEvent::NotificationStreamOpenFailure { peer, error } =>
521 return Poll::Ready(Some(
522 NotificationEvent::NotificationStreamOpenFailure { peer, error },
523 )),
524 },
525 }
526
527 match futures::ready!(self.notif_rx.poll_recv(cx)) {
528 None => return Poll::Ready(None),
529 Some((peer, notification)) =>
530 if self.peers.contains_key(&peer) {
531 return Poll::Ready(Some(NotificationEvent::NotificationReceived {
532 peer,
533 notification,
534 }));
535 },
536 }
537 }
538 }
539}