1// -------------------------------------------------------------------------------------------------
2// Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3// https://nautechsystems.io
4//
5// Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6// You may not use this file except in compliance with the License.
7// You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14// -------------------------------------------------------------------------------------------------
1516#![allow(clippy::missing_errors_doc)] // Under development
1718//! A rate limiter implementation heavily inspired by [governor](https://github.com/antifuchs/governor)
19//!
//! The governor crate does not support different quotas for different keys; this is an open
//! [issue](https://github.com/antifuchs/governor/issues/193).
21pub mod clock;
22mod gcra;
23mod nanos;
24pub mod quota;
2526use std::{
27 hash::Hash,
28 num::NonZeroU64,
29 sync::atomic::{AtomicU64, Ordering},
30 time::Duration,
31};
3233use dashmap::DashMap;
34use futures_util::StreamExt;
35use tokio::time::sleep;
3637use self::{
38 clock::{Clock, FakeRelativeClock, MonotonicClock},
39 gcra::{Gcra, NotUntil},
40 nanos::Nanos,
41 quota::Quota,
42};
/// An in-memory representation of a GCRA's rate-limiting state.
///
/// Implemented using [`AtomicU64`] operations, this state representation can be used to
/// construct rate limiting states for other in-memory states: e.g., this crate uses
/// `InMemoryState` as the states it tracks in the keyed rate limiters it implements.
///
/// Internally, the number tracked here is the theoretical arrival time (a GCRA term) in number of
/// nanoseconds since the rate limiter was created.
///
/// The [`Default`] value holds zero, which the measuring code in this module reports to its
/// callback as "no previous measurement" (`None`).
///
/// `Debug` is derived so the state can be logged and so that types embedding it can derive
/// `Debug` themselves; the output shows the raw counter, e.g. `InMemoryState(0)`.
#[derive(Debug, Default)]
pub struct InMemoryState(AtomicU64);
5455impl InMemoryState {
56pub(crate) fn measure_and_replace_one<T, F, E>(&self, mut f: F) -> Result<T, E>
57where
58F: FnMut(Option<Nanos>) -> Result<(T, Nanos), E>,
59 {
60let mut prev = self.0.load(Ordering::Acquire);
61let mut decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
62while let Ok((result, new_data)) = decision {
63match self.0.compare_exchange_weak(
64 prev,
65 new_data.into(),
66 Ordering::Release,
67 Ordering::Relaxed,
68 ) {
69Ok(_) => return Ok(result),
70Err(next_prev) => prev = next_prev,
71 }
72 decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
73 }
74// This map shouldn't be needed, as we only get here in the error case, but the compiler
75 // can't see it.
76decision.map(|(result, _)| result)
77 }
78}
7980/// A concurrent, thread-safe and fairly performant hashmap based on [`DashMap`].
81pub type DashMapStateStore<K> = DashMap<K, InMemoryState>;
8283/// A way for rate limiters to keep state.
84///
85/// There are two important kinds of state stores: Direct and keyed. The direct kind have only
86/// one state, and are useful for "global" rate limit enforcement (e.g. a process should never
87/// do more than N tasks a day). The keyed kind allows one rate limit per key (e.g. an API
88/// call budget per client API key).
89///
90/// A direct state store is expressed as [`StateStore::Key`] = `NotKeyed`.
91/// Keyed state stores have a
92/// type parameter for the key and set their key to that.
93pub trait StateStore {
94/// The type of key that the state store can represent.
95type Key;
9697/// Updates a state store's rate limiting state for a given key, using the given closure.
98 ///
99 /// The closure parameter takes the old value (`None` if this is the first measurement) of the
100 /// state store at the key's location, checks if the request an be accommodated and:
101 ///
102 /// * If the request is rate-limited, returns `Err(E)`.
103 /// * If the request can make it through, returns `Ok(T)` (an arbitrary positive return
104 /// value) and the updated state.
105 ///
106 /// It is `measure_and_replace`'s job then to safely replace the value at the key - it must
107 /// only update the value if the value hasn't changed. The implementations in this
108 /// crate use `AtomicU64` operations for this.
109fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
110where
111F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>;
112}
113114impl<K: Hash + Eq + Clone> StateStore for DashMapStateStore<K> {
115type Key = K;
116117fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
118where
119F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
120 {
121if let Some(v) = self.get(key) {
122// fast path: measure existing entry
123return v.measure_and_replace_one(f);
124 }
125// make an entry and measure that:
126let entry = self.entry(key.clone()).or_default();
127 (*entry).measure_and_replace_one(f)
128 }
129}
130131pub struct RateLimiter<K, C>
132where
133C: Clock,
134{
135 default_gcra: Option<Gcra>,
136 state: DashMapStateStore<K>,
137 gcra: DashMap<K, Gcra>,
138 clock: C,
139 start: C::Instant,
140}
141142impl<K> RateLimiter<K, MonotonicClock>
143where
144K: Eq + Hash,
145{
146pub fn new_with_quota(base_quota: Option<Quota>, keyed_quotas: Vec<(K, Quota)>) -> Self {
147let clock = MonotonicClock {};
148let start = MonotonicClock::now(&clock);
149let gcra = DashMap::from_iter(keyed_quotas.into_iter().map(|(k, q)| (k, Gcra::new(q))));
150Self {
151 default_gcra: base_quota.map(Gcra::new),
152 state: DashMapStateStore::new(),
153 gcra,
154 clock,
155 start,
156 }
157 }
158}
159160impl<K> RateLimiter<K, FakeRelativeClock>
161where
162K: Hash + Eq + Clone,
163{
164pub fn advance_clock(&self, by: Duration) {
165self.clock.advance(by);
166 }
167}
168169impl<K, C> RateLimiter<K, C>
170where
171K: Hash + Eq + Clone,
172 C: Clock,
173{
174pub fn add_quota_for_key(&self, key: K, value: Quota) {
175self.gcra.insert(key, Gcra::new(value));
176 }
177178pub fn check_key(&self, key: &K) -> Result<(), NotUntil<C::Instant>> {
179match self.gcra.get(key) {
180Some(quota) => quota.test_and_update(self.start, key, &self.state, self.clock.now()),
181None => self.default_gcra.as_ref().map_or(Ok(()), |gcra| {
182 gcra.test_and_update(self.start, key, &self.state, self.clock.now())
183 }),
184 }
185 }
186187pub async fn until_key_ready(&self, key: &K) {
188loop {
189match self.check_key(key) {
190Ok(()) => {
191break;
192 }
193Err(neg) => {
194 sleep(neg.wait_time_from(self.clock.now())).await;
195 }
196 }
197 }
198 }
199200pub async fn await_keys_ready(&self, keys: Option<Vec<K>>) {
201let keys = keys.unwrap_or_default();
202let tasks = keys.iter().map(|key| self.until_key_ready(key));
203204 futures::stream::iter(tasks)
205 .for_each_concurrent(None, |key_future| async move {
206 key_future.await;
207 })
208 .await;
209 }
210}
211212////////////////////////////////////////////////////////////////////////////////
213// Tests
214////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use std::{num::NonZeroU32, time::Duration};

    use dashmap::DashMap;

    use super::{
        DashMapStateStore, RateLimiter,
        clock::{Clock, FakeRelativeClock},
        gcra::Gcra,
        quota::Quota,
    };

    type MockLimiter = RateLimiter<String, FakeRelativeClock>;

    /// Shorthand for a quota of `n` requests per second.
    fn per_second(n: u32) -> Quota {
        Quota::per_second(NonZeroU32::new(n).unwrap())
    }

    /// Returns `true` when a request for `key` passes the limiter's check.
    fn passes(limiter: &MockLimiter, key: &str) -> bool {
        limiter.check_key(&key.to_string()).is_ok()
    }

    /// Builds a limiter on a fake clock with a default quota of two requests per second and no
    /// key-specific quotas.
    fn initialize_mock_rate_limiter() -> MockLimiter {
        let clock = FakeRelativeClock::default();
        let start = clock.now();
        RateLimiter {
            default_gcra: Some(Gcra::new(per_second(2))),
            state: DashMapStateStore::new(),
            gcra: DashMap::new(),
            clock,
            start,
        }
    }

    #[test]
    fn test_default_quota() {
        let limiter = initialize_mock_rate_limiter();

        // Two requests fit in the base quota, the third does not.
        assert!(passes(&limiter, "default"));
        assert!(passes(&limiter, "default"));
        assert!(!passes(&limiter, "default"));

        // After advancing the clock by a second the base quota admits requests again.
        limiter.advance_clock(Duration::from_secs(1));
        assert!(passes(&limiter, "default"));
    }

    #[test]
    fn test_custom_key_quota() {
        let limiter = initialize_mock_rate_limiter();

        // Register a dedicated quota of one request per second.
        limiter.add_quota_for_key("custom".to_string(), per_second(1));

        // The dedicated quota governs its own key.
        assert!(passes(&limiter, "custom"));
        assert!(!passes(&limiter, "custom"));

        // Every other key is still governed by the default quota.
        assert!(passes(&limiter, "default"));
        assert!(passes(&limiter, "default"));
        assert!(!passes(&limiter, "default"));
    }

    #[test]
    fn test_multiple_keys() {
        let limiter = initialize_mock_rate_limiter();

        limiter.add_quota_for_key("key1".to_string(), per_second(1));
        limiter.add_quota_for_key("key2".to_string(), per_second(3));

        // key1: one request allowed.
        assert!(passes(&limiter, "key1"));
        assert!(!passes(&limiter, "key1"));

        // key2: three requests allowed.
        assert!(passes(&limiter, "key2"));
        assert!(passes(&limiter, "key2"));
        assert!(passes(&limiter, "key2"));
        assert!(!passes(&limiter, "key2"));
    }

    #[test]
    fn test_quota_reset() {
        let limiter = initialize_mock_rate_limiter();

        // Use up the quota.
        assert!(passes(&limiter, "reset"));
        assert!(passes(&limiter, "reset"));
        assert!(!passes(&limiter, "reset"));

        // 499 ms later the key is still limited.
        limiter.advance_clock(Duration::from_millis(499));
        assert!(!passes(&limiter, "reset"));

        // A further 501 ms later the key is admitted again.
        limiter.advance_clock(Duration::from_millis(501));
        assert!(passes(&limiter, "reset"));
    }

    #[test]
    fn test_different_quotas() {
        let limiter = initialize_mock_rate_limiter();

        limiter.add_quota_for_key("per_second".to_string(), per_second(2));
        limiter.add_quota_for_key(
            "per_minute".to_string(),
            Quota::per_minute(NonZeroU32::new(3).unwrap()),
        );

        // per_second: two requests allowed.
        assert!(passes(&limiter, "per_second"));
        assert!(passes(&limiter, "per_second"));
        assert!(!passes(&limiter, "per_second"));

        // per_minute: three requests allowed.
        assert!(passes(&limiter, "per_minute"));
        assert!(passes(&limiter, "per_minute"));
        assert!(passes(&limiter, "per_minute"));
        assert!(!passes(&limiter, "per_minute"));

        // One second later only the per-second key is admitted again.
        limiter.advance_clock(Duration::from_secs(1));
        assert!(passes(&limiter, "per_second"));
        assert!(!passes(&limiter, "per_minute"));
    }

    #[tokio::test]
    async fn test_await_keys_ready() {
        let limiter = initialize_mock_rate_limiter();

        // Two requests fit in the base quota, the third does not.
        assert!(passes(&limiter, "default"));
        assert!(passes(&limiter, "default"));
        assert!(!passes(&limiter, "default"));

        // Advance the clock, wait for the key to become ready, then check it is admitted.
        limiter.advance_clock(Duration::from_secs(1));
        limiter
            .await_keys_ready(Some(vec!["default".to_string()]))
            .await;
        assert!(passes(&limiter, "default"));
    }
}