nautilus_network/ratelimiter/
mod.rs1pub mod clock;
20mod gcra;
21mod nanos;
22pub mod quota;
23
24use std::{
25 fmt::Debug,
26 hash::Hash,
27 num::NonZeroU64,
28 sync::atomic::{AtomicU64, Ordering},
29 time::Duration,
30};
31
32use dashmap::DashMap;
33use futures_util::StreamExt;
34
35use self::{
36 clock::{Clock, FakeRelativeClock, MonotonicClock},
37 gcra::{Gcra, NotUntil},
38 nanos::Nanos,
39 quota::Quota,
40};
41
/// Lock-free state for a single rate-limiter key.
///
/// Wraps one `AtomicU64` holding the raw `u64` form of the last value written by
/// `measure_and_replace_one`. A raw value of `0` (the `Default`) is read back as
/// "no state recorded yet".
#[derive(Default, Debug)]
pub struct InMemoryState(AtomicU64);
52
53impl InMemoryState {
54 pub(crate) fn measure_and_replace_one<T, F, E>(&self, mut f: F) -> Result<T, E>
60 where
61 F: FnMut(Option<Nanos>) -> Result<(T, Nanos), E>,
62 {
63 let mut prev = self.0.load(Ordering::Acquire);
64 let mut decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
65 while let Ok((result, new_data)) = decision {
66 match self.0.compare_exchange_weak(
69 prev,
70 new_data.into(),
71 Ordering::Release,
72 Ordering::Relaxed,
73 ) {
74 Ok(_) => return Ok(result),
75 Err(e) => prev = e, }
77 decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
78 }
79 decision.map(|(result, _)| result)
82 }
83}
84
/// Concurrent map from a rate-limit key to that key's [`InMemoryState`].
pub type DashMapStateStore<K> = DashMap<K, InMemoryState>;
87
/// Storage backend holding one rate-limiter state value per key.
pub trait StateStore {
    /// Type used to look up an individual key's state.
    type Key;

    /// Reads the state stored for `key` and replaces it according to `f`.
    ///
    /// `f` is given the current state for `key` (`None` when nothing is recorded) and
    /// returns either `Ok((result, new_state))` or `Err(e)`. The [`DashMapStateStore`]
    /// implementation handles these cases as follows:
    /// - On `Ok`, it stores `new_state` and returns `result`.
    /// - On `Err`, it leaves the state unchanged and returns `e`.
    ///
    /// That implementation may invoke `f` more than once when it retries after a
    /// concurrent update, so `f` should not rely on running exactly once.
    fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
    where
        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>;
}
122
123impl<K: Hash + Eq + Clone> StateStore for DashMapStateStore<K> {
124 type Key = K;
125
126 fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
127 where
128 F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
129 {
130 if let Some(v) = self.get(key) {
131 return v.measure_and_replace_one(f);
133 }
134 let entry = self.entry(key.clone()).or_default();
136 (*entry).measure_and_replace_one(f)
137 }
138}
139
140pub struct RateLimiter<K, C>
145where
146 C: Clock,
147{
148 default_gcra: Option<Gcra>,
149 state: DashMapStateStore<K>,
150 gcra: DashMap<K, Gcra>,
151 clock: C,
152 start: C::Instant,
153}
154
155impl<K, C> Debug for RateLimiter<K, C>
156where
157 K: Debug,
158 C: Clock,
159{
160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 f.debug_struct(stringify!(RateLimiter)).finish()
162 }
163}
164
165impl<K> RateLimiter<K, MonotonicClock>
166where
167 K: Eq + Hash,
168{
169 #[must_use]
174 pub fn new_with_quota(base_quota: Option<Quota>, keyed_quotas: Vec<(K, Quota)>) -> Self {
175 let clock = MonotonicClock {};
176 let start = MonotonicClock::now(&clock);
177 let gcra: DashMap<_, _> = keyed_quotas
178 .into_iter()
179 .map(|(k, q)| (k, Gcra::new(q)))
180 .collect();
181 Self {
182 default_gcra: base_quota.map(Gcra::new),
183 state: DashMapStateStore::new(),
184 gcra,
185 clock,
186 start,
187 }
188 }
189}
190
191impl<K> RateLimiter<K, FakeRelativeClock>
192where
193 K: Hash + Eq + Clone,
194{
195 pub fn advance_clock(&self, by: Duration) {
199 self.clock.advance(by);
200 }
201}
202
203impl<K, C> RateLimiter<K, C>
204where
205 K: Hash + Eq + Clone,
206 C: Clock,
207{
208 pub fn add_quota_for_key(&self, key: K, value: Quota) {
210 self.gcra.insert(key, Gcra::new(value));
211 }
212
213 pub fn check_key(&self, key: &K) -> Result<(), NotUntil<C::Instant>> {
219 match self.gcra.get(key) {
220 Some(quota) => quota.test_and_update(self.start, key, &self.state, self.clock.now()),
221 None => self.default_gcra.as_ref().map_or(Ok(()), |gcra| {
222 gcra.test_and_update(self.start, key, &self.state, self.clock.now())
223 }),
224 }
225 }
226
227 pub async fn until_key_ready(&self, key: &K) {
229 loop {
230 match self.check_key(key) {
231 Ok(()) => {
232 break;
233 }
234 Err(e) => {
235 tokio::time::sleep(e.wait_time_from(self.clock.now())).await;
236 }
237 }
238 }
239 }
240
241 pub async fn await_keys_ready(&self, keys: Option<Vec<K>>) {
245 let keys = keys.unwrap_or_default();
246 let tasks = keys.iter().map(|key| self.until_key_ready(key));
247
248 futures::stream::iter(tasks)
249 .for_each_concurrent(None, |key_future| async move {
250 key_future.await;
251 })
252 .await;
253 }
254}
255
#[cfg(test)]
mod tests {
    use std::{num::NonZeroU32, time::Duration};

    use dashmap::DashMap;
    use rstest::rstest;

    use super::{
        DashMapStateStore, RateLimiter,
        clock::{Clock, FakeRelativeClock},
        gcra::Gcra,
        quota::Quota,
    };

    /// Limiter on a fake clock with a default quota of 2 requests per second and no
    /// per-key quotas.
    fn initialize_mock_rate_limiter() -> RateLimiter<String, FakeRelativeClock> {
        let clock = FakeRelativeClock::default();
        let default_quota = Quota::per_second(NonZeroU32::new(2).unwrap());
        RateLimiter {
            default_gcra: Some(Gcra::new(default_quota)),
            state: DashMapStateStore::new(),
            gcra: DashMap::new(),
            start: clock.now(),
            clock,
        }
    }

    /// Returns whether `limiter` currently allows a request for `key`.
    fn allows(limiter: &RateLimiter<String, FakeRelativeClock>, key: &str) -> bool {
        limiter.check_key(&key.to_string()).is_ok()
    }

    #[rstest]
    fn test_default_quota() {
        let limiter = initialize_mock_rate_limiter();

        // Default quota: two requests pass, the third is rejected.
        assert!(allows(&limiter, "default"));
        assert!(allows(&limiter, "default"));
        assert!(!allows(&limiter, "default"));

        // A full second replenishes the quota.
        limiter.advance_clock(Duration::from_secs(1));
        assert!(allows(&limiter, "default"));
    }

    #[rstest]
    fn test_custom_key_quota() {
        let limiter = initialize_mock_rate_limiter();
        limiter.add_quota_for_key(
            "custom".to_string(),
            Quota::per_second(NonZeroU32::new(1).unwrap()),
        );

        // Dedicated quota of one request per second.
        assert!(allows(&limiter, "custom"));
        assert!(!allows(&limiter, "custom"));

        // Keys without a dedicated quota keep using the default of two per second.
        assert!(allows(&limiter, "default"));
        assert!(allows(&limiter, "default"));
        assert!(!allows(&limiter, "default"));
    }

    #[rstest]
    fn test_multiple_keys() {
        let limiter = initialize_mock_rate_limiter();
        limiter.add_quota_for_key(
            "key1".to_string(),
            Quota::per_second(NonZeroU32::new(1).unwrap()),
        );
        limiter.add_quota_for_key(
            "key2".to_string(),
            Quota::per_second(NonZeroU32::new(3).unwrap()),
        );

        assert!(allows(&limiter, "key1"));
        assert!(!allows(&limiter, "key1"));

        for _ in 0..3 {
            assert!(allows(&limiter, "key2"));
        }
        assert!(!allows(&limiter, "key2"));
    }

    #[rstest]
    fn test_quota_reset() {
        let limiter = initialize_mock_rate_limiter();

        assert!(allows(&limiter, "reset"));
        assert!(allows(&limiter, "reset"));
        assert!(!allows(&limiter, "reset"));

        // Just short of one replenish interval (500ms at 2/s): still limited.
        limiter.advance_clock(Duration::from_millis(499));
        assert!(!allows(&limiter, "reset"));

        limiter.advance_clock(Duration::from_millis(501));
        assert!(allows(&limiter, "reset"));
    }

    #[rstest]
    fn test_different_quotas() {
        let limiter = initialize_mock_rate_limiter();
        limiter.add_quota_for_key(
            "per_second".to_string(),
            Quota::per_second(NonZeroU32::new(2).unwrap()),
        );
        limiter.add_quota_for_key(
            "per_minute".to_string(),
            Quota::per_minute(NonZeroU32::new(3).unwrap()),
        );

        assert!(allows(&limiter, "per_second"));
        assert!(allows(&limiter, "per_second"));
        assert!(!allows(&limiter, "per_second"));

        for _ in 0..3 {
            assert!(allows(&limiter, "per_minute"));
        }
        assert!(!allows(&limiter, "per_minute"));

        // One second refills the per-second key but not the per-minute key.
        limiter.advance_clock(Duration::from_secs(1));
        assert!(allows(&limiter, "per_second"));
        assert!(!allows(&limiter, "per_minute"));
    }

    #[tokio::test]
    async fn test_await_keys_ready() {
        let limiter = initialize_mock_rate_limiter();

        assert!(allows(&limiter, "default"));
        assert!(allows(&limiter, "default"));
        assert!(!allows(&limiter, "default"));

        // After replenishment the first check inside the wait succeeds immediately.
        limiter.advance_clock(Duration::from_secs(1));
        limiter
            .await_keys_ready(Some(vec!["default".to_string()]))
            .await;
        assert!(allows(&limiter, "default"));
    }

    #[rstest]
    fn test_gcra_boundary_exact_replenishment() {
        let limiter = initialize_mock_rate_limiter();
        let key = "boundary_test";

        assert!(allows(&limiter, key));
        assert!(allows(&limiter, key));
        assert!(!allows(&limiter, key));

        // Same quota as the mock default, used only to read its replenish interval.
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap());
        limiter.advance_clock(quota.replenish_interval());

        assert!(
            allows(&limiter, key),
            "Request at exact replenish boundary should be allowed"
        );
        assert!(
            !allows(&limiter, key),
            "Immediate follow-up should be rate-limited"
        );
    }
}