1use std::{
19 collections::{BTreeMap, BinaryHeap, HashMap},
20 ops::Deref,
21 pin::Pin,
22 sync::Arc,
23 task::{Context, Poll},
24};
25
26use chrono::{DateTime, Utc};
27use futures::Stream;
28use nautilus_core::{
29 correctness::{check_positive_u64, check_predicate_true, check_valid_string},
30 time::get_atomic_clock_realtime,
31 AtomicTime, UnixNanos,
32};
33use tokio::sync::Mutex;
34use ustr::Ustr;
35
36use crate::timer::{
37 create_valid_interval, LiveTimer, TestTimer, TimeEvent, TimeEventCallback, TimeEventHandlerV2,
38};
39
40pub trait Clock {
46 fn utc_now(&self) -> DateTime<Utc> {
48 DateTime::from_timestamp_nanos(self.timestamp_ns().as_i64())
49 }
50
51 fn timestamp_ns(&self) -> UnixNanos;
53
54 fn timestamp_us(&self) -> u64;
56
57 fn timestamp_ms(&self) -> u64;
59
60 fn timestamp(&self) -> f64;
62
63 fn timer_names(&self) -> Vec<&str>;
65
66 fn timer_count(&self) -> usize;
68
69 fn register_default_handler(&mut self, callback: TimeEventCallback);
72
73 fn get_handler(&self, event: TimeEvent) -> TimeEventHandlerV2;
77
78 fn set_time_alert_ns(
81 &mut self,
82 name: &str,
83 alert_time_ns: UnixNanos,
84 callback: Option<TimeEventCallback>,
85 ) -> anyhow::Result<()>;
86
87 fn set_timer_ns(
91 &mut self,
92 name: &str,
93 interval_ns: u64,
94 start_time_ns: UnixNanos,
95 stop_time_ns: Option<UnixNanos>,
96 callback: Option<TimeEventCallback>,
97 ) -> anyhow::Result<()>;
98
99 fn next_time_ns(&self, name: &str) -> UnixNanos;
103 fn cancel_timer(&mut self, name: &str);
104 fn cancel_timers(&mut self);
105}
106
/// A clock whose time only changes when it is advanced explicitly (see `advance_time` and
/// `advance_to_time_on_heap`); its timers produce events when the clock is advanced.
pub struct TestClock {
    /// Time source read by the `timestamp*` methods and written by the `advance_*` methods.
    time: AtomicTime,
    /// Registered timers keyed by timer name.
    timers: BTreeMap<Ustr, TestTimer>,
    /// Fallback callback for events whose timer has no entry in `callbacks`.
    default_callback: Option<TimeEventCallback>,
    /// Per-timer callbacks keyed by timer name.
    callbacks: HashMap<Ustr, TimeEventCallback>,
    /// Events pushed by `advance_to_time_on_heap` and popped by the `Iterator` implementation.
    heap: BinaryHeap<TimeEvent>,
}
119
120impl TestClock {
121 #[must_use]
123 pub fn new() -> Self {
124 Self {
125 time: AtomicTime::new(false, UnixNanos::default()),
126 timers: BTreeMap::new(),
127 default_callback: None,
128 callbacks: HashMap::new(),
129 heap: BinaryHeap::new(),
130 }
131 }
132
133 #[must_use]
135 pub const fn get_timers(&self) -> &BTreeMap<Ustr, TestTimer> {
136 &self.timers
137 }
138
139 pub fn advance_time(&mut self, to_time_ns: UnixNanos, set_time: bool) -> Vec<TimeEvent> {
148 assert!(
150 to_time_ns >= self.time.get_time_ns(),
151 "`to_time_ns` {} was < `self.time.get_time_ns()` {}",
152 to_time_ns,
153 self.time.get_time_ns()
154 );
155
156 if set_time {
157 self.time.set_time(to_time_ns);
158 }
159
160 let mut events: Vec<TimeEvent> = Vec::new();
162 self.timers.retain(|_, timer| {
163 timer.advance(to_time_ns).for_each(|event| {
164 events.push(event);
165 });
166
167 !timer.is_expired()
168 });
169
170 events.sort_by(|a, b| a.ts_event.cmp(&b.ts_event));
171 events
172 }
173
174 pub fn advance_to_time_on_heap(&mut self, to_time_ns: UnixNanos) {
180 assert!(
182 to_time_ns >= self.time.get_time_ns(),
183 "`to_time_ns` {} was < `self.time.get_time_ns()` {}",
184 to_time_ns,
185 self.time.get_time_ns()
186 );
187
188 self.time.set_time(to_time_ns);
189
190 self.timers.retain(|_, timer| {
192 timer.advance(to_time_ns).for_each(|event| {
193 self.heap.push(event);
194 });
195
196 !timer.is_expired()
197 });
198 }
199
200 #[must_use]
206 pub fn match_handlers(&self, events: Vec<TimeEvent>) -> Vec<TimeEventHandlerV2> {
207 events
208 .into_iter()
209 .map(|event| {
210 let callback = self.callbacks.get(&event.name).cloned().unwrap_or_else(|| {
211 self.default_callback
214 .clone()
215 .expect("Default callback should exist")
216 });
217 TimeEventHandlerV2::new(event, callback)
218 })
219 .collect()
220 }
221}
222
223impl Iterator for TestClock {
224 type Item = TimeEventHandlerV2;
225
226 fn next(&mut self) -> Option<Self::Item> {
227 self.heap.pop().map(|event| self.get_handler(event))
228 }
229}
230
impl Default for TestClock {
    /// Creates a new default [`TestClock`] instance; equivalent to `TestClock::new()`.
    fn default() -> Self {
        Self::new()
    }
}
237
// Exposes the underlying `AtomicTime` so its methods can be called directly on the clock.
impl Deref for TestClock {
    type Target = AtomicTime;

    /// Returns a reference to the clock's internal `AtomicTime`.
    fn deref(&self) -> &Self::Target {
        &self.time
    }
}
245
246impl Clock for TestClock {
247 fn timestamp_ns(&self) -> UnixNanos {
248 self.time.get_time_ns()
249 }
250
251 fn timestamp_us(&self) -> u64 {
252 self.time.get_time_us()
253 }
254
255 fn timestamp_ms(&self) -> u64 {
256 self.time.get_time_ms()
257 }
258
259 fn timestamp(&self) -> f64 {
260 self.time.get_time()
261 }
262
263 fn timer_names(&self) -> Vec<&str> {
264 self.timers
265 .iter()
266 .filter(|(_, timer)| !timer.is_expired())
267 .map(|(k, _)| k.as_str())
268 .collect()
269 }
270
271 fn timer_count(&self) -> usize {
272 self.timers
273 .iter()
274 .filter(|(_, timer)| !timer.is_expired())
275 .count()
276 }
277
278 fn register_default_handler(&mut self, callback: TimeEventCallback) {
279 self.default_callback = Some(callback);
280 }
281
282 fn get_handler(&self, event: TimeEvent) -> TimeEventHandlerV2 {
283 let callback = self
285 .callbacks
286 .get(&event.name)
287 .cloned()
288 .or_else(|| self.default_callback.clone())
289 .unwrap_or_else(|| panic!("Event '{}' should have associated handler", event.name));
290
291 TimeEventHandlerV2::new(event, callback)
292 }
293
294 fn set_time_alert_ns(
295 &mut self,
296 name: &str,
297 alert_time_ns: UnixNanos,
298 callback: Option<TimeEventCallback>,
299 ) -> anyhow::Result<()> {
300 check_valid_string(name, stringify!(name))?;
301 check_predicate_true(
302 callback.is_some() | self.default_callback.is_some(),
303 "No callbacks provided",
304 )?;
305
306 let name_ustr = Ustr::from(name);
307 match callback {
308 Some(callback_py) => self.callbacks.insert(name_ustr, callback_py),
309 None => None,
310 };
311
312 let ts_now = self.time.get_time_ns();
313 let interval_ns = create_valid_interval(std::cmp::max((alert_time_ns - ts_now).into(), 1));
315 let timer = TestTimer::new(name, interval_ns, ts_now, Some(alert_time_ns));
316 self.timers.insert(name_ustr, timer);
317
318 Ok(())
319 }
320
321 fn set_timer_ns(
322 &mut self,
323 name: &str,
324 interval_ns: u64,
325 start_time_ns: UnixNanos,
326 stop_time_ns: Option<UnixNanos>,
327 callback: Option<TimeEventCallback>,
328 ) -> anyhow::Result<()> {
329 check_valid_string(name, stringify!(name))?;
330 check_positive_u64(interval_ns, stringify!(interval_ns))?;
331 check_predicate_true(
332 callback.is_some() | self.default_callback.is_some(),
333 "No callbacks provided",
334 )?;
335
336 let name_ustr = Ustr::from(name);
337 match callback {
338 Some(callback_py) => self.callbacks.insert(name_ustr, callback_py),
339 None => None,
340 };
341
342 let interval_ns = create_valid_interval(interval_ns);
343 let timer = TestTimer::new(name, interval_ns, start_time_ns, stop_time_ns);
344 self.timers.insert(name_ustr, timer);
345
346 Ok(())
347 }
348
349 fn next_time_ns(&self, name: &str) -> UnixNanos {
350 let timer = self.timers.get(&Ustr::from(name));
351 match timer {
352 None => 0.into(),
353 Some(timer) => timer.next_time_ns(),
354 }
355 }
356
357 fn cancel_timer(&mut self, name: &str) {
358 let timer = self.timers.remove(&Ustr::from(name));
359 match timer {
360 None => {}
361 Some(mut timer) => timer.cancel(),
362 }
363 }
364
365 fn cancel_timers(&mut self) {
366 for timer in &mut self.timers.values_mut() {
367 timer.cancel();
368 }
369 self.timers = BTreeMap::new();
370 }
371}
372
/// A clock that reads time from the shared `AtomicTime` returned by
/// `get_atomic_clock_realtime()` and schedules timers as `LiveTimer`s.
pub struct LiveClock {
    /// Shared time source with `'static` lifetime.
    time: &'static AtomicTime,
    /// Registered timers keyed by timer name.
    timers: HashMap<Ustr, LiveTimer>,
    /// Fallback callback for timers that are set without their own callback.
    default_callback: Option<TimeEventCallback>,
    /// Shared heap of time events; handed to `TimeEventStream` and, with the `clock_v2`
    /// feature, to each `LiveTimer`.
    pub heap: Arc<Mutex<BinaryHeap<TimeEvent>>>,
    // Only used when the `clock_v2` feature is enabled, hence the lint allowance.
    #[allow(dead_code)]
    callbacks: HashMap<Ustr, TimeEventCallback>,
}
384
385impl LiveClock {
386 #[must_use]
388 pub fn new() -> Self {
389 Self {
390 time: get_atomic_clock_realtime(),
391 timers: HashMap::new(),
392 default_callback: None,
393 heap: Arc::new(Mutex::new(BinaryHeap::new())),
394 callbacks: HashMap::new(),
395 }
396 }
397
398 #[must_use]
399 pub fn get_event_stream(&self) -> TimeEventStream {
400 TimeEventStream::new(self.heap.clone())
401 }
402
403 #[must_use]
404 pub const fn get_timers(&self) -> &HashMap<Ustr, LiveTimer> {
405 &self.timers
406 }
407
408 fn clear_expired_timers(&mut self) {
410 self.timers.retain(|_, timer| !timer.is_expired());
411 }
412}
413
impl Default for LiveClock {
    /// Creates a new default [`LiveClock`] instance; equivalent to `LiveClock::new()`.
    fn default() -> Self {
        Self::new()
    }
}
420
// Exposes the underlying `AtomicTime` so its methods can be called directly on the clock.
impl Deref for LiveClock {
    type Target = AtomicTime;

    /// Returns the shared `AtomicTime` this clock reads from.
    fn deref(&self) -> &Self::Target {
        self.time
    }
}
428
429impl Clock for LiveClock {
430 fn timestamp_ns(&self) -> UnixNanos {
431 self.time.get_time_ns()
432 }
433
434 fn timestamp_us(&self) -> u64 {
435 self.time.get_time_us()
436 }
437
438 fn timestamp_ms(&self) -> u64 {
439 self.time.get_time_ms()
440 }
441
442 fn timestamp(&self) -> f64 {
443 self.time.get_time()
444 }
445
446 fn timer_names(&self) -> Vec<&str> {
447 self.timers
448 .iter()
449 .filter(|(_, timer)| !timer.is_expired())
450 .map(|(k, _)| k.as_str())
451 .collect()
452 }
453
454 fn timer_count(&self) -> usize {
455 self.timers
456 .iter()
457 .filter(|(_, timer)| !timer.is_expired())
458 .count()
459 }
460
461 fn register_default_handler(&mut self, handler: TimeEventCallback) {
462 self.default_callback = Some(handler);
463 }
464
465 #[allow(unused_variables)]
466 fn get_handler(&self, event: TimeEvent) -> TimeEventHandlerV2 {
467 #[cfg(not(feature = "clock_v2"))]
468 panic!("Cannot get live clock handler without 'clock_v2' feature");
469
470 #[cfg(feature = "clock_v2")]
472 {
473 let callback = self
474 .callbacks
475 .get(&event.name)
476 .cloned()
477 .or_else(|| self.default_callback.clone())
478 .unwrap_or_else(|| panic!("Event '{}' should have associated handler", event.name));
479
480 TimeEventHandlerV2::new(event, callback)
481 }
482 }
483
484 fn set_time_alert_ns(
485 &mut self,
486 name: &str,
487 mut alert_time_ns: UnixNanos,
488 callback: Option<TimeEventCallback>,
489 ) -> anyhow::Result<()> {
490 check_valid_string(name, stringify!(name))?;
491 check_predicate_true(
492 callback.is_some() | self.default_callback.is_some(),
493 "No callbacks provided",
494 )?;
495
496 let callback = match callback {
497 Some(callback) => callback,
498 None => self.default_callback.clone().unwrap(),
499 };
500
501 #[cfg(feature = "clock_v2")]
502 {
503 let name = Ustr::from(name);
504 self.callbacks.insert(name, callback.clone());
505 }
506
507 let ts_now = self.get_time_ns();
508 alert_time_ns = std::cmp::max(alert_time_ns, ts_now);
509 let interval_ns = create_valid_interval(std::cmp::max((alert_time_ns - ts_now).into(), 1));
511
512 #[cfg(not(feature = "clock_v2"))]
513 let mut timer = LiveTimer::new(name, interval_ns, ts_now, Some(alert_time_ns), callback);
514 #[cfg(feature = "clock_v2")]
515 let mut timer = LiveTimer::new(
516 name,
517 interval_ns,
518 ts_now,
519 Some(alert_time_ns),
520 callback,
521 self.heap.clone(),
522 );
523
524 timer.start();
525
526 self.clear_expired_timers();
527 self.timers.insert(Ustr::from(name), timer);
528
529 Ok(())
530 }
531
532 fn set_timer_ns(
533 &mut self,
534 name: &str,
535 interval_ns: u64,
536 start_time_ns: UnixNanos,
537 stop_time_ns: Option<UnixNanos>,
538 callback: Option<TimeEventCallback>,
539 ) -> anyhow::Result<()> {
540 check_valid_string(name, stringify!(name))?;
541 check_positive_u64(interval_ns, stringify!(interval_ns))?;
542 check_predicate_true(
543 callback.is_some() | self.default_callback.is_some(),
544 "No callbacks provided",
545 )?;
546
547 let callback = match callback {
548 Some(callback) => callback,
549 None => self.default_callback.clone().unwrap(),
550 };
551
552 #[cfg(feature = "clock_v2")]
553 {
554 let name = Ustr::from(name);
555 self.callbacks.insert(name, callback.clone());
556 }
557
558 let mut start_time_ns = start_time_ns;
561 if start_time_ns == 0 {
562 start_time_ns = self.timestamp_ns();
564 }
565 let interval_ns = create_valid_interval(interval_ns);
566
567 #[cfg(not(feature = "clock_v2"))]
568 let mut timer = LiveTimer::new(name, interval_ns, start_time_ns, stop_time_ns, callback);
569 #[cfg(feature = "clock_v2")]
570 let mut timer = LiveTimer::new(
571 name,
572 interval_ns,
573 start_time_ns,
574 stop_time_ns,
575 callback,
576 self.heap.clone(),
577 );
578 timer.start();
579
580 self.clear_expired_timers();
581 self.timers.insert(Ustr::from(name), timer);
582
583 Ok(())
584 }
585
586 fn next_time_ns(&self, name: &str) -> UnixNanos {
587 let timer = self.timers.get(&Ustr::from(name));
588 match timer {
589 None => 0.into(),
590 Some(timer) => timer.next_time_ns(),
591 }
592 }
593
594 fn cancel_timer(&mut self, name: &str) {
595 let timer = self.timers.remove(&Ustr::from(name));
596 match timer {
597 None => {}
598 Some(mut timer) => {
599 timer.cancel();
600 }
601 }
602 }
603
604 fn cancel_timers(&mut self) {
605 for timer in &mut self.timers.values_mut() {
606 timer.cancel();
607 }
608 self.timers.clear();
609 }
610}
611
/// A `Stream` of `TimeEvent`s popped from a shared, mutex-protected heap.
pub struct TimeEventStream {
    /// Heap shared with the producer of the events (see `LiveClock::get_event_stream`).
    heap: Arc<Mutex<BinaryHeap<TimeEvent>>>,
}
616
impl TimeEventStream {
    /// Creates a stream that pops its events from the given shared `heap`.
    pub const fn new(heap: Arc<Mutex<BinaryHeap<TimeEvent>>>) -> Self {
        Self { heap }
    }
}
622
623impl Stream for TimeEventStream {
624 type Item = TimeEvent;
625
626 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
627 let mut heap = match self.heap.try_lock() {
628 Ok(guard) => guard,
629 Err(e) => {
630 tracing::error!("Unable to get LiveClock heap lock: {e}");
631 cx.waker().wake_by_ref();
632 return Poll::Pending;
633 }
634 };
635
636 if let Some(event) = heap.pop() {
637 Poll::Ready(Some(event))
638 } else {
639 cx.waker().wake_by_ref();
640 Poll::Pending
641 }
642 }
643}
644
// Unit tests for `TestClock`: time advancement, timer registration, expiry, cancellation
// and callback dispatch.
#[cfg(test)]
mod tests {
    use std::{cell::RefCell, rc::Rc};

    use rstest::{fixture, rstest};

    use super::*;

    /// Callback helper that records whether it has been invoked.
    #[derive(Default)]
    struct TestCallback {
        /// Shared flag set to `true` when the callback runs.
        called: Rc<RefCell<bool>>,
    }

    impl TestCallback {
        /// Wraps an existing flag so the test can observe the invocation from outside.
        const fn new(called: Rc<RefCell<bool>>) -> Self {
            Self { called }
        }
    }

    impl From<TestCallback> for TimeEventCallback {
        /// Converts the helper into a Rust-closure callback that sets the flag.
        fn from(callback: TestCallback) -> Self {
            Self::Rust(Rc::new(move |_event: TimeEvent| {
                *callback.called.borrow_mut() = true;
            }))
        }
    }

    /// Fixture: a fresh `TestClock` with a default handler registered, so timers can be
    /// set without passing a callback.
    #[fixture]
    pub fn test_clock() -> TestClock {
        let mut clock = TestClock::new();
        clock.register_default_handler(TestCallback::default().into());
        clock
    }

    /// Advancing with `set_time = true` moves the clock forward.
    #[rstest]
    fn test_time_monotonicity(mut test_clock: TestClock) {
        let initial_time = test_clock.timestamp_ns();
        test_clock.advance_time((*initial_time + 1000).into(), true);
        assert!(test_clock.timestamp_ns() > initial_time);
    }

    /// A newly set alert is visible through `timer_count` and `timer_names`.
    #[rstest]
    fn test_timer_registration(mut test_clock: TestClock) {
        test_clock
            .set_time_alert_ns(
                "test_timer",
                (*test_clock.timestamp_ns() + 1000).into(),
                None,
            )
            .unwrap();
        assert_eq!(test_clock.timer_count(), 1);
        assert_eq!(test_clock.timer_names(), vec!["test_timer"]);
    }

    /// Advancing to the alert time produces exactly one event carrying the timer's name.
    #[rstest]
    fn test_timer_expiration(mut test_clock: TestClock) {
        let alert_time = (*test_clock.timestamp_ns() + 1000).into();
        test_clock
            .set_time_alert_ns("test_timer", alert_time, None)
            .unwrap();
        let events = test_clock.advance_time(alert_time, true);
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].name.as_str(), "test_timer");
    }

    /// Cancelling a timer removes it from the clock.
    #[rstest]
    fn test_timer_cancellation(mut test_clock: TestClock) {
        test_clock
            .set_time_alert_ns(
                "test_timer",
                (*test_clock.timestamp_ns() + 1000).into(),
                None,
            )
            .unwrap();
        assert_eq!(test_clock.timer_count(), 1);
        test_clock.cancel_timer("test_timer");
        assert_eq!(test_clock.timer_count(), 0);
    }

    /// A recurring 1000 ns timer advanced by 2500 ns fires twice, at +1000 and +2000.
    #[rstest]
    fn test_time_advancement(mut test_clock: TestClock) {
        let start_time = test_clock.timestamp_ns();
        test_clock
            .set_timer_ns("test_timer", 1000, start_time, None, None)
            .unwrap();
        let events = test_clock.advance_time((*start_time + 2500).into(), true);
        assert_eq!(events.len(), 2);
        assert_eq!(*events[0].ts_event, *start_time + 1000);
        assert_eq!(*events[1].ts_event, *start_time + 2000);
    }

    /// Events from a timer without its own callback reach the default handler, while
    /// events from a timer with a custom callback reach that callback.
    #[test]
    fn test_default_and_custom_callbacks() {
        let mut clock = TestClock::new();
        let default_called = Rc::new(RefCell::new(false));
        let custom_called = Rc::new(RefCell::new(false));

        let default_callback = TestCallback::new(Rc::clone(&default_called));
        let custom_callback = TestCallback::new(Rc::clone(&custom_called));

        clock.register_default_handler(TimeEventCallback::from(default_callback));
        clock
            .set_time_alert_ns("default_timer", (*clock.timestamp_ns() + 1000).into(), None)
            .unwrap();
        clock
            .set_time_alert_ns(
                "custom_timer",
                (*clock.timestamp_ns() + 1000).into(),
                Some(TimeEventCallback::from(custom_callback)),
            )
            .unwrap();

        let events = clock.advance_time((*clock.timestamp_ns() + 1000).into(), true);
        let handlers = clock.match_handlers(events);

        // Invoke every matched handler so both flags get a chance to flip.
        for handler in handlers {
            handler.callback.call(handler.event);
        }

        assert!(*default_called.borrow());
        assert!(*custom_called.borrow());
    }

    /// Two timers advanced together yield events ordered by timestamp; the tie at +2000
    /// keeps `timer1` before `timer2`.
    #[rstest]
    fn test_multiple_timers(mut test_clock: TestClock) {
        let start_time = test_clock.timestamp_ns();
        test_clock
            .set_timer_ns("timer1", 1000, start_time, None, None)
            .unwrap();
        test_clock
            .set_timer_ns("timer2", 2000, start_time, None, None)
            .unwrap();
        let events = test_clock.advance_time((*start_time + 2000).into(), true);
        assert_eq!(events.len(), 3);
        assert_eq!(events[0].name.as_str(), "timer1");
        assert_eq!(events[1].name.as_str(), "timer1");
        assert_eq!(events[2].name.as_str(), "timer2");
    }
}