1use std::{
17 collections::{HashMap, VecDeque},
18 fmt::Debug,
19 sync::{
20 Arc,
21 atomic::{AtomicBool, Ordering},
22 },
23 time::Duration,
24};
25
26use bytes::Bytes;
27use futures::stream::Stream;
28use nautilus_common::{
29 logging::{log_task_error, log_task_started, log_task_stopped},
30 msgbus::{
31 BusMessage,
32 database::{DatabaseConfig, MessageBusConfig, MessageBusDatabaseAdapter},
33 switchboard::CLOSE_TOPIC,
34 },
35 runtime::get_runtime,
36};
37use nautilus_core::{
38 UUID4,
39 time::{duration_since_unix_epoch, get_atomic_clock_realtime},
40};
41use nautilus_cryptography::providers::install_cryptographic_provider;
42use nautilus_model::identifiers::TraderId;
43use redis::{AsyncCommands, streams};
44use streams::StreamReadOptions;
45use tokio::time::Instant;
46use ustr::Ustr;
47
48use super::{REDIS_MINID, REDIS_XTRIM, await_handle};
49use crate::redis::{create_redis_connection, get_stream_key};
50
/// Task name of the background publishing task (also passed to `create_redis_connection`).
const MSGBUS_PUBLISH: &str = "msgbus-publish";
/// Task name of the background external-streams reading task.
const MSGBUS_STREAM: &str = "msgbus-stream";
/// Task name used when awaiting the heartbeat task on close.
const MSGBUS_HEARTBEAT: &str = "msgbus-heartbeat";
/// Topic on which heartbeat messages are published.
const HEARTBEAT_TOPIC: &str = "health:heartbeat";
/// Minimum number of seconds between successive `XTRIM` calls for the same stream.
const TRIM_BUFFER_SECS: u64 = 60;

/// Deserialized `XREAD` reply: a list of maps keyed by stream key, each holding that
/// stream's messages, where every message maps its ID to the raw field/value array.
type RedisStreamBulk = Vec<HashMap<String, Vec<HashMap<String, redis::Value>>>>;
58
/// Redis-backed [`MessageBusDatabaseAdapter`].
///
/// Owns a background task that publishes [`BusMessage`]s to Redis streams, an optional task
/// that reads configured external streams, and an optional heartbeat task.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.infrastructure")
)]
pub struct RedisMessageBusDatabase {
    /// The trader ID associated with this database.
    pub trader_id: TraderId,
    /// The instance ID associated with this database.
    pub instance_id: UUID4,
    /// Sender feeding the background publishing task.
    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
    /// Handle of the background publishing task (taken when closing).
    pub_handle: Option<tokio::task::JoinHandle<()>>,
    /// Receiver of messages read from external streams. `None` when no external streams are
    /// configured, or once it has been taken via `get_stream_receiver`.
    stream_rx: Option<tokio::sync::mpsc::Receiver<BusMessage>>,
    /// Handle of the external-streams task, if one was spawned.
    stream_handle: Option<tokio::task::JoinHandle<()>>,
    /// Set to `true` on close to stop the external-streams task.
    stream_signal: Arc<AtomicBool>,
    /// Handle of the heartbeat task, if one was spawned.
    heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
    /// Set to `true` on close to stop the heartbeat task.
    heartbeat_signal: Arc<AtomicBool>,
}
76
77impl Debug for RedisMessageBusDatabase {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct(stringify!(RedisMessageBusDatabase))
80 .field("trader_id", &self.trader_id)
81 .field("instance_id", &self.instance_id)
82 .finish()
83 }
84}
85
86impl MessageBusDatabaseAdapter for RedisMessageBusDatabase {
87 type DatabaseType = Self;
88
89 fn new(
97 trader_id: TraderId,
98 instance_id: UUID4,
99 config: MessageBusConfig,
100 ) -> anyhow::Result<Self> {
101 install_cryptographic_provider();
102
103 let config_clone = config.clone();
104 let db_config = config
105 .database
106 .clone()
107 .ok_or_else(|| anyhow::anyhow!("No database config"))?;
108
109 let (pub_tx, pub_rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
110
111 let pub_handle = Some(get_runtime().spawn(async move {
113 if let Err(e) = publish_messages(pub_rx, trader_id, instance_id, config_clone).await {
114 log_task_error(MSGBUS_PUBLISH, &e);
115 }
116 }));
117
118 let external_streams = config.external_streams.clone().unwrap_or_default();
120 let stream_signal = Arc::new(AtomicBool::new(false));
121 let (stream_rx, stream_handle) = if external_streams.is_empty() {
122 (None, None)
123 } else {
124 let stream_signal_clone = stream_signal.clone();
125 let (stream_tx, stream_rx) = tokio::sync::mpsc::channel::<BusMessage>(100_000);
126 (
127 Some(stream_rx),
128 Some(get_runtime().spawn(async move {
129 if let Err(e) =
130 stream_messages(stream_tx, db_config, external_streams, stream_signal_clone)
131 .await
132 {
133 log_task_error(MSGBUS_STREAM, &e);
134 }
135 })),
136 )
137 };
138
139 let heartbeat_signal = Arc::new(AtomicBool::new(false));
141 let heartbeat_handle = if let Some(heartbeat_interval_secs) = config.heartbeat_interval_secs
142 {
143 let signal = heartbeat_signal.clone();
144 let pub_tx_clone = pub_tx.clone();
145
146 Some(get_runtime().spawn(async move {
147 run_heartbeat(heartbeat_interval_secs, signal, pub_tx_clone).await;
148 }))
149 } else {
150 None
151 };
152
153 Ok(Self {
154 trader_id,
155 instance_id,
156 pub_tx,
157 pub_handle,
158 stream_rx,
159 stream_handle,
160 stream_signal,
161 heartbeat_handle,
162 heartbeat_signal,
163 })
164 }
165
166 fn is_closed(&self) -> bool {
168 self.pub_tx.is_closed()
169 }
170
171 fn publish(&self, topic: Ustr, payload: Bytes) {
173 let msg = BusMessage::new(topic, payload);
174 if let Err(e) = self.pub_tx.send(msg) {
175 log::error!("Failed to send message: {e}");
176 }
177 }
178
179 fn close(&mut self) {
181 log::debug!("Closing");
182
183 self.stream_signal.store(true, Ordering::Relaxed);
184 self.heartbeat_signal.store(true, Ordering::Relaxed);
185
186 if !self.pub_tx.is_closed() {
187 let msg = BusMessage::new_close();
188
189 if let Err(e) = self.pub_tx.send(msg) {
190 log::error!("Failed to send close message: {e:?}");
191 }
192 }
193
194 tokio::task::block_in_place(|| {
196 get_runtime().block_on(async {
197 self.close_async().await;
198 });
199 });
200
201 log::debug!("Closed");
202 }
203}
204
205impl RedisMessageBusDatabase {
206 pub fn get_stream_receiver(
212 &mut self,
213 ) -> anyhow::Result<tokio::sync::mpsc::Receiver<BusMessage>> {
214 self.stream_rx
215 .take()
216 .ok_or_else(|| anyhow::anyhow!("Stream receiver already taken"))
217 }
218
219 pub fn stream(
221 mut stream_rx: tokio::sync::mpsc::Receiver<BusMessage>,
222 ) -> impl Stream<Item = BusMessage> + 'static {
223 async_stream::stream! {
224 while let Some(msg) = stream_rx.recv().await {
225 yield msg;
226 }
227 }
228 }
229
230 pub async fn close_async(&mut self) {
231 await_handle(self.pub_handle.take(), MSGBUS_PUBLISH).await;
232 await_handle(self.stream_handle.take(), MSGBUS_STREAM).await;
233 await_handle(self.heartbeat_handle.take(), MSGBUS_HEARTBEAT).await;
234 }
235}
236
237pub async fn publish_messages(
246 mut rx: tokio::sync::mpsc::UnboundedReceiver<BusMessage>,
247 trader_id: TraderId,
248 instance_id: UUID4,
249 config: MessageBusConfig,
250) -> anyhow::Result<()> {
251 log_task_started(MSGBUS_PUBLISH);
252
253 let db_config = config
254 .database
255 .as_ref()
256 .ok_or_else(|| anyhow::anyhow!("No database config"))?;
257 let mut con = create_redis_connection(MSGBUS_PUBLISH, db_config.clone()).await?;
258 let stream_key = get_stream_key(trader_id, instance_id, &config);
259
260 let autotrim_duration = config
262 .autotrim_mins
263 .filter(|&mins| mins > 0)
264 .map(|mins| Duration::from_secs(u64::from(mins) * 60));
265 let mut last_trim_index: HashMap<String, usize> = HashMap::new();
266
267 let mut buffer: VecDeque<BusMessage> = VecDeque::new();
269 let buffer_interval = Duration::from_millis(u64::from(config.buffer_interval_ms.unwrap_or(0)));
270
271 let flush_timer = tokio::time::sleep(buffer_interval);
275 tokio::pin!(flush_timer);
276
277 loop {
278 tokio::select! {
279 maybe_msg = rx.recv() => {
280 if let Some(msg) = maybe_msg {
281 if msg.topic == CLOSE_TOPIC {
282 tracing::debug!("Received close message");
283 if !buffer.is_empty() {
285 drain_buffer(
286 &mut con,
287 &stream_key,
288 config.stream_per_topic,
289 autotrim_duration,
290 &mut last_trim_index,
291 &mut buffer,
292 ).await?;
293 }
294 break;
295 }
296
297 buffer.push_back(msg);
298
299 if buffer_interval.is_zero() {
300 drain_buffer(
302 &mut con,
303 &stream_key,
304 config.stream_per_topic,
305 autotrim_duration,
306 &mut last_trim_index,
307 &mut buffer,
308 ).await?;
309 }
310 } else {
311 tracing::debug!("Channel hung up");
312 break;
313 }
314 }
315 () = &mut flush_timer, if !buffer_interval.is_zero() => {
318 if !buffer.is_empty() {
319 drain_buffer(
320 &mut con,
321 &stream_key,
322 config.stream_per_topic,
323 autotrim_duration,
324 &mut last_trim_index,
325 &mut buffer,
326 ).await?;
327 }
328
329 flush_timer.as_mut().reset(Instant::now() + buffer_interval);
331 }
332 }
333 }
334
335 if !buffer.is_empty() {
337 drain_buffer(
338 &mut con,
339 &stream_key,
340 config.stream_per_topic,
341 autotrim_duration,
342 &mut last_trim_index,
343 &mut buffer,
344 )
345 .await?;
346 }
347
348 log_task_stopped(MSGBUS_PUBLISH);
349 Ok(())
350}
351
352async fn drain_buffer(
353 conn: &mut redis::aio::ConnectionManager,
354 stream_key: &str,
355 stream_per_topic: bool,
356 autotrim_duration: Option<Duration>,
357 last_trim_index: &mut HashMap<String, usize>,
358 buffer: &mut VecDeque<BusMessage>,
359) -> anyhow::Result<()> {
360 let mut pipe = redis::pipe();
361 pipe.atomic();
362
363 for msg in buffer.drain(..) {
364 let items: Vec<(&str, &[u8])> = vec![
365 ("topic", msg.topic.as_ref()),
366 ("payload", msg.payload.as_ref()),
367 ];
368 let stream_key = if stream_per_topic {
369 format!("{stream_key}:{}", &msg.topic)
370 } else {
371 stream_key.to_string()
372 };
373 pipe.xadd(&stream_key, "*", &items);
374
375 if autotrim_duration.is_none() {
376 continue; }
378
379 let last_trim_ms = last_trim_index.entry(stream_key.clone()).or_insert(0); let unix_duration_now = duration_since_unix_epoch();
382 let trim_buffer = Duration::from_secs(TRIM_BUFFER_SECS);
383
384 if *last_trim_ms < (unix_duration_now - trim_buffer).as_millis() as usize {
386 let min_timestamp_ms =
387 (unix_duration_now - autotrim_duration.unwrap()).as_millis() as usize;
388 let result: Result<(), redis::RedisError> = redis::cmd(REDIS_XTRIM)
389 .arg(stream_key.clone())
390 .arg(REDIS_MINID)
391 .arg(min_timestamp_ms)
392 .query_async(conn)
393 .await;
394
395 if let Err(e) = result {
396 tracing::error!("Error trimming stream '{stream_key}': {e}");
397 } else {
398 last_trim_index.insert(stream_key.clone(), unix_duration_now.as_millis() as usize);
399 }
400 }
401 }
402
403 pipe.query_async(conn).await.map_err(anyhow::Error::from)
404}
405
406pub async fn stream_messages(
414 tx: tokio::sync::mpsc::Sender<BusMessage>,
415 config: DatabaseConfig,
416 stream_keys: Vec<String>,
417 stream_signal: Arc<AtomicBool>,
418) -> anyhow::Result<()> {
419 log_task_started(MSGBUS_STREAM);
420
421 let mut con = create_redis_connection(MSGBUS_STREAM, config).await?;
422
423 let stream_keys = &stream_keys
424 .iter()
425 .map(String::as_str)
426 .collect::<Vec<&str>>();
427
428 tracing::debug!("Listening to streams: [{}]", stream_keys.join(", "));
429
430 let clock = get_atomic_clock_realtime();
432 let timestamp_ms = clock.get_time_ms();
433 let initial_id = timestamp_ms.to_string();
434
435 let mut last_ids: HashMap<String, String> = stream_keys
436 .iter()
437 .map(|&key| (key.to_string(), initial_id.clone()))
438 .collect();
439
440 let opts = StreamReadOptions::default().block(100);
441
442 'outer: loop {
443 if stream_signal.load(Ordering::Relaxed) {
444 tracing::debug!("Received streaming terminate signal");
445 break;
446 }
447
448 let ids: Vec<String> = stream_keys
449 .iter()
450 .map(|&key| last_ids[key].clone())
451 .collect();
452 let id_refs: Vec<&str> = ids.iter().map(String::as_str).collect();
453
454 let result: Result<RedisStreamBulk, _> =
455 con.xread_options(&[&stream_keys], &[&id_refs], &opts).await;
456 match result {
457 Ok(stream_bulk) => {
458 if stream_bulk.is_empty() {
459 continue;
461 }
462 for entry in &stream_bulk {
463 for (stream_key, stream_msgs) in entry {
464 for stream_msg in stream_msgs {
465 for (id, array) in stream_msg {
466 last_ids.insert(stream_key.clone(), id.clone());
467
468 match decode_bus_message(array) {
469 Ok(msg) => {
470 if let Err(e) = tx.send(msg).await {
471 tracing::debug!("Channel closed: {e:?}");
472 break 'outer; }
474 }
475 Err(e) => {
476 tracing::error!("{e:?}");
477 continue;
478 }
479 }
480 }
481 }
482 }
483 }
484 }
485 Err(e) => {
486 anyhow::bail!("Error reading from stream: {e:?}");
487 }
488 }
489 }
490
491 log_task_stopped(MSGBUS_STREAM);
492 Ok(())
493}
494
495fn decode_bus_message(stream_msg: &redis::Value) -> anyhow::Result<BusMessage> {
504 if let redis::Value::Array(stream_msg) = stream_msg {
505 if stream_msg.len() < 4 {
506 anyhow::bail!("Invalid stream message format: {stream_msg:?}");
507 }
508
509 let topic = match &stream_msg[1] {
510 redis::Value::BulkString(bytes) => match String::from_utf8(bytes.clone()) {
511 Ok(topic) => topic,
512 Err(e) => anyhow::bail!("Error parsing topic: {e}"),
513 },
514 _ => {
515 anyhow::bail!("Invalid topic format: {stream_msg:?}");
516 }
517 };
518
519 let payload = match &stream_msg[3] {
520 redis::Value::BulkString(bytes) => Bytes::copy_from_slice(bytes),
521 _ => {
522 anyhow::bail!("Invalid payload format: {stream_msg:?}");
523 }
524 };
525
526 Ok(BusMessage::with_str_topic(topic, payload))
527 } else {
528 anyhow::bail!("Invalid stream message format: {stream_msg:?}")
529 }
530}
531
532async fn run_heartbeat(
533 heartbeat_interval_secs: u16,
534 signal: Arc<AtomicBool>,
535 pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
536) {
537 log_task_started("heartbeat");
538 tracing::debug!("Heartbeat at {heartbeat_interval_secs} second intervals");
539
540 let heartbeat_interval = Duration::from_secs(u64::from(heartbeat_interval_secs));
541 let heartbeat_timer = tokio::time::interval(heartbeat_interval);
542
543 let check_interval = Duration::from_millis(100);
544 let check_timer = tokio::time::interval(check_interval);
545
546 tokio::pin!(heartbeat_timer);
547 tokio::pin!(check_timer);
548
549 loop {
550 if signal.load(Ordering::Relaxed) {
551 tracing::debug!("Received heartbeat terminate signal");
552 break;
553 }
554
555 tokio::select! {
556 _ = heartbeat_timer.tick() => {
557 let heartbeat = create_heartbeat_msg();
558 if let Err(e) = pub_tx.send(heartbeat) {
559 tracing::debug!("Error sending heartbeat: {e}");
561 }
562 },
563 _ = check_timer.tick() => {}
564 }
565 }
566
567 log_task_stopped("heartbeat");
568}
569
570fn create_heartbeat_msg() -> BusMessage {
571 let payload = Bytes::from(chrono::Utc::now().to_rfc3339().into_bytes());
572 BusMessage::with_str_topic(HEARTBEAT_TOPIC, payload)
573}
574
#[cfg(test)]
mod tests {
    use redis::Value;
    use rstest::*;

    use super::*;

    /// Decodes `stream_msg` and asserts that it fails with exactly `expected` as the message.
    fn assert_decode_error(stream_msg: &Value, expected: &str) {
        let err = decode_bus_message(stream_msg).expect_err("decoding should fail");
        assert_eq!(err.to_string(), expected);
    }

    #[rstest]
    fn test_decode_bus_message_valid() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
            Value::BulkString(b"unused".to_vec()),
            Value::BulkString(b"data1".to_vec()),
        ]);

        let msg = decode_bus_message(&stream_msg).expect("message should decode");

        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));
    }

    #[rstest]
    fn test_decode_bus_message_missing_fields() {
        // Only two of the four required elements are present.
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
        ]);

        assert_decode_error(
            &stream_msg,
            "Invalid stream message format: [bulk-string('\"0\"'), bulk-string('\"topic1\"')]",
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_topic_format() {
        // The topic slot holds an integer instead of a bulk string.
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::Int(42),
            Value::BulkString(b"unused".to_vec()),
            Value::BulkString(b"data1".to_vec()),
        ]);

        assert_decode_error(
            &stream_msg,
            "Invalid topic format: [bulk-string('\"0\"'), int(42), bulk-string('\"unused\"'), bulk-string('\"data1\"')]",
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_payload_format() {
        // The payload slot holds an integer instead of a bulk string.
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
            Value::BulkString(b"unused".to_vec()),
            Value::Int(42),
        ]);

        assert_decode_error(
            &stream_msg,
            "Invalid payload format: [bulk-string('\"0\"'), bulk-string('\"topic1\"'), bulk-string('\"unused\"'), int(42)]",
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_stream_msg_format() {
        let stream_msg = Value::BulkString(b"not an array".to_vec());

        assert_decode_error(
            &stream_msg,
            "Invalid stream message format: bulk-string('\"not an array\"')",
        );
    }
}
662
// Integration tests against a live Redis server (`DatabaseConfig::default()`), each of which
// flushes Redis. NOTE(review): gated to Linux, presumably because a Redis service is only
// provisioned there in CI — confirm.
#[cfg(target_os = "linux")]
#[cfg(test)]
mod serial_tests {
    use nautilus_common::testing::wait_until_async;
    use redis::aio::ConnectionManager;
    use rstest::*;

    use super::*;
    use crate::redis::flush_redis;

    /// Connects with the default database config and flushes Redis before the test runs.
    #[fixture]
    async fn redis_connection() -> ConnectionManager {
        let config = DatabaseConfig::default();
        let mut con = create_redis_connection(MSGBUS_STREAM, config)
            .await
            .unwrap();
        flush_redis(&mut con).await.unwrap();
        con
    }

    /// The streaming task exits once its terminate signal is set.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_terminate_signal(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let config = MessageBusConfig {
            database: Some(DatabaseConfig::default()),
            ..Default::default()
        };

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        stream_signal.store(true, Ordering::Relaxed);
        // Expected to return `None` once the task exits and drops `tx`.
        let _ = rx.recv().await;
        rx.close();
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    /// The streaming task exits (via `break 'outer`) when sending to a closed receiver fails.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_when_receiver_closed(
        #[future] redis_connection: ConnectionManager,
    ) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let config = MessageBusConfig {
            database: Some(DatabaseConfig::default()),
            ..Default::default()
        };

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // `stream_messages` starts reading from the current time in ms, so use an ID far in
        // the future to guarantee the entry is delivered.
        let clock = get_atomic_clock_realtime();
        let future_id = (clock.get_time_ms() + 1_000_000).to_string();

        let _: () = con
            .xadd(
                stream_key,
                future_id,
                &[("topic", "topic1"), ("payload", "data1")],
            )
            .await
            .unwrap();

        // Close the receiver before the task starts so its first send fails.
        rx.close();

        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    /// A message added to a watched stream is decoded and forwarded to the receiver.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let config = MessageBusConfig {
            database: Some(DatabaseConfig::default()),
            ..Default::default()
        };

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Future ID so the entry sorts after the task's starting ID (current time in ms).
        let clock = get_atomic_clock_realtime();
        let future_id = (clock.get_time_ms() + 1_000_000).to_string();

        let _: () = con
            .xadd(
                stream_key,
                future_id,
                &[("topic", "topic1"), ("payload", "data1")],
            )
            .await
            .unwrap();

        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        let msg = rx.recv().await.unwrap();
        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));

        rx.close();
        stream_signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    /// `publish_messages` writes a sent message to the trader's stream, then exits on close.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_publish_messages(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let config = MessageBusConfig {
            database: Some(DatabaseConfig::default()),
            stream_per_topic: false,
            ..Default::default()
        };
        let stream_key = get_stream_key(trader_id, instance_id, &config);

        let handle = tokio::spawn(async move {
            publish_messages(rx, trader_id, instance_id, config)
                .await
                .unwrap();
        });

        let msg = BusMessage::with_str_topic("test_topic", Bytes::from("test_payload"));
        tx.send(msg).unwrap();

        // Poll until the message shows up in the stream (up to 3 seconds).
        wait_until_async(
            || {
                let mut con = con.clone();
                let stream_key = stream_key.clone();
                async move {
                    let messages: RedisStreamBulk =
                        con.xread(&[&stream_key], &["0"]).await.unwrap();
                    !messages.is_empty()
                }
            },
            Duration::from_secs(3),
        )
        .await;

        let messages: RedisStreamBulk = con.xread(&[&stream_key], &["0"]).await.unwrap();
        assert_eq!(messages.len(), 1);
        let stream_msgs = messages[0].get(&stream_key).unwrap();
        let stream_msg_array = &stream_msgs[0].values().next().unwrap();
        let decoded_message = decode_bus_message(stream_msg_array).unwrap();
        assert_eq!(decoded_message.topic, "test_topic");
        assert_eq!(decoded_message.payload, Bytes::from("test_payload"));

        // The close message makes the publishing task leave its loop and return.
        let msg = BusMessage::new_close();
        tx.send(msg).unwrap();

        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    /// Last-read IDs are tracked per stream: stream 2's entry has a smaller ID than stream
    /// 1's and must still be delivered.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_multiple_streams(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let stream_key1 = "test:stream:1".to_string();
        let stream_key2 = "test:stream:2".to_string();
        let external_streams = vec![stream_key1.clone(), stream_key2.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        let clock = get_atomic_clock_realtime();
        let base_id = clock.get_time_ms() + 1_000_000;

        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        // Give the task time to connect and start reading.
        tokio::time::sleep(Duration::from_millis(200)).await;

        let _: () = con
            .xadd(
                &stream_key1,
                format!("{}", base_id + 100),
                &[("topic", "stream1-first"), ("payload", "data")],
            )
            .await
            .unwrap();

        let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
            .await
            .expect("Stream 1 message should be received")
            .unwrap();
        assert_eq!(msg.topic, "stream1-first");

        // Lower ID than the stream 1 entry above.
        let _: () = con
            .xadd(
                &stream_key2,
                format!("{}", base_id + 50),
                &[("topic", "stream2-second"), ("payload", "data")],
            )
            .await
            .unwrap();

        let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
            .await
            .expect("Stream 2 message should be received")
            .unwrap();
        assert_eq!(msg.topic, "stream2-second");

        rx.close();
        stream_signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    /// Three streams each receive one entry, with IDs decreasing across streams; every entry
    /// must still be delivered.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_interleaved_at_different_rates(
        #[future] redis_connection: ConnectionManager,
    ) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let stream_key1 = "test:stream:interleaved:1".to_string();
        let stream_key2 = "test:stream:interleaved:2".to_string();
        let stream_key3 = "test:stream:interleaved:3".to_string();
        let external_streams = vec![
            stream_key1.clone(),
            stream_key2.clone(),
            stream_key3.clone(),
        ];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        let clock = get_atomic_clock_realtime();
        let base_id = clock.get_time_ms() + 1_000_000;

        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        // Give the task time to connect and start reading.
        tokio::time::sleep(Duration::from_millis(200)).await;

        let _: () = con
            .xadd(
                &stream_key1,
                format!("{}", base_id + 100),
                &[("topic", "s1m1"), ("payload", "data")],
            )
            .await
            .unwrap();
        let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
            .await
            .expect("Stream 1 message should be received")
            .unwrap();
        assert_eq!(msg.topic, "s1m1");

        let _: () = con
            .xadd(
                &stream_key2,
                format!("{}", base_id + 50),
                &[("topic", "s2m1"), ("payload", "data")],
            )
            .await
            .unwrap();
        let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
            .await
            .expect("Stream 2 message should be received")
            .unwrap();
        assert_eq!(msg.topic, "s2m1");

        let _: () = con
            .xadd(
                &stream_key3,
                format!("{}", base_id + 25),
                &[("topic", "s3m1"), ("payload", "data")],
            )
            .await
            .unwrap();
        let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
            .await
            .expect("Stream 3 message should be received")
            .unwrap();
        assert_eq!(msg.topic, "s3m1");

        rx.close();
        stream_signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    /// Creating and closing the adapter completes. `close` calls
    /// `tokio::task::block_in_place`, hence the multi-threaded runtime.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_close() {
        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let config = MessageBusConfig {
            database: Some(DatabaseConfig::default()),
            ..Default::default()
        };

        let mut db = RedisMessageBusDatabase::new(trader_id, instance_id, config).unwrap();

        db.close();
    }

    /// With a 1 second interval, at least one heartbeat on `HEARTBEAT_TOPIC` is sent within
    /// 2 seconds, and the task stops once signalled.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_heartbeat_task() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
        let signal = Arc::new(AtomicBool::new(false));

        let handle = tokio::spawn(run_heartbeat(1, signal.clone(), tx));

        tokio::time::sleep(Duration::from_secs(2)).await;

        signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();

        // Collect everything the task sent before it stopped.
        let mut heartbeats: Vec<BusMessage> = Vec::new();
        while let Ok(hb) = rx.try_recv() {
            heartbeats.push(hb);
        }

        assert!(!heartbeats.is_empty());

        for hb in heartbeats {
            assert_eq!(hb.topic, HEARTBEAT_TOPIC);
        }
    }
}