1use std::{
17 any::Any,
18 cell::{RefCell, UnsafeCell},
19 collections::VecDeque,
20 fmt::Debug,
21 marker::PhantomData,
22 rc::Rc,
23};
24
25use nautilus_core::{UnixNanos, correctness::FAILED};
26use ustr::Ustr;
27
28use crate::{
29 actor::{
30 Actor,
31 registry::{get_actor_unchecked, register_actor},
32 },
33 clock::Clock,
34 msgbus::{
35 self,
36 handler::{MessageHandler, ShareableMessageHandler},
37 },
38 timer::{TimeEvent, TimeEventCallback},
39};
40
/// A rate limit expressed as a maximum message count per time interval.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimit {
    /// Maximum number of messages allowed within one interval.
    pub limit: usize,
    /// Length of the interval in nanoseconds.
    pub interval_ns: u64,
}
47
impl RateLimit {
    /// Creates a new [`RateLimit`] allowing `limit` messages per `interval_ns` nanoseconds.
    #[must_use]
    pub const fn new(limit: usize, interval_ns: u64) -> Self {
        Self { limit, interval_ns }
    }
}
55
/// Rate limits messages sent to it, buffering or dropping those over the limit.
///
/// At most `limit` messages are forwarded to `output_send` within `interval`
/// nanoseconds. Messages received while over the limit are passed to
/// `output_drop` when it is set; otherwise they are queued in `buffer` and
/// forwarded later from a timer callback.
pub struct Throttler<T, F> {
    /// The number of messages received via `send`.
    pub recv_count: usize,
    /// The number of messages forwarded to `output_send`.
    pub sent_count: usize,
    /// Whether the throttler is currently limiting the message rate.
    pub is_limiting: bool,
    /// The maximum number of messages that can be sent within the interval.
    pub limit: usize,
    /// Messages waiting to be sent (pushed at the front, drained from the back).
    pub buffer: VecDeque<T>,
    /// Timestamps of the most recent sends, newest first; at most `limit` are kept.
    pub timestamps: VecDeque<UnixNanos>,
    /// The clock used for timestamps and for scheduling the resume/process timer.
    pub clock: Rc<RefCell<dyn Clock>>,
    /// The actor ID under which the throttler is looked up by its timer callbacks.
    pub actor_id: Ustr,
    /// The rate-limit interval in nanoseconds.
    interval: u64,
    /// The name of the timer set on the clock while limiting.
    timer_name: String,
    /// The callback invoked with every message that is sent.
    output_send: F,
    /// The callback invoked with every message that is dropped; `None` means buffer instead.
    output_drop: Option<F>,
}
86
/// [`Actor`] implementation for [`Throttler`], identified by its `actor_id`.
impl<T, F> Actor for Throttler<T, F>
where
    T: 'static,
    F: Fn(T) + 'static,
{
    /// Returns the actor ID the throttler was constructed with.
    fn id(&self) -> Ustr {
        self.actor_id
    }

    /// No-op: messages delivered through the actor interface are ignored.
    fn handle(&mut self, _msg: &dyn Any) {}

    /// Returns `self` as [`Any`] so it can be downcast to the concrete throttler type.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
102
103impl<T, F> Debug for Throttler<T, F>
104where
105 T: Debug,
106{
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 f.debug_struct(stringify!(InnerThrottler))
109 .field("recv_count", &self.recv_count)
110 .field("sent_count", &self.sent_count)
111 .field("is_limiting", &self.is_limiting)
112 .field("limit", &self.limit)
113 .field("buffer", &self.buffer)
114 .field("timestamps", &self.timestamps)
115 .field("interval", &self.interval)
116 .field("timer_name", &self.timer_name)
117 .finish()
118 }
119}
120
121impl<T, F> Throttler<T, F> {
122 #[inline]
123 pub fn new(
124 limit: usize,
125 interval: u64,
126 clock: Rc<RefCell<dyn Clock>>,
127 timer_name: String,
128 output_send: F,
129 output_drop: Option<F>,
130 actor_id: Ustr,
131 ) -> Self {
132 Self {
133 recv_count: 0,
134 sent_count: 0,
135 is_limiting: false,
136 limit,
137 buffer: VecDeque::new(),
138 timestamps: VecDeque::with_capacity(limit),
139 clock,
140 interval,
141 timer_name,
142 output_send,
143 output_drop,
144 actor_id,
145 }
146 }
147
148 #[inline]
154 pub fn set_timer(&mut self, callback: Option<TimeEventCallback>) {
155 let delta = self.delta_next();
156 let mut clock = self.clock.borrow_mut();
157 if clock.timer_names().contains(&self.timer_name.as_str()) {
158 clock.cancel_timer(&self.timer_name);
159 }
160 let alert_ts = clock.timestamp_ns() + delta;
161
162 clock
163 .set_time_alert_ns(&self.timer_name, alert_ts, callback)
164 .expect(FAILED);
165 }
166
167 #[inline]
169 pub fn delta_next(&mut self) -> u64 {
170 match self.timestamps.get(self.limit - 1) {
171 Some(ts) => {
172 let diff = self.clock.borrow().timestamp_ns().as_u64() - ts.as_u64();
173 self.interval.saturating_sub(diff)
174 }
175 None => 0,
176 }
177 }
178
179 #[inline]
181 pub fn reset(&mut self) {
182 self.buffer.clear();
183 self.recv_count = 0;
184 self.sent_count = 0;
185 self.is_limiting = false;
186 self.timestamps.clear();
187 }
188
189 #[inline]
191 pub fn used(&self) -> f64 {
192 if self.timestamps.is_empty() {
193 return 0.0;
194 }
195
196 let now = self.clock.borrow().timestamp_ns().as_i64();
197 let interval_start = now - self.interval as i64;
198
199 let messages_in_current_interval = self
200 .timestamps
201 .iter()
202 .take_while(|&&ts| ts.as_i64() > interval_start)
203 .count();
204
205 (messages_in_current_interval as f64) / (self.limit as f64)
206 }
207
208 #[inline]
210 pub fn qsize(&self) -> usize {
211 self.buffer.len()
212 }
213}
214
215impl<T, F> Throttler<T, F>
216where
217 T: 'static,
218 F: Fn(T) + 'static,
219{
220 pub fn to_actor(self) -> Rc<UnsafeCell<Self>> {
221 let process_handler = ThrottlerProcess::<T, F>::new(self.actor_id);
223 msgbus::register(
224 process_handler.id().as_str(),
225 ShareableMessageHandler::from(Rc::new(process_handler) as Rc<dyn MessageHandler>),
226 );
227
228 let actor = Rc::new(UnsafeCell::new(self));
230 register_actor(actor.clone());
231
232 actor
233 }
234
235 #[inline]
236 pub fn send_msg(&mut self, msg: T) {
237 let now = self.clock.borrow().timestamp_ns();
238
239 if self.timestamps.len() >= self.limit {
240 self.timestamps.pop_back();
241 }
242 self.timestamps.push_front(now);
243
244 self.sent_count += 1;
245 (self.output_send)(msg);
246 }
247
248 #[inline]
249 pub fn limit_msg(&mut self, msg: T) {
250 let callback = if self.output_drop.is_none() {
251 self.buffer.push_front(msg);
252 log::debug!("Buffering {}", self.buffer.len());
253 Some(ThrottlerProcess::<T, F>::new(self.actor_id).get_timer_callback())
254 } else {
255 log::debug!("Dropping");
256 if let Some(drop) = &self.output_drop {
257 drop(msg);
258 }
259 Some(throttler_resume::<T, F>(self.actor_id))
260 };
261 if !self.is_limiting {
262 log::debug!("Limiting");
263 self.set_timer(callback);
264 self.is_limiting = true;
265 }
266 }
267
268 #[inline]
269 pub fn send(&mut self, msg: T)
270 where
271 T: 'static,
272 F: Fn(T) + 'static,
273 {
274 self.recv_count += 1;
275
276 if self.is_limiting || self.delta_next() > 0 {
277 self.limit_msg(msg);
278 } else {
279 self.send_msg(msg);
280 }
281 }
282}
283
/// Message handler that drains a throttler's buffered messages.
///
/// It is addressed on the message bus by `endpoint` (`{actor_id}_process`) and
/// locates its throttler through `actor_id`. The type parameters only record the
/// throttler's message and callback types; no values of those types are stored.
struct ThrottlerProcess<T, F> {
    /// ID of the throttler actor this handler operates on.
    actor_id: Ustr,
    /// Message bus endpoint (and handler ID) of this handler.
    endpoint: Ustr,
    phantom_t: PhantomData<T>,
    phantom_f: PhantomData<F>,
}
294
295impl<T, F> ThrottlerProcess<T, F> {
296 pub fn new(actor_id: Ustr) -> Self {
297 let endpoint = Ustr::from(&format!("{}_process", actor_id));
298 Self {
299 actor_id,
300 endpoint,
301 phantom_t: PhantomData,
302 phantom_f: PhantomData,
303 }
304 }
305
306 pub fn get_timer_callback(&self) -> TimeEventCallback {
307 let endpoint_clone = self.endpoint;
308 let process_callback = Rc::new(move |_event: TimeEvent| {
309 msgbus::send(&endpoint_clone, &());
310 });
311 TimeEventCallback::Rust(process_callback)
312 }
313}
314
315impl<T, F> MessageHandler for ThrottlerProcess<T, F>
316where
317 T: 'static,
318 F: Fn(T) + 'static,
319{
320 fn id(&self) -> Ustr {
321 self.endpoint
322 }
323
324 fn handle(&self, _message: &dyn Any) {
325 let throttler = get_actor_unchecked::<Throttler<T, F>>(&self.actor_id);
326 while let Some(msg) = throttler.buffer.pop_back() {
327 throttler.send_msg(msg);
328
329 if !throttler.buffer.is_empty() && throttler.delta_next() > 0 {
333 throttler.is_limiting = true;
334
335 let endpoint_clone = self.endpoint;
336 let process_callback = Rc::new(move |_event: TimeEvent| {
338 msgbus::send(&endpoint_clone, &());
339 });
340 throttler.set_timer(Some(TimeEventCallback::Rust(process_callback)));
341 return;
342 }
343 }
344
345 throttler.is_limiting = false;
346 }
347
348 fn as_any(&self) -> &dyn Any {
349 self
350 }
351}
352
353pub fn throttler_resume<T, F>(actor_id: Ustr) -> TimeEventCallback
355where
356 T: 'static,
357 F: Fn(T) + 'static,
358{
359 let callback = Rc::new(move |_event: TimeEvent| {
360 let throttler = get_actor_unchecked::<Throttler<T, F>>(&actor_id);
361 throttler.is_limiting = false;
362 });
363
364 TimeEventCallback::Rust(callback)
365}
366
367#[cfg(test)]
371mod tests {
372 use std::{
373 cell::{RefCell, UnsafeCell},
374 rc::Rc,
375 };
376
377 use nautilus_core::UUID4;
378 use rstest::{fixture, rstest};
379 use ustr::Ustr;
380
381 use super::{RateLimit, Throttler};
382 use crate::clock::TestClock;
    /// Shared cell holding a throttler over `u64` messages with boxed callbacks,
    /// as returned by `Throttler::to_actor`.
    type SharedThrottler = Rc<UnsafeCell<Throttler<u64, Box<dyn Fn(u64)>>>>;

    /// Test wrapper bundling a registered throttler with the test clock it uses
    /// and the rate-limit interval it was built with.
    #[derive(Clone)]
    struct TestThrottler {
        throttler: SharedThrottler,
        clock: Rc<RefCell<TestClock>>,
        interval: u64,
    }
395
    impl TestThrottler {
        /// Returns a mutable reference to the throttler inside the shared cell.
        // `mut_from_ref` is allowed because the reference comes out of an `UnsafeCell`.
        #[allow(clippy::mut_from_ref)]
        pub fn get_throttler(&self) -> &mut Throttler<u64, Box<dyn Fn(u64)>> {
            // SAFETY: NOTE(review) — this relies on callers never using two of these
            // references at the same time; nothing here enforces that. TODO confirm.
            unsafe { &mut *self.throttler.get() }
        }
    }
402
403 #[fixture]
404 pub fn test_throttler_buffered() -> TestThrottler {
405 let output_send: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
406 log::debug!("Sent: {msg}");
407 });
408 let clock = Rc::new(RefCell::new(TestClock::new()));
409 let inner_clock = Rc::clone(&clock);
410 let rate_limit = RateLimit::new(5, 10);
411 let interval = rate_limit.interval_ns;
412 let actor_id = Ustr::from(&UUID4::new().to_string());
413
414 TestThrottler {
415 throttler: Throttler::new(
416 rate_limit.limit,
417 rate_limit.interval_ns,
418 clock,
419 "buffer_timer".to_string(),
420 output_send,
421 None,
422 actor_id,
423 )
424 .to_actor(),
425 clock: inner_clock,
426 interval,
427 }
428 }
429
430 #[fixture]
431 pub fn test_throttler_unbuffered() -> TestThrottler {
432 let output_send: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
433 log::debug!("Sent: {msg}");
434 });
435 let output_drop: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
436 log::debug!("Dropped: {msg}");
437 });
438 let clock = Rc::new(RefCell::new(TestClock::new()));
439 let inner_clock = Rc::clone(&clock);
440 let rate_limit = RateLimit::new(5, 10);
441 let interval = rate_limit.interval_ns;
442 let actor_id = Ustr::from(&UUID4::new().to_string());
443
444 TestThrottler {
445 throttler: Throttler::new(
446 rate_limit.limit,
447 rate_limit.interval_ns,
448 clock,
449 "dropper_timer".to_string(),
450 output_send,
451 Some(output_drop),
452 actor_id,
453 )
454 .to_actor(),
455 clock: inner_clock,
456 interval,
457 }
458 }
459
460 #[rstest]
461 fn test_buffering_send_to_limit_becomes_throttled(test_throttler_buffered: TestThrottler) {
462 let throttler = test_throttler_buffered.get_throttler();
463 for _ in 0..6 {
464 throttler.send(42);
465 }
466 assert_eq!(throttler.qsize(), 1);
467
468 assert!(throttler.is_limiting);
469 assert_eq!(throttler.recv_count, 6);
470 assert_eq!(throttler.sent_count, 5);
471 assert_eq!(throttler.clock.borrow().timer_names(), vec!["buffer_timer"]);
472 }
473
474 #[rstest]
475 fn test_buffering_used_when_sent_to_limit_returns_one(test_throttler_buffered: TestThrottler) {
476 let throttler = test_throttler_buffered.get_throttler();
477
478 for _ in 0..5 {
479 throttler.send(42);
480 }
481
482 assert_eq!(throttler.used(), 1.0);
483 assert_eq!(throttler.recv_count, 5);
484 assert_eq!(throttler.sent_count, 5);
485 }
486
487 #[rstest]
488 fn test_buffering_used_when_half_interval_from_limit_returns_one(
489 test_throttler_buffered: TestThrottler,
490 ) {
491 let throttler = test_throttler_buffered.get_throttler();
492
493 for _ in 0..5 {
494 throttler.send(42);
495 }
496
497 let half_interval = test_throttler_buffered.interval / 2;
498 {
500 let mut clock = test_throttler_buffered.clock.borrow_mut();
501 clock.advance_time(half_interval.into(), true);
502 }
503
504 assert_eq!(throttler.used(), 1.0);
505 assert_eq!(throttler.recv_count, 5);
506 assert_eq!(throttler.sent_count, 5);
507 }
508
509 #[rstest]
510 fn test_buffering_used_before_limit_when_halfway_returns_half(
511 test_throttler_buffered: TestThrottler,
512 ) {
513 let throttler = test_throttler_buffered.get_throttler();
514
515 for _ in 0..3 {
516 throttler.send(42);
517 }
518
519 assert_eq!(throttler.used(), 0.6);
520 assert_eq!(throttler.recv_count, 3);
521 assert_eq!(throttler.sent_count, 3);
522 }
523
    /// After one full interval the timer callback flushes the buffered message.
    #[rstest]
    fn test_buffering_refresh_when_at_limit_sends_remaining_items(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        // One more than the limit: the sixth message is buffered
        for _ in 0..6 {
            throttler.send(42);
        }

        // Advance the clock by the interval and run the resulting timer callbacks
        {
            let mut clock = test_throttler_buffered.clock.borrow_mut();
            let time_events = clock.advance_time(test_throttler_buffered.interval.into(), true);
            for each_event in clock.match_handlers(time_events) {
                // Release the mutable borrow: the callback leads to the throttler
                // borrowing the same clock
                drop(clock);
                each_event.callback.call(each_event.event);

                // Re-borrow the clock for the next iteration
                clock = test_throttler_buffered.clock.borrow_mut();
            }
        }

        // Only the flushed message falls inside the current interval: 1 of 5
        assert_eq!(throttler.used(), 0.2);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 6);
        assert_eq!(throttler.qsize(), 0);
    }
554
    /// A second burst after the buffer was flushed is throttled again.
    #[rstest]
    fn test_buffering_send_message_after_buffering_message(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();

        // First burst: five sent, one buffered
        for _ in 0..6 {
            throttler.send(43);
        }

        // Advance the clock by the interval and run the resulting timer callbacks
        {
            let mut clock = test_throttler_buffered.clock.borrow_mut();
            let time_events = clock.advance_time(test_throttler_buffered.interval.into(), true);
            for each_event in clock.match_handlers(time_events) {
                // Release the mutable borrow: the callback leads to the throttler
                // borrowing the same clock
                drop(clock);
                each_event.callback.call(each_event.event);

                // Re-borrow the clock for the next iteration
                clock = test_throttler_buffered.clock.borrow_mut();
            }
        }

        // Second burst: the flushed message already used one slot of this interval
        for _ in 0..6 {
            throttler.send(42);
        }

        // Assert final state
        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 12);
        assert_eq!(throttler.sent_count, 10);
        assert_eq!(throttler.qsize(), 2);
    }
587
    /// After the buffer was flushed, a small follow-up burst fits within the limit.
    #[rstest]
    fn test_buffering_send_message_after_halfway_after_buffering_message(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        // First burst: five sent, one buffered
        for _ in 0..6 {
            throttler.send(42);
        }

        // Advance the clock by the interval and run the resulting timer callbacks
        {
            let mut clock = test_throttler_buffered.clock.borrow_mut();
            let time_events = clock.advance_time(test_throttler_buffered.interval.into(), true);
            for each_event in clock.match_handlers(time_events) {
                // Release the mutable borrow: the callback leads to the throttler
                // borrowing the same clock
                drop(clock);
                each_event.callback.call(each_event.event);

                // Re-borrow the clock for the next iteration
                clock = test_throttler_buffered.clock.borrow_mut();
            }
        }

        // Three more messages: together with the flushed one, 4 of 5 slots are used
        for _ in 0..3 {
            throttler.send(42);
        }

        // Assert final state
        assert_eq!(throttler.used(), 0.8);
        assert_eq!(throttler.recv_count, 9);
        assert_eq!(throttler.sent_count, 9);
        assert_eq!(throttler.qsize(), 0);
    }
622
623 #[rstest]
624 fn test_dropping_send_sends_message_to_handler(test_throttler_unbuffered: TestThrottler) {
625 let throttler = test_throttler_unbuffered.get_throttler();
626 throttler.send(42);
627
628 assert!(!throttler.is_limiting);
629 assert_eq!(throttler.recv_count, 1);
630 assert_eq!(throttler.sent_count, 1);
631 }
632
633 #[rstest]
634 fn test_dropping_send_to_limit_drops_message(test_throttler_unbuffered: TestThrottler) {
635 let throttler = test_throttler_unbuffered.get_throttler();
636 for _ in 0..6 {
637 throttler.send(42);
638 }
639 assert_eq!(throttler.qsize(), 0);
640
641 assert!(throttler.is_limiting);
642 assert_eq!(throttler.used(), 1.0);
643 assert_eq!(throttler.clock.borrow().timer_count(), 1);
644 assert_eq!(
645 throttler.clock.borrow().timer_names(),
646 vec!["dropper_timer"]
647 );
648 assert_eq!(throttler.recv_count, 6);
649 assert_eq!(throttler.sent_count, 5);
650 }
651
    /// After one full interval the resume timer fires and limiting stops.
    #[rstest]
    fn test_dropping_advance_time_when_at_limit_dropped_message(
        test_throttler_unbuffered: TestThrottler,
    ) {
        let throttler = test_throttler_unbuffered.get_throttler();
        // One more than the limit: the sixth message is dropped
        for _ in 0..6 {
            throttler.send(42);
        }

        // Advance the clock by the interval and run the resulting timer callbacks
        {
            let mut clock = test_throttler_unbuffered.clock.borrow_mut();
            let time_events = clock.advance_time(test_throttler_unbuffered.interval.into(), true);
            for each_event in clock.match_handlers(time_events) {
                // Release the mutable borrow while the callback runs
                drop(clock);
                each_event.callback.call(each_event.event);

                // Re-borrow the clock for the next iteration
                clock = test_throttler_unbuffered.clock.borrow_mut();
            }
        }

        // The timer has fired, and every recorded send is now a full interval old
        assert_eq!(throttler.clock.borrow().timer_count(), 0);
        assert!(!throttler.is_limiting);
        assert_eq!(throttler.used(), 0.0);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
    }
681
    /// Once limiting has stopped, a new message is sent normally again.
    #[rstest]
    fn test_dropping_send_message_after_dropping_message(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        // One more than the limit: the sixth message is dropped
        for _ in 0..6 {
            throttler.send(42);
        }

        // Advance the clock by the interval and run the resulting timer callbacks
        {
            let mut clock = test_throttler_unbuffered.clock.borrow_mut();
            let time_events = clock.advance_time(test_throttler_unbuffered.interval.into(), true);
            for each_event in clock.match_handlers(time_events) {
                // Release the mutable borrow while the callback runs
                drop(clock);
                each_event.callback.call(each_event.event);

                // Re-borrow the clock for the next iteration
                clock = test_throttler_unbuffered.clock.borrow_mut();
            }
        }

        throttler.send(42);

        // Only the new message falls inside the current interval: 1 of 5
        assert_eq!(throttler.used(), 0.2);
        assert_eq!(throttler.clock.borrow().timer_count(), 0);
        assert!(!throttler.is_limiting);
        assert_eq!(throttler.recv_count, 7);
        assert_eq!(throttler.sent_count, 6);
    }
711
712 use proptest::prelude::*;
713
    /// One step of a generated throttler scenario.
    #[derive(Clone, Debug)]
    enum ThrottlerInput {
        /// Send the given message through the throttler.
        SendMessage(u64),
        /// Advance the test clock by the given number of nanoseconds.
        AdvanceClock(u8),
    }
719
    /// Strategy producing a single [`ThrottlerInput`].
    ///
    /// With weights 2:8 it yields either `SendMessage(42)` or an `AdvanceClock`
    /// whose duration lies in `5..=9`.
    fn throttler_input_strategy() -> impl Strategy<Value = ThrottlerInput> {
        prop_oneof![
            // The generated bool is ignored; the message is always 42
            2 => prop::bool::ANY.prop_map(|_| ThrottlerInput::SendMessage(42)),
            // `v % 5 + 5` maps any u8 into 5..=9
            8 => prop::num::u8::ANY.prop_map(|v| ThrottlerInput::AdvanceClock(v % 5 + 5)),
        ]
    }
727
    /// Strategy producing a scenario of 10 to 150 (inclusive) throttler inputs.
    fn throttler_test_strategy() -> impl Strategy<Value = Vec<ThrottlerInput>> {
        prop::collection::vec(throttler_input_strategy(), 10..=150)
    }
732
733 fn test_throttler_with_inputs(inputs: Vec<ThrottlerInput>, test_throttler: TestThrottler) {
734 let test_clock = test_throttler.clock.clone();
735 let interval = test_throttler.interval;
736 let throttler = test_throttler.get_throttler();
737 let mut sent_count = 0;
738
739 for input in inputs {
740 match input {
741 ThrottlerInput::SendMessage(msg) => {
742 throttler.send(msg);
743 sent_count += 1;
744 }
745 ThrottlerInput::AdvanceClock(duration) => {
746 let mut clock_ref = test_clock.borrow_mut();
747 let current_time = clock_ref.get_time_ns();
748 let time_events =
749 clock_ref.advance_time(current_time + u64::from(duration), true);
750 for each_event in clock_ref.match_handlers(time_events) {
751 drop(clock_ref);
752 each_event.callback.call(each_event.event);
753 clock_ref = test_clock.borrow_mut();
754 }
755 }
756 }
757
758 let buffered_messages = throttler.qsize() > 0;
763 let now = throttler.clock.borrow().timestamp_ns().as_u64();
764 let limit_filled_within_interval = throttler
765 .timestamps
766 .get(throttler.limit - 1)
767 .is_some_and(|&ts| (now - ts.as_u64()) < interval);
768 let expected_limiting = buffered_messages && limit_filled_within_interval;
769 assert_eq!(throttler.is_limiting, expected_limiting);
770
771 assert_eq!(sent_count, throttler.sent_count + throttler.qsize());
773 }
774
775 let time_events = test_clock
777 .borrow_mut()
778 .advance_time((interval * 100).into(), true);
779 let mut clock_ref = test_clock.borrow_mut();
780 for each_event in clock_ref.match_handlers(time_events) {
781 drop(clock_ref);
782 each_event.callback.call(each_event.event);
783 clock_ref = test_clock.borrow_mut();
784 }
785 assert_eq!(throttler.qsize(), 0);
786 }
787
788 #[test]
789 #[ignore = "Used for manually testing failing cases"]
790 fn test_case() {
791 let inputs = [
792 ThrottlerInput::SendMessage(42),
793 ThrottlerInput::AdvanceClock(5),
794 ThrottlerInput::SendMessage(42),
795 ThrottlerInput::SendMessage(42),
796 ThrottlerInput::SendMessage(42),
797 ThrottlerInput::SendMessage(42),
798 ThrottlerInput::SendMessage(42),
799 ThrottlerInput::AdvanceClock(5),
800 ThrottlerInput::SendMessage(42),
801 ThrottlerInput::SendMessage(42),
802 ]
803 .to_vec();
804
805 let test_throttler = test_throttler_buffered();
806 test_throttler_with_inputs(inputs, test_throttler);
807 }
808
809 #[test]
810 fn prop_test() {
811 let test_throttler = test_throttler_buffered();
812
813 proptest!(move |(inputs in throttler_test_strategy())| {
814 test_throttler_with_inputs(inputs, test_throttler.clone());
815 let throttler = unsafe { &mut *(test_throttler.throttler.get() as *mut _ as *mut Throttler<u64, Box<dyn Fn(u64)>>) };
817 throttler.reset();
818 throttler.clock.borrow_mut().reset();
819 });
820 }
821}