libp2p_swarm/
stream_protocol.rs

1use either::Either;
2use std::fmt;
3use std::hash::{Hash, Hasher};
4use std::sync::Arc;
5
6/// Identifies a protocol for a stream.
7///
8/// libp2p nodes use stream protocols to negotiate what to do with a newly opened stream.
9/// Stream protocols are string-based and must start with a forward slash: `/`.
10#[derive(Clone, Eq)]
11pub struct StreamProtocol {
12    inner: Either<&'static str, Arc<str>>,
13}
14
15impl StreamProtocol {
16    /// Construct a new protocol from a static string slice.
17    ///
18    /// # Panics
19    ///
20    /// This function panics if the protocol does not start with a forward slash: `/`.
21    pub const fn new(s: &'static str) -> Self {
22        match s.as_bytes() {
23            [b'/', ..] => {}
24            _ => panic!("Protocols should start with a /"),
25        }
26
27        StreamProtocol {
28            inner: Either::Left(s),
29        }
30    }
31
32    /// Attempt to construct a protocol from an owned string.
33    ///
34    /// This function will fail if the protocol does not start with a forward slash: `/`.
35    /// Where possible, you should use [`StreamProtocol::new`] instead to avoid allocations.
36    pub fn try_from_owned(protocol: String) -> Result<Self, InvalidProtocol> {
37        if !protocol.starts_with('/') {
38            return Err(InvalidProtocol::missing_forward_slash());
39        }
40
41        Ok(StreamProtocol {
42            inner: Either::Right(Arc::from(protocol)), // FIXME: Can we somehow reuse the allocation from the owned string?
43        })
44    }
45}
46
47impl AsRef<str> for StreamProtocol {
48    fn as_ref(&self) -> &str {
49        either::for_both!(&self.inner, s => s)
50    }
51}
52
53impl fmt::Debug for StreamProtocol {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        either::for_both!(&self.inner, s => s.fmt(f))
56    }
57}
58
59impl fmt::Display for StreamProtocol {
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        self.inner.fmt(f)
62    }
63}
64
65impl PartialEq<&str> for StreamProtocol {
66    fn eq(&self, other: &&str) -> bool {
67        self.as_ref() == *other
68    }
69}
70
71impl PartialEq<StreamProtocol> for &str {
72    fn eq(&self, other: &StreamProtocol) -> bool {
73        *self == other.as_ref()
74    }
75}
76
77impl PartialEq for StreamProtocol {
78    fn eq(&self, other: &Self) -> bool {
79        self.as_ref() == other.as_ref()
80    }
81}
82
83impl Hash for StreamProtocol {
84    fn hash<H: Hasher>(&self, state: &mut H) {
85        self.as_ref().hash(state)
86    }
87}
88
/// Error returned when a string cannot be turned into a [`StreamProtocol`], for example by
/// [`StreamProtocol::try_from_owned`].
///
/// Currently the only rejected case is a string that does not start with a forward slash: `/`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvalidProtocol {
    // private field to prevent construction outside of this module
    _private: (),
}

impl InvalidProtocol {
    /// Creates the error for a protocol string that does not start with a forward slash.
    pub(crate) const fn missing_forward_slash() -> Self {
        InvalidProtocol { _private: () }
    }
}

impl fmt::Display for InvalidProtocol {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Nothing to interpolate, so write the fixed message directly.
        f.write_str("invalid protocol: string does not start with a forward slash")
    }
}

impl std::error::Error for InvalidProtocol {}
111
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn stream_protocol_print() {
        let protocol = StreamProtocol::new("/foo/bar/1.0.0");

        let debug = format!("{protocol:?}");
        let display = format!("{protocol}");

        assert_eq!(
            debug, r#""/foo/bar/1.0.0""#,
            "protocol to debug print as string with quotes"
        );
        assert_eq!(
            display, "/foo/bar/1.0.0",
            "protocol to display print as string without quotes"
        );
    }

    /// The owned (`Arc<str>`-backed) representation must print exactly like the `'static` one.
    #[test]
    fn owned_stream_protocol_print() {
        let protocol = StreamProtocol::try_from_owned("/foo/bar/1.0.0".to_owned())
            .expect("protocol starts with a forward slash");

        assert_eq!(
            format!("{protocol:?}"),
            r#""/foo/bar/1.0.0""#,
            "owned protocol to debug print as string with quotes"
        );
        assert_eq!(
            format!("{protocol}"),
            "/foo/bar/1.0.0",
            "owned protocol to display print as string without quotes"
        );
    }

    #[test]
    fn try_from_owned_requires_leading_slash() {
        assert!(StreamProtocol::try_from_owned("foo/bar/1.0.0".to_owned()).is_err());
        assert!(StreamProtocol::try_from_owned(String::new()).is_err());
        assert!(StreamProtocol::try_from_owned("/".to_owned()).is_ok());
    }

    #[test]
    #[should_panic(expected = "Protocols should start with a /")]
    fn new_panics_without_leading_slash() {
        let _ = StreamProtocol::new("foo/bar/1.0.0");
    }

    /// Equality and hashing depend only on the protocol text, not on how it is stored.
    #[test]
    fn static_and_owned_protocols_compare_equal() {
        let from_static = StreamProtocol::new("/foo/bar/1.0.0");
        let from_owned = StreamProtocol::try_from_owned("/foo/bar/1.0.0".to_owned())
            .expect("protocol starts with a forward slash");

        assert_eq!(from_static, from_owned);
        assert_eq!(from_static, "/foo/bar/1.0.0");
        assert_eq!("/foo/bar/1.0.0", from_owned);
        assert_ne!(from_static, StreamProtocol::new("/foo/bar/2.0.0"));

        let mut set = HashSet::new();
        set.insert(from_static);
        set.insert(from_owned);
        assert_eq!(set.len(), 1, "equal protocols must hash identically");
    }
}