nautilus_network/ratelimiter/
mod.rs1#![allow(clippy::missing_errors_doc)] pub mod clock;
22mod gcra;
23mod nanos;
24pub mod quota;
25
26use std::{
27 hash::Hash,
28 num::NonZeroU64,
29 sync::atomic::{AtomicU64, Ordering},
30 time::Duration,
31};
32
33use dashmap::DashMap;
34use futures_util::StreamExt;
35use tokio::time::sleep;
36
37use self::{
38 clock::{Clock, FakeRelativeClock, MonotonicClock},
39 gcra::{Gcra, NotUntil},
40 nanos::Nanos,
41 quota::Quota,
42};
43
/// Atomic per-key state cell used by [`DashMapStateStore`].
///
/// Holds a [`Nanos`] value stored as a raw `u64`. The value `0` (what `Default` produces)
/// means "no state recorded yet"; see `measure_and_replace_one`, which surfaces it as `None`.
#[derive(Default)]
pub struct InMemoryState(AtomicU64);
54
55impl InMemoryState {
56 pub(crate) fn measure_and_replace_one<T, F, E>(&self, mut f: F) -> Result<T, E>
57 where
58 F: FnMut(Option<Nanos>) -> Result<(T, Nanos), E>,
59 {
60 let mut prev = self.0.load(Ordering::Acquire);
61 let mut decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
62 while let Ok((result, new_data)) = decision {
63 match self.0.compare_exchange_weak(
64 prev,
65 new_data.into(),
66 Ordering::Release,
67 Ordering::Relaxed,
68 ) {
69 Ok(_) => return Ok(result),
70 Err(next_prev) => prev = next_prev,
71 }
72 decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
73 }
74 decision.map(|(result, _)| result)
77 }
78}
79
/// Per-key rate-limiter state: one [`InMemoryState`] cell per key in a concurrent [`DashMap`].
pub type DashMapStateStore<K> = DashMap<K, InMemoryState>;
82
/// Storage for per-key rate-limiter state.
pub trait StateStore {
    /// Identifies an individual state cell.
    type Key;

    /// Reads the state stored for `key`, lets `f` compute a new state, and stores it.
    ///
    /// `f` receives `None` when no state has been recorded for `key` yet. On `Ok((value, new))`
    /// the implementation stores `new` and returns `Ok(value)`; an `Err` from `f` is returned
    /// to the caller. Because `f` is `Fn`, implementations may invoke it more than once (the
    /// `DashMapStateStore` impl retries it after a lost compare-and-swap).
    fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
    where
        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>;
}
113
114impl<K: Hash + Eq + Clone> StateStore for DashMapStateStore<K> {
115 type Key = K;
116
117 fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
118 where
119 F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
120 {
121 if let Some(v) = self.get(key) {
122 return v.measure_and_replace_one(f);
124 }
125 let entry = self.entry(key.clone()).or_default();
127 (*entry).measure_and_replace_one(f)
128 }
129}
130
/// Keyed rate limiter: each key is checked against its own [`Gcra`] quota, or against
/// `default_gcra` when no key-specific quota has been registered.
pub struct RateLimiter<K, C>
where
    C: Clock,
{
    /// Quota for keys without an entry in `gcra`; `None` leaves such keys unlimited.
    default_gcra: Option<Gcra>,
    /// Per-key state cells consulted and updated on every check.
    state: DashMapStateStore<K>,
    /// Key-specific quotas that take precedence over `default_gcra`.
    gcra: DashMap<K, Gcra>,
    /// Time source used for every check.
    clock: C,
    /// Reference instant captured at construction and passed to `Gcra::test_and_update`.
    start: C::Instant,
}
141
142impl<K> RateLimiter<K, MonotonicClock>
143where
144 K: Eq + Hash,
145{
146 pub fn new_with_quota(base_quota: Option<Quota>, keyed_quotas: Vec<(K, Quota)>) -> Self {
147 let clock = MonotonicClock {};
148 let start = MonotonicClock::now(&clock);
149 let gcra = DashMap::from_iter(keyed_quotas.into_iter().map(|(k, q)| (k, Gcra::new(q))));
150 Self {
151 default_gcra: base_quota.map(Gcra::new),
152 state: DashMapStateStore::new(),
153 gcra,
154 clock,
155 start,
156 }
157 }
158}
159
160impl<K> RateLimiter<K, FakeRelativeClock>
161where
162 K: Hash + Eq + Clone,
163{
164 pub fn advance_clock(&self, by: Duration) {
165 self.clock.advance(by);
166 }
167}
168
169impl<K, C> RateLimiter<K, C>
170where
171 K: Hash + Eq + Clone,
172 C: Clock,
173{
174 pub fn add_quota_for_key(&self, key: K, value: Quota) {
175 self.gcra.insert(key, Gcra::new(value));
176 }
177
178 pub fn check_key(&self, key: &K) -> Result<(), NotUntil<C::Instant>> {
179 match self.gcra.get(key) {
180 Some(quota) => quota.test_and_update(self.start, key, &self.state, self.clock.now()),
181 None => self.default_gcra.as_ref().map_or(Ok(()), |gcra| {
182 gcra.test_and_update(self.start, key, &self.state, self.clock.now())
183 }),
184 }
185 }
186
187 pub async fn until_key_ready(&self, key: &K) {
188 loop {
189 match self.check_key(key) {
190 Ok(()) => {
191 break;
192 }
193 Err(neg) => {
194 sleep(neg.wait_time_from(self.clock.now())).await;
195 }
196 }
197 }
198 }
199
200 pub async fn await_keys_ready(&self, keys: Option<Vec<K>>) {
201 let keys = keys.unwrap_or_default();
202 let tasks = keys.iter().map(|key| self.until_key_ready(key));
203
204 futures::stream::iter(tasks)
205 .for_each_concurrent(None, |key_future| async move {
206 key_future.await;
207 })
208 .await;
209 }
210}
211
#[cfg(test)]
mod tests {
    use std::{num::NonZeroU32, time::Duration};

    use dashmap::DashMap;

    use super::{
        DashMapStateStore, RateLimiter,
        clock::{Clock, FakeRelativeClock},
        gcra::Gcra,
        quota::Quota,
    };

    /// Builds a fake-clock limiter with no keyed quotas and the given default quota.
    fn mock_rate_limiter_with_default(
        base_quota: Option<Quota>,
    ) -> RateLimiter<String, FakeRelativeClock> {
        let clock = FakeRelativeClock::default();
        let start = clock.now();
        RateLimiter {
            default_gcra: base_quota.map(Gcra::new),
            state: DashMapStateStore::new(),
            gcra: DashMap::new(),
            clock,
            start,
        }
    }

    /// Fake-clock limiter whose default quota is 2 requests per second.
    fn initialize_mock_rate_limiter() -> RateLimiter<String, FakeRelativeClock> {
        mock_rate_limiter_with_default(Some(Quota::per_second(NonZeroU32::new(2).unwrap())))
    }

    #[test]
    fn test_default_quota() {
        let mock_limiter = initialize_mock_rate_limiter();

        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());

        assert!(mock_limiter.check_key(&"default".to_string()).is_err());

        mock_limiter.advance_clock(Duration::from_secs(1));
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
    }

    #[test]
    fn test_custom_key_quota() {
        let mock_limiter = initialize_mock_rate_limiter();

        mock_limiter.add_quota_for_key(
            "custom".to_string(),
            Quota::per_second(NonZeroU32::new(1).unwrap()),
        );

        assert!(mock_limiter.check_key(&"custom".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"custom".to_string()).is_err());

        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_err());
    }

    #[test]
    fn test_multiple_keys() {
        let mock_limiter = initialize_mock_rate_limiter();

        mock_limiter.add_quota_for_key(
            "key1".to_string(),
            Quota::per_second(NonZeroU32::new(1).unwrap()),
        );
        mock_limiter.add_quota_for_key(
            "key2".to_string(),
            Quota::per_second(NonZeroU32::new(3).unwrap()),
        );

        assert!(mock_limiter.check_key(&"key1".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"key1".to_string()).is_err());

        assert!(mock_limiter.check_key(&"key2".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"key2".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"key2".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"key2".to_string()).is_err());
    }

    #[test]
    fn test_quota_reset() {
        let mock_limiter = initialize_mock_rate_limiter();

        assert!(mock_limiter.check_key(&"reset".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"reset".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"reset".to_string()).is_err());

        mock_limiter.advance_clock(Duration::from_millis(499));
        assert!(mock_limiter.check_key(&"reset".to_string()).is_err());

        mock_limiter.advance_clock(Duration::from_millis(501));
        assert!(mock_limiter.check_key(&"reset".to_string()).is_ok());
    }

    #[test]
    fn test_different_quotas() {
        let mock_limiter = initialize_mock_rate_limiter();

        mock_limiter.add_quota_for_key(
            "per_second".to_string(),
            Quota::per_second(NonZeroU32::new(2).unwrap()),
        );
        mock_limiter.add_quota_for_key(
            "per_minute".to_string(),
            Quota::per_minute(NonZeroU32::new(3).unwrap()),
        );

        assert!(mock_limiter.check_key(&"per_second".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_second".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_second".to_string()).is_err());

        assert!(mock_limiter.check_key(&"per_minute".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_minute".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_minute".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_minute".to_string()).is_err());

        mock_limiter.advance_clock(Duration::from_secs(1));
        assert!(mock_limiter.check_key(&"per_second".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_minute".to_string()).is_err());
    }

    #[test]
    fn test_no_default_quota_allows_unlisted_keys() {
        // With no default quota, `check_key` falls through to `Ok(())` for unknown keys.
        let mock_limiter = mock_rate_limiter_with_default(None);

        for _ in 0..10 {
            assert!(mock_limiter.check_key(&"unlisted".to_string()).is_ok());
        }
    }

    #[test]
    fn test_no_default_quota_still_limits_keyed_quotas() {
        let mock_limiter = mock_rate_limiter_with_default(None);

        mock_limiter.add_quota_for_key(
            "custom".to_string(),
            Quota::per_second(NonZeroU32::new(1).unwrap()),
        );

        assert!(mock_limiter.check_key(&"custom".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"custom".to_string()).is_err());
        assert!(mock_limiter.check_key(&"other".to_string()).is_ok());
    }

    #[tokio::test]
    async fn test_await_keys_ready() {
        let mock_limiter = initialize_mock_rate_limiter();

        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());

        assert!(mock_limiter.check_key(&"default".to_string()).is_err());

        mock_limiter.advance_clock(Duration::from_secs(1));
        mock_limiter
            .await_keys_ready(Some(vec!["default".to_string()]))
            .await;
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
    }

    #[tokio::test]
    async fn test_await_keys_ready_with_no_keys() {
        let mock_limiter = initialize_mock_rate_limiter();

        // `None` means "no keys": it must return immediately without consuming any quota.
        mock_limiter.await_keys_ready(None).await;

        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_err());
    }
}