// nautilus_network/ratelimiter/mod.rs
pub mod clock;
20mod gcra;
21mod nanos;
22pub mod quota;
23
24use std::{
25 fmt::Debug,
26 hash::Hash,
27 num::NonZeroU64,
28 sync::atomic::{AtomicU64, Ordering},
29 time::Duration,
30};
31
32use dashmap::DashMap;
33use futures_util::StreamExt;
34use tokio::time::sleep;
35
36use self::{
37 clock::{Clock, FakeRelativeClock, MonotonicClock},
38 gcra::{Gcra, NotUntil},
39 nanos::Nanos,
40 quota::Quota,
41};
42
43#[derive(Debug, Default)]
52pub struct InMemoryState(AtomicU64);
53
54impl InMemoryState {
55 pub(crate) fn measure_and_replace_one<T, F, E>(&self, mut f: F) -> Result<T, E>
61 where
62 F: FnMut(Option<Nanos>) -> Result<(T, Nanos), E>,
63 {
64 let mut prev = self.0.load(Ordering::Acquire);
65 let mut decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
66 while let Ok((result, new_data)) = decision {
67 match self.0.compare_exchange_weak(
70 prev,
71 new_data.into(),
72 Ordering::Release,
73 Ordering::Relaxed,
74 ) {
75 Ok(_) => return Ok(result),
76 Err(next_prev) => prev = next_prev, }
78 decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
79 }
80 decision.map(|(result, _)| result)
83 }
84}
85
/// Keyed state store backed by a [`DashMap`], holding one [`InMemoryState`] per key.
pub type DashMapStateStore<K> = DashMap<K, InMemoryState>;
88
/// Abstraction over a keyed store of rate-limiting state.
pub trait StateStore {
    /// The type of key used to look up a state entry.
    type Key;

    /// Evaluates `f` against the state stored under `key` and records the state it proposes.
    ///
    /// `f` is given the current state (`None` when nothing has been stored yet) and
    /// returns either `Ok((result, new_state))` or `Err(e)`. On `Ok` the new state is
    /// stored and `result` is returned; on `Err` the error is returned to the caller.
    ///
    /// `f` is bound by `Fn` rather than `FnOnce`: the `DashMapStateStore`
    /// implementation in this file calls it again whenever a concurrent update wins
    /// the race, so it should be free of side effects.
    fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
    where
        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>;
}
123
124impl<K: Hash + Eq + Clone> StateStore for DashMapStateStore<K> {
125 type Key = K;
126
127 fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
128 where
129 F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
130 {
131 if let Some(v) = self.get(key) {
132 return v.measure_and_replace_one(f);
134 }
135 let entry = self.entry(key.clone()).or_default();
137 (*entry).measure_and_replace_one(f)
138 }
139}
140
141pub struct RateLimiter<K, C>
146where
147 C: Clock,
148{
149 default_gcra: Option<Gcra>,
150 state: DashMapStateStore<K>,
151 gcra: DashMap<K, Gcra>,
152 clock: C,
153 start: C::Instant,
154}
155
156impl<K, C> Debug for RateLimiter<K, C>
157where
158 K: Debug,
159 C: Clock,
160{
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 f.debug_struct(stringify!(RateLimiter)).finish()
163 }
164}
165
166impl<K> RateLimiter<K, MonotonicClock>
167where
168 K: Eq + Hash,
169{
170 #[must_use]
175 pub fn new_with_quota(base_quota: Option<Quota>, keyed_quotas: Vec<(K, Quota)>) -> Self {
176 let clock = MonotonicClock {};
177 let start = MonotonicClock::now(&clock);
178 let gcra = DashMap::from_iter(keyed_quotas.into_iter().map(|(k, q)| (k, Gcra::new(q))));
179 Self {
180 default_gcra: base_quota.map(Gcra::new),
181 state: DashMapStateStore::new(),
182 gcra,
183 clock,
184 start,
185 }
186 }
187}
188
189impl<K> RateLimiter<K, FakeRelativeClock>
190where
191 K: Hash + Eq + Clone,
192{
193 pub fn advance_clock(&self, by: Duration) {
197 self.clock.advance(by);
198 }
199}
200
201impl<K, C> RateLimiter<K, C>
202where
203 K: Hash + Eq + Clone,
204 C: Clock,
205{
206 pub fn add_quota_for_key(&self, key: K, value: Quota) {
208 self.gcra.insert(key, Gcra::new(value));
209 }
210
211 pub fn check_key(&self, key: &K) -> Result<(), NotUntil<C::Instant>> {
217 match self.gcra.get(key) {
218 Some(quota) => quota.test_and_update(self.start, key, &self.state, self.clock.now()),
219 None => self.default_gcra.as_ref().map_or(Ok(()), |gcra| {
220 gcra.test_and_update(self.start, key, &self.state, self.clock.now())
221 }),
222 }
223 }
224
225 pub async fn until_key_ready(&self, key: &K) {
227 loop {
228 match self.check_key(key) {
229 Ok(()) => {
230 break;
231 }
232 Err(neg) => {
233 sleep(neg.wait_time_from(self.clock.now())).await;
234 }
235 }
236 }
237 }
238
239 pub async fn await_keys_ready(&self, keys: Option<Vec<K>>) {
243 let keys = keys.unwrap_or_default();
244 let tasks = keys.iter().map(|key| self.until_key_ready(key));
245
246 futures::stream::iter(tasks)
247 .for_each_concurrent(None, |key_future| async move {
248 key_future.await;
249 })
250 .await;
251 }
252}
253
#[cfg(test)]
mod tests {
    use std::{num::NonZeroU32, time::Duration};

    use dashmap::DashMap;
    use rstest::rstest;

    use super::{
        DashMapStateStore, RateLimiter,
        clock::{Clock, FakeRelativeClock},
        gcra::Gcra,
        quota::Quota,
    };

    /// Shorthand for a quota of `count` requests per second.
    fn per_second(count: u32) -> Quota {
        Quota::per_second(NonZeroU32::new(count).unwrap())
    }

    /// Builds a limiter on a fake clock with a default quota of two requests per second
    /// and no keyed quotas.
    fn initialize_mock_rate_limiter() -> RateLimiter<String, FakeRelativeClock> {
        let clock = FakeRelativeClock::default();
        let start = clock.now();
        RateLimiter {
            default_gcra: Some(Gcra::new(per_second(2))),
            state: DashMapStateStore::new(),
            gcra: DashMap::new(),
            clock,
            start,
        }
    }

    #[rstest]
    fn test_default_quota() {
        let limiter = initialize_mock_rate_limiter();
        let key = "default".to_string();

        // The default quota admits two requests back to back.
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());

        // The third request is rejected.
        assert!(limiter.check_key(&key).is_err());

        // One second later the key is allowed again.
        limiter.advance_clock(Duration::from_secs(1));
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_custom_key_quota() {
        let limiter = initialize_mock_rate_limiter();
        let custom = "custom".to_string();
        let default = "default".to_string();

        limiter.add_quota_for_key(custom.clone(), per_second(1));

        // The keyed quota admits a single request.
        assert!(limiter.check_key(&custom).is_ok());
        assert!(limiter.check_key(&custom).is_err());

        // Other keys still follow the default quota.
        assert!(limiter.check_key(&default).is_ok());
        assert!(limiter.check_key(&default).is_ok());
        assert!(limiter.check_key(&default).is_err());
    }

    #[rstest]
    fn test_multiple_keys() {
        let limiter = initialize_mock_rate_limiter();
        let key1 = "key1".to_string();
        let key2 = "key2".to_string();

        limiter.add_quota_for_key(key1.clone(), per_second(1));
        limiter.add_quota_for_key(key2.clone(), per_second(3));

        // key1: one request allowed, then rejected.
        assert!(limiter.check_key(&key1).is_ok());
        assert!(limiter.check_key(&key1).is_err());

        // key2: three requests allowed, then rejected.
        assert!(limiter.check_key(&key2).is_ok());
        assert!(limiter.check_key(&key2).is_ok());
        assert!(limiter.check_key(&key2).is_ok());
        assert!(limiter.check_key(&key2).is_err());
    }

    #[rstest]
    fn test_quota_reset() {
        let limiter = initialize_mock_rate_limiter();
        let key = "reset".to_string();

        // Exhaust the default quota.
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());

        // Just short of half a second: still rejected.
        limiter.advance_clock(Duration::from_millis(499));
        assert!(limiter.check_key(&key).is_err());

        // A full second has now elapsed in total: allowed again.
        limiter.advance_clock(Duration::from_millis(501));
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_different_quotas() {
        let limiter = initialize_mock_rate_limiter();
        let fast = "per_second".to_string();
        let slow = "per_minute".to_string();

        limiter.add_quota_for_key(fast.clone(), per_second(2));
        limiter.add_quota_for_key(
            slow.clone(),
            Quota::per_minute(NonZeroU32::new(3).unwrap()),
        );

        // Per-second quota: two allowed, third rejected.
        assert!(limiter.check_key(&fast).is_ok());
        assert!(limiter.check_key(&fast).is_ok());
        assert!(limiter.check_key(&fast).is_err());

        // Per-minute quota: three allowed, fourth rejected.
        assert!(limiter.check_key(&slow).is_ok());
        assert!(limiter.check_key(&slow).is_ok());
        assert!(limiter.check_key(&slow).is_ok());
        assert!(limiter.check_key(&slow).is_err());

        // After one second only the per-second key has recovered.
        limiter.advance_clock(Duration::from_secs(1));
        assert!(limiter.check_key(&fast).is_ok());
        assert!(limiter.check_key(&slow).is_err());
    }

    #[tokio::test]
    async fn test_await_keys_ready() {
        let limiter = initialize_mock_rate_limiter();
        let key = "default".to_string();

        // Exhaust the default quota.
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());

        assert!(limiter.check_key(&key).is_err());

        // Once the clock has advanced, waiting on the key completes and a further
        // request is still allowed.
        limiter.advance_clock(Duration::from_secs(1));
        limiter.await_keys_ready(Some(vec![key.clone()])).await;
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_gcra_boundary_exact_replenishment() {
        let limiter = initialize_mock_rate_limiter();
        let key = "boundary_test".to_string();

        // Exhaust the default quota.
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());

        assert!(limiter.check_key(&key).is_err());

        // Advance by exactly one replenish interval of the default quota.
        let replenish_interval = per_second(2).replenish_interval();
        limiter.advance_clock(replenish_interval);

        assert!(
            limiter.check_key(&key).is_ok(),
            "Request at exact replenish boundary should be allowed"
        );

        assert!(
            limiter.check_key(&key).is_err(),
            "Immediate follow-up should be rate-limited"
        );
    }
}