nautilus_infrastructure/redis/
msgbus.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    collections::{HashMap, VecDeque},
18    sync::{
19        atomic::{AtomicBool, Ordering},
20        Arc,
21    },
22    time::{Duration, Instant},
23};
24
25use bytes::Bytes;
26use futures::stream::Stream;
27use nautilus_common::{
28    msgbus::{
29        database::{BusMessage, DatabaseConfig, MessageBusConfig, MessageBusDatabaseAdapter},
30        CLOSE_TOPIC,
31    },
32    runtime::get_runtime,
33};
34use nautilus_core::{
35    time::{duration_since_unix_epoch, get_atomic_clock_realtime},
36    UUID4,
37};
38use nautilus_cryptography::providers::install_cryptographic_provider;
39use nautilus_model::identifiers::TraderId;
40use redis::*;
41use streams::StreamReadOptions;
42
43use super::{await_handle, REDIS_MINID, REDIS_XTRIM};
44use crate::redis::{create_redis_connection, get_stream_key};
45
/// Name given to the publish task's Redis connection and used when awaiting its handle.
const MSGBUS_PUBLISH: &str = "msgbus-publish";
/// Name given to the streaming task's Redis connection and used when awaiting its handle.
const MSGBUS_STREAM: &str = "msgbus-stream";
/// Name used when awaiting the heartbeat task's handle.
const MSGBUS_HEARTBEAT: &str = "msgbus-heartbeat";
/// Topic on which heartbeat messages are published.
const HEARTBEAT_TOPIC: &str = "health:heartbeat";
/// Minimum number of seconds between two auto-trim operations on the same stream.
const TRIM_BUFFER_SECS: u64 = 60;

/// Decoded shape of a stream read reply: one map per stream (stream key -> entries),
/// where each entry maps a message ID to its raw field array.
type RedisStreamBulk = Vec<HashMap<String, Vec<HashMap<String, redis::Value>>>>;
53
/// A message bus database adapter backed by Redis streams.
///
/// Messages are handed to a background publish task over an unbounded channel. Optional
/// background tasks stream messages from external streams and publish heartbeats.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.infrastructure")
)]
pub struct RedisMessageBusDatabase {
    /// The trader ID for this message bus database.
    pub trader_id: TraderId,
    /// The instance ID for this message bus database.
    pub instance_id: UUID4,
    /// Sending half of the channel feeding the publish task.
    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
    /// Handle of the publish task (taken when the adapter is closed).
    pub_handle: Option<tokio::task::JoinHandle<()>>,
    /// Receiving half of the external stream channel; `None` when no external streams
    /// are configured or once it has been taken via `get_stream_receiver`.
    stream_rx: Option<tokio::sync::mpsc::Receiver<BusMessage>>,
    /// Handle of the streaming task; `None` when no external streams are configured.
    stream_handle: Option<tokio::task::JoinHandle<()>>,
    /// Set to `true` on close to ask the streaming task to stop.
    stream_signal: Arc<AtomicBool>,
    /// Handle of the heartbeat task; `None` when no heartbeat interval is configured.
    heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
    /// Set to `true` on close to ask the heartbeat task to stop.
    heartbeat_signal: Arc<AtomicBool>,
}
71
72impl MessageBusDatabaseAdapter for RedisMessageBusDatabase {
73    type DatabaseType = RedisMessageBusDatabase;
74
75    /// Creates a new [`RedisMessageBusDatabase`] instance.
76    fn new(
77        trader_id: TraderId,
78        instance_id: UUID4,
79        config: MessageBusConfig,
80    ) -> anyhow::Result<Self> {
81        install_cryptographic_provider();
82
83        let config_clone = config.clone();
84        let db_config = config
85            .database
86            .clone()
87            .ok_or_else(|| anyhow::anyhow!("No database config"))?;
88
89        let (pub_tx, pub_rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
90
91        // Create publish task (start the runtime here for now)
92        let pub_handle = Some(get_runtime().spawn(async move {
93            publish_messages(pub_rx, trader_id, instance_id, config_clone)
94                .await
95                .expect("Error spawning task '{MSGBUS_PUBLISH}'}");
96        }));
97
98        // Conditionally create stream task and channel if external streams configured
99        let external_streams = config.external_streams.clone().unwrap_or_default();
100        let stream_signal = Arc::new(AtomicBool::new(false));
101        let (stream_rx, stream_handle) = if !external_streams.is_empty() {
102            let stream_signal_clone = stream_signal.clone();
103            let (stream_tx, stream_rx) = tokio::sync::mpsc::channel::<BusMessage>(100_000);
104            (
105                Some(stream_rx),
106                Some(get_runtime().spawn(async move {
107                    stream_messages(stream_tx, db_config, external_streams, stream_signal_clone)
108                        .await
109                        .expect("Error spawning task '{MSGBUS_STREAM}'}");
110                })),
111            )
112        } else {
113            (None, None)
114        };
115
116        // Create heartbeat task
117        let heartbeat_signal = Arc::new(AtomicBool::new(false));
118        let heartbeat_handle = if let Some(heartbeat_interval_secs) = config.heartbeat_interval_secs
119        {
120            let signal = heartbeat_signal.clone();
121            let pub_tx_clone = pub_tx.clone();
122
123            Some(get_runtime().spawn(async move {
124                run_heartbeat(heartbeat_interval_secs, signal, pub_tx_clone).await
125            }))
126        } else {
127            None
128        };
129
130        Ok(Self {
131            trader_id,
132            instance_id,
133            pub_tx,
134            pub_handle,
135            stream_rx,
136            stream_handle,
137            stream_signal,
138            heartbeat_handle,
139            heartbeat_signal,
140        })
141    }
142
143    /// Returns whether the message bus database adapter publishing channel is closed.
144    #[must_use]
145    fn is_closed(&self) -> bool {
146        self.pub_tx.is_closed()
147    }
148
149    /// Publishes a message with the given `topic` and `payload`.
150    fn publish(&self, topic: String, payload: Bytes) {
151        let msg = BusMessage { topic, payload };
152        if let Err(e) = self.pub_tx.send(msg) {
153            log::error!("Failed to send message: {e}");
154        }
155    }
156
157    /// Closes the message bus database adapter.
158    fn close(&mut self) {
159        log::debug!("Closing");
160
161        self.stream_signal.store(true, Ordering::Relaxed);
162        self.heartbeat_signal.store(true, Ordering::Relaxed);
163
164        if !self.pub_tx.is_closed() {
165            let msg = BusMessage {
166                topic: CLOSE_TOPIC.to_string(),
167                payload: Bytes::new(), // Empty
168            };
169            if let Err(e) = self.pub_tx.send(msg) {
170                log::error!("Failed to send close message: {e:?}");
171            }
172        }
173
174        // Keep close sync for now to avoid async trait method
175        tokio::task::block_in_place(|| {
176            get_runtime().block_on(async {
177                self.close_async().await;
178            });
179        });
180
181        log::debug!("Closed");
182    }
183}
184
185impl RedisMessageBusDatabase {
186    /// Gets the stream receiver for this instance.
187    pub fn get_stream_receiver(
188        &mut self,
189    ) -> anyhow::Result<tokio::sync::mpsc::Receiver<BusMessage>> {
190        self.stream_rx
191            .take()
192            .ok_or_else(|| anyhow::anyhow!("Stream receiver already taken"))
193    }
194
195    /// Streams messages arriving on the stream receiver channel.
196    pub fn stream(
197        mut stream_rx: tokio::sync::mpsc::Receiver<BusMessage>,
198    ) -> impl Stream<Item = BusMessage> + 'static {
199        async_stream::stream! {
200            while let Some(msg) = stream_rx.recv().await {
201                yield msg;
202            }
203        }
204    }
205
206    pub async fn close_async(&mut self) {
207        await_handle(self.pub_handle.take(), MSGBUS_PUBLISH).await;
208        await_handle(self.stream_handle.take(), MSGBUS_STREAM).await;
209        await_handle(self.heartbeat_handle.take(), MSGBUS_HEARTBEAT).await;
210    }
211}
212
213pub async fn publish_messages(
214    mut rx: tokio::sync::mpsc::UnboundedReceiver<BusMessage>,
215    trader_id: TraderId,
216    instance_id: UUID4,
217    config: MessageBusConfig,
218) -> anyhow::Result<()> {
219    tracing::debug!("Starting message publishing");
220
221    let db_config = config
222        .database
223        .as_ref()
224        .ok_or_else(|| anyhow::anyhow!("No database config"))?;
225    let mut con = create_redis_connection(MSGBUS_PUBLISH, db_config.clone())?;
226    let stream_key = get_stream_key(trader_id, instance_id, &config);
227
228    // Auto-trimming
229    let autotrim_duration = config
230        .autotrim_mins
231        .filter(|&mins| mins > 0)
232        .map(|mins| Duration::from_secs(mins as u64 * 60));
233    let mut last_trim_index: HashMap<String, usize> = HashMap::new();
234
235    // Buffering
236    let mut buffer: VecDeque<BusMessage> = VecDeque::new();
237    let mut last_drain = Instant::now();
238    let buffer_interval = Duration::from_millis(config.buffer_interval_ms.unwrap_or(0) as u64);
239
240    loop {
241        if last_drain.elapsed() >= buffer_interval && !buffer.is_empty() {
242            drain_buffer(
243                &mut con,
244                &stream_key,
245                config.stream_per_topic,
246                autotrim_duration,
247                &mut last_trim_index,
248                &mut buffer,
249            )?;
250            last_drain = Instant::now();
251        } else {
252            match rx.recv().await {
253                Some(msg) => {
254                    if msg.topic == CLOSE_TOPIC {
255                        tracing::debug!("Received close message");
256                        drop(rx);
257                        break;
258                    }
259                    buffer.push_back(msg);
260                }
261                None => {
262                    tracing::debug!("Channel hung up");
263                    break;
264                }
265            }
266        }
267    }
268
269    // Drain any remaining messages
270    if !buffer.is_empty() {
271        drain_buffer(
272            &mut con,
273            &stream_key,
274            config.stream_per_topic,
275            autotrim_duration,
276            &mut last_trim_index,
277            &mut buffer,
278        )?;
279    }
280
281    tracing::debug!("Stopped message publishing");
282    Ok(())
283}
284
285fn drain_buffer(
286    conn: &mut Connection,
287    stream_key: &str,
288    stream_per_topic: bool,
289    autotrim_duration: Option<Duration>,
290    last_trim_index: &mut HashMap<String, usize>,
291    buffer: &mut VecDeque<BusMessage>,
292) -> anyhow::Result<()> {
293    let mut pipe = redis::pipe();
294    pipe.atomic();
295
296    for msg in buffer.drain(..) {
297        let items: Vec<(&str, &[u8])> = vec![
298            ("topic", msg.topic.as_ref()),
299            ("payload", msg.payload.as_ref()),
300        ];
301        let stream_key = match stream_per_topic {
302            true => format!("{stream_key}:{}", &msg.topic),
303            false => stream_key.to_string(),
304        };
305        pipe.xadd(&stream_key, "*", &items);
306
307        if autotrim_duration.is_none() {
308            continue; // Nothing else to do
309        }
310
311        // Autotrim stream
312        let last_trim_ms = last_trim_index.entry(stream_key.clone()).or_insert(0); // Remove clone
313        let unix_duration_now = duration_since_unix_epoch();
314        let trim_buffer = Duration::from_secs(TRIM_BUFFER_SECS);
315
316        // Improve efficiency of this by batching
317        if *last_trim_ms < (unix_duration_now - trim_buffer).as_millis() as usize {
318            let min_timestamp_ms =
319                (unix_duration_now - autotrim_duration.unwrap()).as_millis() as usize;
320            let result: Result<(), redis::RedisError> = redis::cmd(REDIS_XTRIM)
321                .arg(stream_key.clone())
322                .arg(REDIS_MINID)
323                .arg(min_timestamp_ms)
324                .query(conn);
325
326            if let Err(e) = result {
327                tracing::error!("Error trimming stream '{stream_key}': {e}");
328            } else {
329                last_trim_index.insert(
330                    stream_key.to_string(),
331                    unix_duration_now.as_millis() as usize,
332                );
333            }
334        }
335    }
336
337    pipe.query::<()>(conn).map_err(anyhow::Error::from)
338}
339
340pub async fn stream_messages(
341    tx: tokio::sync::mpsc::Sender<BusMessage>,
342    config: DatabaseConfig,
343    stream_keys: Vec<String>,
344    stream_signal: Arc<AtomicBool>,
345) -> anyhow::Result<()> {
346    tracing::info!("Starting message streaming");
347    let mut con = create_redis_connection(MSGBUS_STREAM, config)?;
348
349    let stream_keys = &stream_keys
350        .iter()
351        .map(String::as_str)
352        .collect::<Vec<&str>>();
353
354    tracing::debug!("Listening to streams: [{}]", stream_keys.join(", "));
355
356    // Start streaming from current timestamp
357    let clock = get_atomic_clock_realtime();
358    let timestamp_ms = clock.get_time_ms();
359    let mut last_id = timestamp_ms.to_string();
360
361    let opts = StreamReadOptions::default().block(100);
362
363    'outer: loop {
364        if stream_signal.load(Ordering::Relaxed) {
365            tracing::debug!("Received streaming terminate signal");
366            break;
367        }
368        let result: Result<RedisStreamBulk, _> =
369            con.xread_options(&[&stream_keys], &[&last_id], &opts);
370        match result {
371            Ok(stream_bulk) => {
372                if stream_bulk.is_empty() {
373                    // Timeout occurred: no messages received
374                    continue;
375                }
376                for entry in stream_bulk.iter() {
377                    for (_stream_key, stream_msgs) in entry.iter() {
378                        for stream_msg in stream_msgs.iter() {
379                            for (id, array) in stream_msg {
380                                last_id.clear();
381                                last_id.push_str(id);
382                                match decode_bus_message(array) {
383                                    Ok(msg) => {
384                                        if let Err(e) = tx.send(msg).await {
385                                            tracing::debug!("Channel closed: {e:?}");
386                                            break 'outer; // End streaming
387                                        }
388                                    }
389                                    Err(e) => {
390                                        tracing::error!("{e:?}");
391                                        continue;
392                                    }
393                                }
394                            }
395                        }
396                    }
397                }
398            }
399            Err(e) => {
400                return Err(anyhow::anyhow!("Error reading from stream: {e:?}"));
401            }
402        }
403    }
404
405    tracing::debug!("Stopped message streaming");
406    Ok(())
407}
408
409fn decode_bus_message(stream_msg: &redis::Value) -> anyhow::Result<BusMessage> {
410    if let redis::Value::Array(stream_msg) = stream_msg {
411        if stream_msg.len() < 4 {
412            anyhow::bail!("Invalid stream message format: {stream_msg:?}");
413        }
414
415        let topic = match &stream_msg[1] {
416            redis::Value::BulkString(bytes) => {
417                String::from_utf8(bytes.clone()).expect("Error parsing topic")
418            }
419            _ => {
420                anyhow::bail!("Invalid topic format: {stream_msg:?}");
421            }
422        };
423
424        let payload = match &stream_msg[3] {
425            redis::Value::BulkString(bytes) => Bytes::copy_from_slice(bytes),
426            _ => {
427                anyhow::bail!("Invalid payload format: {stream_msg:?}");
428            }
429        };
430
431        Ok(BusMessage { topic, payload })
432    } else {
433        anyhow::bail!("Invalid stream message format: {stream_msg:?}")
434    }
435}
436
437async fn run_heartbeat(
438    heartbeat_interval_secs: u16,
439    signal: Arc<AtomicBool>,
440    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
441) {
442    tracing::debug!("Starting heartbeat at {heartbeat_interval_secs} second intervals");
443
444    let heartbeat_interval = Duration::from_secs(heartbeat_interval_secs as u64);
445    let heartbeat_timer = tokio::time::interval(heartbeat_interval);
446
447    let check_interval = Duration::from_millis(100);
448    let check_timer = tokio::time::interval(check_interval);
449
450    tokio::pin!(heartbeat_timer);
451    tokio::pin!(check_timer);
452
453    loop {
454        if signal.load(Ordering::Relaxed) {
455            tracing::debug!("Received heartbeat terminate signal");
456            break;
457        }
458
459        tokio::select! {
460            _ = heartbeat_timer.tick() => {
461                let heartbeat = create_heartbeat_msg();
462                if let Err(e) = pub_tx.send(heartbeat) {
463                    // We expect an error if the channel is closed during shutdown
464                    tracing::debug!("Error sending heartbeat: {e}");
465                }
466            },
467            _ = check_timer.tick() => {}
468        }
469    }
470
471    tracing::debug!("Stopped heartbeat");
472}
473
474fn create_heartbeat_msg() -> BusMessage {
475    BusMessage {
476        topic: HEARTBEAT_TOPIC.to_string(),
477        payload: Bytes::from(chrono::Utc::now().to_rfc3339().into_bytes()),
478    }
479}
480
481////////////////////////////////////////////////////////////////////////////////
482// Tests
483////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use redis::Value;
    use rstest::*;

    use super::*;

    /// Wraps raw bytes in a Redis bulk string value.
    fn bulk(bytes: &[u8]) -> Value {
        Value::BulkString(bytes.to_vec())
    }

    /// Decodes `stream_msg`, which must fail, and returns the rendered error message.
    fn decode_error(stream_msg: &Value) -> String {
        decode_bus_message(stream_msg)
            .expect_err("decoding should fail")
            .to_string()
    }

    #[rstest]
    fn test_decode_bus_message_valid() {
        let stream_msg = Value::Array(vec![
            bulk(b"0"),
            bulk(b"topic1"),
            bulk(b"unused"),
            bulk(b"data1"),
        ]);

        let msg = decode_bus_message(&stream_msg).expect("valid message should decode");

        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));
    }

    #[rstest]
    fn test_decode_bus_message_missing_fields() {
        let stream_msg = Value::Array(vec![bulk(b"0"), bulk(b"topic1")]);

        assert_eq!(
            decode_error(&stream_msg),
            "Invalid stream message format: [bulk-string('\"0\"'), bulk-string('\"topic1\"')]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_topic_format() {
        let stream_msg = Value::Array(vec![
            bulk(b"0"),
            Value::Int(42), // Invalid topic format
            bulk(b"unused"),
            bulk(b"data1"),
        ]);

        assert_eq!(
            decode_error(&stream_msg),
            "Invalid topic format: [bulk-string('\"0\"'), int(42), bulk-string('\"unused\"'), bulk-string('\"data1\"')]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_payload_format() {
        let stream_msg = Value::Array(vec![
            bulk(b"0"),
            bulk(b"topic1"),
            bulk(b"unused"),
            Value::Int(42), // Invalid payload format
        ]);

        assert_eq!(
            decode_error(&stream_msg),
            "Invalid payload format: [bulk-string('\"0\"'), bulk-string('\"topic1\"'), bulk-string('\"unused\"'), int(42)]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_stream_msg_format() {
        let stream_msg = bulk(b"not an array");

        assert_eq!(
            decode_error(&stream_msg),
            "Invalid stream message format: bulk-string('\"not an array\"')"
        );
    }
}
568
#[cfg(target_os = "linux")] // Run Redis tests on Linux platforms only
#[cfg(test)]
mod serial_tests {
    use nautilus_common::testing::wait_until;
    use redis::Commands;
    use rstest::*;

    use super::*;
    use crate::redis::flush_redis;

    /// Provides a connection to a freshly flushed Redis instance.
    #[fixture]
    fn redis_connection() -> redis::Connection {
        let mut con = create_redis_connection(MSGBUS_STREAM, DatabaseConfig::default()).unwrap();
        flush_redis(&mut con).unwrap();
        con
    }

    /// Returns a default message bus config with a default database config attached.
    fn default_bus_config() -> MessageBusConfig {
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());
        config
    }

    /// Returns the stream key for the test trader and a fresh instance ID.
    fn test_stream_key() -> String {
        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        get_stream_key(trader_id, instance_id, &default_bus_config())
    }

    /// Returns a message ID in the future, as streaming begins around the timestamp
    /// the streaming task is spawned.
    fn future_message_id() -> String {
        let clock = get_atomic_clock_realtime();
        (clock.get_time_ms() + 1_000_000).to_string()
    }

    /// Adds the test message to `stream_key` under a future message ID.
    fn publish_test_message(con: &mut redis::Connection, stream_key: &str) {
        let _: () = con
            .xadd(
                stream_key,
                future_message_id(),
                &[("topic", "topic1"), ("payload", "data1")],
            )
            .unwrap();
    }

    /// Spawns the message streaming task against the default database config.
    fn spawn_streaming(
        tx: tokio::sync::mpsc::Sender<BusMessage>,
        external_streams: Vec<String>,
        stream_signal: Arc<AtomicBool>,
    ) -> tokio::task::JoinHandle<()> {
        tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal,
            )
            .await
            .unwrap();
        })
    }

    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_terminate_signal(redis_connection: redis::Connection) {
        let mut con = redis_connection;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let external_streams = vec![test_stream_key()];
        let stream_signal = Arc::new(AtomicBool::new(false));

        // Start the message streaming task
        let handle = spawn_streaming(tx, external_streams, stream_signal.clone());

        stream_signal.store(true, Ordering::Relaxed);
        let _ = rx.recv().await; // Wait for the tx to close

        // Shutdown and cleanup
        rx.close();
        handle.await.unwrap();
        flush_redis(&mut con).unwrap()
    }

    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_when_receiver_closed(redis_connection: redis::Connection) {
        let mut con = redis_connection;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let stream_key = test_stream_key();
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));

        // Publish test message
        publish_test_message(&mut con, &stream_key);

        // Immediately close channel
        rx.close();

        // Start the message streaming task
        let handle = spawn_streaming(tx, external_streams, stream_signal.clone());

        // Shutdown and cleanup
        handle.await.unwrap();
        flush_redis(&mut con).unwrap()
    }

    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages(redis_connection: redis::Connection) {
        let mut con = redis_connection;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let stream_key = test_stream_key();
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));

        // Publish test message
        publish_test_message(&mut con, &stream_key);

        // Start the message streaming task
        let handle = spawn_streaming(tx, external_streams, stream_signal.clone());

        // Receive and verify the message
        let received = rx.recv().await.unwrap();
        assert_eq!(received.topic, "topic1");
        assert_eq!(received.payload, Bytes::from("data1"));

        // Shutdown and cleanup
        rx.close();
        stream_signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();
        flush_redis(&mut con).unwrap()
    }

    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_publish_messages(redis_connection: redis::Connection) {
        let mut con = redis_connection;
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = default_bus_config();
        config.stream_per_topic = false;
        let stream_key = get_stream_key(trader_id, instance_id, &config);

        // Start the publish_messages task
        let handle = tokio::spawn(async move {
            publish_messages(rx, trader_id, instance_id, config)
                .await
                .unwrap();
        });

        // Send a test message
        tx.send(BusMessage {
            topic: "test_topic".to_string(),
            payload: Bytes::from("test_payload"),
        })
        .unwrap();

        // Wait until the message is published to Redis
        wait_until(
            || {
                let published: RedisStreamBulk = con.xread(&[&stream_key], &["0"]).unwrap();
                !published.is_empty()
            },
            Duration::from_secs(2),
        );

        // Verify the message was published to Redis
        let published: RedisStreamBulk = con.xread(&[&stream_key], &["0"]).unwrap();
        assert_eq!(published.len(), 1);
        let stream_entries = published[0].get(&stream_key).unwrap();
        let first_entry = stream_entries[0].values().next().unwrap();
        let decoded = decode_bus_message(first_entry).unwrap();
        assert_eq!(decoded.topic, "test_topic");
        assert_eq!(decoded.payload, Bytes::from("test_payload"));

        // Stop publishing task
        tx.send(BusMessage {
            topic: CLOSE_TOPIC.to_string(),
            payload: Bytes::new(), // Empty
        })
        .unwrap();

        // Shutdown and cleanup
        handle.await.unwrap();
        flush_redis(&mut con).unwrap();
    }

    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_close() {
        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();

        let mut db =
            RedisMessageBusDatabase::new(trader_id, instance_id, default_bus_config()).unwrap();

        // Close the message bus database (test should not hang)
        db.close();
    }

    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_heartbeat_task() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
        let signal = Arc::new(AtomicBool::new(false));

        // Start the heartbeat task with a short interval
        let handle = tokio::spawn(run_heartbeat(1, signal.clone(), tx));

        // Wait for a couple of heartbeats
        tokio::time::sleep(Duration::from_secs(2)).await;

        // Stop the heartbeat task
        signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();

        // Ensure heartbeats were sent
        let heartbeats: Vec<BusMessage> = std::iter::from_fn(|| rx.try_recv().ok()).collect();

        assert!(!heartbeats.is_empty());

        for heartbeat in &heartbeats {
            assert_eq!(heartbeat.topic, HEARTBEAT_TOPIC);
        }
    }
}