1use crate::{
22 codec::ProtocolCodec,
23 error::{Error, NegotiationError, SubstreamError},
24 multistream_select::{
25 NegotiationError as MultiStreamNegotiationError, ProtocolError as MultiStreamProtocolError,
26 },
27 protocol::{
28 connection::{ConnectionHandle, Permit},
29 Direction, TransportEvent,
30 },
31 substream::Substream,
32 transport::{
33 manager::{ProtocolContext, TransportManagerEvent},
34 Endpoint,
35 },
36 types::{protocol::ProtocolName, ConnectionId, SubstreamId},
37 PeerId,
38};
39
40use futures::{stream::FuturesUnordered, Stream, StreamExt};
41use multiaddr::Multiaddr;
42use tokio::sync::mpsc::{channel, Receiver, Sender};
43
44#[cfg(any(feature = "quic", feature = "webrtc", feature = "websocket"))]
45use std::sync::atomic::Ordering;
46use std::{
47 collections::HashMap,
48 fmt::Debug,
49 pin::Pin,
50 sync::{atomic::AtomicUsize, Arc},
51 task::{Context, Poll},
52};
53
/// Target attached to the `tracing` events emitted from this module.
const LOG_TARGET: &str = "litep2p::protocol-set";
56
57#[derive(Debug)]
59pub enum InnerTransportEvent {
60 ConnectionEstablished {
62 peer: PeerId,
64
65 connection: ConnectionId,
67
68 endpoint: Endpoint,
70
71 sender: ConnectionHandle,
73 },
74
75 ConnectionClosed {
77 peer: PeerId,
79
80 connection: ConnectionId,
82 },
83
84 DialFailure {
88 peer: PeerId,
90
91 address: Multiaddr,
93 },
94
95 SubstreamOpened {
97 peer: PeerId,
99
100 protocol: ProtocolName,
108
109 fallback: Option<ProtocolName>,
114
115 direction: Direction,
124
125 substream: Substream,
127 },
128
129 SubstreamOpenFailure {
133 substream: SubstreamId,
135
136 error: SubstreamError,
138 },
139}
140
141impl From<InnerTransportEvent> for TransportEvent {
142 fn from(event: InnerTransportEvent) -> Self {
143 match event {
144 InnerTransportEvent::DialFailure { peer, address } =>
145 TransportEvent::DialFailure { peer, address },
146 InnerTransportEvent::SubstreamOpened {
147 peer,
148 protocol,
149 fallback,
150 direction,
151 substream,
152 } => TransportEvent::SubstreamOpened {
153 peer,
154 protocol,
155 fallback,
156 direction,
157 substream,
158 },
159 InnerTransportEvent::SubstreamOpenFailure { substream, error } =>
160 TransportEvent::SubstreamOpenFailure { substream, error },
161 event => panic!("cannot convert {event:?}"),
162 }
163 }
164}
165
166#[derive(Debug)]
168pub enum ProtocolCommand {
169 OpenSubstream {
171 protocol: ProtocolName,
173
174 fallback_names: Vec<ProtocolName>,
184
185 substream_id: SubstreamId,
194
195 permit: Permit,
202 },
203
204 ForceClose,
206}
207
208pub struct ProtocolSet {
213 pub(crate) protocols: HashMap<ProtocolName, ProtocolContext>,
215 mgr_tx: Sender<TransportManagerEvent>,
216 connection: ConnectionHandle,
217 rx: Receiver<ProtocolCommand>,
218 #[allow(unused)]
219 next_substream_id: Arc<AtomicUsize>,
220 fallback_names: HashMap<ProtocolName, ProtocolName>,
221}
222
223impl ProtocolSet {
224 pub fn new(
225 connection_id: ConnectionId,
226 mgr_tx: Sender<TransportManagerEvent>,
227 next_substream_id: Arc<AtomicUsize>,
228 protocols: HashMap<ProtocolName, ProtocolContext>,
229 ) -> Self {
230 let (tx, rx) = channel(256);
231
232 let fallback_names = protocols
233 .iter()
234 .flat_map(|(protocol, context)| {
235 context
236 .fallback_names
237 .iter()
238 .map(|fallback| (fallback.clone(), protocol.clone()))
239 .collect::<HashMap<_, _>>()
240 })
241 .collect();
242
243 ProtocolSet {
244 rx,
245 mgr_tx,
246 protocols,
247 next_substream_id,
248 fallback_names,
249 connection: ConnectionHandle::new(connection_id, tx),
250 }
251 }
252
253 pub fn try_get_permit(&mut self) -> Option<Permit> {
255 self.connection.try_get_permit()
256 }
257
258 #[cfg(any(feature = "quic", feature = "webrtc", feature = "websocket"))]
260 pub fn next_substream_id(&self) -> SubstreamId {
261 SubstreamId::from(self.next_substream_id.fetch_add(1usize, Ordering::Relaxed))
262 }
263
264 pub fn protocols(&self) -> Vec<ProtocolName> {
266 self.protocols
267 .keys()
268 .cloned()
269 .chain(self.fallback_names.keys().cloned())
270 .collect()
271 }
272
273 pub async fn report_substream_open(
275 &mut self,
276 peer: PeerId,
277 protocol: ProtocolName,
278 direction: Direction,
279 substream: Substream,
280 ) -> Result<(), SubstreamError> {
281 tracing::debug!(target: LOG_TARGET, %protocol, ?peer, ?direction, "substream opened");
282
283 let (protocol, fallback) = match self.fallback_names.get(&protocol) {
284 Some(main_protocol) => (main_protocol.clone(), Some(protocol)),
285 None => (protocol, None),
286 };
287
288 let Some(protocol_context) = self.protocols.get(&protocol) else {
289 return Err(NegotiationError::MultistreamSelectError(
290 MultiStreamNegotiationError::ProtocolError(
291 MultiStreamProtocolError::ProtocolNotSupported,
292 ),
293 )
294 .into());
295 };
296
297 let event = InnerTransportEvent::SubstreamOpened {
298 peer,
299 protocol: protocol.clone(),
300 fallback,
301 direction,
302 substream,
303 };
304
305 protocol_context
306 .tx
307 .send(event)
308 .await
309 .map_err(|_| SubstreamError::ConnectionClosed)
310 }
311
312 pub fn protocol_codec(&self, protocol: &ProtocolName) -> ProtocolCodec {
314 self.protocols
317 .get(self.fallback_names.get(protocol).map_or(protocol, |protocol| protocol))
318 .expect("protocol to exist")
319 .codec
320 }
321
322 pub async fn report_substream_open_failure(
324 &mut self,
325 protocol: ProtocolName,
326 substream: SubstreamId,
327 error: SubstreamError,
328 ) -> crate::Result<()> {
329 tracing::debug!(
330 target: LOG_TARGET,
331 %protocol,
332 ?substream,
333 ?error,
334 "failed to open substream",
335 );
336
337 self.protocols
338 .get_mut(&protocol)
339 .ok_or(Error::ProtocolNotSupported(protocol.to_string()))?
340 .tx
341 .send(InnerTransportEvent::SubstreamOpenFailure { substream, error })
342 .await
343 .map_err(From::from)
344 }
345
346 pub(crate) async fn report_connection_established(
348 &mut self,
349 peer: PeerId,
350 endpoint: Endpoint,
351 ) -> crate::Result<()> {
352 let connection_handle = self.connection.downgrade();
353 let mut futures = self
354 .protocols
355 .values()
356 .map(|sender| {
357 let endpoint = endpoint.clone();
358 let connection_handle = connection_handle.clone();
359
360 async move {
361 sender
362 .tx
363 .send(InnerTransportEvent::ConnectionEstablished {
364 peer,
365 connection: endpoint.connection_id(),
366 endpoint,
367 sender: connection_handle,
368 })
369 .await
370 }
371 })
372 .collect::<FuturesUnordered<_>>();
373
374 while !futures.is_empty() {
375 if let Some(Err(error)) = futures.next().await {
376 return Err(error.into());
377 }
378 }
379
380 Ok(())
381 }
382
383 pub(crate) async fn report_connection_closed(
385 &mut self,
386 peer: PeerId,
387 connection_id: ConnectionId,
388 ) -> crate::Result<()> {
389 let mut futures = self
390 .protocols
391 .values()
392 .map(|sender| async move {
393 sender
394 .tx
395 .send(InnerTransportEvent::ConnectionClosed {
396 peer,
397 connection: connection_id,
398 })
399 .await
400 })
401 .collect::<FuturesUnordered<_>>();
402
403 while !futures.is_empty() {
404 if let Some(Err(error)) = futures.next().await {
405 return Err(error.into());
406 }
407 }
408
409 self.mgr_tx
410 .send(TransportManagerEvent::ConnectionClosed {
411 peer,
412 connection: connection_id,
413 })
414 .await
415 .map_err(From::from)
416 }
417}
418
419impl Stream for ProtocolSet {
420 type Item = ProtocolCommand;
421
422 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
423 self.rx.poll_recv(cx)
424 }
425}
426
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock::substream::MockSubstream;
    use std::collections::HashSet;

    /// Build a `ProtocolSet` with a single protocol `/notif/1` that has two
    /// fallback names.
    ///
    /// Both receivers are returned so the caller decides how long the channels
    /// stay open.
    fn make_protocol_set() -> (
        ProtocolSet,
        Receiver<TransportManagerEvent>,
        Receiver<InnerTransportEvent>,
    ) {
        let (mgr_tx, mgr_rx) = channel(64);
        let (protocol_tx, protocol_rx) = channel(64);

        let protocol_set = ProtocolSet::new(
            ConnectionId::from(0usize),
            mgr_tx,
            Default::default(),
            HashMap::from_iter([(
                ProtocolName::from("/notif/1"),
                ProtocolContext {
                    tx: protocol_tx,
                    codec: ProtocolCodec::Identity(32),
                    fallback_names: vec![
                        ProtocolName::from("/notif/1/fallback/1"),
                        ProtocolName::from("/notif/1/fallback/2"),
                    ],
                },
            )]),
        );

        (protocol_set, mgr_rx, protocol_rx)
    }

    /// Build a mock substream with ID 0 for a random peer.
    fn mock_substream() -> Substream {
        Substream::new_mock(
            PeerId::random(),
            SubstreamId::from(0usize),
            Box::new(MockSubstream::new()),
        )
    }

    #[tokio::test]
    async fn fallback_is_provided() {
        let (mut protocol_set, _mgr_rx, _protocol_rx) = make_protocol_set();

        let expected_protocols = HashSet::from([
            ProtocolName::from("/notif/1"),
            ProtocolName::from("/notif/1/fallback/1"),
            ProtocolName::from("/notif/1/fallback/2"),
        ]);

        // Compare as sets (and check the length) so that both missing and
        // duplicated names are detected, not only unexpected ones.
        let advertised = protocol_set.protocols();
        assert_eq!(advertised.len(), expected_protocols.len());
        assert_eq!(
            advertised.into_iter().collect::<HashSet<_>>(),
            expected_protocols,
        );

        protocol_set
            .report_substream_open(
                PeerId::random(),
                ProtocolName::from("/notif/1/fallback/2"),
                Direction::Inbound,
                mock_substream(),
            )
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn main_protocol_reported_if_main_protocol_negotiated() {
        let (mut protocol_set, _mgr_rx, mut protocol_rx) = make_protocol_set();

        protocol_set
            .report_substream_open(
                PeerId::random(),
                ProtocolName::from("/notif/1"),
                Direction::Inbound,
                mock_substream(),
            )
            .await
            .unwrap();

        match protocol_rx.recv().await.unwrap() {
            InnerTransportEvent::SubstreamOpened {
                protocol, fallback, ..
            } => {
                assert!(fallback.is_none());
                assert_eq!(protocol, ProtocolName::from("/notif/1"));
            }
            _ => panic!("invalid event received"),
        }
    }

    #[tokio::test]
    async fn fallback_is_reported_to_protocol() {
        let (mut protocol_set, _mgr_rx, mut protocol_rx) = make_protocol_set();

        protocol_set
            .report_substream_open(
                PeerId::random(),
                ProtocolName::from("/notif/1/fallback/2"),
                Direction::Inbound,
                mock_substream(),
            )
            .await
            .unwrap();

        match protocol_rx.recv().await.unwrap() {
            InnerTransportEvent::SubstreamOpened {
                protocol, fallback, ..
            } => {
                assert_eq!(fallback, Some(ProtocolName::from("/notif/1/fallback/2")));
                assert_eq!(protocol, ProtocolName::from("/notif/1"));
            }
            _ => panic!("invalid event received"),
        }
    }
}