nautilus_infrastructure/redis/
msgbus.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    collections::{HashMap, VecDeque},
18    sync::{
19        Arc,
20        atomic::{AtomicBool, Ordering},
21    },
22    time::{Duration, Instant},
23};
24
25use bytes::Bytes;
26use futures::stream::Stream;
27use nautilus_common::{
28    msgbus::{
29        BusMessage, CLOSE_TOPIC,
30        database::{DatabaseConfig, MessageBusConfig, MessageBusDatabaseAdapter},
31    },
32    runtime::get_runtime,
33};
34use nautilus_core::{
35    UUID4,
36    time::{duration_since_unix_epoch, get_atomic_clock_realtime},
37};
38use nautilus_cryptography::providers::install_cryptographic_provider;
39use nautilus_model::identifiers::TraderId;
40use redis::*;
41use streams::StreamReadOptions;
42
43use super::{REDIS_MINID, REDIS_XTRIM, await_handle};
44use crate::redis::{create_redis_connection, get_stream_key};
45
/// Name of the publishing task; used in its error log and passed to `create_redis_connection`.
const MSGBUS_PUBLISH: &str = "msgbus-publish";
/// Name of the external stream reader task; used in its error log and passed to
/// `create_redis_connection`.
const MSGBUS_STREAM: &str = "msgbus-stream";
/// Name of the heartbeat task, used when awaiting its handle on close.
const MSGBUS_HEARTBEAT: &str = "msgbus-heartbeat";
/// Topic on which heartbeat messages are published.
const HEARTBEAT_TOPIC: &str = "health:heartbeat";
/// Minimum number of seconds between successive auto-trims of the same stream.
const TRIM_BUFFER_SECS: u64 = 60;

/// Decoded `XREAD` reply: a list of `stream key -> entries` maps, where each entry maps an
/// entry ID to its raw field/value array (see `decode_bus_message`).
type RedisStreamBulk = Vec<HashMap<String, Vec<HashMap<String, redis::Value>>>>;
53
/// Redis-backed message bus database adapter.
///
/// Owns a background task that publishes bus messages to Redis streams. Depending on the
/// configuration, it may also own a task that reads external Redis streams and a heartbeat task.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.infrastructure")
)]
pub struct RedisMessageBusDatabase {
    /// The trader ID for this message bus database.
    pub trader_id: TraderId,
    /// The instance ID for this message bus database.
    pub instance_id: UUID4,
    /// Sender feeding the publishing task; also carries the close message on shutdown.
    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
    /// Handle of the publishing task (taken when awaited on close).
    pub_handle: Option<tokio::task::JoinHandle<()>>,
    /// Receiver for messages read from external streams; `None` if no external streams were
    /// configured or once it has been taken via `get_stream_receiver`.
    stream_rx: Option<tokio::sync::mpsc::Receiver<BusMessage>>,
    /// Handle of the external stream reader task, if one was started.
    stream_handle: Option<tokio::task::JoinHandle<()>>,
    /// Set to `true` on close to tell the stream reader task to stop.
    stream_signal: Arc<AtomicBool>,
    /// Handle of the heartbeat task, if one was started.
    heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
    /// Set to `true` on close to tell the heartbeat task to stop.
    heartbeat_signal: Arc<AtomicBool>,
}
71
72impl MessageBusDatabaseAdapter for RedisMessageBusDatabase {
73    type DatabaseType = RedisMessageBusDatabase;
74
75    /// Creates a new [`RedisMessageBusDatabase`] instance.
76    fn new(
77        trader_id: TraderId,
78        instance_id: UUID4,
79        config: MessageBusConfig,
80    ) -> anyhow::Result<Self> {
81        install_cryptographic_provider();
82
83        let config_clone = config.clone();
84        let db_config = config
85            .database
86            .clone()
87            .ok_or_else(|| anyhow::anyhow!("No database config"))?;
88
89        let (pub_tx, pub_rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
90
91        // Create publish task (start the runtime here for now)
92        let pub_handle = Some(get_runtime().spawn(async move {
93            if let Err(e) = publish_messages(pub_rx, trader_id, instance_id, config_clone).await {
94                log::error!("Error in task '{MSGBUS_PUBLISH}': {e}");
95            };
96        }));
97
98        // Conditionally create stream task and channel if external streams configured
99        let external_streams = config.external_streams.clone().unwrap_or_default();
100        let stream_signal = Arc::new(AtomicBool::new(false));
101        let (stream_rx, stream_handle) = if !external_streams.is_empty() {
102            let stream_signal_clone = stream_signal.clone();
103            let (stream_tx, stream_rx) = tokio::sync::mpsc::channel::<BusMessage>(100_000);
104            (
105                Some(stream_rx),
106                Some(get_runtime().spawn(async move {
107                    if let Err(e) =
108                        stream_messages(stream_tx, db_config, external_streams, stream_signal_clone)
109                            .await
110                    {
111                        log::error!("Error in task '{MSGBUS_STREAM}': {e}");
112                    }
113                })),
114            )
115        } else {
116            (None, None)
117        };
118
119        // Create heartbeat task
120        let heartbeat_signal = Arc::new(AtomicBool::new(false));
121        let heartbeat_handle = if let Some(heartbeat_interval_secs) = config.heartbeat_interval_secs
122        {
123            let signal = heartbeat_signal.clone();
124            let pub_tx_clone = pub_tx.clone();
125
126            Some(get_runtime().spawn(async move {
127                run_heartbeat(heartbeat_interval_secs, signal, pub_tx_clone).await
128            }))
129        } else {
130            None
131        };
132
133        Ok(Self {
134            trader_id,
135            instance_id,
136            pub_tx,
137            pub_handle,
138            stream_rx,
139            stream_handle,
140            stream_signal,
141            heartbeat_handle,
142            heartbeat_signal,
143        })
144    }
145
146    /// Returns whether the message bus database adapter publishing channel is closed.
147    fn is_closed(&self) -> bool {
148        self.pub_tx.is_closed()
149    }
150
151    /// Publishes a message with the given `topic` and `payload`.
152    fn publish(&self, topic: String, payload: Bytes) {
153        let msg = BusMessage { topic, payload };
154        if let Err(e) = self.pub_tx.send(msg) {
155            log::error!("Failed to send message: {e}");
156        }
157    }
158
159    /// Closes the message bus database adapter.
160    fn close(&mut self) {
161        log::debug!("Closing");
162
163        self.stream_signal.store(true, Ordering::Relaxed);
164        self.heartbeat_signal.store(true, Ordering::Relaxed);
165
166        if !self.pub_tx.is_closed() {
167            let msg = BusMessage {
168                topic: CLOSE_TOPIC.to_string(),
169                payload: Bytes::new(), // Empty
170            };
171            if let Err(e) = self.pub_tx.send(msg) {
172                log::error!("Failed to send close message: {e:?}");
173            }
174        }
175
176        // Keep close sync for now to avoid async trait method
177        tokio::task::block_in_place(|| {
178            get_runtime().block_on(async {
179                self.close_async().await;
180            });
181        });
182
183        log::debug!("Closed");
184    }
185}
186
187impl RedisMessageBusDatabase {
188    /// Gets the stream receiver for this instance.
189    pub fn get_stream_receiver(
190        &mut self,
191    ) -> anyhow::Result<tokio::sync::mpsc::Receiver<BusMessage>> {
192        self.stream_rx
193            .take()
194            .ok_or_else(|| anyhow::anyhow!("Stream receiver already taken"))
195    }
196
197    /// Streams messages arriving on the stream receiver channel.
198    pub fn stream(
199        mut stream_rx: tokio::sync::mpsc::Receiver<BusMessage>,
200    ) -> impl Stream<Item = BusMessage> + 'static {
201        async_stream::stream! {
202            while let Some(msg) = stream_rx.recv().await {
203                yield msg;
204            }
205        }
206    }
207
208    pub async fn close_async(&mut self) {
209        await_handle(self.pub_handle.take(), MSGBUS_PUBLISH).await;
210        await_handle(self.stream_handle.take(), MSGBUS_STREAM).await;
211        await_handle(self.heartbeat_handle.take(), MSGBUS_HEARTBEAT).await;
212    }
213}
214
215pub async fn publish_messages(
216    mut rx: tokio::sync::mpsc::UnboundedReceiver<BusMessage>,
217    trader_id: TraderId,
218    instance_id: UUID4,
219    config: MessageBusConfig,
220) -> anyhow::Result<()> {
221    tracing::debug!("Starting message publishing");
222
223    let db_config = config
224        .database
225        .as_ref()
226        .ok_or_else(|| anyhow::anyhow!("No database config"))?;
227    let mut con = create_redis_connection(MSGBUS_PUBLISH, db_config.clone()).await?;
228    let stream_key = get_stream_key(trader_id, instance_id, &config);
229
230    // Auto-trimming
231    let autotrim_duration = config
232        .autotrim_mins
233        .filter(|&mins| mins > 0)
234        .map(|mins| Duration::from_secs(mins as u64 * 60));
235    let mut last_trim_index: HashMap<String, usize> = HashMap::new();
236
237    // Buffering
238    let mut buffer: VecDeque<BusMessage> = VecDeque::new();
239    let mut last_drain = Instant::now();
240    let buffer_interval = Duration::from_millis(config.buffer_interval_ms.unwrap_or(0) as u64);
241
242    loop {
243        if last_drain.elapsed() >= buffer_interval && !buffer.is_empty() {
244            drain_buffer(
245                &mut con,
246                &stream_key,
247                config.stream_per_topic,
248                autotrim_duration,
249                &mut last_trim_index,
250                &mut buffer,
251            )
252            .await?;
253            last_drain = Instant::now();
254        } else {
255            match rx.recv().await {
256                Some(msg) => {
257                    if msg.topic == CLOSE_TOPIC {
258                        tracing::debug!("Received close message");
259                        drop(rx);
260                        break;
261                    }
262                    buffer.push_back(msg);
263                }
264                None => {
265                    tracing::debug!("Channel hung up");
266                    break;
267                }
268            }
269        }
270    }
271
272    // Drain any remaining messages
273    if !buffer.is_empty() {
274        drain_buffer(
275            &mut con,
276            &stream_key,
277            config.stream_per_topic,
278            autotrim_duration,
279            &mut last_trim_index,
280            &mut buffer,
281        )
282        .await?;
283    }
284
285    tracing::debug!("Stopped message publishing");
286    Ok(())
287}
288
289async fn drain_buffer(
290    conn: &mut redis::aio::ConnectionManager,
291    stream_key: &str,
292    stream_per_topic: bool,
293    autotrim_duration: Option<Duration>,
294    last_trim_index: &mut HashMap<String, usize>,
295    buffer: &mut VecDeque<BusMessage>,
296) -> anyhow::Result<()> {
297    let mut pipe = redis::pipe();
298    pipe.atomic();
299
300    for msg in buffer.drain(..) {
301        let items: Vec<(&str, &[u8])> = vec![
302            ("topic", msg.topic.as_ref()),
303            ("payload", msg.payload.as_ref()),
304        ];
305        let stream_key = match stream_per_topic {
306            true => format!("{stream_key}:{}", &msg.topic),
307            false => stream_key.to_string(),
308        };
309        pipe.xadd(&stream_key, "*", &items);
310
311        if autotrim_duration.is_none() {
312            continue; // Nothing else to do
313        }
314
315        // Autotrim stream
316        let last_trim_ms = last_trim_index.entry(stream_key.clone()).or_insert(0); // Remove clone
317        let unix_duration_now = duration_since_unix_epoch();
318        let trim_buffer = Duration::from_secs(TRIM_BUFFER_SECS);
319
320        // Improve efficiency of this by batching
321        if *last_trim_ms < (unix_duration_now - trim_buffer).as_millis() as usize {
322            let min_timestamp_ms =
323                (unix_duration_now - autotrim_duration.unwrap()).as_millis() as usize;
324            let result: Result<(), redis::RedisError> = redis::cmd(REDIS_XTRIM)
325                .arg(stream_key.clone())
326                .arg(REDIS_MINID)
327                .arg(min_timestamp_ms)
328                .query_async(conn)
329                .await;
330
331            if let Err(e) = result {
332                tracing::error!("Error trimming stream '{stream_key}': {e}");
333            } else {
334                last_trim_index.insert(
335                    stream_key.to_string(),
336                    unix_duration_now.as_millis() as usize,
337                );
338            }
339        }
340    }
341
342    pipe.query_async(conn).await.map_err(anyhow::Error::from)
343}
344
345pub async fn stream_messages(
346    tx: tokio::sync::mpsc::Sender<BusMessage>,
347    config: DatabaseConfig,
348    stream_keys: Vec<String>,
349    stream_signal: Arc<AtomicBool>,
350) -> anyhow::Result<()> {
351    tracing::info!("Starting message streaming");
352    let mut con = create_redis_connection(MSGBUS_STREAM, config).await?;
353
354    let stream_keys = &stream_keys
355        .iter()
356        .map(String::as_str)
357        .collect::<Vec<&str>>();
358
359    tracing::debug!("Listening to streams: [{}]", stream_keys.join(", "));
360
361    // Start streaming from current timestamp
362    let clock = get_atomic_clock_realtime();
363    let timestamp_ms = clock.get_time_ms();
364    let mut last_id = timestamp_ms.to_string();
365
366    let opts = StreamReadOptions::default().block(100);
367
368    'outer: loop {
369        if stream_signal.load(Ordering::Relaxed) {
370            tracing::debug!("Received streaming terminate signal");
371            break;
372        }
373        let result: Result<RedisStreamBulk, _> =
374            con.xread_options(&[&stream_keys], &[&last_id], &opts).await;
375        match result {
376            Ok(stream_bulk) => {
377                if stream_bulk.is_empty() {
378                    // Timeout occurred: no messages received
379                    continue;
380                }
381                for entry in stream_bulk.iter() {
382                    for (_stream_key, stream_msgs) in entry.iter() {
383                        for stream_msg in stream_msgs.iter() {
384                            for (id, array) in stream_msg {
385                                last_id.clear();
386                                last_id.push_str(id);
387                                match decode_bus_message(array) {
388                                    Ok(msg) => {
389                                        if let Err(e) = tx.send(msg).await {
390                                            tracing::debug!("Channel closed: {e:?}");
391                                            break 'outer; // End streaming
392                                        }
393                                    }
394                                    Err(e) => {
395                                        tracing::error!("{e:?}");
396                                        continue;
397                                    }
398                                }
399                            }
400                        }
401                    }
402                }
403            }
404            Err(e) => {
405                anyhow::bail!("Error reading from stream: {e:?}");
406            }
407        }
408    }
409
410    tracing::debug!("Stopped message streaming");
411    Ok(())
412}
413
414fn decode_bus_message(stream_msg: &redis::Value) -> anyhow::Result<BusMessage> {
415    if let redis::Value::Array(stream_msg) = stream_msg {
416        if stream_msg.len() < 4 {
417            anyhow::bail!("Invalid stream message format: {stream_msg:?}");
418        }
419
420        let topic = match &stream_msg[1] {
421            redis::Value::BulkString(bytes) => match String::from_utf8(bytes.clone()) {
422                Ok(topic) => topic,
423                Err(e) => anyhow::bail!("Error parsing topic: {e}"),
424            },
425            _ => {
426                anyhow::bail!("Invalid topic format: {stream_msg:?}");
427            }
428        };
429
430        let payload = match &stream_msg[3] {
431            redis::Value::BulkString(bytes) => Bytes::copy_from_slice(bytes),
432            _ => {
433                anyhow::bail!("Invalid payload format: {stream_msg:?}");
434            }
435        };
436
437        Ok(BusMessage { topic, payload })
438    } else {
439        anyhow::bail!("Invalid stream message format: {stream_msg:?}")
440    }
441}
442
443async fn run_heartbeat(
444    heartbeat_interval_secs: u16,
445    signal: Arc<AtomicBool>,
446    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
447) {
448    tracing::debug!("Starting heartbeat at {heartbeat_interval_secs} second intervals");
449
450    let heartbeat_interval = Duration::from_secs(heartbeat_interval_secs as u64);
451    let heartbeat_timer = tokio::time::interval(heartbeat_interval);
452
453    let check_interval = Duration::from_millis(100);
454    let check_timer = tokio::time::interval(check_interval);
455
456    tokio::pin!(heartbeat_timer);
457    tokio::pin!(check_timer);
458
459    loop {
460        if signal.load(Ordering::Relaxed) {
461            tracing::debug!("Received heartbeat terminate signal");
462            break;
463        }
464
465        tokio::select! {
466            _ = heartbeat_timer.tick() => {
467                let heartbeat = create_heartbeat_msg();
468                if let Err(e) = pub_tx.send(heartbeat) {
469                    // We expect an error if the channel is closed during shutdown
470                    tracing::debug!("Error sending heartbeat: {e}");
471                }
472            },
473            _ = check_timer.tick() => {}
474        }
475    }
476
477    tracing::debug!("Stopped heartbeat");
478}
479
480fn create_heartbeat_msg() -> BusMessage {
481    BusMessage {
482        topic: HEARTBEAT_TOPIC.to_string(),
483        payload: Bytes::from(chrono::Utc::now().to_rfc3339().into_bytes()),
484    }
485}
486
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
// Unit tests for `decode_bus_message`; these do not need a Redis server. Each case builds a
// flat four-element array, and the decoder reads the topic from index 1 and the payload from
// index 3. The field-name slots ("0", "unused") hold arbitrary values because the decoder
// does not inspect them.
#[cfg(test)]
mod tests {
    use redis::Value;
    use rstest::*;

    use super::*;

    // Well-formed entry: topic and payload come from positions 1 and 3.
    #[rstest]
    fn test_decode_bus_message_valid() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
            Value::BulkString(b"unused".to_vec()),
            Value::BulkString(b"data1".to_vec()),
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));
    }

    // Fewer than four elements is rejected, and the error message includes the array.
    #[rstest]
    fn test_decode_bus_message_missing_fields() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid stream message format: [bulk-string('\"0\"'), bulk-string('\"topic1\"')]"
        );
    }

    // A non-bulk-string topic is rejected.
    #[rstest]
    fn test_decode_bus_message_invalid_topic_format() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::Int(42), // Invalid topic format
            Value::BulkString(b"unused".to_vec()),
            Value::BulkString(b"data1".to_vec()),
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid topic format: [bulk-string('\"0\"'), int(42), bulk-string('\"unused\"'), bulk-string('\"data1\"')]"
        );
    }

    // A non-bulk-string payload is rejected.
    #[rstest]
    fn test_decode_bus_message_invalid_payload_format() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
            Value::BulkString(b"unused".to_vec()),
            Value::Int(42), // Invalid payload format
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid payload format: [bulk-string('\"0\"'), bulk-string('\"topic1\"'), bulk-string('\"unused\"'), int(42)]"
        );
    }

    // A value that is not an array at all is rejected.
    #[rstest]
    fn test_decode_bus_message_invalid_stream_msg_format() {
        let stream_msg = Value::BulkString(b"not an array".to_vec());

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid stream message format: bulk-string('\"not an array\"')"
        );
    }
}
574
// Integration tests against a live Redis server, connected with `DatabaseConfig::default()`
// settings. Tests that take the `redis_connection` fixture start from a flushed database
// and flush it again when they finish.
#[cfg(target_os = "linux")] // Run Redis tests on Linux platforms only
#[cfg(test)]
mod serial_tests {
    use nautilus_common::testing::wait_until_async;
    use redis::aio::ConnectionManager;
    use rstest::*;

    use super::*;
    use crate::redis::flush_redis;

    /// Opens a Redis connection with default settings and flushes the database.
    #[fixture]
    async fn redis_connection() -> ConnectionManager {
        let config = DatabaseConfig::default();
        let mut con = create_redis_connection(MSGBUS_STREAM, config)
            .await
            .unwrap();
        flush_redis(&mut con).await.unwrap();
        con
    }

    // Setting the terminate signal makes `stream_messages` return, which drops its sender
    // and unblocks `rx.recv()`.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_terminate_signal(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Start the message streaming task
        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        stream_signal.store(true, Ordering::Relaxed);
        let _ = rx.recv().await; // Wait for the tx to close

        // Shutdown and cleanup
        rx.close();
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap()
    }

    // With the receiver already closed, sending the first streamed entry fails and
    // `stream_messages` returns `Ok` via its channel-closed path.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_when_receiver_closed(
        #[future] redis_connection: ConnectionManager,
    ) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Use a message ID in the future, as streaming begins
        // around the timestamp the thread is spawned.
        let clock = get_atomic_clock_realtime();
        let future_id = (clock.get_time_ms() + 1_000_000).to_string();

        // Publish test message
        let _: () = con
            .xadd(
                stream_key,
                future_id,
                &[("topic", "topic1"), ("payload", "data1")],
            )
            .await
            .unwrap();

        // Immediately close channel
        rx.close();

        // Start the message streaming task
        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        // Shutdown and cleanup
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap()
    }

    // An entry added to a watched stream is decoded and delivered on the channel.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Use a message ID in the future, as streaming begins
        // around the timestamp the task is spawned.
        let clock = get_atomic_clock_realtime();
        let future_id = (clock.get_time_ms() + 1_000_000).to_string();

        // Publish test message
        let _: () = con
            .xadd(
                stream_key,
                future_id,
                &[("topic", "topic1"), ("payload", "data1")],
            )
            .await
            .unwrap();

        // Start the message streaming task
        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        // Receive and verify the message
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));

        // Shutdown and cleanup
        rx.close();
        stream_signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap()
    }

    // A message sent to `publish_messages` lands in the (single) trader stream. A close
    // message then stops the task.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_publish_messages(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());
        config.stream_per_topic = false;
        let stream_key = get_stream_key(trader_id, instance_id, &config);

        // Start the publish_messages task
        let handle = tokio::spawn(async move {
            publish_messages(rx, trader_id, instance_id, config)
                .await
                .unwrap();
        });

        // Send a test message
        let msg = BusMessage {
            topic: "test_topic".to_string(),
            payload: Bytes::from("test_payload"),
        };
        tx.send(msg).unwrap();

        // Wait until the message is published to Redis
        wait_until_async(
            || {
                let mut con = con.clone();
                let stream_key = stream_key.clone();
                async move {
                    let messages: RedisStreamBulk =
                        con.xread(&[&stream_key], &["0"]).await.unwrap();
                    !messages.is_empty()
                }
            },
            Duration::from_secs(2),
        )
        .await;

        // Verify the message was published to Redis
        let messages: RedisStreamBulk = con.xread(&[&stream_key], &["0"]).await.unwrap();
        assert_eq!(messages.len(), 1);
        let stream_msgs = messages[0].get(&stream_key).unwrap();
        let stream_msg_array = &stream_msgs[0].values().next().unwrap();
        let decoded_message = decode_bus_message(stream_msg_array).unwrap();
        assert_eq!(decoded_message.topic, "test_topic");
        assert_eq!(decoded_message.payload, Bytes::from("test_payload"));

        // Stop publishing task
        let msg = BusMessage {
            topic: CLOSE_TOPIC.to_string(),
            payload: Bytes::new(), // Empty
        };
        tx.send(msg).unwrap();

        // Shutdown and cleanup
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    // Constructing the adapter and closing it must return rather than hang.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_close() {
        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let mut db = RedisMessageBusDatabase::new(trader_id, instance_id, config).unwrap();

        // Close the message bus database (test should not hang)
        db.close();
    }

    // With a 1 second interval running for about 2 seconds, at least one heartbeat on
    // `HEARTBEAT_TOPIC` should be sent before the signal stops the task.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_heartbeat_task() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
        let signal = Arc::new(AtomicBool::new(false));

        // Start the heartbeat task with a short interval
        let handle = tokio::spawn(run_heartbeat(1, signal.clone(), tx));

        // Wait for a couple of heartbeats
        tokio::time::sleep(Duration::from_secs(2)).await;

        // Stop the heartbeat task
        signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();

        // Ensure heartbeats were sent
        let mut heartbeats: Vec<BusMessage> = Vec::new();
        while let Ok(hb) = rx.try_recv() {
            heartbeats.push(hb);
        }

        assert!(!heartbeats.is_empty());

        for hb in heartbeats {
            assert_eq!(hb.topic, HEARTBEAT_TOPIC);
        }
    }
}