// litep2p/protocol/protocol_set.rs

// Copyright 2023 litep2p developers
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
20
21use crate::{
22    codec::ProtocolCodec,
23    error::{Error, NegotiationError, SubstreamError},
24    multistream_select::{
25        NegotiationError as MultiStreamNegotiationError, ProtocolError as MultiStreamProtocolError,
26    },
27    protocol::{
28        connection::{ConnectionHandle, Permit},
29        Direction, TransportEvent,
30    },
31    substream::Substream,
32    transport::{
33        manager::{ProtocolContext, TransportManagerEvent},
34        Endpoint,
35    },
36    types::{protocol::ProtocolName, ConnectionId, SubstreamId},
37    PeerId,
38};
39
40use futures::{stream::FuturesUnordered, Stream, StreamExt};
41use multiaddr::Multiaddr;
42use tokio::sync::mpsc::{channel, Receiver, Sender};
43
44#[cfg(any(feature = "quic", feature = "webrtc", feature = "websocket"))]
45use std::sync::atomic::Ordering;
46use std::{
47    collections::HashMap,
48    fmt::Debug,
49    pin::Pin,
50    sync::{atomic::AtomicUsize, Arc},
51    task::{Context, Poll},
52};
53
/// `tracing` target attached to the log events emitted by this module.
const LOG_TARGET: &str = "litep2p::protocol-set";
56
57/// Events emitted by the underlying transport protocols.
58#[derive(Debug)]
59pub enum InnerTransportEvent {
60    /// Connection established to `peer`.
61    ConnectionEstablished {
62        /// Peer ID.
63        peer: PeerId,
64
65        /// Connection ID.
66        connection: ConnectionId,
67
68        /// Endpoint.
69        endpoint: Endpoint,
70
71        /// Handle for communicating with the connection.
72        sender: ConnectionHandle,
73    },
74
75    /// Connection closed.
76    ConnectionClosed {
77        /// Peer ID.
78        peer: PeerId,
79
80        /// Connection ID.
81        connection: ConnectionId,
82    },
83
84    /// Failed to dial peer.
85    ///
86    /// This is reported to that protocol which initiated the connection.
87    DialFailure {
88        /// Peer ID.
89        peer: PeerId,
90
91        /// Dialed address.
92        address: Multiaddr,
93    },
94
95    /// Substream opened for `peer`.
96    SubstreamOpened {
97        /// Peer ID.
98        peer: PeerId,
99
100        /// Protocol name.
101        ///
102        /// One protocol handler may handle multiple sub-protocols (such as `/ipfs/identify/1.0.0`
103        /// and `/ipfs/identify/push/1.0.0`) or it may have aliases which should be handled by
104        /// the same protocol handler. When the substream is sent from transport to the protocol
105        /// handler, the protocol name that was used to negotiate the substream is also sent so
106        /// the protocol can handle the substream appropriately.
107        protocol: ProtocolName,
108
109        /// Fallback name.
110        ///
111        /// If the substream was negotiated using a fallback name of the main protocol,
112        /// `fallback` is `Some`.
113        fallback: Option<ProtocolName>,
114
115        /// Substream direction.
116        ///
117        /// Informs the protocol whether the substream is inbound (opened by the remote node)
118        /// or outbound (opened by the local node). This allows the protocol to distinguish
119        /// between the two types of substreams and execute correct code for the substream.
120        ///
121        /// Outbound substreams also contain the substream ID which allows the protocol to
122        /// distinguish between different outbound substreams.
123        direction: Direction,
124
125        /// Substream.
126        substream: Substream,
127    },
128
129    /// Failed to open substream.
130    ///
131    /// Substream open failures are reported only for outbound substreams.
132    SubstreamOpenFailure {
133        /// Substream ID.
134        substream: SubstreamId,
135
136        /// Error that occurred when the substream was being opened.
137        error: SubstreamError,
138    },
139}
140
141impl From<InnerTransportEvent> for TransportEvent {
142    fn from(event: InnerTransportEvent) -> Self {
143        match event {
144            InnerTransportEvent::DialFailure { peer, address } =>
145                TransportEvent::DialFailure { peer, address },
146            InnerTransportEvent::SubstreamOpened {
147                peer,
148                protocol,
149                fallback,
150                direction,
151                substream,
152            } => TransportEvent::SubstreamOpened {
153                peer,
154                protocol,
155                fallback,
156                direction,
157                substream,
158            },
159            InnerTransportEvent::SubstreamOpenFailure { substream, error } =>
160                TransportEvent::SubstreamOpenFailure { substream, error },
161            event => panic!("cannot convert {event:?}"),
162        }
163    }
164}
165
166/// Events emitted by the installed protocols to transport.
167#[derive(Debug)]
168pub enum ProtocolCommand {
169    /// Open substream.
170    OpenSubstream {
171        /// Protocol name.
172        protocol: ProtocolName,
173
174        /// Fallback names.
175        ///
176        /// If the protocol has changed its name but wishes to support the old name(s), it must
177        /// provide the old protocol names in `fallback_names`. These are fed into
178        /// `multistream-select` which them attempts to negotiate a protocol for the substream
179        /// using one of the provided names and if the substream is negotiated successfully, will
180        /// report back the actual protocol name that was negotiated, in case the protocol
181        /// needs to deal with the old version of the protocol in different way compared to
182        /// the new version.
183        fallback_names: Vec<ProtocolName>,
184
185        /// Substream ID.
186        ///
187        /// Protocol allocates an ephemeral ID for outbound substreams which allows it to track
188        /// the state of its pending substream. The ID is given back to protocol in
189        /// [`TransportEvent::SubstreamOpened`]/[`TransportEvent::SubstreamOpenFailure`].
190        ///
191        /// This allows the protocol to distinguish inbound substreams from outbound substreams
192        /// and associate incoming substreams with whatever logic it has.
193        substream_id: SubstreamId,
194
195        /// Connection permit.
196        ///
197        /// `Permit` allows the connection to be kept open while the permit is held and it is given
198        /// to the substream to hold once it has been opened. When the substream is dropped, the
199        /// permit is dropped and the connection may be closed if no other permit is being
200        /// held.
201        permit: Permit,
202    },
203
204    /// Forcibly close the connection, even if other protocols have substreams open over it.
205    ForceClose,
206}
207
208/// Supported protocol information.
209///
210/// Each connection gets a copy of [`ProtocolSet`] which allows it to interact
211/// directly with installed protocols.
212pub struct ProtocolSet {
213    /// Installed protocols.
214    pub(crate) protocols: HashMap<ProtocolName, ProtocolContext>,
215    mgr_tx: Sender<TransportManagerEvent>,
216    connection: ConnectionHandle,
217    rx: Receiver<ProtocolCommand>,
218    #[allow(unused)]
219    next_substream_id: Arc<AtomicUsize>,
220    fallback_names: HashMap<ProtocolName, ProtocolName>,
221}
222
223impl ProtocolSet {
224    pub fn new(
225        connection_id: ConnectionId,
226        mgr_tx: Sender<TransportManagerEvent>,
227        next_substream_id: Arc<AtomicUsize>,
228        protocols: HashMap<ProtocolName, ProtocolContext>,
229    ) -> Self {
230        let (tx, rx) = channel(256);
231
232        let fallback_names = protocols
233            .iter()
234            .flat_map(|(protocol, context)| {
235                context
236                    .fallback_names
237                    .iter()
238                    .map(|fallback| (fallback.clone(), protocol.clone()))
239                    .collect::<HashMap<_, _>>()
240            })
241            .collect();
242
243        ProtocolSet {
244            rx,
245            mgr_tx,
246            protocols,
247            next_substream_id,
248            fallback_names,
249            connection: ConnectionHandle::new(connection_id, tx),
250        }
251    }
252
253    /// Try to acquire permit to keep the connection open.
254    pub fn try_get_permit(&mut self) -> Option<Permit> {
255        self.connection.try_get_permit()
256    }
257
258    /// Get next substream ID.
259    #[cfg(any(feature = "quic", feature = "webrtc", feature = "websocket"))]
260    pub fn next_substream_id(&self) -> SubstreamId {
261        SubstreamId::from(self.next_substream_id.fetch_add(1usize, Ordering::Relaxed))
262    }
263
264    /// Get the list of all supported protocols.
265    pub fn protocols(&self) -> Vec<ProtocolName> {
266        self.protocols
267            .keys()
268            .cloned()
269            .chain(self.fallback_names.keys().cloned())
270            .collect()
271    }
272
273    /// Report to `protocol` that substream was opened for `peer`.
274    pub async fn report_substream_open(
275        &mut self,
276        peer: PeerId,
277        protocol: ProtocolName,
278        direction: Direction,
279        substream: Substream,
280    ) -> Result<(), SubstreamError> {
281        tracing::debug!(target: LOG_TARGET, %protocol, ?peer, ?direction, "substream opened");
282
283        let (protocol, fallback) = match self.fallback_names.get(&protocol) {
284            Some(main_protocol) => (main_protocol.clone(), Some(protocol)),
285            None => (protocol, None),
286        };
287
288        let Some(protocol_context) = self.protocols.get(&protocol) else {
289            return Err(NegotiationError::MultistreamSelectError(
290                MultiStreamNegotiationError::ProtocolError(
291                    MultiStreamProtocolError::ProtocolNotSupported,
292                ),
293            )
294            .into());
295        };
296
297        let event = InnerTransportEvent::SubstreamOpened {
298            peer,
299            protocol: protocol.clone(),
300            fallback,
301            direction,
302            substream,
303        };
304
305        protocol_context
306            .tx
307            .send(event)
308            .await
309            .map_err(|_| SubstreamError::ConnectionClosed)
310    }
311
312    /// Get codec used by the protocol.
313    pub fn protocol_codec(&self, protocol: &ProtocolName) -> ProtocolCodec {
314        // NOTE: `protocol` must exist in `self.protocol` as it was negotiated
315        // using the protocols from this set
316        self.protocols
317            .get(self.fallback_names.get(protocol).map_or(protocol, |protocol| protocol))
318            .expect("protocol to exist")
319            .codec
320    }
321
322    /// Report to `protocol` that connection failed to open substream for `peer`.
323    pub async fn report_substream_open_failure(
324        &mut self,
325        protocol: ProtocolName,
326        substream: SubstreamId,
327        error: SubstreamError,
328    ) -> crate::Result<()> {
329        tracing::debug!(
330            target: LOG_TARGET,
331            %protocol,
332            ?substream,
333            ?error,
334            "failed to open substream",
335        );
336
337        self.protocols
338            .get_mut(&protocol)
339            .ok_or(Error::ProtocolNotSupported(protocol.to_string()))?
340            .tx
341            .send(InnerTransportEvent::SubstreamOpenFailure { substream, error })
342            .await
343            .map_err(From::from)
344    }
345
346    /// Report to protocols that a connection was established.
347    pub(crate) async fn report_connection_established(
348        &mut self,
349        peer: PeerId,
350        endpoint: Endpoint,
351    ) -> crate::Result<()> {
352        let connection_handle = self.connection.downgrade();
353        let mut futures = self
354            .protocols
355            .values()
356            .map(|sender| {
357                let endpoint = endpoint.clone();
358                let connection_handle = connection_handle.clone();
359
360                async move {
361                    sender
362                        .tx
363                        .send(InnerTransportEvent::ConnectionEstablished {
364                            peer,
365                            connection: endpoint.connection_id(),
366                            endpoint,
367                            sender: connection_handle,
368                        })
369                        .await
370                }
371            })
372            .collect::<FuturesUnordered<_>>();
373
374        while !futures.is_empty() {
375            if let Some(Err(error)) = futures.next().await {
376                return Err(error.into());
377            }
378        }
379
380        Ok(())
381    }
382
383    /// Report to protocols that a connection was closed.
384    pub(crate) async fn report_connection_closed(
385        &mut self,
386        peer: PeerId,
387        connection_id: ConnectionId,
388    ) -> crate::Result<()> {
389        let mut futures = self
390            .protocols
391            .values()
392            .map(|sender| async move {
393                sender
394                    .tx
395                    .send(InnerTransportEvent::ConnectionClosed {
396                        peer,
397                        connection: connection_id,
398                    })
399                    .await
400            })
401            .collect::<FuturesUnordered<_>>();
402
403        while !futures.is_empty() {
404            if let Some(Err(error)) = futures.next().await {
405                return Err(error.into());
406            }
407        }
408
409        self.mgr_tx
410            .send(TransportManagerEvent::ConnectionClosed {
411                peer,
412                connection: connection_id,
413            })
414            .await
415            .map_err(From::from)
416    }
417}
418
419impl Stream for ProtocolSet {
420    type Item = ProtocolCommand;
421
422    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
423        self.rx.poll_recv(cx)
424    }
425}
426
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mock::substream::MockSubstream;
    use std::collections::HashSet;

    /// Build a [`ProtocolSet`] with `/notif/1` installed under two fallback names.
    ///
    /// Both receivers are returned so that the test keeps the channels open. Dropping the
    /// protocol receiver would make every send to the protocol fail.
    fn protocol_set_with_fallbacks(
    ) -> (ProtocolSet, Receiver<TransportManagerEvent>, Receiver<InnerTransportEvent>) {
        let (tx, rx) = channel(64);
        let (tx1, rx1) = channel(64);

        let protocol_set = ProtocolSet::new(
            ConnectionId::from(0usize),
            tx,
            Default::default(),
            HashMap::from_iter([(
                ProtocolName::from("/notif/1"),
                ProtocolContext {
                    tx: tx1,
                    codec: ProtocolCodec::Identity(32),
                    fallback_names: vec![
                        ProtocolName::from("/notif/1/fallback/1"),
                        ProtocolName::from("/notif/1/fallback/2"),
                    ],
                },
            )]),
        );

        (protocol_set, rx, rx1)
    }

    /// Create a mock substream for a random peer.
    fn mock_substream() -> Substream {
        Substream::new_mock(
            PeerId::random(),
            SubstreamId::from(0usize),
            Box::new(MockSubstream::new()),
        )
    }

    #[tokio::test]
    async fn fallback_is_provided() {
        let (mut protocol_set, _rx, _rx1) = protocol_set_with_fallbacks();

        let expected_protocols = HashSet::from([
            ProtocolName::from("/notif/1"),
            ProtocolName::from("/notif/1/fallback/1"),
            ProtocolName::from("/notif/1/fallback/2"),
        ]);

        // Check both directions: nothing unexpected is advertised, nothing expected is
        // missing, and no name is advertised twice.
        let protocols = protocol_set.protocols();
        assert_eq!(protocols.len(), expected_protocols.len());
        assert_eq!(protocols.into_iter().collect::<HashSet<_>>(), expected_protocols);

        protocol_set
            .report_substream_open(
                PeerId::random(),
                ProtocolName::from("/notif/1/fallback/2"),
                Direction::Inbound,
                mock_substream(),
            )
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn main_protocol_reported_if_main_protocol_negotiated() {
        let (mut protocol_set, _rx, mut rx1) = protocol_set_with_fallbacks();

        protocol_set
            .report_substream_open(
                PeerId::random(),
                ProtocolName::from("/notif/1"),
                Direction::Inbound,
                mock_substream(),
            )
            .await
            .unwrap();

        match rx1.recv().await.unwrap() {
            InnerTransportEvent::SubstreamOpened {
                protocol, fallback, ..
            } => {
                assert!(fallback.is_none());
                assert_eq!(protocol, ProtocolName::from("/notif/1"));
            }
            _ => panic!("invalid event received"),
        }
    }

    #[tokio::test]
    async fn fallback_is_reported_to_protocol() {
        let (mut protocol_set, _rx, mut rx1) = protocol_set_with_fallbacks();

        protocol_set
            .report_substream_open(
                PeerId::random(),
                ProtocolName::from("/notif/1/fallback/2"),
                Direction::Inbound,
                mock_substream(),
            )
            .await
            .unwrap();

        match rx1.recv().await.unwrap() {
            InnerTransportEvent::SubstreamOpened {
                protocol, fallback, ..
            } => {
                assert_eq!(fallback, Some(ProtocolName::from("/notif/1/fallback/2")));
                assert_eq!(protocol, ProtocolName::from("/notif/1"));
            }
            _ => panic!("invalid event received"),
        }
    }

    #[tokio::test]
    async fn unknown_protocol_is_rejected() {
        let (mut protocol_set, _rx, mut rx1) = protocol_set_with_fallbacks();

        let result = protocol_set
            .report_substream_open(
                PeerId::random(),
                ProtocolName::from("/notif/2"),
                Direction::Inbound,
                mock_substream(),
            )
            .await;

        assert!(result.is_err());
        // Nothing may be delivered to the installed protocol for an unknown name.
        assert!(rx1.try_recv().is_err());
    }

    #[tokio::test]
    async fn codec_is_resolved_for_fallback_name() {
        let (protocol_set, _rx, _rx1) = protocol_set_with_fallbacks();

        assert!(matches!(
            protocol_set.protocol_codec(&ProtocolName::from("/notif/1")),
            ProtocolCodec::Identity(32)
        ));
        assert!(matches!(
            protocol_set.protocol_codec(&ProtocolName::from("/notif/1/fallback/1")),
            ProtocolCodec::Identity(32)
        ));
    }
}