1use crate::types::SenderId;
24use parking_lot::{Mutex, RwLock};
25use std::{
26 collections::HashMap,
27 sync::Arc,
28 time::{Duration, Instant},
29};
30
/// How long a sender may go without a [`RateLimiter::check`] before
/// [`RateLimiter::evict_stale`] drops its per-sender state (one hour).
const STALE_ENTRY_TTL: Duration = Duration::from_secs(60 * 60);
33
/// A token bucket: holds at most `capacity` tokens and regains them
/// continuously at `refill_per_sec` tokens per second of elapsed time.
#[derive(Debug, Clone)]
struct TokenBucket {
    /// Tokens currently available (fractional; never raised above `capacity`).
    tokens: f64,
    /// Maximum number of tokens the bucket can hold, i.e. the burst size.
    capacity: f64,
    /// Tokens credited per second of elapsed time.
    refill_per_sec: f64,
    /// When tokens were last credited (or when the bucket was created).
    last: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full.
    fn new(capacity: f64, refill_per_sec: f64) -> Self {
        Self { tokens: capacity, capacity, refill_per_sec, last: Instant::now() }
    }

    /// Credits tokens for the time elapsed since `last`, capped at `capacity`.
    ///
    /// A `now` earlier than `last` counts as zero elapsed time and leaves the
    /// bucket untouched.
    fn refill(&mut self, now: Instant) {
        let elapsed = now.saturating_duration_since(self.last).as_secs_f64();
        if elapsed > 0.0 {
            self.tokens = (self.tokens + elapsed * self.refill_per_sec).min(self.capacity);
            self.last = now;
        }
    }

    /// Refills the bucket as of `now`, then tries to take `n` tokens.
    ///
    /// # Errors
    /// If fewer than `n` tokens are available, nothing is consumed and the
    /// estimated wait until `n` tokens are available is returned. The wait
    /// saturates at [`Duration::MAX`] in three cases: the request can never be
    /// satisfied (`n` exceeds `capacity`), the bucket does not refill, or the
    /// wait is too large to represent.
    fn try_consume(&mut self, n: f64, now: Instant) -> Result<(), Duration> {
        self.refill(now);
        if self.tokens >= n {
            self.tokens -= n;
            return Ok(());
        }
        // Tokens are capped at `capacity`, so a larger request can never
        // succeed. A bucket that never refills can never catch up either.
        // Report an unbounded wait rather than a misleading finite one.
        if n > self.capacity || self.refill_per_sec <= 0.0 {
            return Err(Duration::MAX);
        }
        let secs = (n - self.tokens) / self.refill_per_sec;
        // `Duration::from_secs_f64` panics for values beyond `Duration::MAX`.
        // Clamping to `u64::MAX as f64` does not help, because it rounds up to
        // 2^64, which is out of range. Convert fallibly and saturate instead.
        Err(Duration::try_from_secs_f64(secs).unwrap_or(Duration::MAX))
    }
}
72
73#[derive(Debug)]
74struct UserRateState {
75 requests: TokenBucket,
76 bandwidth: TokenBucket,
77 last_touch: Instant,
78}
79
/// Per-sender admission limits.
///
/// Submissions and bandwidth are budgeted independently. Each budget is a
/// token bucket that starts full at its `*_burst` size and refills
/// continuously at its `*_per_min` rate.
#[derive(Clone, Debug)]
pub struct RateLimitConfig {
    /// Master switch; when `false` the limiter admits every check.
    pub enabled: bool,
    /// Sustained number of submissions allowed per minute.
    pub submit_rate_per_min: u32,
    /// Number of submissions that may be made back-to-back from a full bucket.
    pub submit_burst: u32,
    /// Sustained payload allowance per minute, in the units of `data_len`
    /// (presumably bytes — confirm with callers).
    pub bandwidth_per_min: u64,
    /// Largest payload total that can be admitted back-to-back from a full bucket.
    pub bandwidth_burst: u64,
}

impl RateLimitConfig {
    /// Returns a configuration with rate limiting switched off.
    ///
    /// Every limit is zero; the numbers are irrelevant while `enabled` is `false`.
    pub fn disabled() -> Self {
        let enabled = false;
        RateLimitConfig {
            enabled,
            submit_rate_per_min: 0,
            submit_burst: 0,
            bandwidth_per_min: 0,
            bandwidth_burst: 0,
        }
    }
}
107
108pub struct RateLimiter {
110 cfg: RateLimitConfig,
111 users: RwLock<HashMap<SenderId, Arc<Mutex<UserRateState>>>>,
112}
113
114impl RateLimiter {
115 pub fn new(cfg: RateLimitConfig) -> Self {
117 Self { cfg, users: RwLock::new(HashMap::new()) }
118 }
119
120 fn new_state(&self, now: Instant) -> UserRateState {
121 let requests = TokenBucket::new(
122 self.cfg.submit_burst as f64,
123 self.cfg.submit_rate_per_min as f64 / 60.0,
124 );
125 let bandwidth = TokenBucket::new(
126 self.cfg.bandwidth_burst as f64,
127 self.cfg.bandwidth_per_min as f64 / 60.0,
128 );
129 UserRateState { requests, bandwidth, last_touch: now }
130 }
131
132 fn get_or_create(&self, sender_id: &SenderId, now: Instant) -> Arc<Mutex<UserRateState>> {
133 if let Some(state) = self.users.read().get(sender_id).cloned() {
134 return state;
135 }
136 let mut users = self.users.write();
137 users
138 .entry(*sender_id)
139 .or_insert_with(|| Arc::new(Mutex::new(self.new_state(now))))
140 .clone()
141 }
142
143 pub fn check(&self, sender_id: &SenderId, data_len: u64) -> Result<(), u64> {
148 if !self.cfg.enabled {
149 return Ok(());
150 }
151
152 let now = Instant::now();
153 let state = self.get_or_create(sender_id, now);
154 let mut state = state.lock();
155 state.last_touch = now;
156
157 let req_wait = state.requests.try_consume(1.0, now).err();
158 if let Some(wait) = req_wait {
159 return Err(wait.as_secs().max(1));
160 }
161
162 if let Err(wait) = state.bandwidth.try_consume(data_len as f64, now) {
165 state.requests.tokens = (state.requests.tokens + 1.0).min(state.requests.capacity);
167 return Err(wait.as_secs().max(1));
168 }
169
170 Ok(())
171 }
172
173 pub fn evict_stale(&self) {
176 if !self.cfg.enabled {
177 return;
178 }
179 let now = Instant::now();
180 let mut users = self.users.write();
181 users.retain(|_, state| {
182 let state = state.lock();
183 now.saturating_duration_since(state.last_touch) < STALE_ENTRY_TTL
184 });
185 }
186
187 pub fn tracked_senders(&self) -> usize {
189 self.users.read().len()
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    const SENDER_A: SenderId = [1u8; 32];
    const SENDER_B: SenderId = [2u8; 32];

    /// Enabled config: 1 request/s with a burst of 3, and 100 units/s of
    /// bandwidth with a burst of 6 000.
    fn test_cfg() -> RateLimitConfig {
        RateLimitConfig {
            enabled: true,
            submit_rate_per_min: 60,
            submit_burst: 3,
            bandwidth_per_min: 6_000,
            bandwidth_burst: 6_000,
        }
    }

    /// Runs `edit` against the live state of `sender`, e.g. to rewind its clocks.
    /// The sender's mutex is released before this returns.
    fn with_state(limiter: &RateLimiter, sender: &SenderId, edit: impl FnOnce(&mut UserRateState)) {
        let handle = limiter.get_or_create(sender, Instant::now());
        let mut guard = handle.lock();
        edit(&mut *guard);
    }

    #[test]
    fn disabled_admits_everything() {
        let limiter = RateLimiter::new(RateLimitConfig::disabled());
        (0..100).for_each(|_| limiter.check(&SENDER_A, 1_000_000_000).unwrap());
    }

    #[test]
    fn burst_then_limited() {
        let limiter = RateLimiter::new(test_cfg());
        for _ in 0..3 {
            limiter.check(&SENDER_A, 100).unwrap();
        }
        let retry_after = limiter.check(&SENDER_A, 100).unwrap_err();
        assert!(retry_after >= 1);
    }

    #[test]
    fn bandwidth_exhaustion_limits() {
        let limiter = RateLimiter::new(test_cfg());
        limiter.check(&SENDER_A, 6_000).unwrap();
        let rejected = limiter.check(&SENDER_A, 1);
        assert!(rejected.is_err());
    }

    #[test]
    fn isolated_per_sender() {
        let limiter = RateLimiter::new(test_cfg());
        (0..3).for_each(|_| limiter.check(&SENDER_A, 100).unwrap());
        assert!(limiter.check(&SENDER_A, 100).is_err());
        limiter.check(&SENDER_B, 100).unwrap();
    }

    #[test]
    fn refills_over_time() {
        let limiter = RateLimiter::new(RateLimitConfig {
            enabled: true,
            submit_rate_per_min: 60,
            submit_burst: 1,
            bandwidth_per_min: 600_000,
            bandwidth_burst: 600_000,
        });
        limiter.check(&SENDER_A, 100).unwrap();
        assert!(limiter.check(&SENDER_A, 100).is_err());

        // Pretend two seconds have passed since the request bucket last refilled.
        with_state(&limiter, &SENDER_A, |state| state.requests.last -= Duration::from_secs(2));
        limiter.check(&SENDER_A, 100).unwrap();
    }

    #[test]
    fn evict_stale_removes_untouched() {
        let limiter = RateLimiter::new(test_cfg());
        limiter.check(&SENDER_A, 100).unwrap();
        assert_eq!(limiter.tracked_senders(), 1);

        // Age the entry just past the eviction threshold.
        with_state(&limiter, &SENDER_A, |state| {
            state.last_touch -= STALE_ENTRY_TTL + Duration::from_secs(1)
        });
        limiter.evict_stale();
        assert_eq!(limiter.tracked_senders(), 0);
    }
}