litep2p/protocol/
connection.rs1use crate::{
24 error::{Error, SubstreamError},
25 protocol::protocol_set::ProtocolCommand,
26 types::{protocol::ProtocolName, ConnectionId, SubstreamId},
27};
28
29use tokio::sync::mpsc::{error::TrySendError, Sender, WeakSender};
30
31#[derive(Debug, Clone)]
33enum ConnectionType {
34 Active(Sender<ProtocolCommand>),
36
37 Inactive(WeakSender<ProtocolCommand>),
41}
42
43#[derive(Debug, Clone)]
46pub struct ConnectionHandle {
47 connection: ConnectionType,
49
50 connection_id: ConnectionId,
52}
53
54impl ConnectionHandle {
55 pub fn new(connection_id: ConnectionId, connection: Sender<ProtocolCommand>) -> Self {
60 Self {
61 connection_id,
62 connection: ConnectionType::Active(connection),
63 }
64 }
65
66 pub fn downgrade(&mut self) -> Self {
72 match &self.connection {
73 ConnectionType::Active(connection) => {
74 let handle = Self::new(self.connection_id, connection.clone());
75 self.connection = ConnectionType::Inactive(connection.downgrade());
76
77 handle
78 }
79 ConnectionType::Inactive(_) => {
80 panic!("state mismatch: tried to downgrade an inactive connection")
81 }
82 }
83 }
84
85 pub fn connection_id(&self) -> &ConnectionId {
87 &self.connection_id
88 }
89
90 pub fn close(&mut self) {
92 if let ConnectionType::Active(connection) = &self.connection {
93 self.connection = ConnectionType::Inactive(connection.downgrade());
94 }
95 }
96
97 pub fn try_get_permit(&self) -> Option<Permit> {
99 match &self.connection {
100 ConnectionType::Active(active) => Some(Permit::new(active.clone())),
101 ConnectionType::Inactive(inactive) => Some(Permit::new(inactive.upgrade()?)),
102 }
103 }
104
105 pub fn open_substream(
108 &mut self,
109 protocol: ProtocolName,
110 fallback_names: Vec<ProtocolName>,
111 substream_id: SubstreamId,
112 permit: Permit,
113 ) -> Result<(), SubstreamError> {
114 match &self.connection {
115 ConnectionType::Active(active) => active.clone(),
116 ConnectionType::Inactive(inactive) =>
117 inactive.upgrade().ok_or(SubstreamError::ConnectionClosed)?,
118 }
119 .try_send(ProtocolCommand::OpenSubstream {
120 protocol: protocol.clone(),
121 fallback_names,
122 substream_id,
123 permit,
124 })
125 .map_err(|error| match error {
126 TrySendError::Full(_) => SubstreamError::ChannelClogged,
127 TrySendError::Closed(_) => SubstreamError::ConnectionClosed,
128 })
129 }
130
131 pub fn force_close(&mut self) -> crate::Result<()> {
133 match &self.connection {
134 ConnectionType::Active(active) => active.clone(),
135 ConnectionType::Inactive(inactive) =>
136 inactive.upgrade().ok_or(Error::ConnectionClosed)?,
137 }
138 .try_send(ProtocolCommand::ForceClose)
139 .map_err(|error| match error {
140 TrySendError::Full(_) => Error::ChannelClogged,
141 TrySendError::Closed(_) => Error::ConnectionClosed,
142 })
143 }
144}
145
146#[derive(Debug)]
148pub struct Permit {
149 _connection: Sender<ProtocolCommand>,
151}
152
153impl Permit {
154 pub fn new(_connection: Sender<ProtocolCommand>) -> Self {
156 Self { _connection }
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc::channel;

    #[test]
    #[should_panic(expected = "state mismatch: tried to downgrade an inactive connection")]
    fn downgrade_inactive_connection() {
        let (tx, _rx) = channel(1);
        let mut handle = ConnectionHandle::new(ConnectionId::new(), tx);

        // `downgrade()` hands back an *active* copy and leaves `handle` itself inactive.
        let active = handle.downgrade();
        assert!(std::matches!(active.connection, ConnectionType::Active(_)));
        assert!(std::matches!(handle.connection, ConnectionType::Inactive(_)));

        // Downgrading the now-inactive handle is a state mismatch and must panic.
        let _handle = handle.downgrade();
    }

    #[tokio::test]
    async fn open_substream_open_downgraded_connection() {
        let (tx, mut rx) = channel(1);
        let mut handle = ConnectionHandle::new(ConnectionId::new(), tx);
        let mut handle = handle.downgrade();
        let permit = handle.try_get_permit().unwrap();

        let result = handle.open_substream(
            ProtocolName::from("/protocol/1"),
            Vec::new(),
            SubstreamId::new(),
            permit,
        );

        assert!(result.is_ok());
        assert!(rx.recv().await.is_some());
    }

    #[tokio::test]
    async fn open_substream_closed_downgraded_connection() {
        let (tx, rx) = channel(1);
        let mut handle = ConnectionHandle::new(ConnectionId::new(), tx);
        let mut handle = handle.downgrade();
        let permit = handle.try_get_permit().unwrap();
        drop(rx);

        let result = handle.open_substream(
            ProtocolName::from("/protocol/1"),
            Vec::new(),
            SubstreamId::new(),
            permit,
        );

        assert!(std::matches!(result, Err(SubstreamError::ConnectionClosed)));
    }

    #[tokio::test]
    async fn open_substream_channel_clogged() {
        let (tx, _rx) = channel(1);
        let mut handle = ConnectionHandle::new(ConnectionId::new(), tx);
        let mut handle = handle.downgrade();
        let permit = handle.try_get_permit().unwrap();

        let result = handle.open_substream(
            ProtocolName::from("/protocol/1"),
            Vec::new(),
            SubstreamId::new(),
            permit,
        );
        assert!(result.is_ok());

        let permit = handle.try_get_permit().unwrap();
        match handle.open_substream(
            ProtocolName::from("/protocol/1"),
            Vec::new(),
            SubstreamId::new(),
            permit,
        ) {
            Err(SubstreamError::ChannelClogged) => {}
            error => panic!("invalid error: {error:?}"),
        }
    }

    #[tokio::test]
    async fn open_substream_through_inactive_handle() {
        let (tx, mut rx) = channel(1);
        let mut handle = ConnectionHandle::new(ConnectionId::new(), tx);
        // Keep the active copy alive so the inactive `handle` can still upgrade its sender.
        let _active = handle.downgrade();
        let permit = handle.try_get_permit().unwrap();

        let result = handle.open_substream(
            ProtocolName::from("/protocol/1"),
            Vec::new(),
            SubstreamId::new(),
            permit,
        );

        assert!(result.is_ok());
        assert!(rx.recv().await.is_some());
    }

    #[test]
    fn inactive_handle_without_strong_senders_is_closed() {
        let (tx, _rx) = channel(1);
        let mut handle = ConnectionHandle::new(ConnectionId::new(), tx);
        // The returned active copy holds the only strong sender; dropping it leaves nothing
        // for the inactive handle to upgrade.
        drop(handle.downgrade());

        assert!(handle.try_get_permit().is_none());
        assert!(std::matches!(handle.force_close(), Err(Error::ConnectionClosed)));
    }

    #[test]
    fn close_makes_handle_inactive() {
        let (tx, _rx) = channel(1);
        let mut handle = ConnectionHandle::new(ConnectionId::new(), tx);

        handle.close();
        assert!(std::matches!(handle.connection, ConnectionType::Inactive(_)));
        // The handle held the only strong sender, so no permit can be obtained any more.
        assert!(handle.try_get_permit().is_none());

        // Closing an already inactive handle is a no-op.
        handle.close();
        assert!(std::matches!(handle.connection, ConnectionType::Inactive(_)));
    }
}
243}