1use std::{
19 collections::{BTreeMap, BinaryHeap, HashMap},
20 ops::Deref,
21 pin::Pin,
22 sync::Arc,
23 task::{Context, Poll},
24};
25
26use chrono::{DateTime, Utc};
27use futures::Stream;
28use nautilus_core::{
29 AtomicTime, UnixNanos,
30 correctness::{check_positive_u64, check_predicate_true, check_valid_string},
31 time::get_atomic_clock_realtime,
32};
33use tokio::sync::Mutex;
34use ustr::Ustr;
35
36use crate::timer::{
37 LiveTimer, TestTimer, TimeEvent, TimeEventCallback, TimeEventHandlerV2, create_valid_interval,
38};
39
40pub trait Clock {
46 fn utc_now(&self) -> DateTime<Utc> {
48 DateTime::from_timestamp_nanos(self.timestamp_ns().as_i64())
49 }
50
51 fn timestamp_ns(&self) -> UnixNanos;
53
54 fn timestamp_us(&self) -> u64;
56
57 fn timestamp_ms(&self) -> u64;
59
60 fn timestamp(&self) -> f64;
62
63 fn timer_names(&self) -> Vec<&str>;
65
66 fn timer_count(&self) -> usize;
68
69 fn register_default_handler(&mut self, callback: TimeEventCallback);
72
73 fn get_handler(&self, event: TimeEvent) -> TimeEventHandlerV2;
77
78 fn set_time_alert_ns(
81 &mut self,
82 name: &str,
83 alert_time_ns: UnixNanos,
84 callback: Option<TimeEventCallback>,
85 ) -> anyhow::Result<()>;
86
87 fn set_timer_ns(
91 &mut self,
92 name: &str,
93 interval_ns: u64,
94 start_time_ns: UnixNanos,
95 stop_time_ns: Option<UnixNanos>,
96 callback: Option<TimeEventCallback>,
97 ) -> anyhow::Result<()>;
98
99 fn next_time_ns(&self, name: &str) -> UnixNanos;
103 fn cancel_timer(&mut self, name: &str);
104 fn cancel_timers(&mut self);
105}
106
107pub struct TestClock {
111 time: AtomicTime,
112 timers: BTreeMap<Ustr, TestTimer>,
115 default_callback: Option<TimeEventCallback>,
116 callbacks: HashMap<Ustr, TimeEventCallback>,
117 heap: BinaryHeap<TimeEvent>,
118}
119
120impl TestClock {
121 #[must_use]
123 pub fn new() -> Self {
124 Self {
125 time: AtomicTime::new(false, UnixNanos::default()),
126 timers: BTreeMap::new(),
127 default_callback: None,
128 callbacks: HashMap::new(),
129 heap: BinaryHeap::new(),
130 }
131 }
132
133 #[must_use]
135 pub const fn get_timers(&self) -> &BTreeMap<Ustr, TestTimer> {
136 &self.timers
137 }
138
139 pub fn advance_time(&mut self, to_time_ns: UnixNanos, set_time: bool) -> Vec<TimeEvent> {
148 assert!(
150 to_time_ns >= self.time.get_time_ns(),
151 "`to_time_ns` {} was < `self.time.get_time_ns()` {}",
152 to_time_ns,
153 self.time.get_time_ns()
154 );
155
156 if set_time {
157 self.time.set_time(to_time_ns);
158 }
159
160 let mut events: Vec<TimeEvent> = Vec::new();
162 self.timers.retain(|_, timer| {
163 timer.advance(to_time_ns).for_each(|event| {
164 events.push(event);
165 });
166
167 !timer.is_expired()
168 });
169
170 events.sort_by(|a, b| a.ts_event.cmp(&b.ts_event));
171 events
172 }
173
174 pub fn advance_to_time_on_heap(&mut self, to_time_ns: UnixNanos) {
180 assert!(
182 to_time_ns >= self.time.get_time_ns(),
183 "`to_time_ns` {} was < `self.time.get_time_ns()` {}",
184 to_time_ns,
185 self.time.get_time_ns()
186 );
187
188 self.time.set_time(to_time_ns);
189
190 self.timers.retain(|_, timer| {
192 timer.advance(to_time_ns).for_each(|event| {
193 self.heap.push(event);
194 });
195
196 !timer.is_expired()
197 });
198 }
199
200 #[must_use]
206 pub fn match_handlers(&self, events: Vec<TimeEvent>) -> Vec<TimeEventHandlerV2> {
207 events
208 .into_iter()
209 .map(|event| {
210 let callback = self.callbacks.get(&event.name).cloned().unwrap_or_else(|| {
211 self.default_callback
214 .clone()
215 .expect("Default callback should exist")
216 });
217 TimeEventHandlerV2::new(event, callback)
218 })
219 .collect()
220 }
221}
222
223impl Iterator for TestClock {
224 type Item = TimeEventHandlerV2;
225
226 fn next(&mut self) -> Option<Self::Item> {
227 self.heap.pop().map(|event| self.get_handler(event))
228 }
229}
230
231impl Default for TestClock {
232 fn default() -> Self {
234 Self::new()
235 }
236}
237
238impl Deref for TestClock {
239 type Target = AtomicTime;
240
241 fn deref(&self) -> &Self::Target {
242 &self.time
243 }
244}
245
246impl Clock for TestClock {
247 fn timestamp_ns(&self) -> UnixNanos {
248 self.time.get_time_ns()
249 }
250
251 fn timestamp_us(&self) -> u64 {
252 self.time.get_time_us()
253 }
254
255 fn timestamp_ms(&self) -> u64 {
256 self.time.get_time_ms()
257 }
258
259 fn timestamp(&self) -> f64 {
260 self.time.get_time()
261 }
262
263 fn timer_names(&self) -> Vec<&str> {
264 self.timers
265 .iter()
266 .filter(|(_, timer)| !timer.is_expired())
267 .map(|(k, _)| k.as_str())
268 .collect()
269 }
270
271 fn timer_count(&self) -> usize {
272 self.timers
273 .iter()
274 .filter(|(_, timer)| !timer.is_expired())
275 .count()
276 }
277
278 fn register_default_handler(&mut self, callback: TimeEventCallback) {
279 self.default_callback = Some(callback);
280 }
281
282 fn get_handler(&self, event: TimeEvent) -> TimeEventHandlerV2 {
283 let callback = self
285 .callbacks
286 .get(&event.name)
287 .cloned()
288 .or_else(|| self.default_callback.clone())
289 .unwrap_or_else(|| panic!("Event '{}' should have associated handler", event.name));
290
291 TimeEventHandlerV2::new(event, callback)
292 }
293
294 fn set_time_alert_ns(
295 &mut self,
296 name: &str,
297 alert_time_ns: UnixNanos,
298 callback: Option<TimeEventCallback>,
299 ) -> anyhow::Result<()> {
300 check_valid_string(name, stringify!(name))?;
301 let name_ustr = Ustr::from(name);
302
303 check_predicate_true(
304 callback.is_some()
305 | self.callbacks.contains_key(&name_ustr)
306 | self.default_callback.is_some(),
307 "No callbacks provided",
308 )?;
309
310 match callback {
311 Some(callback_py) => self.callbacks.insert(name_ustr, callback_py),
312 None => None,
313 };
314
315 self.cancel_timer(name);
317
318 let ts_now = self.time.get_time_ns();
320 let interval_ns = create_valid_interval(std::cmp::max((alert_time_ns - ts_now).into(), 1));
321 let timer = TestTimer::new(name, interval_ns, ts_now, Some(alert_time_ns));
322 self.timers.insert(name_ustr, timer);
323
324 Ok(())
325 }
326
327 fn set_timer_ns(
328 &mut self,
329 name: &str,
330 interval_ns: u64,
331 start_time_ns: UnixNanos,
332 stop_time_ns: Option<UnixNanos>,
333 callback: Option<TimeEventCallback>,
334 ) -> anyhow::Result<()> {
335 check_valid_string(name, stringify!(name))?;
336 check_positive_u64(interval_ns, stringify!(interval_ns))?;
337 check_predicate_true(
338 callback.is_some() | self.default_callback.is_some(),
339 "No callbacks provided",
340 )?;
341
342 let name_ustr = Ustr::from(name);
343 match callback {
344 Some(callback_py) => self.callbacks.insert(name_ustr, callback_py),
345 None => None,
346 };
347
348 let interval_ns = create_valid_interval(interval_ns);
349 let timer = TestTimer::new(name, interval_ns, start_time_ns, stop_time_ns);
350 self.timers.insert(name_ustr, timer);
351
352 Ok(())
353 }
354
355 fn next_time_ns(&self, name: &str) -> UnixNanos {
356 let timer = self.timers.get(&Ustr::from(name));
357 match timer {
358 None => 0.into(),
359 Some(timer) => timer.next_time_ns(),
360 }
361 }
362
363 fn cancel_timer(&mut self, name: &str) {
364 let timer = self.timers.remove(&Ustr::from(name));
365 match timer {
366 None => {}
367 Some(mut timer) => timer.cancel(),
368 }
369 }
370
371 fn cancel_timers(&mut self) {
372 for timer in &mut self.timers.values_mut() {
373 timer.cancel();
374 }
375 self.timers = BTreeMap::new();
376 }
377}
378
379pub struct LiveClock {
383 time: &'static AtomicTime,
384 timers: HashMap<Ustr, LiveTimer>,
385 default_callback: Option<TimeEventCallback>,
386 pub heap: Arc<Mutex<BinaryHeap<TimeEvent>>>,
387 #[allow(dead_code)]
388 callbacks: HashMap<Ustr, TimeEventCallback>,
389}
390
391impl LiveClock {
392 #[must_use]
394 pub fn new() -> Self {
395 Self {
396 time: get_atomic_clock_realtime(),
397 timers: HashMap::new(),
398 default_callback: None,
399 heap: Arc::new(Mutex::new(BinaryHeap::new())),
400 callbacks: HashMap::new(),
401 }
402 }
403
404 #[must_use]
405 pub fn get_event_stream(&self) -> TimeEventStream {
406 TimeEventStream::new(self.heap.clone())
407 }
408
409 #[must_use]
410 pub const fn get_timers(&self) -> &HashMap<Ustr, LiveTimer> {
411 &self.timers
412 }
413
414 fn clear_expired_timers(&mut self) {
416 self.timers.retain(|_, timer| !timer.is_expired());
417 }
418}
419
420impl Default for LiveClock {
421 fn default() -> Self {
423 Self::new()
424 }
425}
426
427impl Deref for LiveClock {
428 type Target = AtomicTime;
429
430 fn deref(&self) -> &Self::Target {
431 self.time
432 }
433}
434
435impl Clock for LiveClock {
436 fn timestamp_ns(&self) -> UnixNanos {
437 self.time.get_time_ns()
438 }
439
440 fn timestamp_us(&self) -> u64 {
441 self.time.get_time_us()
442 }
443
444 fn timestamp_ms(&self) -> u64 {
445 self.time.get_time_ms()
446 }
447
448 fn timestamp(&self) -> f64 {
449 self.time.get_time()
450 }
451
452 fn timer_names(&self) -> Vec<&str> {
453 self.timers
454 .iter()
455 .filter(|(_, timer)| !timer.is_expired())
456 .map(|(k, _)| k.as_str())
457 .collect()
458 }
459
460 fn timer_count(&self) -> usize {
461 self.timers
462 .iter()
463 .filter(|(_, timer)| !timer.is_expired())
464 .count()
465 }
466
467 fn register_default_handler(&mut self, handler: TimeEventCallback) {
468 self.default_callback = Some(handler);
469 }
470
471 #[allow(unused_variables)]
472 fn get_handler(&self, event: TimeEvent) -> TimeEventHandlerV2 {
473 #[cfg(not(feature = "clock_v2"))]
474 panic!("Cannot get live clock handler without 'clock_v2' feature");
475
476 #[cfg(feature = "clock_v2")]
478 {
479 let callback = self
480 .callbacks
481 .get(&event.name)
482 .cloned()
483 .or_else(|| self.default_callback.clone())
484 .unwrap_or_else(|| panic!("Event '{}' should have associated handler", event.name));
485
486 TimeEventHandlerV2::new(event, callback)
487 }
488 }
489
490 fn set_time_alert_ns(
491 &mut self,
492 name: &str,
493 mut alert_time_ns: UnixNanos,
494 callback: Option<TimeEventCallback>,
495 ) -> anyhow::Result<()> {
496 check_valid_string(name, stringify!(name))?;
497 let name_ustr = Ustr::from(name);
498
499 check_predicate_true(
500 callback.is_some()
501 | self.callbacks.contains_key(&name_ustr)
502 | self.default_callback.is_some(),
503 "No callbacks provided",
504 )?;
505
506 #[cfg(feature = "clock_v2")]
507 {
508 match callback.clone() {
509 Some(callback) => self.callbacks.insert(name_ustr, callback),
510 None => None,
511 };
512 }
513
514 let callback = match callback {
515 Some(callback) => callback,
516 None => {
517 if self.callbacks.contains_key(&name_ustr) {
518 self.callbacks.get(&name_ustr).unwrap().clone()
519 } else {
520 self.default_callback.clone().unwrap()
521 }
522 }
523 };
524
525 self.cancel_timer(name);
527
528 let ts_now = self.get_time_ns();
530 alert_time_ns = std::cmp::max(alert_time_ns, ts_now);
531 let interval_ns = create_valid_interval(std::cmp::max((alert_time_ns - ts_now).into(), 1));
532
533 #[cfg(not(feature = "clock_v2"))]
534 let mut timer = LiveTimer::new(name, interval_ns, ts_now, Some(alert_time_ns), callback);
535
536 #[cfg(feature = "clock_v2")]
537 let mut timer = LiveTimer::new(
538 name,
539 interval_ns,
540 ts_now,
541 Some(alert_time_ns),
542 callback,
543 self.heap.clone(),
544 );
545
546 timer.start();
547
548 self.clear_expired_timers();
549 self.timers.insert(Ustr::from(name), timer);
550
551 Ok(())
552 }
553
554 fn set_timer_ns(
555 &mut self,
556 name: &str,
557 interval_ns: u64,
558 start_time_ns: UnixNanos,
559 stop_time_ns: Option<UnixNanos>,
560 callback: Option<TimeEventCallback>,
561 ) -> anyhow::Result<()> {
562 check_valid_string(name, stringify!(name))?;
563 check_positive_u64(interval_ns, stringify!(interval_ns))?;
564 check_predicate_true(
565 callback.is_some() | self.default_callback.is_some(),
566 "No callbacks provided",
567 )?;
568
569 let callback = match callback {
570 Some(callback) => callback,
571 None => self.default_callback.clone().unwrap(),
572 };
573
574 #[cfg(feature = "clock_v2")]
575 {
576 let name = Ustr::from(name);
577 self.callbacks.insert(name, callback.clone());
578 }
579
580 let mut start_time_ns = start_time_ns;
583 if start_time_ns == 0 {
584 start_time_ns = self.timestamp_ns();
586 }
587 let interval_ns = create_valid_interval(interval_ns);
588
589 #[cfg(not(feature = "clock_v2"))]
590 let mut timer = LiveTimer::new(name, interval_ns, start_time_ns, stop_time_ns, callback);
591
592 #[cfg(feature = "clock_v2")]
593 let mut timer = LiveTimer::new(
594 name,
595 interval_ns,
596 start_time_ns,
597 stop_time_ns,
598 callback,
599 self.heap.clone(),
600 );
601 timer.start();
602
603 self.clear_expired_timers();
604 self.timers.insert(Ustr::from(name), timer);
605
606 Ok(())
607 }
608
609 fn next_time_ns(&self, name: &str) -> UnixNanos {
610 let timer = self.timers.get(&Ustr::from(name));
611 match timer {
612 None => 0.into(),
613 Some(timer) => timer.next_time_ns(),
614 }
615 }
616
617 fn cancel_timer(&mut self, name: &str) {
618 let timer = self.timers.remove(&Ustr::from(name));
619 match timer {
620 None => {}
621 Some(mut timer) => {
622 timer.cancel();
623 }
624 }
625 }
626
627 fn cancel_timers(&mut self) {
628 for timer in &mut self.timers.values_mut() {
629 timer.cancel();
630 }
631 self.timers.clear();
632 }
633}
634
635pub struct TimeEventStream {
637 heap: Arc<Mutex<BinaryHeap<TimeEvent>>>,
638}
639
640impl TimeEventStream {
641 pub const fn new(heap: Arc<Mutex<BinaryHeap<TimeEvent>>>) -> Self {
642 Self { heap }
643 }
644}
645
646impl Stream for TimeEventStream {
647 type Item = TimeEvent;
648
649 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
650 let mut heap = match self.heap.try_lock() {
651 Ok(guard) => guard,
652 Err(e) => {
653 tracing::error!("Unable to get LiveClock heap lock: {e}");
654 cx.waker().wake_by_ref();
655 return Poll::Pending;
656 }
657 };
658
659 if let Some(event) = heap.pop() {
660 Poll::Ready(Some(event))
661 } else {
662 cx.waker().wake_by_ref();
663 Poll::Pending
664 }
665 }
666}
667
#[cfg(test)]
mod tests {
    use std::{cell::RefCell, rc::Rc};

    use rstest::{fixture, rstest};

    use super::*;

    /// Callback stub that records whether it has been invoked.
    #[derive(Default)]
    struct TestCallback {
        called: Rc<RefCell<bool>>,
    }

    impl TestCallback {
        const fn new(called: Rc<RefCell<bool>>) -> Self {
            Self { called }
        }
    }

    impl From<TestCallback> for TimeEventCallback {
        fn from(callback: TestCallback) -> Self {
            let flag = callback.called;
            Self::Rust(Rc::new(move |_event: TimeEvent| {
                flag.replace(true);
            }))
        }
    }

    /// A fresh `TestClock` with a no-op default handler registered.
    #[fixture]
    pub fn test_clock() -> TestClock {
        let mut clock = TestClock::new();
        let default_handler: TimeEventCallback = TestCallback::default().into();
        clock.register_default_handler(default_handler);
        clock
    }

    #[rstest]
    fn test_time_monotonicity(mut test_clock: TestClock) {
        let before = test_clock.timestamp_ns();
        let target: UnixNanos = (*before + 1000).into();
        test_clock.advance_time(target, true);
        assert!(test_clock.timestamp_ns() > before);
    }

    #[rstest]
    fn test_timer_registration(mut test_clock: TestClock) {
        let alert_at: UnixNanos = (*test_clock.timestamp_ns() + 1000).into();
        test_clock
            .set_time_alert_ns("test_timer", alert_at, None)
            .unwrap();

        assert_eq!(test_clock.timer_count(), 1);
        assert_eq!(test_clock.timer_names(), vec!["test_timer"]);
    }

    #[rstest]
    fn test_timer_expiration(mut test_clock: TestClock) {
        let alert_at: UnixNanos = (*test_clock.timestamp_ns() + 1000).into();
        test_clock
            .set_time_alert_ns("test_timer", alert_at, None)
            .unwrap();

        let fired = test_clock.advance_time(alert_at, true);

        assert_eq!(fired.len(), 1);
        assert_eq!(fired[0].name.as_str(), "test_timer");
    }

    #[rstest]
    fn test_timer_cancellation(mut test_clock: TestClock) {
        let alert_at: UnixNanos = (*test_clock.timestamp_ns() + 1000).into();
        test_clock
            .set_time_alert_ns("test_timer", alert_at, None)
            .unwrap();
        assert_eq!(test_clock.timer_count(), 1);

        test_clock.cancel_timer("test_timer");

        assert_eq!(test_clock.timer_count(), 0);
    }

    #[rstest]
    fn test_time_advancement(mut test_clock: TestClock) {
        let origin = test_clock.timestamp_ns();
        test_clock
            .set_timer_ns("test_timer", 1000, origin, None, None)
            .unwrap();

        let fired = test_clock.advance_time((*origin + 2500).into(), true);

        assert_eq!(fired.len(), 2);
        assert_eq!(*fired[0].ts_event, *origin + 1000);
        assert_eq!(*fired[1].ts_event, *origin + 2000);
    }

    #[test]
    fn test_default_and_custom_callbacks() {
        let default_called = Rc::new(RefCell::new(false));
        let custom_called = Rc::new(RefCell::new(false));

        let mut clock = TestClock::new();
        clock.register_default_handler(TimeEventCallback::from(TestCallback::new(Rc::clone(
            &default_called,
        ))));

        let default_alert: UnixNanos = (*clock.timestamp_ns() + 1000).into();
        clock
            .set_time_alert_ns("default_timer", default_alert, None)
            .unwrap();

        let custom_alert: UnixNanos = (*clock.timestamp_ns() + 1000).into();
        let custom_handler = TimeEventCallback::from(TestCallback::new(Rc::clone(&custom_called)));
        clock
            .set_time_alert_ns("custom_timer", custom_alert, Some(custom_handler))
            .unwrap();

        let target: UnixNanos = (*clock.timestamp_ns() + 1000).into();
        let fired = clock.advance_time(target, true);

        clock
            .match_handlers(fired)
            .into_iter()
            .for_each(|handler| handler.callback.call(handler.event));

        assert!(*default_called.borrow());
        assert!(*custom_called.borrow());
    }

    #[rstest]
    fn test_multiple_timers(mut test_clock: TestClock) {
        let origin = test_clock.timestamp_ns();
        test_clock
            .set_timer_ns("timer1", 1000, origin, None, None)
            .unwrap();
        test_clock
            .set_timer_ns("timer2", 2000, origin, None, None)
            .unwrap();

        let fired = test_clock.advance_time((*origin + 2000).into(), true);
        let names: Vec<&str> = fired.iter().map(|event| event.name.as_str()).collect();

        assert_eq!(fired.len(), 3);
        assert_eq!(names[0], "timer1");
        assert_eq!(names[1], "timer1");
        assert_eq!(names[2], "timer2");
    }
}