use jsonrpsee::ConnectionId;
use parking_lot::Mutex;
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
/// Per-connection bookkeeping of reserved slots and registered identifiers.
///
/// Clones share the same underlying map (it lives behind an `Arc`), so handing a
/// clone to each reservation is cheap. Note that `Default` yields a capacity of 0,
/// which rejects every reservation; use [`RpcConnections::new`] to pick a capacity.
#[derive(Clone, Default)]
pub struct RpcConnections {
    /// Maximum number of slots a single connection may hold at the same time.
    capacity: usize,
    /// Bookkeeping for every connection that currently holds at least one slot.
    data: Arc<Mutex<HashMap<ConnectionId, ConnectionData>>>,
}
/// Slot and identifier bookkeeping for a single connection.
#[derive(Default)]
struct ConnectionData {
    /// Slots currently held: outstanding reservations plus registered identifiers.
    num_identifiers: usize,
    /// Identifiers currently registered on this connection.
    identifiers: HashSet<String>,
}
impl RpcConnections {
    /// Creates an empty tracker that allows at most `capacity` slots per connection.
    pub fn new(capacity: usize) -> Self {
        RpcConnections { capacity, data: Default::default() }
    }

    /// Reserves one slot on `connection_id`.
    ///
    /// Returns `None` if the connection already holds `capacity` slots, which is
    /// always the case when `capacity` is 0. The returned [`ReservedConnection`]
    /// releases the slot when dropped. A successful [`ReservedConnection::register`]
    /// instead transfers the slot to the resulting [`RegisteredConnection`].
    pub fn reserve_space(&self, connection_id: ConnectionId) -> Option<ReservedConnection> {
        // With zero capacity nothing can ever be reserved. Bail out before touching
        // the map; otherwise `entry` would insert an empty record that no release
        // path would ever remove.
        if self.capacity == 0 {
            return None;
        }
        let mut data = self.data.lock();
        let entry = data.entry(connection_id).or_default();
        if entry.num_identifiers >= self.capacity {
            return None;
        }
        // Cannot overflow: `num_identifiers < capacity <= usize::MAX` at this point.
        entry.num_identifiers += 1;
        Some(ReservedConnection { connection_id, rpc_connections: Some(self.clone()) })
    }

    /// Releases one slot on `connection_id`.
    ///
    /// Removes the connection's record once it holds no slots. Calls for unknown
    /// connections are ignored.
    fn unreserve_space(&self, connection_id: ConnectionId) {
        let mut data = self.data.lock();
        // `get_mut` rather than `entry`: an unknown connection has nothing to
        // release, so there is no point inserting a default record only to
        // remove it again.
        if let Some(entry) = data.get_mut(&connection_id) {
            entry.num_identifiers = entry.num_identifiers.saturating_sub(1);
            if entry.num_identifiers == 0 {
                data.remove(&connection_id);
            }
        }
    }

    /// Records `identifier` as registered on `connection_id`.
    ///
    /// Returns `false` if the connection already has `capacity` identifiers, or if
    /// `identifier` is already registered on it. The slot count is left untouched,
    /// because the slot was already counted by `reserve_space`.
    fn register_identifier(&self, connection_id: ConnectionId, identifier: String) -> bool {
        let mut data = self.data.lock();
        let entry = data.entry(connection_id).or_default();
        if entry.identifiers.len() >= self.capacity {
            return false;
        }
        entry.identifiers.insert(identifier)
    }

    /// Removes `identifier` from `connection_id` and releases the slot it held.
    ///
    /// Removes the connection's record once it holds no slots.
    fn unregister_identifier(&self, connection_id: ConnectionId, identifier: &str) {
        let mut data = self.data.lock();
        if let Some(connection_data) = data.get_mut(&connection_id) {
            connection_data.identifiers.remove(identifier);
            connection_data.num_identifiers = connection_data.num_identifiers.saturating_sub(1);
            if connection_data.num_identifiers == 0 {
                data.remove(&connection_id);
            }
        }
    }

    /// Returns whether `identifier` is currently registered on `connection_id`.
    pub fn contains_identifier(&self, connection_id: ConnectionId, identifier: &str) -> bool {
        self.data
            .lock()
            .get(&connection_id)
            .is_some_and(|connection_data| connection_data.identifiers.contains(identifier))
    }
}
/// A slot reserved on a connection that is not yet bound to an identifier.
///
/// Obtained from [`RpcConnections::reserve_space`]. Dropping it releases the slot;
/// [`ReservedConnection::register`] binds the slot to an identifier instead.
pub struct ReservedConnection {
    /// Connection the slot was reserved on.
    connection_id: ConnectionId,
    /// Handle back to the tracker; `None` once it has been taken out of `self`.
    rpc_connections: Option<RpcConnections>,
}
impl ReservedConnection {
    /// Binds the reserved slot to `identifier`, producing a [`RegisteredConnection`].
    ///
    /// Returns `None` if the identifier cannot be registered, for example because
    /// it is already registered on this connection. In that case the reserved slot
    /// is released when `self` is dropped at the end of this call.
    pub fn register(mut self, identifier: String) -> Option<RegisteredConnection> {
        let rpc_connections = self.rpc_connections.take()?;
        if !rpc_connections.register_identifier(self.connection_id, identifier.clone()) {
            // Put the handle back so `Drop` releases the slot. Without this, a
            // failed registration would leave the reservation counted against the
            // connection forever.
            self.rpc_connections = Some(rpc_connections);
            return None;
        }
        // On success the slot moves to the `RegisteredConnection`. `self` is then
        // dropped with `rpc_connections == None`, so it does not release the slot.
        Some(RegisteredConnection { connection_id: self.connection_id, identifier, rpc_connections })
    }
}

#[cfg(test)]
mod register_tests {
    use super::*;

    #[test]
    fn failed_register_releases_reservation() {
        let rpc_connections = RpcConnections::new(2);
        let conn_id = ConnectionId(1);
        let _registered = rpc_connections
            .reserve_space(conn_id)
            .unwrap()
            .register("identifier1".to_string())
            .unwrap();

        let duplicate = rpc_connections.reserve_space(conn_id).unwrap();
        assert_eq!(2, rpc_connections.data.lock().get(&conn_id).unwrap().num_identifiers);

        // Registering the same identifier twice must fail...
        assert!(duplicate.register("identifier1".to_string()).is_none());
        // ...and must give the reserved slot back.
        assert_eq!(1, rpc_connections.data.lock().get(&conn_id).unwrap().num_identifiers);
        assert!(rpc_connections.reserve_space(conn_id).is_some());
    }
}
impl Drop for ReservedConnection {
    /// Gives the reserved slot back to the tracker.
    ///
    /// Does nothing if the tracker handle has already been taken out of `self`.
    fn drop(&mut self) {
        let Some(connections) = self.rpc_connections.take() else {
            return;
        };
        connections.unreserve_space(self.connection_id);
    }
}
/// A slot on a connection that is bound to a registered identifier.
///
/// Dropping it unregisters the identifier and releases the slot.
pub struct RegisteredConnection {
    /// Connection the identifier is registered on.
    connection_id: ConnectionId,
    /// Identifier that owns this slot.
    identifier: String,
    /// Handle back to the tracker, used to unregister on drop.
    rpc_connections: RpcConnections,
}
impl Drop for RegisteredConnection {
    /// Unregisters the identifier and returns its slot to the tracker.
    fn drop(&mut self) {
        let connection_id = self.connection_id;
        let identifier: &str = &self.identifier;
        self.rpc_connections.unregister_identifier(connection_id, identifier);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Slots currently held by `conn_id`; panics if the connection is not tracked.
    fn slots(connections: &RpcConnections, conn_id: ConnectionId) -> usize {
        connections
            .data
            .lock()
            .get(&conn_id)
            .expect("connection should be tracked")
            .num_identifiers
    }

    #[test]
    fn reserve_space() {
        let connections = RpcConnections::new(2);
        let conn_id = ConnectionId(1);

        let reserved = connections.reserve_space(conn_id).expect("first slot is free");
        assert_eq!(slots(&connections, conn_id), 1);
        assert_eq!(connections.data.lock().len(), 1);

        let registered = reserved
            .register("identifier1".to_string())
            .expect("identifier1 registers");
        assert!(connections.contains_identifier(conn_id, "identifier1"));
        assert_eq!(slots(&connections, conn_id), 1);

        // Dropping the only registration empties the tracker entirely.
        drop(registered);
        assert!(connections.data.lock().get(&conn_id).is_none());
        assert!(connections.data.lock().is_empty());
        assert!(!connections.contains_identifier(conn_id, "identifier1"));
    }

    #[test]
    fn reserve_space_capacity_reached() {
        let connections = RpcConnections::new(2);
        let conn_id = ConnectionId(1);

        let reserved = connections.reserve_space(conn_id).expect("first slot is free");
        assert_eq!(slots(&connections, conn_id), 1);
        let first = reserved
            .register("identifier1".to_string())
            .expect("identifier1 registers");
        assert!(connections.contains_identifier(conn_id, "identifier1"));
        assert_eq!(slots(&connections, conn_id), 1);

        let reserved = connections.reserve_space(conn_id).expect("second slot is free");
        assert_eq!(slots(&connections, conn_id), 2);
        let second = reserved
            .register("identifier2".to_string())
            .expect("identifier2 registers");
        assert!(connections.contains_identifier(conn_id, "identifier2"));
        assert_eq!(slots(&connections, conn_id), 2);

        // Both slots are taken, so a third reservation is refused.
        assert!(connections.reserve_space(conn_id).is_none());

        // Freeing one registration frees exactly one slot.
        drop(first);
        assert_eq!(slots(&connections, conn_id), 1);
        assert!(connections.contains_identifier(conn_id, "identifier2"));
        assert!(!connections.contains_identifier(conn_id, "identifier1"));

        let reserved = connections.reserve_space(conn_id);
        assert!(reserved.is_some());
        assert_eq!(slots(&connections, conn_id), 2);

        drop(reserved);
        drop(second);
        assert!(connections.data.lock().get(&conn_id).is_none());
    }
}