1use crate::{
22 error::Error,
23 protocol::notification::types::{
24 Direction, InnerNotificationEvent, NotificationCommand, NotificationError,
25 NotificationEvent, ValidationResult,
26 },
27 types::protocol::ProtocolName,
28 PeerId,
29};
30
31use bytes::BytesMut;
32use futures::Stream;
33use parking_lot::RwLock;
34use tokio::sync::{
35 mpsc::{error::TrySendError, Receiver, Sender},
36 oneshot,
37};
38
39use std::{
40 collections::{HashMap, HashSet},
41 pin::Pin,
42 sync::Arc,
43 task::{Context, Poll},
44};
45
/// Logging target for the file.
const LOG_TARGET: &str = "litep2p::notification::handle";
48
49#[derive(Debug, Clone)]
50pub(crate) struct NotificationEventHandle {
51 tx: Sender<InnerNotificationEvent>,
52}
53
54impl NotificationEventHandle {
55 pub(crate) fn new(tx: Sender<InnerNotificationEvent>) -> Self {
57 Self { tx }
58 }
59
60 pub(crate) async fn report_inbound_substream(
62 &self,
63 protocol: ProtocolName,
64 fallback: Option<ProtocolName>,
65 peer: PeerId,
66 handshake: Vec<u8>,
67 tx: oneshot::Sender<ValidationResult>,
68 ) {
69 let _ = self
70 .tx
71 .send(InnerNotificationEvent::ValidateSubstream {
72 protocol,
73 fallback,
74 peer,
75 handshake,
76 tx,
77 })
78 .await;
79 }
80
81 pub(crate) async fn report_notification_stream_opened(
83 &self,
84 protocol: ProtocolName,
85 fallback: Option<ProtocolName>,
86 direction: Direction,
87 peer: PeerId,
88 handshake: Vec<u8>,
89 sink: NotificationSink,
90 ) {
91 let _ = self
92 .tx
93 .send(InnerNotificationEvent::NotificationStreamOpened {
94 protocol,
95 fallback,
96 direction,
97 peer,
98 handshake,
99 sink,
100 })
101 .await;
102 }
103
104 pub(crate) async fn report_notification_stream_closed(&self, peer: PeerId) {
106 let _ = self.tx.send(InnerNotificationEvent::NotificationStreamClosed { peer }).await;
107 }
108
109 pub(crate) async fn report_notification_stream_open_failure(
111 &self,
112 peer: PeerId,
113 error: NotificationError,
114 ) {
115 let _ = self
116 .tx
117 .send(InnerNotificationEvent::NotificationStreamOpenFailure { peer, error })
118 .await;
119 }
120}
121
122#[derive(Debug, Clone)]
126pub struct NotificationSink {
127 peer: PeerId,
129
130 sync_tx: Sender<Vec<u8>>,
132
133 async_tx: Sender<Vec<u8>>,
135}
136
137impl NotificationSink {
138 pub(crate) fn new(peer: PeerId, sync_tx: Sender<Vec<u8>>, async_tx: Sender<Vec<u8>>) -> Self {
140 Self {
141 peer,
142 async_tx,
143 sync_tx,
144 }
145 }
146
147 pub fn send_sync_notification(&self, notification: Vec<u8>) -> Result<(), NotificationError> {
151 self.sync_tx.try_send(notification).map_err(|error| match error {
152 TrySendError::Closed(_) => NotificationError::NoConnection,
153 TrySendError::Full(_) => NotificationError::ChannelClogged,
154 })
155 }
156
157 pub async fn send_async_notification(&self, notification: Vec<u8>) -> crate::Result<()> {
163 self.async_tx
164 .send(notification)
165 .await
166 .map_err(|_| Error::PeerDoesntExist(self.peer))
167 }
168}
169
170#[derive(Debug)]
172pub struct NotificationHandle {
173 event_rx: Receiver<InnerNotificationEvent>,
175
176 notif_rx: Receiver<(PeerId, BytesMut)>,
178
179 command_tx: Sender<NotificationCommand>,
181
182 peers: HashMap<PeerId, NotificationSink>,
184
185 clogged: HashSet<PeerId>,
187
188 pending_validations: HashMap<PeerId, oneshot::Sender<ValidationResult>>,
190
191 handshake: Arc<RwLock<Vec<u8>>>,
193}
194
195impl NotificationHandle {
196 pub(crate) fn new(
198 event_rx: Receiver<InnerNotificationEvent>,
199 notif_rx: Receiver<(PeerId, BytesMut)>,
200 command_tx: Sender<NotificationCommand>,
201 handshake: Arc<RwLock<Vec<u8>>>,
202 ) -> Self {
203 Self {
204 event_rx,
205 notif_rx,
206 command_tx,
207 handshake,
208 peers: HashMap::new(),
209 clogged: HashSet::new(),
210 pending_validations: HashMap::new(),
211 }
212 }
213
214 pub async fn open_substream(&self, peer: PeerId) -> crate::Result<()> {
223 tracing::trace!(target: LOG_TARGET, ?peer, "open substream");
224
225 if self.peers.contains_key(&peer) {
226 return Err(Error::PeerAlreadyExists(peer));
227 }
228
229 self.command_tx
230 .send(NotificationCommand::OpenSubstream {
231 peers: HashSet::from_iter([peer]),
232 })
233 .await
234 .map_or(Ok(()), |_| Ok(()))
235 }
236
237 pub async fn open_substream_batch(
244 &self,
245 peers: impl Iterator<Item = PeerId>,
246 ) -> Result<(), HashSet<PeerId>> {
247 let (to_add, to_ignore): (Vec<_>, Vec<_>) = peers
248 .map(|peer| match self.peers.contains_key(&peer) {
249 true => (None, Some(peer)),
250 false => (Some(peer), None),
251 })
252 .unzip();
253
254 let to_add = to_add.into_iter().flatten().collect::<HashSet<_>>();
255 let to_ignore = to_ignore.into_iter().flatten().collect::<HashSet<_>>();
256
257 tracing::trace!(
258 target: LOG_TARGET,
259 peers_to_add = ?to_add.len(),
260 peers_to_ignore = ?to_ignore.len(),
261 "open substream",
262 );
263
264 let _ = self.command_tx.send(NotificationCommand::OpenSubstream { peers: to_add }).await;
265
266 match to_ignore.is_empty() {
267 true => Ok(()),
268 false => Err(to_ignore),
269 }
270 }
271
272 pub fn try_open_substream_batch(
280 &self,
281 peers: impl Iterator<Item = PeerId>,
282 ) -> Result<(), HashSet<PeerId>> {
283 let (to_add, to_ignore): (Vec<_>, Vec<_>) = peers
284 .map(|peer| match self.peers.contains_key(&peer) {
285 true => (None, Some(peer)),
286 false => (Some(peer), None),
287 })
288 .unzip();
289
290 let to_add = to_add.into_iter().flatten().collect::<HashSet<_>>();
291 let to_ignore = to_ignore.into_iter().flatten().collect::<HashSet<_>>();
292
293 tracing::trace!(
294 target: LOG_TARGET,
295 peers_to_add = ?to_add.len(),
296 peers_to_ignore = ?to_ignore.len(),
297 "open substream",
298 );
299
300 self.command_tx
301 .try_send(NotificationCommand::OpenSubstream {
302 peers: to_add.clone(),
303 })
304 .map_err(|_| to_add)
305 }
306
307 pub async fn close_substream(&self, peer: PeerId) {
309 tracing::trace!(target: LOG_TARGET, ?peer, "close substream");
310
311 if !self.peers.contains_key(&peer) {
312 return;
313 }
314
315 let _ = self
316 .command_tx
317 .send(NotificationCommand::CloseSubstream {
318 peers: HashSet::from_iter([peer]),
319 })
320 .await;
321 }
322
323 pub async fn close_substream_batch(&self, peers: impl Iterator<Item = PeerId>) {
328 let peers = peers.filter(|peer| self.peers.contains_key(peer)).collect::<HashSet<_>>();
329
330 if peers.is_empty() {
331 return;
332 }
333
334 tracing::trace!(
335 target: LOG_TARGET,
336 ?peers,
337 "close substreams",
338 );
339
340 let _ = self.command_tx.send(NotificationCommand::CloseSubstream { peers }).await;
341 }
342
343 pub fn try_close_substream_batch(
353 &self,
354 peers: impl Iterator<Item = PeerId>,
355 ) -> Result<(), HashSet<PeerId>> {
356 let peers = peers.filter(|peer| self.peers.contains_key(peer)).collect::<HashSet<_>>();
357
358 if peers.is_empty() {
359 return Err(HashSet::new());
360 }
361
362 tracing::trace!(
363 target: LOG_TARGET,
364 ?peers,
365 "close substreams",
366 );
367
368 self.command_tx
369 .try_send(NotificationCommand::CloseSubstream {
370 peers: peers.clone(),
371 })
372 .map_err(|_| peers)
373 }
374
375 pub fn set_handshake(&mut self, handshake: Vec<u8>) {
377 tracing::trace!(target: LOG_TARGET, ?handshake, "set handshake");
378
379 *self.handshake.write() = handshake;
380 }
381
382 pub fn send_validation_result(&mut self, peer: PeerId, result: ValidationResult) {
385 tracing::trace!(target: LOG_TARGET, ?peer, ?result, "send validation result");
386
387 self.pending_validations.remove(&peer).map(|tx| tx.send(result));
388 }
389
390 pub fn send_sync_notification(
394 &mut self,
395 peer: PeerId,
396 notification: Vec<u8>,
397 ) -> Result<(), NotificationError> {
398 match self.peers.get_mut(&peer) {
399 Some(sink) => match sink.send_sync_notification(notification) {
400 Ok(()) => Ok(()),
401 Err(error) => match error {
402 NotificationError::NoConnection => Err(NotificationError::NoConnection),
403 NotificationError::ChannelClogged => {
404 let _ = self.clogged.insert(peer).then(|| {
405 self.command_tx.try_send(NotificationCommand::ForceClose { peer })
406 });
407
408 Err(NotificationError::ChannelClogged)
409 }
410 _ => unreachable!(),
412 },
413 },
414 None => Ok(()),
415 }
416 }
417
418 pub async fn send_async_notification(
424 &mut self,
425 peer: PeerId,
426 notification: Vec<u8>,
427 ) -> crate::Result<()> {
428 match self.peers.get_mut(&peer) {
429 Some(sink) => sink.send_async_notification(notification).await,
430 None => Err(Error::PeerDoesntExist(peer)),
431 }
432 }
433
434 pub fn notification_sink(&self, peer: PeerId) -> Option<NotificationSink> {
438 self.peers.get(&peer).cloned()
439 }
440}
441
442impl Stream for NotificationHandle {
443 type Item = NotificationEvent;
444
445 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
446 loop {
447 match self.event_rx.poll_recv(cx) {
448 Poll::Pending => {}
449 Poll::Ready(None) => return Poll::Ready(None),
450 Poll::Ready(Some(event)) => match event {
451 InnerNotificationEvent::NotificationStreamOpened {
452 protocol,
453 fallback,
454 direction,
455 peer,
456 handshake,
457 sink,
458 } => {
459 self.peers.insert(peer, sink);
460
461 return Poll::Ready(Some(NotificationEvent::NotificationStreamOpened {
462 protocol,
463 fallback,
464 direction,
465 peer,
466 handshake,
467 }));
468 }
469 InnerNotificationEvent::NotificationStreamClosed { peer } => {
470 self.peers.remove(&peer);
471 self.clogged.remove(&peer);
472
473 return Poll::Ready(Some(NotificationEvent::NotificationStreamClosed {
474 peer,
475 }));
476 }
477 InnerNotificationEvent::ValidateSubstream {
478 protocol,
479 fallback,
480 peer,
481 handshake,
482 tx,
483 } => {
484 self.pending_validations.insert(peer, tx);
485
486 return Poll::Ready(Some(NotificationEvent::ValidateSubstream {
487 protocol,
488 fallback,
489 peer,
490 handshake,
491 }));
492 }
493 InnerNotificationEvent::NotificationStreamOpenFailure { peer, error } =>
494 return Poll::Ready(Some(
495 NotificationEvent::NotificationStreamOpenFailure { peer, error },
496 )),
497 },
498 }
499
500 match futures::ready!(self.notif_rx.poll_recv(cx)) {
501 None => return Poll::Ready(None),
502 Some((peer, notification)) =>
503 if self.peers.contains_key(&peer) {
504 return Poll::Ready(Some(NotificationEvent::NotificationReceived {
505 peer,
506 notification,
507 }));
508 },
509 }
510 }
511 }
512}