sc_rpc_spec_v2/common/
connections.rs1use jsonrpsee::ConnectionId;
20use parking_lot::Mutex;
21use std::{
22 collections::{HashMap, HashSet},
23 sync::Arc,
24};
25
26#[derive(Default, Clone)]
29pub struct RpcConnections {
30 capacity: usize,
36 data: Arc<Mutex<HashMap<ConnectionId, ConnectionData>>>,
38}
39
/// State tracked for a single connection.
#[derive(Default)]
struct ConnectionData {
    /// Number of slots currently accounted for on this connection.
    ///
    /// Incremented when space is reserved and decremented when a reservation is
    /// given back or a registered identifier is removed.
    num_identifiers: usize,
    /// Identifiers that are currently registered for this connection.
    identifiers: HashSet<String>,
}
66
67impl RpcConnections {
68 pub fn new(capacity: usize) -> Self {
70 RpcConnections { capacity, data: Default::default() }
71 }
72
73 pub fn reserve_space(&self, connection_id: ConnectionId) -> Option<ReservedConnection> {
78 let mut data = self.data.lock();
79
80 let entry = data.entry(connection_id).or_insert_with(ConnectionData::default);
81 if entry.num_identifiers >= self.capacity {
82 return None;
83 }
84 entry.num_identifiers = entry.num_identifiers.saturating_add(1);
85
86 Some(ReservedConnection { connection_id, rpc_connections: Some(self.clone()) })
87 }
88
89 fn unreserve_space(&self, connection_id: ConnectionId) {
95 let mut data = self.data.lock();
96
97 let entry = data.entry(connection_id).or_insert_with(ConnectionData::default);
98 entry.num_identifiers = entry.num_identifiers.saturating_sub(1);
99
100 if entry.num_identifiers == 0 {
101 data.remove(&connection_id);
102 }
103 }
104
105 fn register_identifier(&self, connection_id: ConnectionId, identifier: String) -> bool {
113 let mut data = self.data.lock();
114
115 let entry = data.entry(connection_id).or_insert_with(ConnectionData::default);
116 if entry.identifiers.len() >= self.capacity {
118 return false;
119 }
120
121 entry.identifiers.insert(identifier)
122 }
123
124 fn unregister_identifier(&self, connection_id: ConnectionId, identifier: &str) {
126 let mut data = self.data.lock();
127 if let Some(connection_data) = data.get_mut(&connection_id) {
128 connection_data.identifiers.remove(identifier);
129 connection_data.num_identifiers = connection_data.num_identifiers.saturating_sub(1);
130
131 if connection_data.num_identifiers == 0 {
132 data.remove(&connection_id);
133 }
134 }
135 }
136
137 pub fn contains_identifier(&self, connection_id: ConnectionId, identifier: &str) -> bool {
139 let data = self.data.lock();
140 data.get(&connection_id)
141 .map(|connection_data| connection_data.identifiers.contains(identifier))
142 .unwrap_or(false)
143 }
144}
145
146pub struct ReservedConnection {
149 connection_id: ConnectionId,
150 rpc_connections: Option<RpcConnections>,
151}
152
153impl ReservedConnection {
154 pub fn register(mut self, identifier: String) -> Option<RegisteredConnection> {
156 let rpc_connections = self.rpc_connections.take()?;
157
158 if rpc_connections.register_identifier(self.connection_id, identifier.clone()) {
159 Some(RegisteredConnection {
160 connection_id: self.connection_id,
161 identifier,
162 rpc_connections,
163 })
164 } else {
165 None
166 }
167 }
168}
169
170impl Drop for ReservedConnection {
171 fn drop(&mut self) {
172 if let Some(rpc_connections) = self.rpc_connections.take() {
173 rpc_connections.unreserve_space(self.connection_id);
174 }
175 }
176}
177
178pub struct RegisteredConnection {
180 connection_id: ConnectionId,
181 identifier: String,
182 rpc_connections: RpcConnections,
183}
184
185impl Drop for RegisteredConnection {
186 fn drop(&mut self) {
187 self.rpc_connections.unregister_identifier(self.connection_id, &self.identifier);
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Slot count recorded for `id`, or `None` when the connection is not tracked.
    fn slots_in_use(connections: &RpcConnections, id: ConnectionId) -> Option<usize> {
        connections.data.lock().get(&id).map(|entry| entry.num_identifiers)
    }

    #[test]
    fn reserve_space() {
        let connections = RpcConnections::new(2);
        let id = ConnectionId(1);

        // Reserving creates the entry and accounts for one slot.
        let slot = connections.reserve_space(id);
        assert!(slot.is_some());
        assert_eq!(Some(1), slots_in_use(&connections, id));
        assert_eq!(connections.data.lock().len(), 1);

        // Registering binds the identifier without consuming another slot.
        let guard = slot.unwrap().register("identifier1".to_string()).unwrap();
        assert!(connections.contains_identifier(id, "identifier1"));
        assert_eq!(Some(1), slots_in_use(&connections, id));

        // Dropping the registration forgets the connection altogether.
        drop(guard);
        assert!(connections.data.lock().get(&id).is_none());
        assert!(connections.data.lock().is_empty());
        assert!(!connections.contains_identifier(id, "identifier1"));
    }

    #[test]
    fn reserve_space_capacity_reached() {
        let connections = RpcConnections::new(2);
        let id = ConnectionId(1);

        // First identifier.
        let slot = connections.reserve_space(id);
        assert!(slot.is_some());
        assert_eq!(Some(1), slots_in_use(&connections, id));

        let first = slot.unwrap().register("identifier1".to_string()).unwrap();
        assert!(connections.contains_identifier(id, "identifier1"));
        assert_eq!(Some(1), slots_in_use(&connections, id));

        // Second identifier fills the connection up.
        let slot = connections.reserve_space(id);
        assert!(slot.is_some());
        assert_eq!(Some(2), slots_in_use(&connections, id));

        let second = slot.unwrap().register("identifier2".to_string()).unwrap();
        assert!(connections.contains_identifier(id, "identifier2"));
        assert_eq!(Some(2), slots_in_use(&connections, id));

        // No room left for a third reservation.
        assert!(connections.reserve_space(id).is_none());

        // Releasing the first identifier frees exactly one slot.
        drop(first);
        assert_eq!(Some(1), slots_in_use(&connections, id));
        assert!(connections.contains_identifier(id, "identifier2"));
        assert!(!connections.contains_identifier(id, "identifier1"));

        // The freed slot can be reserved again.
        let slot = connections.reserve_space(id);
        assert!(slot.is_some());
        assert_eq!(Some(2), slots_in_use(&connections, id));

        // Giving everything back removes the connection.
        drop(slot);
        drop(second);
        assert!(connections.data.lock().get(&id).is_none());
    }
}