nautilus_network/ratelimiter/
mod.rs1pub mod clock;
20mod gcra;
21mod nanos;
22pub mod quota;
23
24use std::{
25 fmt::Debug,
26 hash::Hash,
27 num::NonZeroU64,
28 sync::atomic::{AtomicU64, Ordering},
29 time::Duration,
30};
31
32use dashmap::DashMap;
33use futures_util::StreamExt;
34
35use self::{
36 clock::{Clock, FakeRelativeClock, MonotonicClock},
37 gcra::{Gcra, NotUntil},
38 nanos::Nanos,
39 quota::Quota,
40};
41
42#[derive(Debug, Default)]
51pub struct InMemoryState(AtomicU64);
52
53impl InMemoryState {
54 pub(crate) fn measure_and_replace_one<T, F, E>(&self, mut f: F) -> Result<T, E>
60 where
61 F: FnMut(Option<Nanos>) -> Result<(T, Nanos), E>,
62 {
63 let mut prev = self.0.load(Ordering::Acquire);
64 let mut decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
65 while let Ok((result, new_data)) = decision {
66 match self.0.compare_exchange_weak(
69 prev,
70 new_data.into(),
71 Ordering::Release,
72 Ordering::Relaxed,
73 ) {
74 Ok(_) => return Ok(result),
75 Err(e) => prev = e, }
77 decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
78 }
79 decision.map(|(result, _)| result)
82 }
83}
84
85pub type DashMapStateStore<K> = DashMap<K, InMemoryState>;
87
88pub trait StateStore {
99 type Key;
101
102 fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
119 where
120 F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>;
121}
122
123impl<K: Hash + Eq + Clone> StateStore for DashMapStateStore<K> {
124 type Key = K;
125
126 fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
127 where
128 F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
129 {
130 if let Some(v) = self.get(key) {
131 return v.measure_and_replace_one(f);
133 }
134 let entry = self.entry(key.clone()).or_default();
136 (*entry).measure_and_replace_one(f)
137 }
138}
139
140pub struct RateLimiter<K, C>
145where
146 C: Clock,
147{
148 default_gcra: Option<Gcra>,
149 state: DashMapStateStore<K>,
150 gcra: DashMap<K, Gcra>,
151 clock: C,
152 start: C::Instant,
153}
154
155impl<K, C> Debug for RateLimiter<K, C>
156where
157 K: Debug,
158 C: Clock,
159{
160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 f.debug_struct(stringify!(RateLimiter)).finish()
162 }
163}
164
165impl<K> RateLimiter<K, MonotonicClock>
166where
167 K: Eq + Hash,
168{
169 #[must_use]
174 pub fn new_with_quota(base_quota: Option<Quota>, keyed_quotas: Vec<(K, Quota)>) -> Self {
175 let clock = MonotonicClock {};
176 let start = MonotonicClock::now(&clock);
177 let gcra: DashMap<_, _> = keyed_quotas
178 .into_iter()
179 .map(|(k, q)| (k, Gcra::new(q)))
180 .collect();
181 Self {
182 default_gcra: base_quota.map(Gcra::new),
183 state: DashMapStateStore::new(),
184 gcra,
185 clock,
186 start,
187 }
188 }
189}
190
191impl<K> RateLimiter<K, FakeRelativeClock>
192where
193 K: Hash + Eq + Clone,
194{
195 pub fn advance_clock(&self, by: Duration) {
199 self.clock.advance(by);
200 }
201}
202
203impl<K, C> RateLimiter<K, C>
204where
205 K: Hash + Eq + Clone,
206 C: Clock,
207{
208 pub fn add_quota_for_key(&self, key: K, value: Quota) {
210 self.gcra.insert(key, Gcra::new(value));
211 }
212
213 pub fn check_key(&self, key: &K) -> Result<(), NotUntil<C::Instant>> {
219 match self.gcra.get(key) {
220 Some(quota) => quota.test_and_update(self.start, key, &self.state, self.clock.now()),
221 None => self.default_gcra.as_ref().map_or(Ok(()), |gcra| {
222 gcra.test_and_update(self.start, key, &self.state, self.clock.now())
223 }),
224 }
225 }
226
227 pub async fn until_key_ready(&self, key: &K) {
229 loop {
230 match self.check_key(key) {
231 Ok(()) => {
232 break;
233 }
234 Err(e) => {
235 tokio::time::sleep(e.wait_time_from(self.clock.now())).await;
236 }
237 }
238 }
239 }
240
241 pub async fn await_keys_ready(&self, keys: Option<&[K]>) {
245 let Some(keys) = keys else {
246 return;
247 };
248
249 let tasks = keys.iter().map(|key| self.until_key_ready(key));
250
251 futures::stream::iter(tasks)
252 .for_each_concurrent(None, |key_future| async move {
253 key_future.await;
254 })
255 .await;
256 }
257}
258
#[cfg(test)]
mod tests {
    use std::{num::NonZeroU32, time::Duration};

    use dashmap::DashMap;
    use rstest::rstest;

    use super::{
        DashMapStateStore, RateLimiter,
        clock::{Clock, FakeRelativeClock},
        gcra::Gcra,
        quota::Quota,
    };

    /// Quota allowing `n` requests per second (panics if `n` is zero).
    fn per_second(n: u32) -> Quota {
        Quota::per_second(NonZeroU32::new(n).unwrap())
    }

    /// Builds a limiter on a [`FakeRelativeClock`] whose default quota is two requests per
    /// second and which has no per-key overrides.
    fn initialize_mock_rate_limiter() -> RateLimiter<String, FakeRelativeClock> {
        let clock = FakeRelativeClock::default();
        let start = clock.now();
        RateLimiter {
            default_gcra: Some(Gcra::new(per_second(2))),
            state: DashMapStateStore::new(),
            gcra: DashMap::new(),
            clock,
            start,
        }
    }

    #[rstest]
    fn test_default_quota() {
        let limiter = initialize_mock_rate_limiter();
        let key = String::from("default");

        // Burst of two under the default quota, then limited.
        for _ in 0..2 {
            assert!(limiter.check_key(&key).is_ok());
        }
        assert!(limiter.check_key(&key).is_err());

        // After one second capacity is available again.
        limiter.advance_clock(Duration::from_secs(1));
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_custom_key_quota() {
        let limiter = initialize_mock_rate_limiter();
        let custom = String::from("custom");
        let fallback = String::from("default");

        limiter.add_quota_for_key(custom.clone(), per_second(1));

        // The 1/s override applies to the custom key...
        assert!(limiter.check_key(&custom).is_ok());
        assert!(limiter.check_key(&custom).is_err());

        // ...while other keys still get the 2/s default.
        for _ in 0..2 {
            assert!(limiter.check_key(&fallback).is_ok());
        }
        assert!(limiter.check_key(&fallback).is_err());
    }

    #[rstest]
    fn test_multiple_keys() {
        let limiter = initialize_mock_rate_limiter();
        let key1 = String::from("key1");
        let key2 = String::from("key2");

        limiter.add_quota_for_key(key1.clone(), per_second(1));
        limiter.add_quota_for_key(key2.clone(), per_second(3));

        assert!(limiter.check_key(&key1).is_ok());
        assert!(limiter.check_key(&key1).is_err());

        for _ in 0..3 {
            assert!(limiter.check_key(&key2).is_ok());
        }
        assert!(limiter.check_key(&key2).is_err());
    }

    #[rstest]
    fn test_quota_reset() {
        let limiter = initialize_mock_rate_limiter();
        let key = String::from("reset");

        for _ in 0..2 {
            assert!(limiter.check_key(&key).is_ok());
        }
        assert!(limiter.check_key(&key).is_err());

        // Just under half a second later: still limited.
        limiter.advance_clock(Duration::from_millis(499));
        assert!(limiter.check_key(&key).is_err());

        // One full second elapsed in total: allowed again.
        limiter.advance_clock(Duration::from_millis(501));
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_different_quotas() {
        let limiter = initialize_mock_rate_limiter();
        let secondly = String::from("per_second");
        let minutely = String::from("per_minute");

        limiter.add_quota_for_key(secondly.clone(), per_second(2));
        limiter.add_quota_for_key(
            minutely.clone(),
            Quota::per_minute(NonZeroU32::new(3).unwrap()),
        );

        for _ in 0..2 {
            assert!(limiter.check_key(&secondly).is_ok());
        }
        assert!(limiter.check_key(&secondly).is_err());

        for _ in 0..3 {
            assert!(limiter.check_key(&minutely).is_ok());
        }
        assert!(limiter.check_key(&minutely).is_err());

        // One second restores the per-second key but not the per-minute one.
        limiter.advance_clock(Duration::from_secs(1));
        assert!(limiter.check_key(&secondly).is_ok());
        assert!(limiter.check_key(&minutely).is_err());
    }

    #[tokio::test]
    async fn test_await_keys_ready() {
        let limiter = initialize_mock_rate_limiter();
        let key = String::from("default");

        for _ in 0..2 {
            assert!(limiter.check_key(&key).is_ok());
        }
        assert!(limiter.check_key(&key).is_err());

        // Advance the fake clock before awaiting; the limiter's real-time sleep is not
        // expected to move it.
        limiter.advance_clock(Duration::from_secs(1));
        let keys = [key.clone()];
        limiter.await_keys_ready(Some(&keys[..])).await;
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_gcra_boundary_exact_replenishment() {
        let limiter = initialize_mock_rate_limiter();
        let key = String::from("boundary_test");

        for _ in 0..2 {
            assert!(limiter.check_key(&key).is_ok());
        }
        assert!(limiter.check_key(&key).is_err());

        // Advance by exactly one replenish interval of the default quota.
        limiter.advance_clock(per_second(2).replenish_interval());

        assert!(
            limiter.check_key(&key).is_ok(),
            "Request at exact replenish boundary should be allowed"
        );
        assert!(
            limiter.check_key(&key).is_err(),
            "Immediate follow-up should be rate-limited"
        );
    }
}