nautilus_network/ratelimiter/
mod.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! A rate limiter implementation heavily inspired by [governor](https://github.com/antifuchs/governor).
17//!
//! Governor does not support different quotas for different keys; this is tracked in an open [issue](https://github.com/antifuchs/governor/issues/193).
19pub mod clock;
20mod gcra;
21mod nanos;
22pub mod quota;
23
24use std::{
25    fmt::Debug,
26    hash::Hash,
27    num::NonZeroU64,
28    sync::atomic::{AtomicU64, Ordering},
29    time::Duration,
30};
31
32use dashmap::DashMap;
33use futures_util::StreamExt;
34use tokio::time::sleep;
35
36use self::{
37    clock::{Clock, FakeRelativeClock, MonotonicClock},
38    gcra::{Gcra, NotUntil},
39    nanos::Nanos,
40    quota::Quota,
41};
42
43/// An in-memory representation of a GCRA's rate-limiting state.
44///
45/// Implemented using [`AtomicU64`] operations, this state representation can be used to
46/// construct rate limiting states for other in-memory states: e.g., this crate uses
47/// `InMemoryState` as the states it tracks in the keyed rate limiters it implements.
48///
49/// Internally, the number tracked here is the theoretical arrival time (a GCRA term) in number of
50/// nanoseconds since the rate limiter was created.
51#[derive(Debug, Default)]
52pub struct InMemoryState(AtomicU64);
53
54impl InMemoryState {
55    /// Measures and updates the GCRA's state atomically, retrying on concurrent modifications.
56    ///
57    /// # Errors
58    ///
59    /// Returns an error if the provided closure returns an error.
60    pub(crate) fn measure_and_replace_one<T, F, E>(&self, mut f: F) -> Result<T, E>
61    where
62        F: FnMut(Option<Nanos>) -> Result<(T, Nanos), E>,
63    {
64        let mut prev = self.0.load(Ordering::Acquire);
65        let mut decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
66        while let Ok((result, new_data)) = decision {
67            // Lock-free CAS loop: retry with current value if another thread modified it,
68            // uses weak variant (faster) since spurious failures are fine in a retry loop.
69            match self.0.compare_exchange_weak(
70                prev,
71                new_data.into(),
72                Ordering::Release,
73                Ordering::Relaxed,
74            ) {
75                Ok(_) => return Ok(result),
76                Err(next_prev) => prev = next_prev, // Retry with value written by another thread
77            }
78            decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
79        }
80        // This map shouldn't be needed, as we only get here in the error case, but the compiler
81        // can't see it.
82        decision.map(|(result, _)| result)
83    }
84}
85
86/// A concurrent, thread-safe and fairly performant hashmap based on [`DashMap`].
87pub type DashMapStateStore<K> = DashMap<K, InMemoryState>;
88
89/// A way for rate limiters to keep state.
90///
91/// There are two important kinds of state stores: Direct and keyed. The direct kind have only
92/// one state, and are useful for "global" rate limit enforcement (e.g. a process should never
93/// do more than N tasks a day). The keyed kind allows one rate limit per key (e.g. an API
94/// call budget per client API key).
95///
96/// A direct state store is expressed as [`StateStore::Key`] = `NotKeyed`.
97/// Keyed state stores have a
98/// type parameter for the key and set their key to that.
99pub trait StateStore {
100    /// The type of key that the state store can represent.
101    type Key;
102
103    /// Updates a state store's rate limiting state for a given key, using the given closure.
104    ///
105    /// The closure parameter takes the old value (`None` if this is the first measurement) of the
106    /// state store at the key's location, checks if the request an be accommodated and:
107    ///
108    /// - If the request is rate-limited, returns `Err(E)`.
109    /// - If the request can make it through, returns `Ok(T)` (an arbitrary positive return
110    ///   value) and the updated state.
111    ///
112    /// It is `measure_and_replace`'s job then to safely replace the value at the key - it must
113    /// only update the value if the value hasn't changed. The implementations in this
114    /// crate use `AtomicU64` operations for this.
115    ///
116    /// # Errors
117    ///
118    /// Returns `Err(E)` if the closure returns an error or the request is rate-limited.
119    fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
120    where
121        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>;
122}
123
124impl<K: Hash + Eq + Clone> StateStore for DashMapStateStore<K> {
125    type Key = K;
126
127    fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
128    where
129        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
130    {
131        if let Some(v) = self.get(key) {
132            // fast path: measure existing entry
133            return v.measure_and_replace_one(f);
134        }
135        // make an entry and measure that:
136        let entry = self.entry(key.clone()).or_default();
137        (*entry).measure_and_replace_one(f)
138    }
139}
140
141/// A rate limiter that enforces different quotas per key using the GCRA algorithm.
142///
143/// This implementation allows setting different rate limits for different keys,
144/// with an optional default quota for keys that don't have specific quotas.
145pub struct RateLimiter<K, C>
146where
147    C: Clock,
148{
149    default_gcra: Option<Gcra>,
150    state: DashMapStateStore<K>,
151    gcra: DashMap<K, Gcra>,
152    clock: C,
153    start: C::Instant,
154}
155
156impl<K, C> Debug for RateLimiter<K, C>
157where
158    K: Debug,
159    C: Clock,
160{
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        f.debug_struct(stringify!(RateLimiter)).finish()
163    }
164}
165
166impl<K> RateLimiter<K, MonotonicClock>
167where
168    K: Eq + Hash,
169{
170    /// Creates a new rate limiter with a base quota and keyed quotas.
171    ///
172    /// The base quota applies to all keys that don't have specific quotas.
173    /// Keyed quotas override the base quota for specific keys.
174    #[must_use]
175    pub fn new_with_quota(base_quota: Option<Quota>, keyed_quotas: Vec<(K, Quota)>) -> Self {
176        let clock = MonotonicClock {};
177        let start = MonotonicClock::now(&clock);
178        let gcra = DashMap::from_iter(keyed_quotas.into_iter().map(|(k, q)| (k, Gcra::new(q))));
179        Self {
180            default_gcra: base_quota.map(Gcra::new),
181            state: DashMapStateStore::new(),
182            gcra,
183            clock,
184            start,
185        }
186    }
187}
188
189impl<K> RateLimiter<K, FakeRelativeClock>
190where
191    K: Hash + Eq + Clone,
192{
193    /// Advances the fake clock by the specified duration.
194    ///
195    /// This is only available for testing with `FakeRelativeClock`.
196    pub fn advance_clock(&self, by: Duration) {
197        self.clock.advance(by);
198    }
199}
200
201impl<K, C> RateLimiter<K, C>
202where
203    K: Hash + Eq + Clone,
204    C: Clock,
205{
206    /// Adds or updates a quota for a specific key.
207    pub fn add_quota_for_key(&self, key: K, value: Quota) {
208        self.gcra.insert(key, Gcra::new(value));
209    }
210
211    /// Checks if the given key is allowed under the rate limit.
212    ///
213    /// # Errors
214    ///
215    /// Returns `Err(NotUntil)` if the key is rate-limited, indicating when it will be allowed.
216    pub fn check_key(&self, key: &K) -> Result<(), NotUntil<C::Instant>> {
217        match self.gcra.get(key) {
218            Some(quota) => quota.test_and_update(self.start, key, &self.state, self.clock.now()),
219            None => self.default_gcra.as_ref().map_or(Ok(()), |gcra| {
220                gcra.test_and_update(self.start, key, &self.state, self.clock.now())
221            }),
222        }
223    }
224
225    /// Waits until the specified key is ready (not rate-limited).
226    pub async fn until_key_ready(&self, key: &K) {
227        loop {
228            match self.check_key(key) {
229                Ok(()) => {
230                    break;
231                }
232                Err(neg) => {
233                    sleep(neg.wait_time_from(self.clock.now())).await;
234                }
235            }
236        }
237    }
238
239    /// Waits until all specified keys are ready (not rate-limited).
240    ///
241    /// If no keys are provided, this function returns immediately.
242    pub async fn await_keys_ready(&self, keys: Option<Vec<K>>) {
243        let keys = keys.unwrap_or_default();
244        let tasks = keys.iter().map(|key| self.until_key_ready(key));
245
246        futures::stream::iter(tasks)
247            .for_each_concurrent(None, |key_future| async move {
248                key_future.await;
249            })
250            .await;
251    }
252}
253
254////////////////////////////////////////////////////////////////////////////////
255// Tests
256////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::{num::NonZeroU32, time::Duration};

    use dashmap::DashMap;
    use rstest::rstest;

    use super::{
        DashMapStateStore, RateLimiter,
        clock::{Clock, FakeRelativeClock},
        gcra::Gcra,
        quota::Quota,
    };

    /// Shorthand for a quota of `count` requests per second.
    fn per_second(count: u32) -> Quota {
        Quota::per_second(NonZeroU32::new(count).unwrap())
    }

    /// Builds a limiter on a fake clock with a default quota of 2 requests per second
    /// and no key-specific quotas.
    fn initialize_mock_rate_limiter() -> RateLimiter<String, FakeRelativeClock> {
        let clock = FakeRelativeClock::default();
        let start = clock.now();
        RateLimiter {
            default_gcra: Some(Gcra::new(per_second(2))),
            state: DashMapStateStore::new(),
            gcra: DashMap::new(),
            clock,
            start,
        }
    }

    #[rstest]
    fn test_default_quota() {
        let limiter = initialize_mock_rate_limiter();
        let key = "default".to_string();

        // The base quota admits two requests...
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());

        // ...and rejects the third
        assert!(limiter.check_key(&key).is_err());

        // After a full second the key is admitted again
        limiter.advance_clock(Duration::from_secs(1));
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_custom_key_quota() {
        let limiter = initialize_mock_rate_limiter();
        let custom = "custom".to_string();
        let other = "default".to_string();

        // Register a dedicated quota for the custom key
        limiter.add_quota_for_key(custom.clone(), per_second(1));

        // The custom key is limited by its own quota
        assert!(limiter.check_key(&custom).is_ok());
        assert!(limiter.check_key(&custom).is_err());

        // Every other key still follows the default quota
        assert!(limiter.check_key(&other).is_ok());
        assert!(limiter.check_key(&other).is_ok());
        assert!(limiter.check_key(&other).is_err());
    }

    #[rstest]
    fn test_multiple_keys() {
        let limiter = initialize_mock_rate_limiter();
        let key1 = "key1".to_string();
        let key2 = "key2".to_string();

        limiter.add_quota_for_key(key1.clone(), per_second(1));
        limiter.add_quota_for_key(key2.clone(), per_second(3));

        // key1: one request per second
        assert!(limiter.check_key(&key1).is_ok());
        assert!(limiter.check_key(&key1).is_err());

        // key2: three requests per second
        assert!(limiter.check_key(&key2).is_ok());
        assert!(limiter.check_key(&key2).is_ok());
        assert!(limiter.check_key(&key2).is_ok());
        assert!(limiter.check_key(&key2).is_err());
    }

    #[rstest]
    fn test_quota_reset() {
        let limiter = initialize_mock_rate_limiter();
        let key = "reset".to_string();

        // Use up the quota
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_err());

        // 499ms later the key is still limited
        limiter.advance_clock(Duration::from_millis(499));
        assert!(limiter.check_key(&key).is_err());

        // A further 501ms later the key is admitted again
        limiter.advance_clock(Duration::from_millis(501));
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_different_quotas() {
        let limiter = initialize_mock_rate_limiter();
        let second_key = "per_second".to_string();
        let minute_key = "per_minute".to_string();

        limiter.add_quota_for_key(second_key.clone(), per_second(2));
        limiter.add_quota_for_key(
            minute_key.clone(),
            Quota::per_minute(NonZeroU32::new(3).unwrap()),
        );

        // Per-second quota: two requests pass, the third is rejected
        assert!(limiter.check_key(&second_key).is_ok());
        assert!(limiter.check_key(&second_key).is_ok());
        assert!(limiter.check_key(&second_key).is_err());

        // Per-minute quota: three requests pass, the fourth is rejected
        assert!(limiter.check_key(&minute_key).is_ok());
        assert!(limiter.check_key(&minute_key).is_ok());
        assert!(limiter.check_key(&minute_key).is_ok());
        assert!(limiter.check_key(&minute_key).is_err());

        // One second later only the per-second key is admitted again
        limiter.advance_clock(Duration::from_secs(1));
        assert!(limiter.check_key(&second_key).is_ok());
        assert!(limiter.check_key(&minute_key).is_err());
    }

    #[tokio::test]
    async fn test_await_keys_ready() {
        let limiter = initialize_mock_rate_limiter();
        let key = "default".to_string();

        // The base quota admits two requests...
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());

        // ...and rejects the third
        assert!(limiter.check_key(&key).is_err());

        // Advance the clock, wait for the key, then confirm it is admitted again
        limiter.advance_clock(Duration::from_secs(1));
        limiter.await_keys_ready(Some(vec![key.clone()])).await;
        assert!(limiter.check_key(&key).is_ok());
    }

    #[rstest]
    fn test_gcra_boundary_exact_replenishment() {
        // GCRA boundary condition where t0 equals earliest_time exactly.
        // Exercises the saturating_sub edge case deterministically, without sleeps.
        let limiter = initialize_mock_rate_limiter();
        let key = "boundary_test".to_string();

        // Use up the whole burst capacity (2 requests)
        assert!(limiter.check_key(&key).is_ok());
        assert!(limiter.check_key(&key).is_ok());

        // The next request is rate-limited
        assert!(limiter.check_key(&key).is_err());

        // Advance by exactly one replenish interval (500ms for 2 req/sec)
        let replenish_interval = per_second(2).replenish_interval();
        limiter.advance_clock(replenish_interval);

        // At the exact boundary (t0 == earliest_time) the request is allowed
        assert!(
            limiter.check_key(&key).is_ok(),
            "Request at exact replenish boundary should be allowed"
        );

        // The immediate follow-up is denied (burst exhausted again)
        assert!(
            limiter.check_key(&key).is_err(),
            "Immediate follow-up should be rate-limited"
        );
    }
}