1use std::{
23 any::Any,
24 cell::{RefCell, UnsafeCell},
25 collections::VecDeque,
26 fmt::Debug,
27 marker::PhantomData,
28 rc::Rc,
29};
30
31use nautilus_core::{UnixNanos, correctness::FAILED};
32use ustr::Ustr;
33
34use crate::{
35 actor::{
36 Actor,
37 registry::{get_actor_unchecked, register_actor},
38 },
39 clock::Clock,
40 msgbus::{self, Endpoint, Handler, MStr, ShareableMessageHandler},
41 timer::{TimeEvent, TimeEventCallback},
42};
43
/// A rate-limit specification: `limit` messages per `interval_ns` nanoseconds.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimit {
    /// Maximum number of messages permitted within one interval.
    pub limit: usize,
    /// Interval length in nanoseconds.
    pub interval_ns: u64,
}

impl RateLimit {
    /// Builds a [`RateLimit`] from a message `limit` and an interval in nanoseconds.
    #[must_use]
    pub const fn new(limit: usize, interval_ns: u64) -> Self {
        Self {
            interval_ns,
            limit,
        }
    }
}
58
59pub struct Throttler<T, F> {
64 pub recv_count: usize,
66 pub sent_count: usize,
68 pub is_limiting: bool,
70 pub limit: usize,
72 pub buffer: VecDeque<T>,
74 pub timestamps: VecDeque<UnixNanos>,
76 pub clock: Rc<RefCell<dyn Clock>>,
78 pub actor_id: Ustr,
80 interval: u64,
82 timer_name: Ustr,
84 output_send: F,
86 output_drop: Option<F>,
88}
89
90impl<T, F> Actor for Throttler<T, F>
91where
92 T: 'static + Debug,
93 F: Fn(T) + 'static,
94{
95 fn id(&self) -> Ustr {
96 self.actor_id
97 }
98
99 fn handle(&mut self, _msg: &dyn Any) {}
100
101 fn as_any(&self) -> &dyn Any {
102 self
103 }
104}
105
106impl<T, F> Debug for Throttler<T, F>
107where
108 T: Debug,
109{
110 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 f.debug_struct(stringify!(InnerThrottler))
112 .field("recv_count", &self.recv_count)
113 .field("sent_count", &self.sent_count)
114 .field("is_limiting", &self.is_limiting)
115 .field("limit", &self.limit)
116 .field("buffer", &self.buffer)
117 .field("timestamps", &self.timestamps)
118 .field("interval", &self.interval)
119 .field("timer_name", &self.timer_name)
120 .finish()
121 }
122}
123
124impl<T, F> Throttler<T, F>
125where
126 T: Debug,
127{
128 #[inline]
129 pub fn new(
130 limit: usize,
131 interval: u64,
132 clock: Rc<RefCell<dyn Clock>>,
133 timer_name: String,
134 output_send: F,
135 output_drop: Option<F>,
136 actor_id: Ustr,
137 ) -> Self {
138 Self {
139 recv_count: 0,
140 sent_count: 0,
141 is_limiting: false,
142 limit,
143 buffer: VecDeque::new(),
144 timestamps: VecDeque::with_capacity(limit),
145 clock,
146 interval,
147 timer_name: Ustr::from(&timer_name),
148 output_send,
149 output_drop,
150 actor_id,
151 }
152 }
153
154 #[inline]
164 pub fn set_timer(&mut self, callback: Option<TimeEventCallback>) {
165 let delta = self.delta_next();
166 let mut clock = self.clock.borrow_mut();
167 if clock.timer_exists(&self.timer_name) {
168 clock.cancel_timer(&self.timer_name);
169 }
170 let alert_ts = clock.timestamp_ns() + delta;
171
172 clock
173 .set_time_alert_ns(&self.timer_name, alert_ts, callback, None)
174 .expect(FAILED);
175 }
176
177 #[inline]
179 pub fn delta_next(&mut self) -> u64 {
180 match self.timestamps.get(self.limit - 1) {
181 Some(ts) => {
182 let diff = self.clock.borrow().timestamp_ns().as_u64() - ts.as_u64();
183 self.interval.saturating_sub(diff)
184 }
185 None => 0,
186 }
187 }
188
189 #[inline]
191 pub fn reset(&mut self) {
192 self.buffer.clear();
193 self.recv_count = 0;
194 self.sent_count = 0;
195 self.is_limiting = false;
196 self.timestamps.clear();
197 }
198
199 #[inline]
201 pub fn used(&self) -> f64 {
202 if self.timestamps.is_empty() {
203 return 0.0;
204 }
205
206 let now = self.clock.borrow().timestamp_ns().as_i64();
207 let interval_start = now - self.interval as i64;
208
209 let messages_in_current_interval = self
210 .timestamps
211 .iter()
212 .take_while(|&&ts| ts.as_i64() > interval_start)
213 .count();
214
215 (messages_in_current_interval as f64) / (self.limit as f64)
216 }
217
218 #[inline]
220 pub fn qsize(&self) -> usize {
221 self.buffer.len()
222 }
223}
224
225impl<T, F> Throttler<T, F>
226where
227 T: 'static + Debug,
228 F: Fn(T) + 'static,
229{
230 pub fn to_actor(self) -> Rc<UnsafeCell<Self>> {
231 let process_handler = ThrottlerProcess::<T, F>::new(self.actor_id);
233 msgbus::register_any(
234 process_handler.id().as_str().into(),
235 ShareableMessageHandler::from(Rc::new(process_handler) as Rc<dyn Handler<dyn Any>>),
236 );
237
238 register_actor(self)
240 }
241
242 #[inline]
243 pub fn send_msg(&mut self, msg: T) {
244 let now = self.clock.borrow().timestamp_ns();
245
246 if self.timestamps.len() >= self.limit {
247 self.timestamps.pop_back();
248 }
249 self.timestamps.push_front(now);
250
251 self.sent_count += 1;
252 (self.output_send)(msg);
253 }
254
255 #[inline]
256 pub fn limit_msg(&mut self, msg: T) {
257 let callback = if self.output_drop.is_none() {
258 self.buffer.push_front(msg);
259 log::debug!("Buffering {}", self.buffer.len());
260 Some(ThrottlerProcess::<T, F>::new(self.actor_id).get_timer_callback())
261 } else {
262 log::debug!("Dropping");
263 if let Some(drop) = &self.output_drop {
264 drop(msg);
265 }
266 Some(throttler_resume::<T, F>(self.actor_id))
267 };
268 if !self.is_limiting {
269 log::debug!("Limiting");
270 self.set_timer(callback);
271 self.is_limiting = true;
272 }
273 }
274
275 #[inline]
276 pub fn send(&mut self, msg: T)
277 where
278 T: 'static,
279 F: Fn(T) + 'static,
280 {
281 self.recv_count += 1;
282
283 if self.is_limiting || self.delta_next() > 0 {
284 self.limit_msg(msg);
285 } else {
286 self.send_msg(msg);
287 }
288 }
289}
290
291struct ThrottlerProcess<T, F> {
296 actor_id: Ustr,
297 endpoint: MStr<Endpoint>,
298 phantom_t: PhantomData<T>,
299 phantom_f: PhantomData<F>,
300}
301
302impl<T, F> ThrottlerProcess<T, F>
303where
304 T: Debug,
305{
306 pub fn new(actor_id: Ustr) -> Self {
307 let endpoint = MStr::endpoint(format!("{actor_id}_process")).expect(FAILED);
308 Self {
309 actor_id,
310 endpoint,
311 phantom_t: PhantomData,
312 phantom_f: PhantomData,
313 }
314 }
315
316 pub fn get_timer_callback(&self) -> TimeEventCallback {
317 let endpoint = self.endpoint;
318 TimeEventCallback::from(move |event: TimeEvent| {
319 msgbus::send_any(endpoint, &(event));
320 })
321 }
322}
323
324impl<T, F> Handler<dyn Any> for ThrottlerProcess<T, F>
325where
326 T: 'static + Debug,
327 F: Fn(T) + 'static,
328{
329 fn id(&self) -> Ustr {
330 *self.endpoint
331 }
332
333 fn handle(&self, _message: &dyn Any) {
334 let mut throttler = get_actor_unchecked::<Throttler<T, F>>(&self.actor_id);
335 while let Some(msg) = throttler.buffer.pop_back() {
336 throttler.send_msg(msg);
337
338 if !throttler.buffer.is_empty() && throttler.delta_next() > 0 {
342 throttler.is_limiting = true;
343
344 let endpoint = self.endpoint;
345
346 throttler.set_timer(Some(TimeEventCallback::from(move |event: TimeEvent| {
348 msgbus::send_any(endpoint, &(event));
349 })));
350 return;
351 }
352 }
353
354 throttler.is_limiting = false;
355 }
356}
357
358pub fn throttler_resume<T, F>(actor_id: Ustr) -> TimeEventCallback
360where
361 T: 'static + Debug,
362 F: Fn(T) + 'static,
363{
364 TimeEventCallback::from(move |_event: TimeEvent| {
365 let mut throttler = get_actor_unchecked::<Throttler<T, F>>(&actor_id);
366 throttler.is_limiting = false;
367 })
368}
369
#[cfg(test)]
mod tests {
    use std::{
        cell::{RefCell, UnsafeCell},
        rc::Rc,
    };

    use nautilus_core::{UUID4, UnixNanos};
    use proptest::prelude::*;
    use rstest::{fixture, rstest};
    use ustr::Ustr;

    use super::{RateLimit, Throttler, ThrottlerProcess};
    use crate::{clock::TestClock, msgbus::Handler};

    type SharedThrottler = Rc<UnsafeCell<Throttler<u64, Box<dyn Fn(u64)>>>>;

    /// Test harness bundling a registered throttler with the clock that drives it.
    #[derive(Clone)]
    struct TestThrottler {
        throttler: SharedThrottler,
        clock: Rc<RefCell<TestClock>>,
        interval: u64,
    }

    #[allow(unsafe_code)]
    impl TestThrottler {
        #[allow(clippy::mut_from_ref)]
        pub fn get_throttler(&self) -> &mut Throttler<u64, Box<dyn Fn(u64)>> {
            // SAFETY: tests are single-threaded. NOTE(review): timer callbacks reach
            // the same throttler via `get_actor_unchecked` while this reference is
            // live, so mutable access is aliased; tolerated in test code only.
            unsafe { &mut *self.throttler.get() }
        }
    }

    #[fixture]
    pub fn test_throttler_buffered() -> TestThrottler {
        let output_send: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Sent: {msg}");
        });
        let clock = Rc::new(RefCell::new(TestClock::new()));
        let inner_clock = Rc::clone(&clock);
        let rate_limit = RateLimit::new(5, 10);
        let interval = rate_limit.interval_ns;
        let actor_id = Ustr::from(UUID4::new().as_str());

        TestThrottler {
            throttler: Throttler::new(
                rate_limit.limit,
                rate_limit.interval_ns,
                clock,
                "buffer_timer".to_string(),
                output_send,
                None,
                actor_id,
            )
            .to_actor(),
            clock: inner_clock,
            interval,
        }
    }

    #[fixture]
    pub fn test_throttler_unbuffered() -> TestThrottler {
        let output_send: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Sent: {msg}");
        });
        let output_drop: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Dropped: {msg}");
        });
        let clock = Rc::new(RefCell::new(TestClock::new()));
        let inner_clock = Rc::clone(&clock);
        let rate_limit = RateLimit::new(5, 10);
        let interval = rate_limit.interval_ns;
        let actor_id = Ustr::from(UUID4::new().as_str());

        TestThrottler {
            throttler: Throttler::new(
                rate_limit.limit,
                rate_limit.interval_ns,
                clock,
                "dropper_timer".to_string(),
                output_send,
                Some(output_drop),
                actor_id,
            )
            .to_actor(),
            clock: inner_clock,
            interval,
        }
    }

    /// Advances `clock` to the absolute time `to_time_ns` and runs every timer
    /// callback that fired, without holding the clock borrow during a callback
    /// (callbacks borrow the clock themselves).
    fn advance_and_fire(clock: &Rc<RefCell<TestClock>>, to_time_ns: UnixNanos) {
        let handlers = {
            let mut clock_ref = clock.borrow_mut();
            let time_events = clock_ref.advance_time(to_time_ns, true);
            clock_ref.match_handlers(time_events)
        };
        for each_event in handlers {
            each_event.callback.call(each_event.event);
        }
    }

    #[rstest]
    fn test_buffering_send_to_limit_becomes_throttled(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }
        assert_eq!(throttler.qsize(), 1);

        assert!(throttler.is_limiting);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
        assert_eq!(throttler.clock.borrow().timer_names(), vec!["buffer_timer"]);
    }

    #[rstest]
    fn test_buffering_used_when_sent_to_limit_returns_one(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..5 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 5);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_buffering_used_when_half_interval_from_limit_returns_one(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..5 {
            throttler.send(42);
        }

        let half_interval = test_throttler_buffered.interval / 2;
        {
            let mut clock = test_throttler_buffered.clock.borrow_mut();
            clock.advance_time(half_interval.into(), true);
        }

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 5);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_buffering_used_before_limit_when_halfway_returns_half(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..3 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 0.6);
        assert_eq!(throttler.recv_count, 3);
        assert_eq!(throttler.sent_count, 3);
    }

    #[rstest]
    fn test_buffering_refresh_when_at_limit_sends_remaining_items(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(42);
        }

        advance_and_fire(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval.into(),
        );

        assert_eq!(throttler.used(), 0.2);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 6);
        assert_eq!(throttler.qsize(), 0);
    }

    #[rstest]
    fn test_buffering_send_message_after_buffering_message(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(43);
        }

        advance_and_fire(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval.into(),
        );

        for _ in 0..6 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 12);
        assert_eq!(throttler.sent_count, 10);
        assert_eq!(throttler.qsize(), 2);
    }

    #[rstest]
    fn test_buffering_send_message_after_halfway_after_buffering_message(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(42);
        }

        advance_and_fire(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval.into(),
        );

        for _ in 0..3 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 0.8);
        assert_eq!(throttler.recv_count, 9);
        assert_eq!(throttler.sent_count, 9);
        assert_eq!(throttler.qsize(), 0);
    }

    #[rstest]
    fn test_dropping_send_sends_message_to_handler(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        throttler.send(42);

        assert!(!throttler.is_limiting);
        assert_eq!(throttler.recv_count, 1);
        assert_eq!(throttler.sent_count, 1);
    }

    #[rstest]
    fn test_dropping_send_to_limit_drops_message(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }
        assert_eq!(throttler.qsize(), 0);

        assert!(throttler.is_limiting);
        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.clock.borrow().timer_count(), 1);
        assert_eq!(
            throttler.clock.borrow().timer_names(),
            vec!["dropper_timer"]
        );
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_dropping_advance_time_when_at_limit_dropped_message(
        test_throttler_unbuffered: TestThrottler,
    ) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }

        advance_and_fire(
            &test_throttler_unbuffered.clock,
            test_throttler_unbuffered.interval.into(),
        );

        assert_eq!(throttler.clock.borrow().timer_count(), 0);
        assert!(!throttler.is_limiting);
        assert_eq!(throttler.used(), 0.0);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_dropping_send_message_after_dropping_message(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }

        advance_and_fire(
            &test_throttler_unbuffered.clock,
            test_throttler_unbuffered.interval.into(),
        );

        throttler.send(42);

        assert_eq!(throttler.used(), 0.2);
        assert_eq!(throttler.clock.borrow().timer_count(), 0);
        assert!(!throttler.is_limiting);
        assert_eq!(throttler.recv_count, 7);
        assert_eq!(throttler.sent_count, 6);
    }

    #[derive(Clone, Debug)]
    enum ThrottlerInput {
        SendMessage(u64),
        AdvanceClock(u8),
    }

    /// Mostly clock advances (5..=9 ns) with occasional sends.
    fn throttler_input_strategy() -> impl Strategy<Value = ThrottlerInput> {
        prop_oneof![
            2 => prop::bool::ANY.prop_map(|_| ThrottlerInput::SendMessage(42)),
            8 => prop::num::u8::ANY.prop_map(|v| ThrottlerInput::AdvanceClock(v % 5 + 5)),
        ]
    }

    fn throttler_test_strategy() -> impl Strategy<Value = Vec<ThrottlerInput>> {
        prop::collection::vec(throttler_input_strategy(), 10..=150)
    }

    fn test_throttler_with_inputs(inputs: Vec<ThrottlerInput>, test_throttler: TestThrottler) {
        let test_clock = test_throttler.clock.clone();
        let interval = test_throttler.interval;
        let throttler = test_throttler.get_throttler();
        let mut sent_count = 0;

        for input in inputs {
            match input {
                ThrottlerInput::SendMessage(msg) => {
                    throttler.send(msg);
                    sent_count += 1;
                }
                ThrottlerInput::AdvanceClock(duration) => {
                    let current_time = test_clock.borrow().get_time_ns();
                    advance_and_fire(&test_clock, current_time + u64::from(duration));
                }
            }

            // Limiting must be active exactly when messages are queued and the
            // last `limit` sends all happened within the current interval.
            let buffered_messages = throttler.qsize() > 0;
            let now = throttler.clock.borrow().timestamp_ns().as_u64();
            let limit_filled_within_interval = throttler
                .timestamps
                .get(throttler.limit - 1)
                .is_some_and(|&ts| (now - ts.as_u64()) < interval);
            let expected_limiting = buffered_messages && limit_filled_within_interval;
            assert_eq!(throttler.is_limiting, expected_limiting);

            // Every submitted message has either been sent or is still buffered.
            assert_eq!(sent_count, throttler.sent_count + throttler.qsize());
        }

        // Advance relative to the current time: the previous absolute target
        // (`interval * 100`) could lie in the past after many clock advances.
        let start = test_clock.borrow().get_time_ns();
        advance_and_fire(&test_clock, start + interval * 100);

        // A single timer firing releases at most `limit` messages, so keep stepping
        // one interval at a time until drained (bounded by the number submitted).
        for _ in 0..sent_count {
            if throttler.qsize() == 0 {
                break;
            }
            let now = test_clock.borrow().get_time_ns();
            advance_and_fire(&test_clock, now + interval);
        }
        assert_eq!(throttler.qsize(), 0);
    }

    #[rstest]
    fn prop_test() {
        proptest!(|(inputs in throttler_test_strategy())| {
            let test_throttler = test_throttler_buffered();
            test_throttler_with_inputs(inputs, test_throttler);
        });
    }

    #[rstest]
    fn test_throttler_process_id_returns_ustr() {
        let actor_id = Ustr::from("test_throttler");
        let process = ThrottlerProcess::<String, fn(String)>::new(actor_id);

        let handler_id: Ustr = process.id();

        assert!(handler_id.as_str().contains("test_throttler_process"));
        assert!(!handler_id.is_empty());

        let _type_check: Ustr = handler_id;
    }
}