nautilus_network/ratelimiter/
mod.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! A rate limiter implementation heavily inspired by [governor](https://github.com/antifuchs/governor).
17//!
//! Governor does not support different quotas for different keys; this is tracked in an open [issue](https://github.com/antifuchs/governor/issues/193).
19pub mod clock;
20mod gcra;
21mod nanos;
22pub mod quota;
23
24use std::{
25    fmt::Debug,
26    hash::Hash,
27    num::NonZeroU64,
28    sync::atomic::{AtomicU64, Ordering},
29    time::Duration,
30};
31
32use dashmap::DashMap;
33use futures_util::StreamExt;
34
35use self::{
36    clock::{Clock, FakeRelativeClock, MonotonicClock},
37    gcra::{Gcra, NotUntil},
38    nanos::Nanos,
39    quota::Quota,
40};
41
42/// An in-memory representation of a GCRA's rate-limiting state.
43///
44/// Implemented using [`AtomicU64`] operations, this state representation can be used to
45/// construct rate limiting states for other in-memory states: e.g., this crate uses
46/// `InMemoryState` as the states it tracks in the keyed rate limiters it implements.
47///
48/// Internally, the number tracked here is the theoretical arrival time (a GCRA term) in number of
49/// nanoseconds since the rate limiter was created.
50#[derive(Debug, Default)]
51pub struct InMemoryState(AtomicU64);
52
53impl InMemoryState {
54    /// Measures and updates the GCRA's state atomically, retrying on concurrent modifications.
55    ///
56    /// # Errors
57    ///
58    /// Returns an error if the provided closure returns an error.
59    pub(crate) fn measure_and_replace_one<T, F, E>(&self, mut f: F) -> Result<T, E>
60    where
61        F: FnMut(Option<Nanos>) -> Result<(T, Nanos), E>,
62    {
63        let mut prev = self.0.load(Ordering::Acquire);
64        let mut decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
65        while let Ok((result, new_data)) = decision {
66            // Lock-free CAS loop: retry with current value if another thread modified it,
67            // uses weak variant (faster) since spurious failures are fine in a retry loop.
68            match self.0.compare_exchange_weak(
69                prev,
70                new_data.into(),
71                Ordering::Release,
72                Ordering::Relaxed,
73            ) {
74                Ok(_) => return Ok(result),
75                Err(e) => prev = e, // Retry with value written by another thread
76            }
77            decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
78        }
79        // This map shouldn't be needed, as we only get here in the error case, but the compiler
80        // can't see it.
81        decision.map(|(result, _)| result)
82    }
83}
84
/// A concurrent, thread-safe and fairly performant hashmap based on [`DashMap`].
///
/// Each key maps to its own [`InMemoryState`], which makes this the keyed state store used by
/// the rate limiter in this module.
pub type DashMapStateStore<K> = DashMap<K, InMemoryState>;
87
/// A way for rate limiters to keep state.
///
/// There are two important kinds of state stores: direct and keyed. The direct kind has only
/// one state and is useful for "global" rate limit enforcement (e.g. a process should never
/// do more than N tasks a day). The keyed kind allows one rate limit per key (e.g. an API
/// call budget per client API key).
///
/// A direct state store is expressed as [`StateStore::Key`] = `NotKeyed`.
/// NOTE(review): no `NotKeyed` type is visible in this module — confirm it exists before relying
/// on this sentence.
/// Keyed state stores have a type parameter for the key and set their key to that.
pub trait StateStore {
    /// The type of key that the state store can represent.
    type Key;

    /// Updates a state store's rate limiting state for a given key, using the given closure.
    ///
    /// The closure parameter takes the old value (`None` if this is the first measurement) of the
    /// state store at the key's location, checks if the request can be accommodated and:
    ///
    /// - If the request is rate-limited, returns `Err(E)`.
    /// - If the request can make it through, returns `Ok(T)` (an arbitrary positive return
    ///   value) and the updated state.
    ///
    /// It is `measure_and_replace`'s job then to safely replace the value at the key - it must
    /// only update the value if the value hasn't changed. The implementations in this
    /// crate use `AtomicU64` operations for this.
    ///
    /// Because the closure is bounded by `Fn`, an implementation is free to call it more than
    /// once (for example when retrying after a concurrent update).
    ///
    /// # Errors
    ///
    /// Returns `Err(E)` if the closure returns an error or the request is rate-limited.
    fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
    where
        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>;
}
122
123impl<K: Hash + Eq + Clone> StateStore for DashMapStateStore<K> {
124    type Key = K;
125
126    fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
127    where
128        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
129    {
130        if let Some(v) = self.get(key) {
131            // fast path: measure existing entry
132            return v.measure_and_replace_one(f);
133        }
134        // make an entry and measure that:
135        let entry = self.entry(key.clone()).or_default();
136        (*entry).measure_and_replace_one(f)
137    }
138}
139
/// A rate limiter that enforces different quotas per key using the GCRA algorithm.
///
/// This implementation allows setting different rate limits for different keys,
/// with an optional default quota for keys that don't have specific quotas.
pub struct RateLimiter<K, C>
where
    C: Clock,
{
    /// GCRA built from the default quota; consulted only for keys that have no entry in
    /// `gcra`. When this is `None`, such keys are never rate-limited.
    default_gcra: Option<Gcra>,
    /// Per-key rate-limiting state, used by both the key-specific and the default checks.
    state: DashMapStateStore<K>,
    /// Key-specific GCRA instances; an entry here takes precedence over `default_gcra`.
    gcra: DashMap<K, Gcra>,
    /// Time source queried on every check.
    clock: C,
    /// Instant captured when the limiter was built; handed to every GCRA check together with
    /// the current time.
    start: C::Instant,
}
154
155impl<K, C> Debug for RateLimiter<K, C>
156where
157    K: Debug,
158    C: Clock,
159{
160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        f.debug_struct(stringify!(RateLimiter)).finish()
162    }
163}
164
165impl<K> RateLimiter<K, MonotonicClock>
166where
167    K: Eq + Hash,
168{
169    /// Creates a new rate limiter with a base quota and keyed quotas.
170    ///
171    /// The base quota applies to all keys that don't have specific quotas.
172    /// Keyed quotas override the base quota for specific keys.
173    #[must_use]
174    pub fn new_with_quota(base_quota: Option<Quota>, keyed_quotas: Vec<(K, Quota)>) -> Self {
175        let clock = MonotonicClock {};
176        let start = MonotonicClock::now(&clock);
177        let gcra: DashMap<_, _> = keyed_quotas
178            .into_iter()
179            .map(|(k, q)| (k, Gcra::new(q)))
180            .collect();
181        Self {
182            default_gcra: base_quota.map(Gcra::new),
183            state: DashMapStateStore::new(),
184            gcra,
185            clock,
186            start,
187        }
188    }
189}
190
impl<K> RateLimiter<K, FakeRelativeClock>
where
    K: Hash + Eq + Clone,
{
    /// Advances the fake clock by the specified duration.
    ///
    /// This is only available for testing with `FakeRelativeClock`: the method exists solely on
    /// limiters whose clock type is `FakeRelativeClock`, and it simply forwards `by` to the
    /// clock's own `advance`.
    pub fn advance_clock(&self, by: Duration) {
        self.clock.advance(by);
    }
}
202
203impl<K, C> RateLimiter<K, C>
204where
205    K: Hash + Eq + Clone,
206    C: Clock,
207{
208    /// Adds or updates a quota for a specific key.
209    pub fn add_quota_for_key(&self, key: K, value: Quota) {
210        self.gcra.insert(key, Gcra::new(value));
211    }
212
213    /// Checks if the given key is allowed under the rate limit.
214    ///
215    /// # Errors
216    ///
217    /// Returns `Err(NotUntil)` if the key is rate-limited, indicating when it will be allowed.
218    pub fn check_key(&self, key: &K) -> Result<(), NotUntil<C::Instant>> {
219        match self.gcra.get(key) {
220            Some(quota) => quota.test_and_update(self.start, key, &self.state, self.clock.now()),
221            None => self.default_gcra.as_ref().map_or(Ok(()), |gcra| {
222                gcra.test_and_update(self.start, key, &self.state, self.clock.now())
223            }),
224        }
225    }
226
227    /// Waits until the specified key is ready (not rate-limited).
228    pub async fn until_key_ready(&self, key: &K) {
229        loop {
230            match self.check_key(key) {
231                Ok(()) => {
232                    break;
233                }
234                Err(e) => {
235                    tokio::time::sleep(e.wait_time_from(self.clock.now())).await;
236                }
237            }
238        }
239    }
240
241    /// Waits until all specified keys are ready (not rate-limited).
242    ///
243    /// If no keys are provided, this function returns immediately.
244    pub async fn await_keys_ready(&self, keys: Option<Vec<K>>) {
245        let keys = keys.unwrap_or_default();
246        let tasks = keys.iter().map(|key| self.until_key_ready(key));
247
248        futures::stream::iter(tasks)
249            .for_each_concurrent(None, |key_future| async move {
250                key_future.await;
251            })
252            .await;
253    }
254}
255
#[cfg(test)]
mod tests {
    use std::{num::NonZeroU32, time::Duration};

    use dashmap::DashMap;
    use rstest::rstest;

    use super::{
        DashMapStateStore, RateLimiter,
        clock::{Clock, FakeRelativeClock},
        gcra::Gcra,
        quota::Quota,
    };

    /// Shorthand for a quota of `n` requests per second (`n` must be non-zero).
    fn per_second(n: u32) -> Quota {
        Quota::per_second(NonZeroU32::new(n).unwrap())
    }

    /// Builds a limiter driven by a [`FakeRelativeClock`], with a default quota of two requests
    /// per second and no key-specific quotas.
    fn initialize_mock_rate_limiter() -> RateLimiter<String, FakeRelativeClock> {
        let clock = FakeRelativeClock::default();
        let start = clock.now();
        RateLimiter {
            default_gcra: Some(Gcra::new(per_second(2))),
            state: DashMapStateStore::new(),
            gcra: DashMap::new(),
            clock,
            start,
        }
    }

    #[rstest]
    fn test_default_quota() {
        let limiter = initialize_mock_rate_limiter();
        let key = "default".to_string();

        // The base quota admits two immediate requests
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());

        // The third one exceeds the base quota
        assert!(limiter.check_key(&key).is_err());

        // After one second the key is admitted again
        limiter.advance_clock(Duration::from_secs(1));
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_custom_key_quota() {
        let limiter = initialize_mock_rate_limiter();
        let custom = "custom".to_string();
        let other = "default".to_string();

        // Give one key its own, stricter quota
        limiter.add_quota_for_key(custom.clone(), per_second(1));

        // The custom quota admits exactly one request
        assert!(limiter.check_key(&custom).is_ok());
        assert!(limiter.check_key(&custom).is_err());

        // Every other key still falls back to the default quota
        assert!(limiter.check_key(&other).is_ok());
        assert!(limiter.check_key(&other).is_ok());
        assert!(limiter.check_key(&other).is_err());
    }

    #[rstest]
    fn test_multiple_keys() {
        let limiter = initialize_mock_rate_limiter();
        let key1 = "key1".to_string();
        let key2 = "key2".to_string();

        limiter.add_quota_for_key(key1.clone(), per_second(1));
        limiter.add_quota_for_key(key2.clone(), per_second(3));

        // key1: one request per second
        assert!(limiter.check_key(&key1).is_ok());
        assert!(limiter.check_key(&key1).is_err());

        // key2: three requests per second, tracked independently of key1
        assert!(limiter.check_key(&key2).is_ok());
        assert!(limiter.check_key(&key2).is_ok());
        assert!(limiter.check_key(&key2).is_ok());
        assert!(limiter.check_key(&key2).is_err());
    }

    #[rstest]
    fn test_quota_reset() {
        let limiter = initialize_mock_rate_limiter();
        let key = "reset".to_string();

        // Use up the quota
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());

        // 499ms later the key is still rate-limited
        limiter.advance_clock(Duration::from_millis(499));
        assert!(limiter.check_key(&key).is_err());

        // A further 501ms (one full second in total) frees it up again
        limiter.advance_clock(Duration::from_millis(501));
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_different_quotas() {
        let limiter = initialize_mock_rate_limiter();
        let fast = "per_second".to_string();
        let slow = "per_minute".to_string();

        limiter.add_quota_for_key(fast.clone(), per_second(2));
        limiter.add_quota_for_key(
            slow.clone(),
            Quota::per_minute(NonZeroU32::new(3).unwrap()),
        );

        // Per-second quota: two requests, then rejection
        assert!(limiter.check_key(&fast).is_ok());
        assert!(limiter.check_key(&fast).is_ok());
        assert!(limiter.check_key(&fast).is_err());

        // Per-minute quota: three requests, then rejection
        assert!(limiter.check_key(&slow).is_ok());
        assert!(limiter.check_key(&slow).is_ok());
        assert!(limiter.check_key(&slow).is_ok());
        assert!(limiter.check_key(&slow).is_err());

        // One second is enough for the per-second key but not for the per-minute key
        limiter.advance_clock(Duration::from_secs(1));
        assert!(limiter.check_key(&fast).is_ok());
        assert!(limiter.check_key(&slow).is_err());
    }

    #[tokio::test]
    async fn test_await_keys_ready() {
        let limiter = initialize_mock_rate_limiter();
        let key = "default".to_string();

        // The base quota admits two immediate requests
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());

        // The third one exceeds the base quota
        assert!(limiter.check_key(&key).is_err());

        // Once the clock has moved on, awaiting the key completes and a further check passes
        limiter.advance_clock(Duration::from_secs(1));
        limiter.await_keys_ready(Some(vec![key.clone()])).await;
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_gcra_boundary_exact_replenishment() {
        // Exercises the GCRA boundary where the request time equals the earliest allowed time
        // exactly. The fake clock makes this deterministic, with no sleeping involved.
        let limiter = initialize_mock_rate_limiter();
        let key = "boundary_test".to_string();

        // Use up the whole burst capacity (2 requests)
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());

        // The next request is rate-limited
        assert!(limiter.check_key(&key).is_err());

        // Move the clock forward by exactly one replenish interval of the 2 req/sec quota
        let replenish_interval = per_second(2).replenish_interval();
        limiter.advance_clock(replenish_interval);

        // Exactly at the boundary the request must be admitted
        assert!(
            limiter.check_key(&key).is_ok(),
            "Request at exact replenish boundary should be allowed"
        );

        // An immediate follow-up is rejected because the burst is exhausted again
        assert!(
            limiter.check_key(&key).is_err(),
            "Immediate follow-up should be rate-limited"
        );
    }
}