1use std::{
17 collections::{HashMap, VecDeque},
18 fmt::Debug,
19 sync::{
20 Arc,
21 atomic::{AtomicBool, Ordering},
22 },
23 time::Duration,
24};
25
26use bytes::Bytes;
27use futures::stream::Stream;
28use nautilus_common::{
29 logging::{log_task_error, log_task_started, log_task_stopped},
30 msgbus::{
31 BusMessage,
32 database::{DatabaseConfig, MessageBusConfig, MessageBusDatabaseAdapter},
33 switchboard::CLOSE_TOPIC,
34 },
35 runtime::get_runtime,
36};
37use nautilus_core::{
38 UUID4,
39 time::{duration_since_unix_epoch, get_atomic_clock_realtime},
40};
41use nautilus_cryptography::providers::install_cryptographic_provider;
42use nautilus_model::identifiers::TraderId;
43use redis::{AsyncCommands, streams};
44use streams::StreamReadOptions;
45use tokio::time::Instant;
46use ustr::Ustr;
47
48use super::{REDIS_MINID, REDIS_XTRIM, await_handle};
49use crate::redis::{create_redis_connection, get_stream_key};
50
/// Name of the publishing task, used for task logging, its Redis connection and awaiting its handle.
const MSGBUS_PUBLISH: &str = "msgbus-publish";
/// Name of the external-stream reading task, used for task logging, its Redis connection and awaiting its handle.
const MSGBUS_STREAM: &str = "msgbus-stream";
/// Name under which the heartbeat task's handle is awaited on close.
const MSGBUS_HEARTBEAT: &str = "msgbus-heartbeat";
/// Topic on which heartbeat messages are published.
const HEARTBEAT_TOPIC: &str = "health:heartbeat";
/// Minimum number of seconds between two `XTRIM` calls on the same stream.
const TRIM_BUFFER_SECS: u64 = 60;

/// Deserialized `XREAD` reply: one map per returned stream (stream key -> entries), where each
/// entry is a map of entry ID -> the entry's field/value array.
type RedisStreamBulk = Vec<HashMap<String, Vec<HashMap<String, redis::Value>>>>;
58
/// Redis-backed message bus database.
///
/// Messages passed to `publish` are forwarded over a channel to a background publishing task
/// that writes them to Redis streams. Optional background tasks read configured external
/// streams and emit periodic heartbeats.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.infrastructure")
)]
pub struct RedisMessageBusDatabase {
    /// The trader ID for this message bus database.
    pub trader_id: TraderId,
    /// The instance ID for this message bus database.
    pub instance_id: UUID4,
    // Sending side of the channel consumed by the publishing task.
    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
    // Handle of the publishing task; taken when awaited on close.
    pub_handle: Option<tokio::task::JoinHandle<()>>,
    // Receiver for messages read from external streams; `None` when no external streams were
    // configured or after it has been taken via `get_stream_receiver`.
    stream_rx: Option<tokio::sync::mpsc::Receiver<BusMessage>>,
    // Handle of the external-stream reading task, if one was spawned.
    stream_handle: Option<tokio::task::JoinHandle<()>>,
    // Set to `true` to ask the stream reading task to stop.
    stream_signal: Arc<AtomicBool>,
    // Handle of the heartbeat task, if one was spawned.
    heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
    // Set to `true` to ask the heartbeat task to stop.
    heartbeat_signal: Arc<AtomicBool>,
}
76
77impl Debug for RedisMessageBusDatabase {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct(stringify!(RedisMessageBusDatabase))
80 .field("trader_id", &self.trader_id)
81 .field("instance_id", &self.instance_id)
82 .finish()
83 }
84}
85
86impl MessageBusDatabaseAdapter for RedisMessageBusDatabase {
87 type DatabaseType = Self;
88
89 fn new(
97 trader_id: TraderId,
98 instance_id: UUID4,
99 config: MessageBusConfig,
100 ) -> anyhow::Result<Self> {
101 install_cryptographic_provider();
102
103 let config_clone = config.clone();
104 let db_config = config
105 .database
106 .clone()
107 .ok_or_else(|| anyhow::anyhow!("No database config"))?;
108
109 let (pub_tx, pub_rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
110
111 let pub_handle = Some(get_runtime().spawn(async move {
113 if let Err(e) = publish_messages(pub_rx, trader_id, instance_id, config_clone).await {
114 log_task_error(MSGBUS_PUBLISH, &e);
115 }
116 }));
117
118 let external_streams = config.external_streams.clone().unwrap_or_default();
120 let stream_signal = Arc::new(AtomicBool::new(false));
121 let (stream_rx, stream_handle) = if external_streams.is_empty() {
122 (None, None)
123 } else {
124 let stream_signal_clone = stream_signal.clone();
125 let (stream_tx, stream_rx) = tokio::sync::mpsc::channel::<BusMessage>(100_000);
126 (
127 Some(stream_rx),
128 Some(get_runtime().spawn(async move {
129 if let Err(e) =
130 stream_messages(stream_tx, db_config, external_streams, stream_signal_clone)
131 .await
132 {
133 log_task_error(MSGBUS_STREAM, &e);
134 }
135 })),
136 )
137 };
138
139 let heartbeat_signal = Arc::new(AtomicBool::new(false));
141 let heartbeat_handle = if let Some(heartbeat_interval_secs) = config.heartbeat_interval_secs
142 {
143 let signal = heartbeat_signal.clone();
144 let pub_tx_clone = pub_tx.clone();
145
146 Some(get_runtime().spawn(async move {
147 run_heartbeat(heartbeat_interval_secs, signal, pub_tx_clone).await;
148 }))
149 } else {
150 None
151 };
152
153 Ok(Self {
154 trader_id,
155 instance_id,
156 pub_tx,
157 pub_handle,
158 stream_rx,
159 stream_handle,
160 stream_signal,
161 heartbeat_handle,
162 heartbeat_signal,
163 })
164 }
165
166 fn is_closed(&self) -> bool {
168 self.pub_tx.is_closed()
169 }
170
171 fn publish(&self, topic: Ustr, payload: Bytes) {
173 let msg = BusMessage::new(topic, payload);
174 if let Err(e) = self.pub_tx.send(msg) {
175 log::error!("Failed to send message: {e}");
176 }
177 }
178
179 fn close(&mut self) {
181 log::debug!("Closing");
182
183 self.stream_signal.store(true, Ordering::Relaxed);
184 self.heartbeat_signal.store(true, Ordering::Relaxed);
185
186 if !self.pub_tx.is_closed() {
187 let msg = BusMessage::new_close();
188
189 if let Err(e) = self.pub_tx.send(msg) {
190 log::error!("Failed to send close message: {e:?}");
191 }
192 }
193
194 tokio::task::block_in_place(|| {
196 get_runtime().block_on(async {
197 self.close_async().await;
198 });
199 });
200
201 log::debug!("Closed");
202 }
203}
204
205impl RedisMessageBusDatabase {
206 pub fn get_stream_receiver(
212 &mut self,
213 ) -> anyhow::Result<tokio::sync::mpsc::Receiver<BusMessage>> {
214 self.stream_rx
215 .take()
216 .ok_or_else(|| anyhow::anyhow!("Stream receiver already taken"))
217 }
218
219 pub fn stream(
221 mut stream_rx: tokio::sync::mpsc::Receiver<BusMessage>,
222 ) -> impl Stream<Item = BusMessage> + 'static {
223 async_stream::stream! {
224 while let Some(msg) = stream_rx.recv().await {
225 yield msg;
226 }
227 }
228 }
229
230 pub async fn close_async(&mut self) {
231 await_handle(self.pub_handle.take(), MSGBUS_PUBLISH).await;
232 await_handle(self.stream_handle.take(), MSGBUS_STREAM).await;
233 await_handle(self.heartbeat_handle.take(), MSGBUS_HEARTBEAT).await;
234 }
235}
236
237pub async fn publish_messages(
246 mut rx: tokio::sync::mpsc::UnboundedReceiver<BusMessage>,
247 trader_id: TraderId,
248 instance_id: UUID4,
249 config: MessageBusConfig,
250) -> anyhow::Result<()> {
251 log_task_started(MSGBUS_PUBLISH);
252
253 let db_config = config
254 .database
255 .as_ref()
256 .ok_or_else(|| anyhow::anyhow!("No database config"))?;
257 let mut con = create_redis_connection(MSGBUS_PUBLISH, db_config.clone()).await?;
258 let stream_key = get_stream_key(trader_id, instance_id, &config);
259
260 let autotrim_duration = config
262 .autotrim_mins
263 .filter(|&mins| mins > 0)
264 .map(|mins| Duration::from_secs(u64::from(mins) * 60));
265 let mut last_trim_index: HashMap<String, usize> = HashMap::new();
266
267 let mut buffer: VecDeque<BusMessage> = VecDeque::new();
269 let buffer_interval = Duration::from_millis(u64::from(config.buffer_interval_ms.unwrap_or(0)));
270
271 let flush_timer = tokio::time::sleep(buffer_interval);
275 tokio::pin!(flush_timer);
276
277 loop {
278 tokio::select! {
279 maybe_msg = rx.recv() => {
280 if let Some(msg) = maybe_msg {
281 if msg.topic == CLOSE_TOPIC {
282 tracing::debug!("Received close message");
283 if !buffer.is_empty() {
285 drain_buffer(
286 &mut con,
287 &stream_key,
288 config.stream_per_topic,
289 autotrim_duration,
290 &mut last_trim_index,
291 &mut buffer,
292 ).await?;
293 }
294 break;
295 }
296
297 buffer.push_back(msg);
298
299 if buffer_interval.is_zero() {
300 drain_buffer(
302 &mut con,
303 &stream_key,
304 config.stream_per_topic,
305 autotrim_duration,
306 &mut last_trim_index,
307 &mut buffer,
308 ).await?;
309 }
310 } else {
311 tracing::debug!("Channel hung up");
312 break;
313 }
314 }
315 () = &mut flush_timer, if !buffer_interval.is_zero() => {
318 if !buffer.is_empty() {
319 drain_buffer(
320 &mut con,
321 &stream_key,
322 config.stream_per_topic,
323 autotrim_duration,
324 &mut last_trim_index,
325 &mut buffer,
326 ).await?;
327 }
328
329 flush_timer.as_mut().reset(Instant::now() + buffer_interval);
331 }
332 }
333 }
334
335 if !buffer.is_empty() {
337 drain_buffer(
338 &mut con,
339 &stream_key,
340 config.stream_per_topic,
341 autotrim_duration,
342 &mut last_trim_index,
343 &mut buffer,
344 )
345 .await?;
346 }
347
348 log_task_stopped(MSGBUS_PUBLISH);
349 Ok(())
350}
351
352async fn drain_buffer(
353 conn: &mut redis::aio::ConnectionManager,
354 stream_key: &str,
355 stream_per_topic: bool,
356 autotrim_duration: Option<Duration>,
357 last_trim_index: &mut HashMap<String, usize>,
358 buffer: &mut VecDeque<BusMessage>,
359) -> anyhow::Result<()> {
360 let mut pipe = redis::pipe();
361 pipe.atomic();
362
363 for msg in buffer.drain(..) {
364 let items: Vec<(&str, &[u8])> = vec![
365 ("topic", msg.topic.as_ref()),
366 ("payload", msg.payload.as_ref()),
367 ];
368 let stream_key = if stream_per_topic {
369 format!("{stream_key}:{}", &msg.topic)
370 } else {
371 stream_key.to_string()
372 };
373 pipe.xadd(&stream_key, "*", &items);
374
375 if autotrim_duration.is_none() {
376 continue; }
378
379 let last_trim_ms = last_trim_index.entry(stream_key.clone()).or_insert(0); let unix_duration_now = duration_since_unix_epoch();
382 let trim_buffer = Duration::from_secs(TRIM_BUFFER_SECS);
383
384 if *last_trim_ms < (unix_duration_now - trim_buffer).as_millis() as usize {
386 let min_timestamp_ms =
387 (unix_duration_now - autotrim_duration.unwrap()).as_millis() as usize;
388 let result: Result<(), redis::RedisError> = redis::cmd(REDIS_XTRIM)
389 .arg(stream_key.clone())
390 .arg(REDIS_MINID)
391 .arg(min_timestamp_ms)
392 .query_async(conn)
393 .await;
394
395 if let Err(e) = result {
396 tracing::error!("Error trimming stream '{stream_key}': {e}");
397 } else {
398 last_trim_index.insert(
399 stream_key.to_string(),
400 unix_duration_now.as_millis() as usize,
401 );
402 }
403 }
404 }
405
406 pipe.query_async(conn).await.map_err(anyhow::Error::from)
407}
408
409pub async fn stream_messages(
417 tx: tokio::sync::mpsc::Sender<BusMessage>,
418 config: DatabaseConfig,
419 stream_keys: Vec<String>,
420 stream_signal: Arc<AtomicBool>,
421) -> anyhow::Result<()> {
422 log_task_started(MSGBUS_STREAM);
423
424 let mut con = create_redis_connection(MSGBUS_STREAM, config).await?;
425
426 let stream_keys = &stream_keys
427 .iter()
428 .map(String::as_str)
429 .collect::<Vec<&str>>();
430
431 tracing::debug!("Listening to streams: [{}]", stream_keys.join(", "));
432
433 let clock = get_atomic_clock_realtime();
435 let timestamp_ms = clock.get_time_ms();
436 let mut last_id = timestamp_ms.to_string();
437
438 let opts = StreamReadOptions::default().block(100);
439
440 'outer: loop {
441 if stream_signal.load(Ordering::Relaxed) {
442 tracing::debug!("Received streaming terminate signal");
443 break;
444 }
445 let result: Result<RedisStreamBulk, _> =
446 con.xread_options(&[&stream_keys], &[&last_id], &opts).await;
447 match result {
448 Ok(stream_bulk) => {
449 if stream_bulk.is_empty() {
450 continue;
452 }
453 for entry in &stream_bulk {
454 for stream_msgs in entry.values() {
455 for stream_msg in stream_msgs {
456 for (id, array) in stream_msg {
457 last_id.clear();
458 last_id.push_str(id);
459 match decode_bus_message(array) {
460 Ok(msg) => {
461 if let Err(e) = tx.send(msg).await {
462 tracing::debug!("Channel closed: {e:?}");
463 break 'outer; }
465 }
466 Err(e) => {
467 tracing::error!("{e:?}");
468 continue;
469 }
470 }
471 }
472 }
473 }
474 }
475 }
476 Err(e) => {
477 anyhow::bail!("Error reading from stream: {e:?}");
478 }
479 }
480 }
481
482 log_task_stopped(MSGBUS_STREAM);
483 Ok(())
484}
485
486fn decode_bus_message(stream_msg: &redis::Value) -> anyhow::Result<BusMessage> {
495 if let redis::Value::Array(stream_msg) = stream_msg {
496 if stream_msg.len() < 4 {
497 anyhow::bail!("Invalid stream message format: {stream_msg:?}");
498 }
499
500 let topic = match &stream_msg[1] {
501 redis::Value::BulkString(bytes) => match String::from_utf8(bytes.clone()) {
502 Ok(topic) => topic,
503 Err(e) => anyhow::bail!("Error parsing topic: {e}"),
504 },
505 _ => {
506 anyhow::bail!("Invalid topic format: {stream_msg:?}");
507 }
508 };
509
510 let payload = match &stream_msg[3] {
511 redis::Value::BulkString(bytes) => Bytes::copy_from_slice(bytes),
512 _ => {
513 anyhow::bail!("Invalid payload format: {stream_msg:?}");
514 }
515 };
516
517 Ok(BusMessage::with_str_topic(topic, payload))
518 } else {
519 anyhow::bail!("Invalid stream message format: {stream_msg:?}")
520 }
521}
522
523async fn run_heartbeat(
524 heartbeat_interval_secs: u16,
525 signal: Arc<AtomicBool>,
526 pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
527) {
528 log_task_started("heartbeat");
529 tracing::debug!("Heartbeat at {heartbeat_interval_secs} second intervals");
530
531 let heartbeat_interval = Duration::from_secs(u64::from(heartbeat_interval_secs));
532 let heartbeat_timer = tokio::time::interval(heartbeat_interval);
533
534 let check_interval = Duration::from_millis(100);
535 let check_timer = tokio::time::interval(check_interval);
536
537 tokio::pin!(heartbeat_timer);
538 tokio::pin!(check_timer);
539
540 loop {
541 if signal.load(Ordering::Relaxed) {
542 tracing::debug!("Received heartbeat terminate signal");
543 break;
544 }
545
546 tokio::select! {
547 _ = heartbeat_timer.tick() => {
548 let heartbeat = create_heartbeat_msg();
549 if let Err(e) = pub_tx.send(heartbeat) {
550 tracing::debug!("Error sending heartbeat: {e}");
552 }
553 },
554 _ = check_timer.tick() => {}
555 }
556 }
557
558 log_task_stopped("heartbeat");
559}
560
561fn create_heartbeat_msg() -> BusMessage {
562 let payload = Bytes::from(chrono::Utc::now().to_rfc3339().into_bytes());
563 BusMessage::with_str_topic(HEARTBEAT_TOPIC, payload)
564}
565
#[cfg(test)]
mod tests {
    use redis::Value;
    use rstest::*;

    use super::*;

    /// Wraps raw bytes in a RESP bulk-string value.
    fn bulk(bytes: &[u8]) -> Value {
        Value::BulkString(bytes.to_vec())
    }

    /// Decodes `value`, requiring failure, and returns the rendered error message.
    fn decode_error(value: &Value) -> String {
        decode_bus_message(value)
            .err()
            .expect("decoding should have failed")
            .to_string()
    }

    #[rstest]
    fn test_decode_bus_message_valid() {
        let entry = Value::Array(vec![
            bulk(b"0"),
            bulk(b"topic1"),
            bulk(b"unused"),
            bulk(b"data1"),
        ]);

        let msg = decode_bus_message(&entry).expect("decoding should succeed");
        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));
    }

    #[rstest]
    fn test_decode_bus_message_missing_fields() {
        let entry = Value::Array(vec![bulk(b"0"), bulk(b"topic1")]);

        assert_eq!(
            decode_error(&entry),
            "Invalid stream message format: [bulk-string('\"0\"'), bulk-string('\"topic1\"')]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_topic_format() {
        // Topic slot holds an integer instead of a bulk string
        let entry = Value::Array(vec![
            bulk(b"0"),
            Value::Int(42),
            bulk(b"unused"),
            bulk(b"data1"),
        ]);

        assert_eq!(
            decode_error(&entry),
            "Invalid topic format: [bulk-string('\"0\"'), int(42), bulk-string('\"unused\"'), bulk-string('\"data1\"')]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_payload_format() {
        // Payload slot holds an integer instead of a bulk string
        let entry = Value::Array(vec![
            bulk(b"0"),
            bulk(b"topic1"),
            bulk(b"unused"),
            Value::Int(42),
        ]);

        assert_eq!(
            decode_error(&entry),
            "Invalid payload format: [bulk-string('\"0\"'), bulk-string('\"topic1\"'), bulk-string('\"unused\"'), int(42)]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_stream_msg_format() {
        let entry = bulk(b"not an array");

        assert_eq!(
            decode_error(&entry),
            "Invalid stream message format: bulk-string('\"not an array\"')"
        );
    }
}
653
// Integration tests against a live Redis server reached with `DatabaseConfig::default()`
// (see the `redis_connection` fixture).
// NOTE(review): presumably Linux-only because that is where a Redis service is available in
// CI — confirm.
#[cfg(target_os = "linux")]
#[cfg(test)]
mod serial_tests {
    use nautilus_common::testing::wait_until_async;
    use redis::aio::ConnectionManager;
    use rstest::*;

    use super::*;
    use crate::redis::flush_redis;

    /// Connects to Redis with `DatabaseConfig::default()` and flushes it so each test starts
    /// from an empty database.
    #[fixture]
    async fn redis_connection() -> ConnectionManager {
        let config = DatabaseConfig::default();
        let mut con = create_redis_connection(MSGBUS_STREAM, config)
            .await
            .unwrap();
        flush_redis(&mut con).await.unwrap();
        con
    }

    /// The stream reading task returns `Ok` once its terminate signal is set.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_terminate_signal(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let config = MessageBusConfig {
            database: Some(DatabaseConfig::default()),
            ..Default::default()
        };

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // `tx` moves into the task, so the channel closes when the task returns
        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        stream_signal.store(true, Ordering::Relaxed);
        // Nothing was added to the stream, so this yields `None` once the task drops `tx`
        let _ = rx.recv().await;
        rx.close();

        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    /// The stream reading task returns `Ok` when forwarding fails because the receiver is closed.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_when_receiver_closed(
        #[future] redis_connection: ConnectionManager,
    ) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let config = MessageBusConfig {
            database: Some(DatabaseConfig::default()),
            ..Default::default()
        };

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // `stream_messages` starts reading after the current time in ms, so use an entry ID
        // far enough in the future to be guaranteed to be read
        let clock = get_atomic_clock_realtime();
        let future_id = (clock.get_time_ms() + 1_000_000).to_string();

        let _: () = con
            .xadd(
                stream_key,
                future_id,
                &[("topic", "topic1"), ("payload", "data1")],
            )
            .await
            .unwrap();

        // Close the receiver before the task starts so its first send fails
        rx.close();

        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    /// An entry added to a watched stream is decoded and forwarded on the channel.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let config = MessageBusConfig {
            database: Some(DatabaseConfig::default()),
            ..Default::default()
        };

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Entry ID in the future so it is after the task's starting read position
        let clock = get_atomic_clock_realtime();
        let future_id = (clock.get_time_ms() + 1_000_000).to_string();

        let _: () = con
            .xadd(
                stream_key,
                future_id,
                &[("topic", "topic1"), ("payload", "data1")],
            )
            .await
            .unwrap();

        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        let msg = rx.recv().await.unwrap();
        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));

        // Stop the task: close the channel and set the terminate signal
        rx.close();
        stream_signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    /// A published message lands in the trader's stream, and a close message stops the task.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_publish_messages(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let config = MessageBusConfig {
            database: Some(DatabaseConfig::default()),
            stream_per_topic: false,
            ..Default::default()
        };
        let stream_key = get_stream_key(trader_id, instance_id, &config);

        let handle = tokio::spawn(async move {
            publish_messages(rx, trader_id, instance_id, config)
                .await
                .unwrap();
        });

        let msg = BusMessage::with_str_topic("test_topic", Bytes::from("test_payload"));
        tx.send(msg).unwrap();

        // Poll (for up to 3 seconds) until the message shows up in the stream
        wait_until_async(
            || {
                let mut con = con.clone();
                let stream_key = stream_key.clone();
                async move {
                    let messages: RedisStreamBulk =
                        con.xread(&[&stream_key], &["0"]).await.unwrap();
                    !messages.is_empty()
                }
            },
            Duration::from_secs(3),
        )
        .await;

        let messages: RedisStreamBulk = con.xread(&[&stream_key], &["0"]).await.unwrap();
        assert_eq!(messages.len(), 1);
        let stream_msgs = messages[0].get(&stream_key).unwrap();
        let stream_msg_array = &stream_msgs[0].values().next().unwrap();
        let decoded_message = decode_bus_message(stream_msg_array).unwrap();
        assert_eq!(decoded_message.topic, "test_topic");
        assert_eq!(decoded_message.payload, Bytes::from("test_payload"));

        // The close message makes `publish_messages` return
        let msg = BusMessage::new_close();
        tx.send(msg).unwrap();

        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    /// Creating the database and closing it completes without panicking.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_close() {
        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let config = MessageBusConfig {
            database: Some(DatabaseConfig::default()),
            ..Default::default()
        };

        let mut db = RedisMessageBusDatabase::new(trader_id, instance_id, config).unwrap();

        db.close();
    }

    /// With a 1-second interval, at least one heartbeat is sent within 2 seconds, always on
    /// `HEARTBEAT_TOPIC`.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_heartbeat_task() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
        let signal = Arc::new(AtomicBool::new(false));

        let handle = tokio::spawn(run_heartbeat(1, signal.clone(), tx));

        tokio::time::sleep(Duration::from_secs(2)).await;

        signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();

        // Collect everything sent before the task stopped
        let mut heartbeats: Vec<BusMessage> = Vec::new();
        while let Ok(hb) = rx.try_recv() {
            heartbeats.push(hb);
        }

        assert!(!heartbeats.is_empty());

        for hb in heartbeats {
            assert_eq!(hb.topic, HEARTBEAT_TOPIC);
        }
    }
}