1use std::{
17 collections::{HashMap, VecDeque},
18 sync::{
19 Arc,
20 atomic::{AtomicBool, Ordering},
21 },
22 time::{Duration, Instant},
23};
24
25use bytes::Bytes;
26use futures::stream::Stream;
27use nautilus_common::{
28 msgbus::{
29 BusMessage, CLOSE_TOPIC,
30 database::{DatabaseConfig, MessageBusConfig, MessageBusDatabaseAdapter},
31 },
32 runtime::get_runtime,
33};
34use nautilus_core::{
35 UUID4,
36 time::{duration_since_unix_epoch, get_atomic_clock_realtime},
37};
38use nautilus_cryptography::providers::install_cryptographic_provider;
39use nautilus_model::identifiers::TraderId;
40use redis::*;
41use streams::StreamReadOptions;
42
43use super::{REDIS_MINID, REDIS_XTRIM, await_handle};
44use crate::redis::{create_redis_connection, get_stream_key};
45
/// Name of the publishing task; passed to `create_redis_connection` and used in log messages.
const MSGBUS_PUBLISH: &str = "msgbus-publish";
/// Name of the external-stream reading task; passed to `create_redis_connection` and used in logs.
const MSGBUS_STREAM: &str = "msgbus-stream";
/// Name of the heartbeat task, used when awaiting its handle on close.
const MSGBUS_HEARTBEAT: &str = "msgbus-heartbeat";
/// Topic on which heartbeat messages are published.
const HEARTBEAT_TOPIC: &str = "health:heartbeat";
/// Minimum number of seconds between successive `XTRIM` calls on the same stream.
const TRIM_BUFFER_SECS: u64 = 60;

/// Shape that `XREAD` replies are deserialized into: a list of maps from stream key to that
/// stream's entries, each entry being a map from entry ID to its raw field/value array.
type RedisStreamBulk = Vec<HashMap<String, Vec<HashMap<String, redis::Value>>>>;
53
/// Redis-backed message bus database adapter.
///
/// Owns the background tasks that publish bus messages to Redis streams, optionally read
/// external streams back into the process, and optionally emit heartbeats.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.infrastructure")
)]
pub struct RedisMessageBusDatabase {
    /// The trader ID for this message bus database.
    pub trader_id: TraderId,
    /// The instance ID for this message bus database.
    pub instance_id: UUID4,
    // Sender feeding the publishing task (also used by the heartbeat task).
    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
    // Join handle of the publishing task; taken when closing.
    pub_handle: Option<tokio::task::JoinHandle<()>>,
    // Receiver of messages read from external streams; `None` when none are configured
    // or once taken via `get_stream_receiver`.
    stream_rx: Option<tokio::sync::mpsc::Receiver<BusMessage>>,
    // Join handle of the external-stream task, if one was spawned.
    stream_handle: Option<tokio::task::JoinHandle<()>>,
    // Set to `true` to ask the external-stream task to stop.
    stream_signal: Arc<AtomicBool>,
    // Join handle of the heartbeat task, if one was spawned.
    heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
    // Set to `true` to ask the heartbeat task to stop.
    heartbeat_signal: Arc<AtomicBool>,
}
71
72impl MessageBusDatabaseAdapter for RedisMessageBusDatabase {
73 type DatabaseType = RedisMessageBusDatabase;
74
75 fn new(
77 trader_id: TraderId,
78 instance_id: UUID4,
79 config: MessageBusConfig,
80 ) -> anyhow::Result<Self> {
81 install_cryptographic_provider();
82
83 let config_clone = config.clone();
84 let db_config = config
85 .database
86 .clone()
87 .ok_or_else(|| anyhow::anyhow!("No database config"))?;
88
89 let (pub_tx, pub_rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
90
91 let pub_handle = Some(get_runtime().spawn(async move {
93 if let Err(e) = publish_messages(pub_rx, trader_id, instance_id, config_clone).await {
94 log::error!("Error in task '{MSGBUS_PUBLISH}': {e}");
95 };
96 }));
97
98 let external_streams = config.external_streams.clone().unwrap_or_default();
100 let stream_signal = Arc::new(AtomicBool::new(false));
101 let (stream_rx, stream_handle) = if !external_streams.is_empty() {
102 let stream_signal_clone = stream_signal.clone();
103 let (stream_tx, stream_rx) = tokio::sync::mpsc::channel::<BusMessage>(100_000);
104 (
105 Some(stream_rx),
106 Some(get_runtime().spawn(async move {
107 if let Err(e) =
108 stream_messages(stream_tx, db_config, external_streams, stream_signal_clone)
109 .await
110 {
111 log::error!("Error in task '{MSGBUS_STREAM}': {e}");
112 }
113 })),
114 )
115 } else {
116 (None, None)
117 };
118
119 let heartbeat_signal = Arc::new(AtomicBool::new(false));
121 let heartbeat_handle = if let Some(heartbeat_interval_secs) = config.heartbeat_interval_secs
122 {
123 let signal = heartbeat_signal.clone();
124 let pub_tx_clone = pub_tx.clone();
125
126 Some(get_runtime().spawn(async move {
127 run_heartbeat(heartbeat_interval_secs, signal, pub_tx_clone).await
128 }))
129 } else {
130 None
131 };
132
133 Ok(Self {
134 trader_id,
135 instance_id,
136 pub_tx,
137 pub_handle,
138 stream_rx,
139 stream_handle,
140 stream_signal,
141 heartbeat_handle,
142 heartbeat_signal,
143 })
144 }
145
146 fn is_closed(&self) -> bool {
148 self.pub_tx.is_closed()
149 }
150
151 fn publish(&self, topic: String, payload: Bytes) {
153 let msg = BusMessage { topic, payload };
154 if let Err(e) = self.pub_tx.send(msg) {
155 log::error!("Failed to send message: {e}");
156 }
157 }
158
159 fn close(&mut self) {
161 log::debug!("Closing");
162
163 self.stream_signal.store(true, Ordering::Relaxed);
164 self.heartbeat_signal.store(true, Ordering::Relaxed);
165
166 if !self.pub_tx.is_closed() {
167 let msg = BusMessage {
168 topic: CLOSE_TOPIC.to_string(),
169 payload: Bytes::new(), };
171 if let Err(e) = self.pub_tx.send(msg) {
172 log::error!("Failed to send close message: {e:?}");
173 }
174 }
175
176 tokio::task::block_in_place(|| {
178 get_runtime().block_on(async {
179 self.close_async().await;
180 });
181 });
182
183 log::debug!("Closed");
184 }
185}
186
187impl RedisMessageBusDatabase {
188 pub fn get_stream_receiver(
190 &mut self,
191 ) -> anyhow::Result<tokio::sync::mpsc::Receiver<BusMessage>> {
192 self.stream_rx
193 .take()
194 .ok_or_else(|| anyhow::anyhow!("Stream receiver already taken"))
195 }
196
197 pub fn stream(
199 mut stream_rx: tokio::sync::mpsc::Receiver<BusMessage>,
200 ) -> impl Stream<Item = BusMessage> + 'static {
201 async_stream::stream! {
202 while let Some(msg) = stream_rx.recv().await {
203 yield msg;
204 }
205 }
206 }
207
208 pub async fn close_async(&mut self) {
209 await_handle(self.pub_handle.take(), MSGBUS_PUBLISH).await;
210 await_handle(self.stream_handle.take(), MSGBUS_STREAM).await;
211 await_handle(self.heartbeat_handle.take(), MSGBUS_HEARTBEAT).await;
212 }
213}
214
215pub async fn publish_messages(
216 mut rx: tokio::sync::mpsc::UnboundedReceiver<BusMessage>,
217 trader_id: TraderId,
218 instance_id: UUID4,
219 config: MessageBusConfig,
220) -> anyhow::Result<()> {
221 tracing::debug!("Starting message publishing");
222
223 let db_config = config
224 .database
225 .as_ref()
226 .ok_or_else(|| anyhow::anyhow!("No database config"))?;
227 let mut con = create_redis_connection(MSGBUS_PUBLISH, db_config.clone()).await?;
228 let stream_key = get_stream_key(trader_id, instance_id, &config);
229
230 let autotrim_duration = config
232 .autotrim_mins
233 .filter(|&mins| mins > 0)
234 .map(|mins| Duration::from_secs(mins as u64 * 60));
235 let mut last_trim_index: HashMap<String, usize> = HashMap::new();
236
237 let mut buffer: VecDeque<BusMessage> = VecDeque::new();
239 let mut last_drain = Instant::now();
240 let buffer_interval = Duration::from_millis(config.buffer_interval_ms.unwrap_or(0) as u64);
241
242 loop {
243 if last_drain.elapsed() >= buffer_interval && !buffer.is_empty() {
244 drain_buffer(
245 &mut con,
246 &stream_key,
247 config.stream_per_topic,
248 autotrim_duration,
249 &mut last_trim_index,
250 &mut buffer,
251 )
252 .await?;
253 last_drain = Instant::now();
254 } else {
255 match rx.recv().await {
256 Some(msg) => {
257 if msg.topic == CLOSE_TOPIC {
258 tracing::debug!("Received close message");
259 drop(rx);
260 break;
261 }
262 buffer.push_back(msg);
263 }
264 None => {
265 tracing::debug!("Channel hung up");
266 break;
267 }
268 }
269 }
270 }
271
272 if !buffer.is_empty() {
274 drain_buffer(
275 &mut con,
276 &stream_key,
277 config.stream_per_topic,
278 autotrim_duration,
279 &mut last_trim_index,
280 &mut buffer,
281 )
282 .await?;
283 }
284
285 tracing::debug!("Stopped message publishing");
286 Ok(())
287}
288
289async fn drain_buffer(
290 conn: &mut redis::aio::ConnectionManager,
291 stream_key: &str,
292 stream_per_topic: bool,
293 autotrim_duration: Option<Duration>,
294 last_trim_index: &mut HashMap<String, usize>,
295 buffer: &mut VecDeque<BusMessage>,
296) -> anyhow::Result<()> {
297 let mut pipe = redis::pipe();
298 pipe.atomic();
299
300 for msg in buffer.drain(..) {
301 let items: Vec<(&str, &[u8])> = vec![
302 ("topic", msg.topic.as_ref()),
303 ("payload", msg.payload.as_ref()),
304 ];
305 let stream_key = match stream_per_topic {
306 true => format!("{stream_key}:{}", &msg.topic),
307 false => stream_key.to_string(),
308 };
309 pipe.xadd(&stream_key, "*", &items);
310
311 if autotrim_duration.is_none() {
312 continue; }
314
315 let last_trim_ms = last_trim_index.entry(stream_key.clone()).or_insert(0); let unix_duration_now = duration_since_unix_epoch();
318 let trim_buffer = Duration::from_secs(TRIM_BUFFER_SECS);
319
320 if *last_trim_ms < (unix_duration_now - trim_buffer).as_millis() as usize {
322 let min_timestamp_ms =
323 (unix_duration_now - autotrim_duration.unwrap()).as_millis() as usize;
324 let result: Result<(), redis::RedisError> = redis::cmd(REDIS_XTRIM)
325 .arg(stream_key.clone())
326 .arg(REDIS_MINID)
327 .arg(min_timestamp_ms)
328 .query_async(conn)
329 .await;
330
331 if let Err(e) = result {
332 tracing::error!("Error trimming stream '{stream_key}': {e}");
333 } else {
334 last_trim_index.insert(
335 stream_key.to_string(),
336 unix_duration_now.as_millis() as usize,
337 );
338 }
339 }
340 }
341
342 pipe.query_async(conn).await.map_err(anyhow::Error::from)
343}
344
345pub async fn stream_messages(
346 tx: tokio::sync::mpsc::Sender<BusMessage>,
347 config: DatabaseConfig,
348 stream_keys: Vec<String>,
349 stream_signal: Arc<AtomicBool>,
350) -> anyhow::Result<()> {
351 tracing::info!("Starting message streaming");
352 let mut con = create_redis_connection(MSGBUS_STREAM, config).await?;
353
354 let stream_keys = &stream_keys
355 .iter()
356 .map(String::as_str)
357 .collect::<Vec<&str>>();
358
359 tracing::debug!("Listening to streams: [{}]", stream_keys.join(", "));
360
361 let clock = get_atomic_clock_realtime();
363 let timestamp_ms = clock.get_time_ms();
364 let mut last_id = timestamp_ms.to_string();
365
366 let opts = StreamReadOptions::default().block(100);
367
368 'outer: loop {
369 if stream_signal.load(Ordering::Relaxed) {
370 tracing::debug!("Received streaming terminate signal");
371 break;
372 }
373 let result: Result<RedisStreamBulk, _> =
374 con.xread_options(&[&stream_keys], &[&last_id], &opts).await;
375 match result {
376 Ok(stream_bulk) => {
377 if stream_bulk.is_empty() {
378 continue;
380 }
381 for entry in stream_bulk.iter() {
382 for (_stream_key, stream_msgs) in entry.iter() {
383 for stream_msg in stream_msgs.iter() {
384 for (id, array) in stream_msg {
385 last_id.clear();
386 last_id.push_str(id);
387 match decode_bus_message(array) {
388 Ok(msg) => {
389 if let Err(e) = tx.send(msg).await {
390 tracing::debug!("Channel closed: {e:?}");
391 break 'outer; }
393 }
394 Err(e) => {
395 tracing::error!("{e:?}");
396 continue;
397 }
398 }
399 }
400 }
401 }
402 }
403 }
404 Err(e) => {
405 anyhow::bail!("Error reading from stream: {e:?}");
406 }
407 }
408 }
409
410 tracing::debug!("Stopped message streaming");
411 Ok(())
412}
413
414fn decode_bus_message(stream_msg: &redis::Value) -> anyhow::Result<BusMessage> {
415 if let redis::Value::Array(stream_msg) = stream_msg {
416 if stream_msg.len() < 4 {
417 anyhow::bail!("Invalid stream message format: {stream_msg:?}");
418 }
419
420 let topic = match &stream_msg[1] {
421 redis::Value::BulkString(bytes) => match String::from_utf8(bytes.clone()) {
422 Ok(topic) => topic,
423 Err(e) => anyhow::bail!("Error parsing topic: {e}"),
424 },
425 _ => {
426 anyhow::bail!("Invalid topic format: {stream_msg:?}");
427 }
428 };
429
430 let payload = match &stream_msg[3] {
431 redis::Value::BulkString(bytes) => Bytes::copy_from_slice(bytes),
432 _ => {
433 anyhow::bail!("Invalid payload format: {stream_msg:?}");
434 }
435 };
436
437 Ok(BusMessage { topic, payload })
438 } else {
439 anyhow::bail!("Invalid stream message format: {stream_msg:?}")
440 }
441}
442
443async fn run_heartbeat(
444 heartbeat_interval_secs: u16,
445 signal: Arc<AtomicBool>,
446 pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
447) {
448 tracing::debug!("Starting heartbeat at {heartbeat_interval_secs} second intervals");
449
450 let heartbeat_interval = Duration::from_secs(heartbeat_interval_secs as u64);
451 let heartbeat_timer = tokio::time::interval(heartbeat_interval);
452
453 let check_interval = Duration::from_millis(100);
454 let check_timer = tokio::time::interval(check_interval);
455
456 tokio::pin!(heartbeat_timer);
457 tokio::pin!(check_timer);
458
459 loop {
460 if signal.load(Ordering::Relaxed) {
461 tracing::debug!("Received heartbeat terminate signal");
462 break;
463 }
464
465 tokio::select! {
466 _ = heartbeat_timer.tick() => {
467 let heartbeat = create_heartbeat_msg();
468 if let Err(e) = pub_tx.send(heartbeat) {
469 tracing::debug!("Error sending heartbeat: {e}");
471 }
472 },
473 _ = check_timer.tick() => {}
474 }
475 }
476
477 tracing::debug!("Stopped heartbeat");
478}
479
480fn create_heartbeat_msg() -> BusMessage {
481 BusMessage {
482 topic: HEARTBEAT_TOPIC.to_string(),
483 payload: Bytes::from(chrono::Utc::now().to_rfc3339().into_bytes()),
484 }
485}
486
#[cfg(test)]
mod tests {
    // Pure unit tests for `decode_bus_message`. No Redis server is needed: the inputs are
    // hand-built `redis::Value` arrays shaped like the field/value list that `drain_buffer`
    // writes ("topic", <topic>, "payload", <payload>).
    use redis::Value;
    use rstest::*;

    use super::*;

    // A well-formed entry: topic at index 1, payload at index 3.
    #[rstest]
    fn test_decode_bus_message_valid() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
            Value::BulkString(b"unused".to_vec()),
            Value::BulkString(b"data1".to_vec()),
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));
    }

    // Fewer than four elements is rejected, and the error Debug-prints the inner array.
    #[rstest]
    fn test_decode_bus_message_missing_fields() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid stream message format: [bulk-string('\"0\"'), bulk-string('\"topic1\"')]"
        );
    }

    // A non-bulk topic element (here an integer) is rejected.
    #[rstest]
    fn test_decode_bus_message_invalid_topic_format() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::Int(42),
            Value::BulkString(b"unused".to_vec()),
            Value::BulkString(b"data1".to_vec()),
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid topic format: [bulk-string('\"0\"'), int(42), bulk-string('\"unused\"'), bulk-string('\"data1\"')]"
        );
    }

    // A non-bulk payload element (here an integer) is rejected.
    #[rstest]
    fn test_decode_bus_message_invalid_payload_format() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
            Value::BulkString(b"unused".to_vec()),
            Value::Int(42),
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid payload format: [bulk-string('\"0\"'), bulk-string('\"topic1\"'), bulk-string('\"unused\"'), int(42)]"
        );
    }

    // A value that is not an array at all is rejected, and the outer value is Debug-printed.
    #[rstest]
    fn test_decode_bus_message_invalid_stream_msg_format() {
        let stream_msg = Value::BulkString(b"not an array".to_vec());

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid stream message format: bulk-string('\"not an array\"')"
        );
    }
}
574
// Integration tests against a live Redis server reachable with the default `DatabaseConfig`,
// compiled only on Linux. Tests using the `redis_connection` fixture flush Redis at the
// start and again at the end.
// NOTE(review): the module name suggests these must run serially because they share one
// Redis instance — confirm how the test runner enforces that.
#[cfg(target_os = "linux")]
#[cfg(test)]
mod serial_tests {
    use nautilus_common::testing::wait_until_async;
    use redis::aio::ConnectionManager;
    use rstest::*;

    use super::*;
    use crate::redis::flush_redis;

    // Connects with the default database config and flushes Redis so each test starts clean.
    #[fixture]
    async fn redis_connection() -> ConnectionManager {
        let config = DatabaseConfig::default();
        let mut con = create_redis_connection(MSGBUS_STREAM, config)
            .await
            .unwrap();
        flush_redis(&mut con).await.unwrap();
        con
    }

    // Setting the terminate signal makes `stream_messages` return `Ok` (the spawned task
    // unwraps the result, so an error would fail the test via `handle.await.unwrap()`).
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_terminate_signal(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        stream_signal.store(true, Ordering::Relaxed);
        // Nothing was added to the stream, so this is expected to resolve to `None` once the
        // task exits and drops `tx`
        let _ = rx.recv().await;
        rx.close();

        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap()
    }

    // With the receiver closed up front, forwarding the first entry fails and
    // `stream_messages` must return `Ok` via its channel-closed exit rather than hang.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_when_receiver_closed(
        #[future] redis_connection: ConnectionManager,
    ) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Entry ID 1_000_000 ms in the future, so it sorts after the start ID that
        // `stream_messages` derives from the current time
        let clock = get_atomic_clock_realtime();
        let future_id = (clock.get_time_ms() + 1_000_000).to_string();

        let _: () = con
            .xadd(
                stream_key,
                future_id,
                &[("topic", "topic1"), ("payload", "data1")],
            )
            .await
            .unwrap();

        // Close the receiving side before streaming starts
        rx.close();

        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap()
    }

    // An entry added with a future ID is read back and decoded into the expected message.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Entry ID 1_000_000 ms in the future, so it sorts after the streaming start ID
        let clock = get_atomic_clock_realtime();
        let future_id = (clock.get_time_ms() + 1_000_000).to_string();

        let _: () = con
            .xadd(
                stream_key,
                future_id,
                &[("topic", "topic1"), ("payload", "data1")],
            )
            .await
            .unwrap();

        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        let msg = rx.recv().await.unwrap();
        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));

        // Stop the task through both exit paths: closed receiver and terminate signal
        rx.close();
        stream_signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap()
    }

    // A message sent to `publish_messages` lands in the (single, non-per-topic) stream, and
    // a close message then ends the task.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_publish_messages(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());
        config.stream_per_topic = false;
        let stream_key = get_stream_key(trader_id, instance_id, &config);

        let handle = tokio::spawn(async move {
            publish_messages(rx, trader_id, instance_id, config)
                .await
                .unwrap();
        });

        let msg = BusMessage {
            topic: "test_topic".to_string(),
            payload: Bytes::from("test_payload"),
        };
        tx.send(msg).unwrap();

        // Poll (up to 2s, presumably) until the entry is visible in the stream
        wait_until_async(
            || {
                let mut con = con.clone();
                let stream_key = stream_key.clone();
                async move {
                    let messages: RedisStreamBulk =
                        con.xread(&[&stream_key], &["0"]).await.unwrap();
                    !messages.is_empty()
                }
            },
            Duration::from_secs(2),
        )
        .await;

        let messages: RedisStreamBulk = con.xread(&[&stream_key], &["0"]).await.unwrap();
        assert_eq!(messages.len(), 1);
        let stream_msgs = messages[0].get(&stream_key).unwrap();
        let stream_msg_array = &stream_msgs[0].values().next().unwrap();
        let decoded_message = decode_bus_message(stream_msg_array).unwrap();
        assert_eq!(decoded_message.topic, "test_topic");
        assert_eq!(decoded_message.payload, Bytes::from("test_payload"));

        // The close sentinel makes the publishing loop exit
        let msg = BusMessage {
            topic: CLOSE_TOPIC.to_string(),
            payload: Bytes::new(),
        };
        tx.send(msg).unwrap();

        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    // Constructing the adapter from a default `MessageBusConfig` (plus a database config)
    // and closing it must return without panicking or hanging.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_close() {
        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let mut db = RedisMessageBusDatabase::new(trader_id, instance_id, config).unwrap();

        db.close();
    }

    // With a 1-second interval over ~2 seconds, at least one heartbeat should be queued,
    // and every queued message uses the heartbeat topic.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_heartbeat_task() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
        let signal = Arc::new(AtomicBool::new(false));

        let handle = tokio::spawn(run_heartbeat(1, signal.clone(), tx));

        tokio::time::sleep(Duration::from_secs(2)).await;

        signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();

        // Collect everything the task queued before it stopped
        let mut heartbeats: Vec<BusMessage> = Vec::new();
        while let Ok(hb) = rx.try_recv() {
            heartbeats.push(hb);
        }

        assert!(!heartbeats.is_empty());

        for hb in heartbeats {
            assert_eq!(hb.topic, HEARTBEAT_TOPIC);
        }
    }
}