1use std::{
19 collections::{BTreeMap, BinaryHeap, HashMap},
20 ops::Deref,
21 pin::Pin,
22 sync::Arc,
23 task::{Context, Poll},
24};
25
26use chrono::{DateTime, Utc};
27use futures::Stream;
28use nautilus_core::{
29 AtomicTime, UnixNanos,
30 correctness::{check_positive_u64, check_predicate_true, check_valid_string},
31 time::get_atomic_clock_realtime,
32};
33use tokio::sync::Mutex;
34use ustr::Ustr;
35
36use crate::timer::{
37 LiveTimer, TestTimer, TimeEvent, TimeEventCallback, TimeEventHandlerV2, create_valid_interval,
38};
39
40pub trait Clock {
46 fn utc_now(&self) -> DateTime<Utc> {
48 DateTime::from_timestamp_nanos(self.timestamp_ns().as_i64())
49 }
50
51 fn timestamp_ns(&self) -> UnixNanos;
53
54 fn timestamp_us(&self) -> u64;
56
57 fn timestamp_ms(&self) -> u64;
59
60 fn timestamp(&self) -> f64;
62
63 fn timer_names(&self) -> Vec<&str>;
65
66 fn timer_count(&self) -> usize;
68
69 fn register_default_handler(&mut self, callback: TimeEventCallback);
72
73 fn get_handler(&self, event: TimeEvent) -> TimeEventHandlerV2;
77
78 fn set_time_alert_ns(
81 &mut self,
82 name: &str,
83 alert_time_ns: UnixNanos,
84 callback: Option<TimeEventCallback>,
85 ) -> anyhow::Result<()>;
86
87 fn set_timer_ns(
91 &mut self,
92 name: &str,
93 interval_ns: u64,
94 start_time_ns: UnixNanos,
95 stop_time_ns: Option<UnixNanos>,
96 callback: Option<TimeEventCallback>,
97 ) -> anyhow::Result<()>;
98
99 fn next_time_ns(&self, name: &str) -> UnixNanos;
103 fn cancel_timer(&mut self, name: &str);
104 fn cancel_timers(&mut self);
105
106 fn reset(&mut self);
107}
108
/// A manually driven clock: its time and timers only move when the caller advances them
/// (see `advance_time` / `advance_to_time_on_heap`).
pub struct TestClock {
    // Current clock time. `new` builds it with `AtomicTime::new(false, ..)`; presumably
    // `false` selects non-realtime mode — TODO confirm.
    time: AtomicTime,
    // Active timers keyed by name; a `BTreeMap`, so iteration is ordered by name.
    timers: BTreeMap<Ustr, TestTimer>,
    // Fallback callback used when an event's name has no entry in `callbacks`.
    default_callback: Option<TimeEventCallback>,
    // Per-timer callbacks keyed by timer name.
    callbacks: HashMap<Ustr, TimeEventCallback>,
    // Events queued by `advance_to_time_on_heap` and drained by the `Iterator` impl.
    heap: BinaryHeap<TimeEvent>,
}
121
122impl TestClock {
123 #[must_use]
125 pub fn new() -> Self {
126 Self {
127 time: AtomicTime::new(false, UnixNanos::default()),
128 timers: BTreeMap::new(),
129 default_callback: None,
130 callbacks: HashMap::new(),
131 heap: BinaryHeap::new(),
132 }
133 }
134
135 #[must_use]
137 pub const fn get_timers(&self) -> &BTreeMap<Ustr, TestTimer> {
138 &self.timers
139 }
140
141 pub fn advance_time(&mut self, to_time_ns: UnixNanos, set_time: bool) -> Vec<TimeEvent> {
150 assert!(
152 to_time_ns >= self.time.get_time_ns(),
153 "`to_time_ns` {} was < `self.time.get_time_ns()` {}",
154 to_time_ns,
155 self.time.get_time_ns()
156 );
157
158 if set_time {
159 self.time.set_time(to_time_ns);
160 }
161
162 let mut events: Vec<TimeEvent> = Vec::new();
164 self.timers.retain(|_, timer| {
165 timer.advance(to_time_ns).for_each(|event| {
166 events.push(event);
167 });
168
169 !timer.is_expired()
170 });
171
172 events.sort_by(|a, b| a.ts_event.cmp(&b.ts_event));
173 events
174 }
175
176 pub fn advance_to_time_on_heap(&mut self, to_time_ns: UnixNanos) {
182 assert!(
184 to_time_ns >= self.time.get_time_ns(),
185 "`to_time_ns` {} was < `self.time.get_time_ns()` {}",
186 to_time_ns,
187 self.time.get_time_ns()
188 );
189
190 self.time.set_time(to_time_ns);
191
192 self.timers.retain(|_, timer| {
194 timer.advance(to_time_ns).for_each(|event| {
195 self.heap.push(event);
196 });
197
198 !timer.is_expired()
199 });
200 }
201
202 #[must_use]
208 pub fn match_handlers(&self, events: Vec<TimeEvent>) -> Vec<TimeEventHandlerV2> {
209 events
210 .into_iter()
211 .map(|event| {
212 let callback = self.callbacks.get(&event.name).cloned().unwrap_or_else(|| {
213 self.default_callback
216 .clone()
217 .expect("Default callback should exist")
218 });
219 TimeEventHandlerV2::new(event, callback)
220 })
221 .collect()
222 }
223}
224
225impl Iterator for TestClock {
226 type Item = TimeEventHandlerV2;
227
228 fn next(&mut self) -> Option<Self::Item> {
229 self.heap.pop().map(|event| self.get_handler(event))
230 }
231}
232
233impl Default for TestClock {
234 fn default() -> Self {
236 Self::new()
237 }
238}
239
240impl Deref for TestClock {
241 type Target = AtomicTime;
242
243 fn deref(&self) -> &Self::Target {
244 &self.time
245 }
246}
247
248impl Clock for TestClock {
249 fn timestamp_ns(&self) -> UnixNanos {
250 self.time.get_time_ns()
251 }
252
253 fn timestamp_us(&self) -> u64 {
254 self.time.get_time_us()
255 }
256
257 fn timestamp_ms(&self) -> u64 {
258 self.time.get_time_ms()
259 }
260
261 fn timestamp(&self) -> f64 {
262 self.time.get_time()
263 }
264
265 fn timer_names(&self) -> Vec<&str> {
266 self.timers
267 .iter()
268 .filter(|(_, timer)| !timer.is_expired())
269 .map(|(k, _)| k.as_str())
270 .collect()
271 }
272
273 fn timer_count(&self) -> usize {
274 self.timers
275 .iter()
276 .filter(|(_, timer)| !timer.is_expired())
277 .count()
278 }
279
280 fn register_default_handler(&mut self, callback: TimeEventCallback) {
281 self.default_callback = Some(callback);
282 }
283
284 fn get_handler(&self, event: TimeEvent) -> TimeEventHandlerV2 {
285 let callback = self
287 .callbacks
288 .get(&event.name)
289 .cloned()
290 .or_else(|| self.default_callback.clone())
291 .unwrap_or_else(|| panic!("Event '{}' should have associated handler", event.name));
292
293 TimeEventHandlerV2::new(event, callback)
294 }
295
296 fn set_time_alert_ns(
297 &mut self,
298 name: &str,
299 alert_time_ns: UnixNanos,
300 callback: Option<TimeEventCallback>,
301 ) -> anyhow::Result<()> {
302 check_valid_string(name, stringify!(name))?;
303
304 let name = Ustr::from(name);
305
306 check_predicate_true(
307 callback.is_some()
308 | self.callbacks.contains_key(&name)
309 | self.default_callback.is_some(),
310 "No callbacks provided",
311 )?;
312
313 match callback {
314 Some(callback_py) => self.callbacks.insert(name, callback_py),
315 None => None,
316 };
317
318 self.cancel_timer(name.as_str());
320
321 let ts_now = self.time.get_time_ns();
323 let interval_ns = create_valid_interval(std::cmp::max((alert_time_ns - ts_now).into(), 1));
324 let timer = TestTimer::new(name, interval_ns, ts_now, Some(alert_time_ns));
325 self.timers.insert(name, timer);
326
327 Ok(())
328 }
329
330 fn set_timer_ns(
331 &mut self,
332 name: &str,
333 interval_ns: u64,
334 start_time_ns: UnixNanos,
335 stop_time_ns: Option<UnixNanos>,
336 callback: Option<TimeEventCallback>,
337 ) -> anyhow::Result<()> {
338 check_valid_string(name, stringify!(name))?;
339 check_positive_u64(interval_ns, stringify!(interval_ns))?;
340 check_predicate_true(
341 callback.is_some() | self.default_callback.is_some(),
342 "No callbacks provided",
343 )?;
344
345 let name = Ustr::from(name);
346
347 match callback {
348 Some(callback_py) => self.callbacks.insert(name, callback_py),
349 None => None,
350 };
351
352 let interval_ns = create_valid_interval(interval_ns);
353 let timer = TestTimer::new(name, interval_ns, start_time_ns, stop_time_ns);
354 self.timers.insert(name, timer);
355
356 Ok(())
357 }
358
359 fn next_time_ns(&self, name: &str) -> UnixNanos {
360 let timer = self.timers.get(&Ustr::from(name));
361 match timer {
362 None => 0.into(),
363 Some(timer) => timer.next_time_ns(),
364 }
365 }
366
367 fn cancel_timer(&mut self, name: &str) {
368 let timer = self.timers.remove(&Ustr::from(name));
369 match timer {
370 None => {}
371 Some(mut timer) => timer.cancel(),
372 }
373 }
374
375 fn cancel_timers(&mut self) {
376 for timer in &mut self.timers.values_mut() {
377 timer.cancel();
378 }
379 self.timers = BTreeMap::new();
380 }
381
382 fn reset(&mut self) {
383 self.time = AtomicTime::new(false, UnixNanos::default());
384 self.timers = BTreeMap::new();
385 self.heap = BinaryHeap::new();
386 self.callbacks = HashMap::new();
387 }
388}
389
/// A clock backed by the real-time `AtomicTime` obtained from `get_atomic_clock_realtime`,
/// whose timers are `LiveTimer`s.
pub struct LiveClock {
    // Shared real-time source (a `'static` reference obtained in `new`).
    time: &'static AtomicTime,
    // Active live timers keyed by name.
    timers: HashMap<Ustr, LiveTimer>,
    // Fallback callback used when no per-name callback is supplied.
    default_callback: Option<TimeEventCallback>,
    // Shared event queue: handed to `TimeEventStream`, and with the `clock_v2` feature also
    // to each `LiveTimer`.
    pub heap: Arc<Mutex<BinaryHeap<TimeEvent>>>,
    // Per-timer callbacks keyed by name; only inserted into under the `clock_v2` feature.
    #[allow(dead_code)]
    callbacks: HashMap<Ustr, TimeEventCallback>,
}
401
402impl LiveClock {
403 #[must_use]
405 pub fn new() -> Self {
406 Self {
407 time: get_atomic_clock_realtime(),
408 timers: HashMap::new(),
409 default_callback: None,
410 heap: Arc::new(Mutex::new(BinaryHeap::new())),
411 callbacks: HashMap::new(),
412 }
413 }
414
415 #[must_use]
416 pub fn get_event_stream(&self) -> TimeEventStream {
417 TimeEventStream::new(self.heap.clone())
418 }
419
420 #[must_use]
421 pub const fn get_timers(&self) -> &HashMap<Ustr, LiveTimer> {
422 &self.timers
423 }
424
425 fn clear_expired_timers(&mut self) {
427 self.timers.retain(|_, timer| !timer.is_expired());
428 }
429}
430
431impl Default for LiveClock {
432 fn default() -> Self {
434 Self::new()
435 }
436}
437
438impl Deref for LiveClock {
439 type Target = AtomicTime;
440
441 fn deref(&self) -> &Self::Target {
442 self.time
443 }
444}
445
446impl Clock for LiveClock {
447 fn timestamp_ns(&self) -> UnixNanos {
448 self.time.get_time_ns()
449 }
450
451 fn timestamp_us(&self) -> u64 {
452 self.time.get_time_us()
453 }
454
455 fn timestamp_ms(&self) -> u64 {
456 self.time.get_time_ms()
457 }
458
459 fn timestamp(&self) -> f64 {
460 self.time.get_time()
461 }
462
463 fn timer_names(&self) -> Vec<&str> {
464 self.timers
465 .iter()
466 .filter(|(_, timer)| !timer.is_expired())
467 .map(|(k, _)| k.as_str())
468 .collect()
469 }
470
471 fn timer_count(&self) -> usize {
472 self.timers
473 .iter()
474 .filter(|(_, timer)| !timer.is_expired())
475 .count()
476 }
477
478 fn register_default_handler(&mut self, handler: TimeEventCallback) {
479 self.default_callback = Some(handler);
480 }
481
482 #[allow(unused_variables)]
483 fn get_handler(&self, event: TimeEvent) -> TimeEventHandlerV2 {
484 #[cfg(not(feature = "clock_v2"))]
485 panic!("Cannot get live clock handler without 'clock_v2' feature");
486
487 #[cfg(feature = "clock_v2")]
489 {
490 let callback = self
491 .callbacks
492 .get(&event.name)
493 .cloned()
494 .or_else(|| self.default_callback.clone())
495 .unwrap_or_else(|| panic!("Event '{}' should have associated handler", event.name));
496
497 TimeEventHandlerV2::new(event, callback)
498 }
499 }
500
501 fn set_time_alert_ns(
502 &mut self,
503 name: &str,
504 mut alert_time_ns: UnixNanos,
505 callback: Option<TimeEventCallback>,
506 ) -> anyhow::Result<()> {
507 check_valid_string(name, stringify!(name))?;
508
509 let name = Ustr::from(name);
510
511 check_predicate_true(
512 callback.is_some()
513 | self.callbacks.contains_key(&name)
514 | self.default_callback.is_some(),
515 "No callbacks provided",
516 )?;
517
518 #[cfg(feature = "clock_v2")]
519 {
520 match callback.clone() {
521 Some(callback) => self.callbacks.insert(name, callback),
522 None => None,
523 };
524 }
525
526 let callback = match callback {
527 Some(callback) => callback,
528 None => {
529 if self.callbacks.contains_key(&name) {
530 self.callbacks.get(&name).unwrap().clone()
531 } else {
532 self.default_callback.clone().unwrap()
533 }
534 }
535 };
536
537 self.cancel_timer(name.as_str());
539
540 let ts_now = self.get_time_ns();
542 alert_time_ns = std::cmp::max(alert_time_ns, ts_now);
543 let interval_ns = create_valid_interval(std::cmp::max((alert_time_ns - ts_now).into(), 1));
544
545 #[cfg(not(feature = "clock_v2"))]
546 let mut timer = LiveTimer::new(name, interval_ns, ts_now, Some(alert_time_ns), callback);
547
548 #[cfg(feature = "clock_v2")]
549 let mut timer = LiveTimer::new(
550 name,
551 interval_ns,
552 ts_now,
553 Some(alert_time_ns),
554 callback,
555 self.heap.clone(),
556 );
557
558 timer.start();
559
560 self.clear_expired_timers();
561 self.timers.insert(name, timer);
562
563 Ok(())
564 }
565
566 fn set_timer_ns(
567 &mut self,
568 name: &str,
569 interval_ns: u64,
570 start_time_ns: UnixNanos,
571 stop_time_ns: Option<UnixNanos>,
572 callback: Option<TimeEventCallback>,
573 ) -> anyhow::Result<()> {
574 check_valid_string(name, stringify!(name))?;
575 check_positive_u64(interval_ns, stringify!(interval_ns))?;
576 check_predicate_true(
577 callback.is_some() | self.default_callback.is_some(),
578 "No callbacks provided",
579 )?;
580
581 let name = Ustr::from(name);
582
583 let callback = match callback {
584 Some(callback) => callback,
585 None => self.default_callback.clone().unwrap(),
586 };
587
588 #[cfg(feature = "clock_v2")]
589 {
590 self.callbacks.insert(name, callback.clone());
591 }
592
593 let mut start_time_ns = start_time_ns;
596 if start_time_ns == 0 {
597 start_time_ns = self.timestamp_ns();
599 }
600 let interval_ns = create_valid_interval(interval_ns);
601
602 #[cfg(not(feature = "clock_v2"))]
603 let mut timer = LiveTimer::new(name, interval_ns, start_time_ns, stop_time_ns, callback);
604
605 #[cfg(feature = "clock_v2")]
606 let mut timer = LiveTimer::new(
607 name,
608 interval_ns,
609 start_time_ns,
610 stop_time_ns,
611 callback,
612 self.heap.clone(),
613 );
614 timer.start();
615
616 self.clear_expired_timers();
617 self.timers.insert(name, timer);
618
619 Ok(())
620 }
621
622 fn next_time_ns(&self, name: &str) -> UnixNanos {
623 let timer = self.timers.get(&Ustr::from(name));
624 match timer {
625 None => 0.into(),
626 Some(timer) => timer.next_time_ns(),
627 }
628 }
629
630 fn cancel_timer(&mut self, name: &str) {
631 let timer = self.timers.remove(&Ustr::from(name));
632 match timer {
633 None => {}
634 Some(mut timer) => {
635 timer.cancel();
636 }
637 }
638 }
639
640 fn cancel_timers(&mut self) {
641 for timer in &mut self.timers.values_mut() {
642 timer.cancel();
643 }
644 self.timers.clear();
645 }
646
647 fn reset(&mut self) {
648 self.timers = HashMap::new();
649 self.heap = Arc::new(Mutex::new(BinaryHeap::new()));
650 self.callbacks = HashMap::new();
651 }
652}
653
/// A [`Stream`] that yields [`TimeEvent`]s popped from a shared, mutex-guarded heap
/// (the one exposed by `LiveClock::get_event_stream`).
pub struct TimeEventStream {
    // Shared event queue; locked with `try_lock` on every poll.
    heap: Arc<Mutex<BinaryHeap<TimeEvent>>>,
}
658
659impl TimeEventStream {
660 pub const fn new(heap: Arc<Mutex<BinaryHeap<TimeEvent>>>) -> Self {
661 Self { heap }
662 }
663}
664
665impl Stream for TimeEventStream {
666 type Item = TimeEvent;
667
668 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
669 let mut heap = match self.heap.try_lock() {
670 Ok(guard) => guard,
671 Err(e) => {
672 tracing::error!("Unable to get LiveClock heap lock: {e}");
673 cx.waker().wake_by_ref();
674 return Poll::Pending;
675 }
676 };
677
678 if let Some(event) = heap.pop() {
679 Poll::Ready(Some(event))
680 } else {
681 cx.waker().wake_by_ref();
682 Poll::Pending
683 }
684 }
685}
686
#[cfg(test)]
mod tests {
    use std::{cell::RefCell, rc::Rc};

    use rstest::{fixture, rstest};

    use super::*;

    /// Test helper whose callback records that it was invoked.
    #[derive(Default)]
    struct TestCallback {
        // Set to `true` by the closure built in `From<TestCallback> for TimeEventCallback`.
        called: Rc<RefCell<bool>>,
    }

    impl TestCallback {
        const fn new(called: Rc<RefCell<bool>>) -> Self {
            Self { called }
        }
    }

    // Wraps the helper as a `TimeEventCallback::Rust` closure that flips the shared flag.
    impl From<TestCallback> for TimeEventCallback {
        fn from(callback: TestCallback) -> Self {
            Self::Rust(Rc::new(move |_event: TimeEvent| {
                *callback.called.borrow_mut() = true;
            }))
        }
    }

    /// A `TestClock` with a default handler registered, so timers can be set without an
    /// explicit callback.
    #[fixture]
    pub fn test_clock() -> TestClock {
        let mut clock = TestClock::new();
        clock.register_default_handler(TestCallback::default().into());
        clock
    }

    // Advancing with `set_time = true` moves the clock forward.
    #[rstest]
    fn test_time_monotonicity(mut test_clock: TestClock) {
        let initial_time = test_clock.timestamp_ns();
        test_clock.advance_time((*initial_time + 1000).into(), true);
        assert!(test_clock.timestamp_ns() > initial_time);
    }

    // A newly set alert is reported by `timer_count` and `timer_names`.
    #[rstest]
    fn test_timer_registration(mut test_clock: TestClock) {
        test_clock
            .set_time_alert_ns(
                "test_timer",
                (*test_clock.timestamp_ns() + 1000).into(),
                None,
            )
            .unwrap();
        assert_eq!(test_clock.timer_count(), 1);
        assert_eq!(test_clock.timer_names(), vec!["test_timer"]);
    }

    // Advancing exactly to the alert time yields one event for that alert.
    #[rstest]
    fn test_timer_expiration(mut test_clock: TestClock) {
        let alert_time = (*test_clock.timestamp_ns() + 1000).into();
        test_clock
            .set_time_alert_ns("test_timer", alert_time, None)
            .unwrap();
        let events = test_clock.advance_time(alert_time, true);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].name.as_str(), "test_timer");
    }

    // `cancel_timer` removes the named timer.
    #[rstest]
    fn test_timer_cancellation(mut test_clock: TestClock) {
        test_clock
            .set_time_alert_ns(
                "test_timer",
                (*test_clock.timestamp_ns() + 1000).into(),
                None,
            )
            .unwrap();
        assert_eq!(test_clock.timer_count(), 1);
        test_clock.cancel_timer("test_timer");
        assert_eq!(test_clock.timer_count(), 0);
    }

    // A 1000ns repeating timer advanced by 2500ns fires at +1000 and +2000, in order.
    #[rstest]
    fn test_time_advancement(mut test_clock: TestClock) {
        let start_time = test_clock.timestamp_ns();
        test_clock
            .set_timer_ns("test_timer", 1000, start_time, None, None)
            .unwrap();
        let events = test_clock.advance_time((*start_time + 2500).into(), true);
        assert_eq!(events.len(), 2);
        assert_eq!(*events[0].ts_event, *start_time + 1000);
        assert_eq!(*events[1].ts_event, *start_time + 2000);
    }

    // Events without a dedicated callback go to the default handler; the others go to
    // their own callback.
    #[test]
    fn test_default_and_custom_callbacks() {
        let mut clock = TestClock::new();
        let default_called = Rc::new(RefCell::new(false));
        let custom_called = Rc::new(RefCell::new(false));

        let default_callback = TestCallback::new(Rc::clone(&default_called));
        let custom_callback = TestCallback::new(Rc::clone(&custom_called));

        clock.register_default_handler(TimeEventCallback::from(default_callback));
        clock
            .set_time_alert_ns("default_timer", (*clock.timestamp_ns() + 1000).into(), None)
            .unwrap();
        clock
            .set_time_alert_ns(
                "custom_timer",
                (*clock.timestamp_ns() + 1000).into(),
                Some(TimeEventCallback::from(custom_callback)),
            )
            .unwrap();

        let events = clock.advance_time((*clock.timestamp_ns() + 1000).into(), true);
        let handlers = clock.match_handlers(events);

        for handler in handlers {
            handler.callback.call(handler.event);
        }

        assert!(*default_called.borrow());
        assert!(*custom_called.borrow());
    }

    // Two timers firing at the same timestamp (+2000) keep timer-name order after sorting.
    #[rstest]
    fn test_multiple_timers(mut test_clock: TestClock) {
        let start_time = test_clock.timestamp_ns();
        test_clock
            .set_timer_ns("timer1", 1000, start_time, None, None)
            .unwrap();
        test_clock
            .set_timer_ns("timer2", 2000, start_time, None, None)
            .unwrap();
        let events = test_clock.advance_time((*start_time + 2000).into(), true);
        assert_eq!(events.len(), 3);
        assert_eq!(events[0].name.as_str(), "timer1");
        assert_eq!(events[1].name.as_str(), "timer1");
        assert_eq!(events[2].name.as_str(), "timer2");
    }
}
828}