1use std::{
23 any::Any,
24 cell::{RefCell, UnsafeCell},
25 collections::VecDeque,
26 fmt::Debug,
27 marker::PhantomData,
28 rc::Rc,
29};
30
31use nautilus_core::{UnixNanos, correctness::FAILED};
32use ustr::Ustr;
33
34use crate::{
35 actor::{
36 Actor,
37 registry::{get_actor_unchecked, register_actor},
38 },
39 clock::Clock,
40 msgbus::{
41 self, Endpoint, MStr,
42 handler::{MessageHandler, ShareableMessageHandler},
43 },
44 timer::{TimeEvent, TimeEventCallback},
45};
46
/// A throttling budget: at most `limit` messages per `interval_ns` window.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimit {
    /// Maximum number of messages permitted within one window.
    pub limit: usize,
    /// Length of the window, in nanoseconds.
    pub interval_ns: u64,
}

impl RateLimit {
    /// Builds a [`RateLimit`] from a message budget and a window length (ns).
    #[must_use]
    pub const fn new(limit: usize, interval_ns: u64) -> Self {
        Self {
            interval_ns,
            limit,
        }
    }
}
61
/// Forwards messages to `output_send` at most `limit` times per `interval`.
///
/// Messages beyond the limit are either buffered (when `output_drop` is `None`)
/// and flushed later by a timer, or handed to `output_drop`.
pub struct Throttler<T, F> {
    /// Number of messages received via `send` (sent, buffered or dropped).
    pub recv_count: usize,
    /// Number of messages actually forwarded to `output_send`.
    pub sent_count: usize,
    /// Whether the throttler is currently limiting (a timer is pending).
    pub is_limiting: bool,
    /// Maximum number of messages sent per `interval`.
    pub limit: usize,
    /// Buffered messages: pushed at the front, flushed from the back (FIFO).
    pub buffer: VecDeque<T>,
    /// Send times of the most recent messages, newest first, capped at `limit`.
    pub timestamps: VecDeque<UnixNanos>,
    /// Clock used for timestamps and for scheduling the throttling timer.
    pub clock: Rc<RefCell<dyn Clock>>,
    /// Id under which this throttler is registered as an actor.
    pub actor_id: Ustr,
    /// Window length in nanoseconds (compared against `timestamp_ns` differences).
    interval: u64,
    /// Name of the clock timer used while limiting.
    timer_name: Ustr,
    /// Callback invoked for every message that is sent.
    output_send: F,
    /// Optional callback for messages over the limit; `None` means buffer instead.
    output_drop: Option<F>,
}
92
93impl<T, F> Actor for Throttler<T, F>
94where
95 T: 'static + Debug,
96 F: Fn(T) + 'static,
97{
98 fn id(&self) -> Ustr {
99 self.actor_id
100 }
101
102 fn handle(&mut self, _msg: &dyn Any) {}
103
104 fn as_any(&self) -> &dyn Any {
105 self
106 }
107}
108
109impl<T, F> Debug for Throttler<T, F>
110where
111 T: Debug,
112{
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct(stringify!(InnerThrottler))
115 .field("recv_count", &self.recv_count)
116 .field("sent_count", &self.sent_count)
117 .field("is_limiting", &self.is_limiting)
118 .field("limit", &self.limit)
119 .field("buffer", &self.buffer)
120 .field("timestamps", &self.timestamps)
121 .field("interval", &self.interval)
122 .field("timer_name", &self.timer_name)
123 .finish()
124 }
125}
126
127impl<T, F> Throttler<T, F>
128where
129 T: Debug,
130{
131 #[inline]
132 pub fn new(
133 limit: usize,
134 interval: u64,
135 clock: Rc<RefCell<dyn Clock>>,
136 timer_name: String,
137 output_send: F,
138 output_drop: Option<F>,
139 actor_id: Ustr,
140 ) -> Self {
141 Self {
142 recv_count: 0,
143 sent_count: 0,
144 is_limiting: false,
145 limit,
146 buffer: VecDeque::new(),
147 timestamps: VecDeque::with_capacity(limit),
148 clock,
149 interval,
150 timer_name: Ustr::from(&timer_name),
151 output_send,
152 output_drop,
153 actor_id,
154 }
155 }
156
157 #[inline]
167 pub fn set_timer(&mut self, callback: Option<TimeEventCallback>) {
168 let delta = self.delta_next();
169 let mut clock = self.clock.borrow_mut();
170 if clock.timer_exists(&self.timer_name) {
171 clock.cancel_timer(&self.timer_name);
172 }
173 let alert_ts = clock.timestamp_ns() + delta;
174
175 clock
176 .set_time_alert_ns(&self.timer_name, alert_ts, callback, None)
177 .expect(FAILED);
178 }
179
180 #[inline]
182 pub fn delta_next(&mut self) -> u64 {
183 match self.timestamps.get(self.limit - 1) {
184 Some(ts) => {
185 let diff = self.clock.borrow().timestamp_ns().as_u64() - ts.as_u64();
186 self.interval.saturating_sub(diff)
187 }
188 None => 0,
189 }
190 }
191
192 #[inline]
194 pub fn reset(&mut self) {
195 self.buffer.clear();
196 self.recv_count = 0;
197 self.sent_count = 0;
198 self.is_limiting = false;
199 self.timestamps.clear();
200 }
201
202 #[inline]
204 pub fn used(&self) -> f64 {
205 if self.timestamps.is_empty() {
206 return 0.0;
207 }
208
209 let now = self.clock.borrow().timestamp_ns().as_i64();
210 let interval_start = now - self.interval as i64;
211
212 let messages_in_current_interval = self
213 .timestamps
214 .iter()
215 .take_while(|&&ts| ts.as_i64() > interval_start)
216 .count();
217
218 (messages_in_current_interval as f64) / (self.limit as f64)
219 }
220
221 #[inline]
223 pub fn qsize(&self) -> usize {
224 self.buffer.len()
225 }
226}
227
228impl<T, F> Throttler<T, F>
229where
230 T: 'static + Debug,
231 F: Fn(T) + 'static,
232{
233 pub fn to_actor(self) -> Rc<UnsafeCell<Self>> {
234 let process_handler = ThrottlerProcess::<T, F>::new(self.actor_id);
236 msgbus::register(
237 process_handler.id().as_str().into(),
238 ShareableMessageHandler::from(Rc::new(process_handler) as Rc<dyn MessageHandler>),
239 );
240
241 register_actor(self)
243 }
244
245 #[inline]
246 pub fn send_msg(&mut self, msg: T) {
247 let now = self.clock.borrow().timestamp_ns();
248
249 if self.timestamps.len() >= self.limit {
250 self.timestamps.pop_back();
251 }
252 self.timestamps.push_front(now);
253
254 self.sent_count += 1;
255 (self.output_send)(msg);
256 }
257
258 #[inline]
259 pub fn limit_msg(&mut self, msg: T) {
260 let callback = if self.output_drop.is_none() {
261 self.buffer.push_front(msg);
262 log::debug!("Buffering {}", self.buffer.len());
263 Some(ThrottlerProcess::<T, F>::new(self.actor_id).get_timer_callback())
264 } else {
265 log::debug!("Dropping");
266 if let Some(drop) = &self.output_drop {
267 drop(msg);
268 }
269 Some(throttler_resume::<T, F>(self.actor_id))
270 };
271 if !self.is_limiting {
272 log::debug!("Limiting");
273 self.set_timer(callback);
274 self.is_limiting = true;
275 }
276 }
277
278 #[inline]
279 pub fn send(&mut self, msg: T)
280 where
281 T: 'static,
282 F: Fn(T) + 'static,
283 {
284 self.recv_count += 1;
285
286 if self.is_limiting || self.delta_next() > 0 {
287 self.limit_msg(msg);
288 } else {
289 self.send_msg(msg);
290 }
291 }
292}
293
294struct ThrottlerProcess<T, F> {
299 actor_id: Ustr,
300 endpoint: MStr<Endpoint>,
301 phantom_t: PhantomData<T>,
302 phantom_f: PhantomData<F>,
303}
304
305impl<T, F> ThrottlerProcess<T, F>
306where
307 T: Debug,
308{
309 pub fn new(actor_id: Ustr) -> Self {
310 let endpoint = MStr::endpoint(format!("{actor_id}_process")).expect(FAILED);
311 Self {
312 actor_id,
313 endpoint,
314 phantom_t: PhantomData,
315 phantom_f: PhantomData,
316 }
317 }
318
319 pub fn get_timer_callback(&self) -> TimeEventCallback {
320 let endpoint = self.endpoint;
321 TimeEventCallback::from(move |event: TimeEvent| {
322 msgbus::send_any(endpoint, &(event));
323 })
324 }
325}
326
327impl<T, F> MessageHandler for ThrottlerProcess<T, F>
328where
329 T: 'static + Debug,
330 F: Fn(T) + 'static,
331{
332 fn id(&self) -> Ustr {
333 *self.endpoint
334 }
335
336 fn handle(&self, _message: &dyn Any) {
337 let mut throttler = get_actor_unchecked::<Throttler<T, F>>(&self.actor_id);
338 while let Some(msg) = throttler.buffer.pop_back() {
339 throttler.send_msg(msg);
340
341 if !throttler.buffer.is_empty() && throttler.delta_next() > 0 {
345 throttler.is_limiting = true;
346
347 let endpoint = self.endpoint;
348
349 throttler.set_timer(Some(TimeEventCallback::from(move |event: TimeEvent| {
351 msgbus::send_any(endpoint, &(event));
352 })));
353 return;
354 }
355 }
356
357 throttler.is_limiting = false;
358 }
359
360 fn as_any(&self) -> &dyn Any {
361 self
362 }
363}
364
365pub fn throttler_resume<T, F>(actor_id: Ustr) -> TimeEventCallback
367where
368 T: 'static + Debug,
369 F: Fn(T) + 'static,
370{
371 TimeEventCallback::from(move |_event: TimeEvent| {
372 let mut throttler = get_actor_unchecked::<Throttler<T, F>>(&actor_id);
373 throttler.is_limiting = false;
374 })
375}
376
#[cfg(test)]
mod tests {
    use std::{
        cell::{RefCell, UnsafeCell},
        rc::Rc,
    };

    use nautilus_core::UUID4;
    use proptest::prelude::*;
    use rstest::{fixture, rstest};
    use ustr::Ustr;

    use super::{RateLimit, Throttler, ThrottlerProcess};
    use crate::{clock::TestClock, msgbus::handler::MessageHandler};

    type SharedThrottler = Rc<UnsafeCell<Throttler<u64, Box<dyn Fn(u64)>>>>;

    /// Test harness bundling a registered throttler with the test clock it uses.
    #[derive(Clone)]
    struct TestThrottler {
        throttler: SharedThrottler,
        clock: Rc<RefCell<TestClock>>,
        interval: u64,
    }

    #[allow(unsafe_code)]
    impl TestThrottler {
        /// Returns a mutable reference to the registered throttler.
        // `&mut` from `&self` is deliberate: the throttler lives in an `UnsafeCell`
        // shared with the actor registry.
        #[allow(clippy::mut_from_ref)]
        pub fn get_throttler(&self) -> &mut Throttler<u64, Box<dyn Fn(u64)>> {
            // SAFETY: tests are single-threaded and the cell outlives the returned
            // reference (held via `self`). NOTE(review): timer callbacks also mutate
            // the throttler through the registry while a test holds this reference.
            unsafe { &mut *self.throttler.get() }
        }
    }

    /// Advances `clock` to the absolute time `to_time_ns` and runs every fired
    /// timer callback.
    ///
    /// The clock borrow is released around each callback because the throttler's
    /// callbacks borrow the same clock (to read the time and reschedule timers).
    fn advance_clock_and_fire(clock: &Rc<RefCell<TestClock>>, to_time_ns: u64) {
        let mut clock_ref = clock.borrow_mut();
        let time_events = clock_ref.advance_time(to_time_ns.into(), true);
        for each_event in clock_ref.match_handlers(time_events) {
            drop(clock_ref);
            each_event.callback.call(each_event.event);
            clock_ref = clock.borrow_mut();
        }
    }

    #[fixture]
    pub fn test_throttler_buffered() -> TestThrottler {
        let output_send: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Sent: {msg}");
        });
        let clock = Rc::new(RefCell::new(TestClock::new()));
        let inner_clock = Rc::clone(&clock);
        let rate_limit = RateLimit::new(5, 10);
        let interval = rate_limit.interval_ns;
        let actor_id = Ustr::from(UUID4::new().as_str());

        TestThrottler {
            throttler: Throttler::new(
                rate_limit.limit,
                rate_limit.interval_ns,
                clock,
                "buffer_timer".to_string(),
                output_send,
                None,
                actor_id,
            )
            .to_actor(),
            clock: inner_clock,
            interval,
        }
    }

    #[fixture]
    pub fn test_throttler_unbuffered() -> TestThrottler {
        let output_send: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Sent: {msg}");
        });
        let output_drop: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Dropped: {msg}");
        });
        let clock = Rc::new(RefCell::new(TestClock::new()));
        let inner_clock = Rc::clone(&clock);
        let rate_limit = RateLimit::new(5, 10);
        let interval = rate_limit.interval_ns;
        let actor_id = Ustr::from(UUID4::new().as_str());

        TestThrottler {
            throttler: Throttler::new(
                rate_limit.limit,
                rate_limit.interval_ns,
                clock,
                "dropper_timer".to_string(),
                output_send,
                Some(output_drop),
                actor_id,
            )
            .to_actor(),
            clock: inner_clock,
            interval,
        }
    }

    #[rstest]
    fn test_buffering_send_to_limit_becomes_throttled(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }
        assert_eq!(throttler.qsize(), 1);

        assert!(throttler.is_limiting);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
        assert_eq!(throttler.clock.borrow().timer_names(), vec!["buffer_timer"]);
    }

    #[rstest]
    fn test_buffering_used_when_sent_to_limit_returns_one(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..5 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 5);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_buffering_used_when_half_interval_from_limit_returns_one(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..5 {
            throttler.send(42);
        }

        let half_interval = test_throttler_buffered.interval / 2;
        {
            // Only move time forward; no timer is pending so nothing needs firing.
            let mut clock = test_throttler_buffered.clock.borrow_mut();
            clock.advance_time(half_interval.into(), true);
        }

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 5);
        assert_eq!(throttler.sent_count, 5);
    }

    // 3 of a limit of 5 messages sent in the window => 0.6 used.
    #[rstest]
    fn test_buffering_used_before_limit_returns_fraction_of_limit(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..3 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 0.6);
        assert_eq!(throttler.recv_count, 3);
        assert_eq!(throttler.sent_count, 3);
    }

    #[rstest]
    fn test_buffering_refresh_when_at_limit_sends_remaining_items(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(42);
        }

        advance_clock_and_fire(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval,
        );

        assert_eq!(throttler.used(), 0.2);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 6);
        assert_eq!(throttler.qsize(), 0);
    }

    #[rstest]
    fn test_buffering_send_message_after_buffering_message(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(43);
        }

        advance_clock_and_fire(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval,
        );

        for _ in 0..6 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 12);
        assert_eq!(throttler.sent_count, 10);
        assert_eq!(throttler.qsize(), 2);
    }

    #[rstest]
    fn test_buffering_send_message_after_halfway_after_buffering_message(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(42);
        }

        advance_clock_and_fire(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval,
        );

        for _ in 0..3 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 0.8);
        assert_eq!(throttler.recv_count, 9);
        assert_eq!(throttler.sent_count, 9);
        assert_eq!(throttler.qsize(), 0);
    }

    #[rstest]
    fn test_dropping_send_sends_message_to_handler(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        throttler.send(42);

        assert!(!throttler.is_limiting);
        assert_eq!(throttler.recv_count, 1);
        assert_eq!(throttler.sent_count, 1);
    }

    #[rstest]
    fn test_dropping_send_to_limit_drops_message(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }
        assert_eq!(throttler.qsize(), 0);

        assert!(throttler.is_limiting);
        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.clock.borrow().timer_count(), 1);
        assert_eq!(
            throttler.clock.borrow().timer_names(),
            vec!["dropper_timer"]
        );
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_dropping_advance_time_when_at_limit_dropped_message(
        test_throttler_unbuffered: TestThrottler,
    ) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }

        advance_clock_and_fire(
            &test_throttler_unbuffered.clock,
            test_throttler_unbuffered.interval,
        );

        assert_eq!(throttler.clock.borrow().timer_count(), 0);
        assert!(!throttler.is_limiting);
        assert_eq!(throttler.used(), 0.0);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_dropping_send_message_after_dropping_message(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }

        advance_clock_and_fire(
            &test_throttler_unbuffered.clock,
            test_throttler_unbuffered.interval,
        );

        throttler.send(42);

        assert_eq!(throttler.used(), 0.2);
        assert_eq!(throttler.clock.borrow().timer_count(), 0);
        assert!(!throttler.is_limiting);
        assert_eq!(throttler.recv_count, 7);
        assert_eq!(throttler.sent_count, 6);
    }

    /// One step of a randomized throttler scenario.
    #[derive(Clone, Debug)]
    enum ThrottlerInput {
        SendMessage(u64),
        AdvanceClock(u8),
    }

    /// Sends with weight 2, clock advances of 5..=9 ns with weight 8.
    fn throttler_input_strategy() -> impl Strategy<Value = ThrottlerInput> {
        prop_oneof![
            2 => prop::bool::ANY.prop_map(|_| ThrottlerInput::SendMessage(42)),
            8 => prop::num::u8::ANY.prop_map(|v| ThrottlerInput::AdvanceClock(v % 5 + 5)),
        ]
    }

    /// Scenarios of 10 to 150 steps.
    fn throttler_test_strategy() -> impl Strategy<Value = Vec<ThrottlerInput>> {
        prop::collection::vec(throttler_input_strategy(), 10..=150)
    }

    /// Replays `inputs` against a buffered throttler, checking invariants after each
    /// step and that the buffer fully drains once enough time has passed.
    fn test_throttler_with_inputs(inputs: Vec<ThrottlerInput>, test_throttler: TestThrottler) {
        let test_clock = test_throttler.clock.clone();
        let interval = test_throttler.interval;
        let throttler = test_throttler.get_throttler();
        // Messages submitted via `send` (each is either sent or still buffered).
        let mut submitted_count = 0;

        for input in inputs {
            match input {
                ThrottlerInput::SendMessage(msg) => {
                    throttler.send(msg);
                    submitted_count += 1;
                }
                ThrottlerInput::AdvanceClock(duration) => {
                    let now = test_clock.borrow().get_time_ns().as_u64();
                    advance_clock_and_fire(&test_clock, now + u64::from(duration));
                }
            }

            // The throttler should be limiting exactly when messages are buffered and
            // the last `limit` sends all fall within the current interval.
            let buffered_messages = throttler.qsize() > 0;
            let now = throttler.clock.borrow().timestamp_ns().as_u64();
            let limit_filled_within_interval = throttler
                .timestamps
                .get(throttler.limit - 1)
                .is_some_and(|&ts| (now - ts.as_u64()) < interval);
            let expected_limiting = buffered_messages && limit_filled_within_interval;
            assert_eq!(throttler.is_limiting, expected_limiting);

            // Nothing is lost: every submitted message was sent or is still buffered.
            assert_eq!(submitted_count, throttler.sent_count + throttler.qsize());
        }

        // Advance relative to the current time: `advance_time` takes an absolute
        // target, and the scenario may already have moved the clock past any fixed
        // value (up to 150 steps of 9 ns).
        let now = test_clock.borrow().get_time_ns().as_u64();
        advance_clock_and_fire(&test_clock, now + interval * 100);
        assert_eq!(throttler.qsize(), 0);
    }

    #[rstest]
    fn prop_test() {
        proptest!(|(inputs in throttler_test_strategy())| {
            // Fresh throttler per case so cases do not share state.
            let test_throttler = test_throttler_buffered();
            test_throttler_with_inputs(inputs, test_throttler);
        });
    }

    #[rstest]
    fn test_throttler_process_id_returns_ustr() {
        let actor_id = Ustr::from("test_throttler");
        let process = ThrottlerProcess::<String, fn(String)>::new(actor_id);

        let handler_id: Ustr = process.id();

        assert!(handler_id.as_str().contains("test_throttler_process"));
        assert!(!handler_id.is_empty());

        // Compile-time check that `id()` yields a `Ustr`.
        let _type_check: Ustr = handler_id;
    }
}