nautilus_network/ratelimiter/
mod.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16#![allow(clippy::missing_errors_doc)] // Under development
17
//! A rate limiter implementation heavily inspired by [governor](https://github.com/antifuchs/governor).
//!
//! Governor does not support a different quota for each key; this is an open
//! [issue](https://github.com/antifuchs/governor/issues/193). This module adds that: each key
//! may have its own quota, with an optional base quota for keys that have none.
21pub mod clock;
22mod gcra;
23mod nanos;
24pub mod quota;
25
26use std::{
27    hash::Hash,
28    num::NonZeroU64,
29    sync::atomic::{AtomicU64, Ordering},
30    time::Duration,
31};
32
33use dashmap::DashMap;
34use futures_util::StreamExt;
35use tokio::time::sleep;
36
37use self::{
38    clock::{Clock, FakeRelativeClock, MonotonicClock},
39    gcra::{Gcra, NotUntil},
40    nanos::Nanos,
41    quota::Quota,
42};
43
44/// An in-memory representation of a GCRA's rate-limiting state.
45///
46/// Implemented using [`AtomicU64`] operations, this state representation can be used to
47/// construct rate limiting states for other in-memory states: e.g., this crate uses
48/// `InMemoryState` as the states it tracks in the keyed rate limiters it implements.
49///
50/// Internally, the number tracked here is the theoretical arrival time (a GCRA term) in number of
51/// nanoseconds since the rate limiter was created.
52#[derive(Default)]
53pub struct InMemoryState(AtomicU64);
54
55impl InMemoryState {
56    pub(crate) fn measure_and_replace_one<T, F, E>(&self, mut f: F) -> Result<T, E>
57    where
58        F: FnMut(Option<Nanos>) -> Result<(T, Nanos), E>,
59    {
60        let mut prev = self.0.load(Ordering::Acquire);
61        let mut decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
62        while let Ok((result, new_data)) = decision {
63            match self.0.compare_exchange_weak(
64                prev,
65                new_data.into(),
66                Ordering::Release,
67                Ordering::Relaxed,
68            ) {
69                Ok(_) => return Ok(result),
70                Err(next_prev) => prev = next_prev,
71            }
72            decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
73        }
74        // This map shouldn't be needed, as we only get here in the error case, but the compiler
75        // can't see it.
76        decision.map(|(result, _)| result)
77    }
78}
79
/// Keyed state store: a concurrent hash map, based on [`DashMap`], holding one
/// [`InMemoryState`] per key.
pub type DashMapStateStore<K> = DashMap<K, InMemoryState>;
82
/// A way for rate limiters to keep state.
///
/// There are two important kinds of state stores: direct and keyed. Direct stores hold only
/// one state and are useful for "global" rate limit enforcement (e.g. a process should never
/// do more than N tasks a day). Keyed stores hold one state per key (e.g. an API call budget
/// per client API key).
///
/// In governor, a direct state store is expressed as [`StateStore::Key`] = `NotKeyed`, while
/// keyed state stores take a type parameter for the key and use it as [`StateStore::Key`].
/// The implementation provided here, [`DashMapStateStore`], is keyed.
pub trait StateStore {
    /// The type of key that the state store can represent.
    type Key;

    /// Updates a state store's rate limiting state for a given key, using the given closure.
    ///
    /// The closure takes the old value at the key's location (`None` if this is the first
    /// measurement), checks whether the request can be accommodated, and:
    ///
    /// * If the request is rate-limited, returns `Err(E)`.
    /// * If the request can go through, returns `Ok((T, new_state))`: an arbitrary success
    ///   value `T` together with the updated state.
    ///
    /// It is then `measure_and_replace`'s job to safely replace the value at the key: it must
    /// only update the value if it has not changed since it was read. The implementation in
    /// this module uses `AtomicU64` compare-and-swap for this, re-running the closure against
    /// the newly observed value when the swap fails.
    fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
    where
        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>;
}
113
114impl<K: Hash + Eq + Clone> StateStore for DashMapStateStore<K> {
115    type Key = K;
116
117    fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
118    where
119        F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
120    {
121        if let Some(v) = self.get(key) {
122            // fast path: measure existing entry
123            return v.measure_and_replace_one(f);
124        }
125        // make an entry and measure that:
126        let entry = self.entry(key.clone()).or_default();
127        (*entry).measure_and_replace_one(f)
128    }
129}
130
/// A keyed GCRA rate limiter supporting a different quota per key.
///
/// A key is checked against its own quota if one was registered (see `add_quota_for_key`),
/// otherwise against the optional base quota. Keys with neither are never limited.
pub struct RateLimiter<K, C>
where
    C: Clock,
{
    /// Quota applied to keys without a keyed quota; `None` leaves such keys unlimited.
    default_gcra: Option<Gcra>,
    /// Per-key rate-limiting state (theoretical arrival times).
    state: DashMapStateStore<K>,
    /// Keyed quotas, which take precedence over `default_gcra`.
    gcra: DashMap<K, Gcra>,
    /// Clock used to timestamp each check.
    clock: C,
    /// Reference instant passed to every GCRA check, captured when the limiter is built.
    start: C::Instant,
}
141
142impl<K> RateLimiter<K, MonotonicClock>
143where
144    K: Eq + Hash,
145{
146    pub fn new_with_quota(base_quota: Option<Quota>, keyed_quotas: Vec<(K, Quota)>) -> Self {
147        let clock = MonotonicClock {};
148        let start = MonotonicClock::now(&clock);
149        let gcra = DashMap::from_iter(keyed_quotas.into_iter().map(|(k, q)| (k, Gcra::new(q))));
150        Self {
151            default_gcra: base_quota.map(Gcra::new),
152            state: DashMapStateStore::new(),
153            gcra,
154            clock,
155            start,
156        }
157    }
158}
159
160impl<K> RateLimiter<K, FakeRelativeClock>
161where
162    K: Hash + Eq + Clone,
163{
164    pub fn advance_clock(&self, by: Duration) {
165        self.clock.advance(by);
166    }
167}
168
169impl<K, C> RateLimiter<K, C>
170where
171    K: Hash + Eq + Clone,
172    C: Clock,
173{
174    pub fn add_quota_for_key(&self, key: K, value: Quota) {
175        self.gcra.insert(key, Gcra::new(value));
176    }
177
178    pub fn check_key(&self, key: &K) -> Result<(), NotUntil<C::Instant>> {
179        match self.gcra.get(key) {
180            Some(quota) => quota.test_and_update(self.start, key, &self.state, self.clock.now()),
181            None => self.default_gcra.as_ref().map_or(Ok(()), |gcra| {
182                gcra.test_and_update(self.start, key, &self.state, self.clock.now())
183            }),
184        }
185    }
186
187    pub async fn until_key_ready(&self, key: &K) {
188        loop {
189            match self.check_key(key) {
190                Ok(()) => {
191                    break;
192                }
193                Err(neg) => {
194                    sleep(neg.wait_time_from(self.clock.now())).await;
195                }
196            }
197        }
198    }
199
200    pub async fn await_keys_ready(&self, keys: Option<Vec<K>>) {
201        let keys = keys.unwrap_or_default();
202        let tasks = keys.iter().map(|key| self.until_key_ready(key));
203
204        futures::stream::iter(tasks)
205            .for_each_concurrent(None, |key_future| async move {
206                key_future.await;
207            })
208            .await;
209    }
210}
211
212////////////////////////////////////////////////////////////////////////////////
213// Tests
214////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::{num::NonZeroU32, time::Duration};

    use dashmap::DashMap;

    use super::{
        DashMapStateStore, RateLimiter,
        clock::{Clock, FakeRelativeClock},
        gcra::Gcra,
        quota::Quota,
    };

    /// Builds a limiter on a fake clock with the given base quota and no keyed quotas.
    fn mock_rate_limiter(base_quota: Option<Quota>) -> RateLimiter<String, FakeRelativeClock> {
        let clock = FakeRelativeClock::default();
        let start = clock.now();
        RateLimiter {
            default_gcra: base_quota.map(Gcra::new),
            state: DashMapStateStore::new(),
            gcra: DashMap::new(),
            clock,
            start,
        }
    }

    /// Builds a limiter on a fake clock with a base quota of 2 requests per second.
    fn initialize_mock_rate_limiter() -> RateLimiter<String, FakeRelativeClock> {
        mock_rate_limiter(Some(Quota::per_second(NonZeroU32::new(2).unwrap())))
    }

    #[test]
    fn test_default_quota() {
        let mock_limiter = initialize_mock_rate_limiter();

        // Check base quota is not exceeded
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());

        // Check base quota is exceeded
        assert!(mock_limiter.check_key(&"default".to_string()).is_err());

        // Increment clock and check base quota is reset
        mock_limiter.advance_clock(Duration::from_secs(1));
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
    }

    #[test]
    fn test_custom_key_quota() {
        let mock_limiter = initialize_mock_rate_limiter();

        // Add new key quota pair
        mock_limiter.add_quota_for_key(
            "custom".to_string(),
            Quota::per_second(NonZeroU32::new(1).unwrap()),
        );

        // Check custom quota
        assert!(mock_limiter.check_key(&"custom".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"custom".to_string()).is_err());

        // Check that default quota still applies to other keys
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_err());
    }

    #[test]
    fn test_multiple_keys() {
        let mock_limiter = initialize_mock_rate_limiter();

        mock_limiter.add_quota_for_key(
            "key1".to_string(),
            Quota::per_second(NonZeroU32::new(1).unwrap()),
        );
        mock_limiter.add_quota_for_key(
            "key2".to_string(),
            Quota::per_second(NonZeroU32::new(3).unwrap()),
        );

        // Test key1
        assert!(mock_limiter.check_key(&"key1".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"key1".to_string()).is_err());

        // Test key2
        assert!(mock_limiter.check_key(&"key2".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"key2".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"key2".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"key2".to_string()).is_err());
    }

    #[test]
    fn test_quota_reset() {
        let mock_limiter = initialize_mock_rate_limiter();

        // Exhaust quota
        assert!(mock_limiter.check_key(&"reset".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"reset".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"reset".to_string()).is_err());

        // Advance clock by less than a second
        mock_limiter.advance_clock(Duration::from_millis(499));
        assert!(mock_limiter.check_key(&"reset".to_string()).is_err());

        // Advance clock to reset
        mock_limiter.advance_clock(Duration::from_millis(501));
        assert!(mock_limiter.check_key(&"reset".to_string()).is_ok());
    }

    #[test]
    fn test_different_quotas() {
        let mock_limiter = initialize_mock_rate_limiter();

        mock_limiter.add_quota_for_key(
            "per_second".to_string(),
            Quota::per_second(NonZeroU32::new(2).unwrap()),
        );
        mock_limiter.add_quota_for_key(
            "per_minute".to_string(),
            Quota::per_minute(NonZeroU32::new(3).unwrap()),
        );

        // Test per_second quota
        assert!(mock_limiter.check_key(&"per_second".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_second".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_second".to_string()).is_err());

        // Test per_minute quota
        assert!(mock_limiter.check_key(&"per_minute".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_minute".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_minute".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_minute".to_string()).is_err());

        // Advance clock and check reset
        mock_limiter.advance_clock(Duration::from_secs(1));
        assert!(mock_limiter.check_key(&"per_second".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_minute".to_string()).is_err());
    }

    #[test]
    fn test_no_base_quota_is_unlimited() {
        let mock_limiter = mock_rate_limiter(None);
        let key = "unlimited".to_string();

        // Without a base quota, keys lacking a keyed quota are never limited
        for _ in 0..100 {
            assert!(mock_limiter.check_key(&key).is_ok());
        }
    }

    #[test]
    fn test_keyed_quota_without_base_quota() {
        let mock_limiter = mock_rate_limiter(None);
        mock_limiter.add_quota_for_key(
            "limited".to_string(),
            Quota::per_second(NonZeroU32::new(1).unwrap()),
        );

        // The keyed quota still applies
        assert!(mock_limiter.check_key(&"limited".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"limited".to_string()).is_err());

        // Other keys remain unlimited
        for _ in 0..10 {
            assert!(mock_limiter.check_key(&"other".to_string()).is_ok());
        }
    }

    #[tokio::test]
    async fn test_await_keys_ready() {
        let mock_limiter = initialize_mock_rate_limiter();

        // Check base quota is not exceeded
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());

        // Check base quota is exceeded
        assert!(mock_limiter.check_key(&"default".to_string()).is_err());

        // Wait keys to be ready and check base quota is reset
        mock_limiter.advance_clock(Duration::from_secs(1));
        mock_limiter
            .await_keys_ready(Some(vec!["default".to_string()]))
            .await;
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
    }

    #[tokio::test]
    async fn test_await_keys_ready_without_keys() {
        let mock_limiter = initialize_mock_rate_limiter();

        // Awaiting no keys completes immediately and consumes no quota
        mock_limiter.await_keys_ready(None).await;
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_err());
    }
}