1use std::{
23 any::Any,
24 cell::{RefCell, UnsafeCell},
25 collections::VecDeque,
26 fmt::Debug,
27 marker::PhantomData,
28 rc::Rc,
29};
30
31use nautilus_core::{UnixNanos, correctness::FAILED};
32use ustr::Ustr;
33
34use crate::{
35 actor::{
36 Actor,
37 registry::{get_actor_unchecked, register_actor},
38 },
39 clock::Clock,
40 msgbus::{
41 self,
42 handler::{MessageHandler, ShareableMessageHandler},
43 },
44 timer::{TimeEvent, TimeEventCallback},
45};
46
/// A throttling budget: at most `limit` messages per `interval_ns` nanoseconds.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RateLimit {
    /// Maximum number of messages permitted within one interval.
    pub limit: usize,
    /// Length of the interval in nanoseconds.
    pub interval_ns: u64,
}

impl RateLimit {
    /// Builds a [`RateLimit`] from a message count and an interval in nanoseconds.
    #[must_use]
    pub const fn new(limit: usize, interval_ns: u64) -> Self {
        RateLimit { limit, interval_ns }
    }
}
61
62pub struct Throttler<T, F> {
67 pub recv_count: usize,
69 pub sent_count: usize,
71 pub is_limiting: bool,
73 pub limit: usize,
75 pub buffer: VecDeque<T>,
77 pub timestamps: VecDeque<UnixNanos>,
79 pub clock: Rc<RefCell<dyn Clock>>,
81 pub actor_id: Ustr,
83 interval: u64,
85 timer_name: Ustr,
87 output_send: F,
89 output_drop: Option<F>,
91}
92
93impl<T, F> Actor for Throttler<T, F>
94where
95 T: 'static + Debug,
96 F: Fn(T) + 'static,
97{
98 fn id(&self) -> Ustr {
99 self.actor_id
100 }
101
102 fn handle(&mut self, _msg: &dyn Any) {}
103
104 fn as_any(&self) -> &dyn Any {
105 self
106 }
107}
108
109impl<T, F> Debug for Throttler<T, F>
110where
111 T: Debug,
112{
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct(stringify!(InnerThrottler))
115 .field("recv_count", &self.recv_count)
116 .field("sent_count", &self.sent_count)
117 .field("is_limiting", &self.is_limiting)
118 .field("limit", &self.limit)
119 .field("buffer", &self.buffer)
120 .field("timestamps", &self.timestamps)
121 .field("interval", &self.interval)
122 .field("timer_name", &self.timer_name)
123 .finish()
124 }
125}
126
127impl<T, F> Throttler<T, F>
128where
129 T: Debug,
130{
131 #[inline]
132 pub fn new(
133 limit: usize,
134 interval: u64,
135 clock: Rc<RefCell<dyn Clock>>,
136 timer_name: String,
137 output_send: F,
138 output_drop: Option<F>,
139 actor_id: Ustr,
140 ) -> Self {
141 Self {
142 recv_count: 0,
143 sent_count: 0,
144 is_limiting: false,
145 limit,
146 buffer: VecDeque::new(),
147 timestamps: VecDeque::with_capacity(limit),
148 clock,
149 interval,
150 timer_name: Ustr::from(&timer_name),
151 output_send,
152 output_drop,
153 actor_id,
154 }
155 }
156
157 #[inline]
167 pub fn set_timer(&mut self, callback: Option<TimeEventCallback>) {
168 let delta = self.delta_next();
169 let mut clock = self.clock.borrow_mut();
170 if clock.timer_exists(&self.timer_name) {
171 clock.cancel_timer(&self.timer_name);
172 }
173 let alert_ts = clock.timestamp_ns() + delta;
174
175 clock
176 .set_time_alert_ns(&self.timer_name, alert_ts, callback, None)
177 .expect(FAILED);
178 }
179
180 #[inline]
182 pub fn delta_next(&mut self) -> u64 {
183 match self.timestamps.get(self.limit - 1) {
184 Some(ts) => {
185 let diff = self.clock.borrow().timestamp_ns().as_u64() - ts.as_u64();
186 self.interval.saturating_sub(diff)
187 }
188 None => 0,
189 }
190 }
191
192 #[inline]
194 pub fn reset(&mut self) {
195 self.buffer.clear();
196 self.recv_count = 0;
197 self.sent_count = 0;
198 self.is_limiting = false;
199 self.timestamps.clear();
200 }
201
202 #[inline]
204 pub fn used(&self) -> f64 {
205 if self.timestamps.is_empty() {
206 return 0.0;
207 }
208
209 let now = self.clock.borrow().timestamp_ns().as_i64();
210 let interval_start = now - self.interval as i64;
211
212 let messages_in_current_interval = self
213 .timestamps
214 .iter()
215 .take_while(|&&ts| ts.as_i64() > interval_start)
216 .count();
217
218 (messages_in_current_interval as f64) / (self.limit as f64)
219 }
220
221 #[inline]
223 pub fn qsize(&self) -> usize {
224 self.buffer.len()
225 }
226}
227
228impl<T, F> Throttler<T, F>
229where
230 T: 'static + Debug,
231 F: Fn(T) + 'static,
232{
233 pub fn to_actor(self) -> Rc<UnsafeCell<Self>> {
234 let process_handler = ThrottlerProcess::<T, F>::new(self.actor_id);
236 msgbus::register(
237 process_handler.id().as_str().into(),
238 ShareableMessageHandler::from(Rc::new(process_handler) as Rc<dyn MessageHandler>),
239 );
240
241 register_actor(self)
243 }
244
245 #[inline]
246 pub fn send_msg(&mut self, msg: T) {
247 let now = self.clock.borrow().timestamp_ns();
248
249 if self.timestamps.len() >= self.limit {
250 self.timestamps.pop_back();
251 }
252 self.timestamps.push_front(now);
253
254 self.sent_count += 1;
255 (self.output_send)(msg);
256 }
257
258 #[inline]
259 pub fn limit_msg(&mut self, msg: T) {
260 let callback = if self.output_drop.is_none() {
261 self.buffer.push_front(msg);
262 log::debug!("Buffering {}", self.buffer.len());
263 Some(ThrottlerProcess::<T, F>::new(self.actor_id).get_timer_callback())
264 } else {
265 log::debug!("Dropping");
266 if let Some(drop) = &self.output_drop {
267 drop(msg);
268 }
269 Some(throttler_resume::<T, F>(self.actor_id))
270 };
271 if !self.is_limiting {
272 log::debug!("Limiting");
273 self.set_timer(callback);
274 self.is_limiting = true;
275 }
276 }
277
278 #[inline]
279 pub fn send(&mut self, msg: T)
280 where
281 T: 'static,
282 F: Fn(T) + 'static,
283 {
284 self.recv_count += 1;
285
286 if self.is_limiting || self.delta_next() > 0 {
287 self.limit_msg(msg);
288 } else {
289 self.send_msg(msg);
290 }
291 }
292}
293
294struct ThrottlerProcess<T, F> {
299 actor_id: Ustr,
300 endpoint: Ustr,
301 phantom_t: PhantomData<T>,
302 phantom_f: PhantomData<F>,
303}
304
305impl<T, F> ThrottlerProcess<T, F>
306where
307 T: Debug,
308{
309 pub fn new(actor_id: Ustr) -> Self {
310 let endpoint = Ustr::from(&format!("{actor_id}_process"));
311 Self {
312 actor_id,
313 endpoint,
314 phantom_t: PhantomData,
315 phantom_f: PhantomData,
316 }
317 }
318
319 pub fn get_timer_callback(&self) -> TimeEventCallback {
320 let endpoint = self.endpoint.into(); TimeEventCallback::from(move |event: TimeEvent| {
322 msgbus::send_any(endpoint, &(event));
323 })
324 }
325}
326
327impl<T, F> MessageHandler for ThrottlerProcess<T, F>
328where
329 T: 'static + Debug,
330 F: Fn(T) + 'static,
331{
332 fn id(&self) -> Ustr {
333 self.endpoint
334 }
335
336 fn handle(&self, _message: &dyn Any) {
337 let throttler = get_actor_unchecked::<Throttler<T, F>>(&self.actor_id);
338 while let Some(msg) = throttler.buffer.pop_back() {
339 throttler.send_msg(msg);
340
341 if !throttler.buffer.is_empty() && throttler.delta_next() > 0 {
345 throttler.is_limiting = true;
346
347 let endpoint = self.endpoint.into(); throttler.set_timer(Some(TimeEventCallback::from(move |event: TimeEvent| {
351 msgbus::send_any(endpoint, &(event));
352 })));
353 return;
354 }
355 }
356
357 throttler.is_limiting = false;
358 }
359
360 fn as_any(&self) -> &dyn Any {
361 self
362 }
363}
364
365pub fn throttler_resume<T, F>(actor_id: Ustr) -> TimeEventCallback
367where
368 T: 'static + Debug,
369 F: Fn(T) + 'static,
370{
371 TimeEventCallback::from(move |_event: TimeEvent| {
372 let throttler = get_actor_unchecked::<Throttler<T, F>>(&actor_id);
373 throttler.is_limiting = false;
374 })
375}
376
#[cfg(test)]
mod tests {
    use std::{
        cell::{RefCell, UnsafeCell},
        rc::Rc,
    };

    use nautilus_core::{UUID4, UnixNanos};
    use proptest::prelude::*;
    use rstest::{fixture, rstest};
    use ustr::Ustr;

    use super::{RateLimit, Throttler};
    use crate::clock::TestClock;

    type SharedThrottler = Rc<UnsafeCell<Throttler<u64, Box<dyn Fn(u64)>>>>;

    /// Test wrapper pairing a registered throttler with its concrete `TestClock`,
    /// so tests can both drive the throttler and advance time.
    #[derive(Clone)]
    struct TestThrottler {
        throttler: SharedThrottler,
        clock: Rc<RefCell<TestClock>>,
        interval: u64,
    }

    #[allow(unsafe_code)]
    impl TestThrottler {
        #[allow(clippy::mut_from_ref)]
        pub fn get_throttler(&self) -> &mut Throttler<u64, Box<dyn Fn(u64)>> {
            // SAFETY: tests run single-threaded. Timer callbacks also obtain `&mut`
            // access through the actor registry while a reference returned here
            // may still be alive; that aliasing is tolerated in test code only.
            unsafe { &mut *self.throttler.get() }
        }
    }

    /// Advances `clock` to `to_time_ns` and invokes every timer callback that fired.
    ///
    /// The clock borrow is released before the callbacks run, because the
    /// throttler's callbacks borrow the clock themselves (e.g. to re-arm a timer).
    fn advance_and_fire(clock: &RefCell<TestClock>, to_time_ns: UnixNanos) {
        let mut guard = clock.borrow_mut();
        let time_events = guard.advance_time(to_time_ns, true);
        let handlers = guard.match_handlers(time_events);
        drop(guard);

        for handler in handlers {
            handler.callback.call(handler.event);
        }
    }

    #[fixture]
    pub fn test_throttler_buffered() -> TestThrottler {
        let output_send: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Sent: {msg}");
        });
        let clock = Rc::new(RefCell::new(TestClock::new()));
        let inner_clock = Rc::clone(&clock);
        let rate_limit = RateLimit::new(5, 10);
        let interval = rate_limit.interval_ns;
        let actor_id = Ustr::from(&UUID4::new().to_string());

        TestThrottler {
            throttler: Throttler::new(
                rate_limit.limit,
                rate_limit.interval_ns,
                clock,
                "buffer_timer".to_string(),
                output_send,
                None,
                actor_id,
            )
            .to_actor(),
            clock: inner_clock,
            interval,
        }
    }

    #[fixture]
    pub fn test_throttler_unbuffered() -> TestThrottler {
        let output_send: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Sent: {msg}");
        });
        let output_drop: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Dropped: {msg}");
        });
        let clock = Rc::new(RefCell::new(TestClock::new()));
        let inner_clock = Rc::clone(&clock);
        let rate_limit = RateLimit::new(5, 10);
        let interval = rate_limit.interval_ns;
        let actor_id = Ustr::from(&UUID4::new().to_string());

        TestThrottler {
            throttler: Throttler::new(
                rate_limit.limit,
                rate_limit.interval_ns,
                clock,
                "dropper_timer".to_string(),
                output_send,
                Some(output_drop),
                actor_id,
            )
            .to_actor(),
            clock: inner_clock,
            interval,
        }
    }

    #[rstest]
    fn test_buffering_send_to_limit_becomes_throttled(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }
        assert_eq!(throttler.qsize(), 1);

        assert!(throttler.is_limiting);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
        assert_eq!(throttler.clock.borrow().timer_names(), vec!["buffer_timer"]);
    }

    #[rstest]
    fn test_buffering_used_when_sent_to_limit_returns_one(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..5 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 5);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_buffering_used_when_half_interval_from_limit_returns_one(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..5 {
            throttler.send(42);
        }

        let half_interval = test_throttler_buffered.interval / 2;
        {
            // No timer is pending here, so the returned events are not dispatched.
            let mut clock = test_throttler_buffered.clock.borrow_mut();
            clock.advance_time(half_interval.into(), true);
        }

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 5);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_buffering_used_before_limit_when_halfway_returns_half(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..3 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 0.6);
        assert_eq!(throttler.recv_count, 3);
        assert_eq!(throttler.sent_count, 3);
    }

    #[rstest]
    fn test_buffering_refresh_when_at_limit_sends_remaining_items(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(42);
        }

        advance_and_fire(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval.into(),
        );

        assert_eq!(throttler.used(), 0.2);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 6);
        assert_eq!(throttler.qsize(), 0);
    }

    #[rstest]
    fn test_buffering_send_message_after_buffering_message(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(43);
        }

        advance_and_fire(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval.into(),
        );

        for _ in 0..6 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 12);
        assert_eq!(throttler.sent_count, 10);
        assert_eq!(throttler.qsize(), 2);
    }

    #[rstest]
    fn test_buffering_send_message_after_halfway_after_buffering_message(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(42);
        }

        advance_and_fire(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval.into(),
        );

        for _ in 0..3 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 0.8);
        assert_eq!(throttler.recv_count, 9);
        assert_eq!(throttler.sent_count, 9);
        assert_eq!(throttler.qsize(), 0);
    }

    #[rstest]
    fn test_dropping_send_sends_message_to_handler(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        throttler.send(42);

        assert!(!throttler.is_limiting);
        assert_eq!(throttler.recv_count, 1);
        assert_eq!(throttler.sent_count, 1);
    }

    #[rstest]
    fn test_dropping_send_to_limit_drops_message(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }
        assert_eq!(throttler.qsize(), 0);

        assert!(throttler.is_limiting);
        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.clock.borrow().timer_count(), 1);
        assert_eq!(
            throttler.clock.borrow().timer_names(),
            vec!["dropper_timer"]
        );
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_dropping_advance_time_when_at_limit_dropped_message(
        test_throttler_unbuffered: TestThrottler,
    ) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }

        advance_and_fire(
            &test_throttler_unbuffered.clock,
            test_throttler_unbuffered.interval.into(),
        );

        assert_eq!(throttler.clock.borrow().timer_count(), 0);
        assert!(!throttler.is_limiting);
        assert_eq!(throttler.used(), 0.0);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_dropping_send_message_after_dropping_message(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }

        advance_and_fire(
            &test_throttler_unbuffered.clock,
            test_throttler_unbuffered.interval.into(),
        );

        throttler.send(42);

        assert_eq!(throttler.used(), 0.2);
        assert_eq!(throttler.clock.borrow().timer_count(), 0);
        assert!(!throttler.is_limiting);
        assert_eq!(throttler.recv_count, 7);
        assert_eq!(throttler.sent_count, 6);
    }

    #[derive(Clone, Debug)]
    enum ThrottlerInput {
        SendMessage(u64),
        AdvanceClock(u8),
    }

    /// Generates sends (weight 2) and clock advances of 5..=9 ns (weight 8).
    fn throttler_input_strategy() -> impl Strategy<Value = ThrottlerInput> {
        prop_oneof![
            2 => prop::bool::ANY.prop_map(|_| ThrottlerInput::SendMessage(42)),
            8 => prop::num::u8::ANY.prop_map(|v| ThrottlerInput::AdvanceClock(v % 5 + 5)),
        ]
    }

    fn throttler_test_strategy() -> impl Strategy<Value = Vec<ThrottlerInput>> {
        prop::collection::vec(throttler_input_strategy(), 10..=150)
    }

    /// Replays `inputs` against a buffered throttler and checks its invariants after
    /// each step. It then verifies that the buffer fully drains.
    fn test_throttler_with_inputs(inputs: Vec<ThrottlerInput>, test_throttler: TestThrottler) {
        let test_clock = test_throttler.clock.clone();
        let interval = test_throttler.interval;
        let throttler = test_throttler.get_throttler();
        let mut received = 0;

        for input in inputs {
            match input {
                ThrottlerInput::SendMessage(msg) => {
                    throttler.send(msg);
                    received += 1;
                }
                ThrottlerInput::AdvanceClock(duration) => {
                    let now = test_clock.borrow().get_time_ns();
                    advance_and_fire(&test_clock, now + u64::from(duration));
                }
            }

            // The throttler is limiting exactly when messages are buffered and the
            // last `limit` sends all fall within the current interval.
            let buffered_messages = throttler.qsize() > 0;
            let now = throttler.clock.borrow().timestamp_ns().as_u64();
            let limit_filled_within_interval = throttler
                .timestamps
                .get(throttler.limit - 1)
                .is_some_and(|&ts| (now - ts.as_u64()) < interval);
            let expected_limiting = buffered_messages && limit_filled_within_interval;
            assert_eq!(throttler.is_limiting, expected_limiting);

            // Every received message has either been sent or is still buffered.
            assert_eq!(received, throttler.sent_count + throttler.qsize());
        }

        // Advance relative to the current time: the generated inputs may already
        // have moved the clock near or past any fixed absolute timestamp.
        let now = test_clock.borrow().get_time_ns();
        advance_and_fire(&test_clock, now + interval * 100);
        assert_eq!(throttler.qsize(), 0);
    }

    #[ignore = "Used for manually testing failing cases"]
    #[rstest]
    fn test_case() {
        let inputs = vec![
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::AdvanceClock(5),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::AdvanceClock(5),
            ThrottlerInput::SendMessage(42),
            ThrottlerInput::SendMessage(42),
        ];

        test_throttler_with_inputs(inputs, test_throttler_buffered());
    }

    #[rstest]
    fn prop_test() {
        let test_throttler = test_throttler_buffered();

        proptest!(move |(inputs in throttler_test_strategy())| {
            test_throttler_with_inputs(inputs, test_throttler.clone());

            // Restore a clean throttler and clock for the next generated case.
            let throttler = test_throttler.get_throttler();
            throttler.reset();
            throttler.clock.borrow_mut().reset();
        });
    }
}