1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
//! State stores for rate limiters
use std::{marker::PhantomData, prelude::v1::*};
pub mod direct;
mod in_memory;
pub mod keyed;
pub use self::in_memory::InMemoryState;
use crate::nanos::Nanos;
use crate::{clock, Quota};
use crate::{
gcra::Gcra,
middleware::{NoOpMiddleware, RateLimitingMiddleware},
};
pub use direct::*;
/// A way for rate limiters to keep state.
///
/// There are two important kinds of state stores: direct and keyed. The direct kind has only
/// one state, and is useful for "global" rate limit enforcement (e.g. a process should never
/// do more than N tasks a day). The keyed kind allows one rate limit per key (e.g. an API
/// call budget per client API key).
///
/// A direct state store is expressed as [`StateStore::Key`] = [`NotKeyed`].
/// Keyed state stores have a type parameter for the key and set their key to that.
pub trait StateStore {
    /// The type of key that the state store can represent.
    type Key;

    /// Updates a state store's rate limiting state for a given key, using the given closure.
    ///
    /// The closure parameter takes the old value (`None` if this is the first measurement) of
    /// the state store at the key's location, checks if the request can be accommodated and:
    ///
    /// * If the request is rate-limited, returns `Err(E)`.
    /// * If the request can make it through, returns `Ok(T)` (an arbitrary positive return
    ///   value) and the updated state.
    ///
    /// It is `measure_and_replace`'s job then to safely replace the value at the key - it must
    /// only update the value if the value hasn't changed. The implementations in this
    /// crate use `AtomicU64` operations for this.
    ///
    /// Because `F` is bounded by `Fn` rather than `FnOnce`, an implementation is allowed to
    /// invoke the closure more than once (for example, to retry after the stored value
    /// changed underneath it). Closures should therefore not rely on being called exactly once.
    ///
    /// # Errors
    ///
    /// Returns the `Err(E)` produced by the closure when the request is rate-limited.
    fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
    where
        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>;
}
/// A rate limiter.
///
/// This is the structure that ties together the parameters (how many cells to allow in what time
/// period) and the concrete state of rate limiting decisions. This crate ships in-memory state
/// stores, but it's possible (by implementing the [`StateStore`] trait) to make others.
///
/// Type parameters:
/// * `K` — the key type; it must equal the state store's `StateStore::Key`.
/// * `S` — the [`StateStore`] holding the rate-limiting state.
/// * `C` — the clock supplying time readings.
/// * `MW` — the middleware type (defaults to [`NoOpMiddleware`]); it is tracked only at the
///   type level, no middleware value is stored.
#[derive(Debug)]
pub struct RateLimiter<K, S, C, MW = NoOpMiddleware>
where
S: StateStore<Key = K>,
C: clock::Clock,
MW: RateLimitingMiddleware<C::Instant>,
{
// The state store holding the limiter's rate-limiting state.
state: S,
// GCRA parameters; `RateLimiter::new` builds these from the given `Quota`.
gcra: Gcra,
// The clock used for readings; `RateLimiter::new` stores a clone of the caller's clock.
clock: C,
// Clock reading taken once, at construction time, via `clock.now()`.
start: C::Instant,
// Zero-sized marker recording the middleware type `MW`.
middleware: PhantomData<MW>,
}
impl<K, S, C, MW> RateLimiter<K, S, C, MW>
where
    S: StateStore<Key = K>,
    C: clock::Clock,
    MW: RateLimitingMiddleware<C::Instant>,
{
    /// Creates a new rate limiter from its individual components.
    ///
    /// This is the most generic way to construct a rate-limiter; most users should prefer
    /// [`direct`] or other methods instead.
    ///
    /// The `quota` is converted into GCRA parameters and `state` is stored as given. The
    /// borrowed `clock` is read once, to record the start instant, and then cloned into the
    /// limiter.
    pub fn new(quota: Quota, state: S, clock: &C) -> Self {
        Self {
            state,
            gcra: Gcra::new(quota),
            // Take the start reading from the caller's clock before cloning it.
            start: clock.now(),
            clock: clock.clone(),
            middleware: PhantomData,
        }
    }

    /// Consumes the `RateLimiter`, handing back only its state store.
    ///
    /// Mainly useful for debugging and testing.
    pub fn into_state_store(self) -> S {
        let Self { state, .. } = self;
        state
    }
}
impl<K, S, C, MW> RateLimiter<K, S, C, MW>
where
    S: StateStore<Key = K>,
    C: clock::Clock,
    MW: RateLimitingMiddleware<C::Instant>,
{
    /// Re-types this rate limiter so that it uses the middleware `Outer`.
    ///
    /// The state store, GCRA parameters, clock and start instant are moved over unchanged;
    /// only the type-level middleware marker is replaced.
    pub fn with_middleware<Outer>(self) -> RateLimiter<K, S, C, Outer>
    where
        Outer: RateLimitingMiddleware<C::Instant>,
    {
        let Self {
            state,
            gcra,
            clock,
            start,
            middleware: _,
        } = self;
        RateLimiter {
            state,
            gcra,
            clock,
            start,
            middleware: PhantomData,
        }
    }
}
#[cfg(feature = "std")]
impl<K, S, C, MW> RateLimiter<K, S, C, MW>
where
    S: StateStore<Key = K>,
    C: clock::ReasonablyRealtime,
    MW: RateLimitingMiddleware<C::Instant>,
{
    /// Returns the reference point reported by this limiter's clock.
    ///
    /// Only available with the `std` feature and for clocks implementing
    /// `clock::ReasonablyRealtime`.
    pub(crate) fn reference_reading(&self) -> C::Instant {
        C::reference_point(&self.clock)
    }
}
#[cfg(all(feature = "std", test))]
mod test {
    use super::*;
    use crate::Quota;
    use all_asserts::assert_gt;
    use nonzero_ext::nonzero;

    /// Exercises the derived `Debug` impl of a direct `RateLimiter`.
    #[test]
    fn ratelimiter_impl_coverage() {
        let three_per_second = Quota::per_second(nonzero!(3u32));
        let limiter = RateLimiter::direct(three_per_second);
        let debug_repr = format!("{:?}", limiter);
        assert_gt!(debug_repr.len(), 0);
    }
}