nautilus_infrastructure/redis/
msgbus.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    collections::{HashMap, VecDeque},
18    sync::{
19        Arc,
20        atomic::{AtomicBool, Ordering},
21    },
22    time::{Duration, Instant},
23};
24
25use bytes::Bytes;
26use futures::stream::Stream;
27use nautilus_common::{
28    msgbus::{
29        CLOSE_TOPIC,
30        database::{BusMessage, DatabaseConfig, MessageBusConfig, MessageBusDatabaseAdapter},
31    },
32    runtime::get_runtime,
33};
34use nautilus_core::{
35    UUID4,
36    time::{duration_since_unix_epoch, get_atomic_clock_realtime},
37};
38use nautilus_cryptography::providers::install_cryptographic_provider;
39use nautilus_model::identifiers::TraderId;
40use redis::*;
41use streams::StreamReadOptions;
42
43use super::{REDIS_MINID, REDIS_XTRIM, await_handle};
44use crate::redis::{create_redis_connection, get_stream_key};
45
/// Topic on which heartbeat messages are published.
const HEARTBEAT_TOPIC: &str = "health:heartbeat";

/// Label of the publishing task: names its Redis connection, its log messages and its handle.
const MSGBUS_PUBLISH: &str = "msgbus-publish";
/// Label of the external-stream reading task: names its Redis connection, its log messages and
/// its handle.
const MSGBUS_STREAM: &str = "msgbus-stream";
/// Label of the heartbeat task, used when awaiting its handle on close.
const MSGBUS_HEARTBEAT: &str = "msgbus-heartbeat";

/// Minimum number of seconds between successive `XTRIM` calls on the same stream.
const TRIM_BUFFER_SECS: u64 = 60;
51
/// Typed reply of an `XREAD` call: one map per returned stream, keyed by stream key, whose
/// entries each map a message ID to that message's field/value array.
type RedisStreamBulk = Vec<HashMap<String, Vec<HashMap<String, redis::Value>>>>;
53
/// Redis-backed message bus database adapter.
///
/// Publishes message bus traffic to Redis streams from a background task, and optionally
/// reads messages from external Redis streams and emits periodic heartbeats.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.infrastructure")
)]
pub struct RedisMessageBusDatabase {
    /// The trader ID for this message bus database.
    pub trader_id: TraderId,
    /// The instance ID for this message bus database.
    pub instance_id: UUID4,
    /// Sender feeding messages to the background publish task.
    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
    /// Handle to the publish task; taken when awaited on close.
    pub_handle: Option<tokio::task::JoinHandle<()>>,
    /// Receiver of messages read from external streams; `None` when no external streams are
    /// configured or once it has been taken.
    stream_rx: Option<tokio::sync::mpsc::Receiver<BusMessage>>,
    /// Handle to the external stream task, if one was spawned.
    stream_handle: Option<tokio::task::JoinHandle<()>>,
    /// Set to `true` to tell the stream task to stop.
    stream_signal: Arc<AtomicBool>,
    /// Handle to the heartbeat task, if one was spawned.
    heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
    /// Set to `true` to tell the heartbeat task to stop.
    heartbeat_signal: Arc<AtomicBool>,
}
71
72impl MessageBusDatabaseAdapter for RedisMessageBusDatabase {
73    type DatabaseType = RedisMessageBusDatabase;
74
75    /// Creates a new [`RedisMessageBusDatabase`] instance.
76    fn new(
77        trader_id: TraderId,
78        instance_id: UUID4,
79        config: MessageBusConfig,
80    ) -> anyhow::Result<Self> {
81        install_cryptographic_provider();
82
83        let config_clone = config.clone();
84        let db_config = config
85            .database
86            .clone()
87            .ok_or_else(|| anyhow::anyhow!("No database config"))?;
88
89        let (pub_tx, pub_rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
90
91        // Create publish task (start the runtime here for now)
92        let pub_handle = Some(get_runtime().spawn(async move {
93            if let Err(e) = publish_messages(pub_rx, trader_id, instance_id, config_clone).await {
94                log::error!("Failed to spawn task '{}': {}", MSGBUS_PUBLISH, e);
95            };
96        }));
97
98        // Conditionally create stream task and channel if external streams configured
99        let external_streams = config.external_streams.clone().unwrap_or_default();
100        let stream_signal = Arc::new(AtomicBool::new(false));
101        let (stream_rx, stream_handle) = if !external_streams.is_empty() {
102            let stream_signal_clone = stream_signal.clone();
103            let (stream_tx, stream_rx) = tokio::sync::mpsc::channel::<BusMessage>(100_000);
104            (
105                Some(stream_rx),
106                Some(get_runtime().spawn(async move {
107                    if let Err(e) =
108                        stream_messages(stream_tx, db_config, external_streams, stream_signal_clone)
109                            .await
110                    {
111                        log::error!("Failed to spawn task '{}': {}", MSGBUS_STREAM, e);
112                    }
113                })),
114            )
115        } else {
116            (None, None)
117        };
118
119        // Create heartbeat task
120        let heartbeat_signal = Arc::new(AtomicBool::new(false));
121        let heartbeat_handle = if let Some(heartbeat_interval_secs) = config.heartbeat_interval_secs
122        {
123            let signal = heartbeat_signal.clone();
124            let pub_tx_clone = pub_tx.clone();
125
126            Some(get_runtime().spawn(async move {
127                run_heartbeat(heartbeat_interval_secs, signal, pub_tx_clone).await
128            }))
129        } else {
130            None
131        };
132
133        Ok(Self {
134            trader_id,
135            instance_id,
136            pub_tx,
137            pub_handle,
138            stream_rx,
139            stream_handle,
140            stream_signal,
141            heartbeat_handle,
142            heartbeat_signal,
143        })
144    }
145
146    /// Returns whether the message bus database adapter publishing channel is closed.
147    #[must_use]
148    fn is_closed(&self) -> bool {
149        self.pub_tx.is_closed()
150    }
151
152    /// Publishes a message with the given `topic` and `payload`.
153    fn publish(&self, topic: String, payload: Bytes) {
154        let msg = BusMessage { topic, payload };
155        if let Err(e) = self.pub_tx.send(msg) {
156            log::error!("Failed to send message: {e}");
157        }
158    }
159
160    /// Closes the message bus database adapter.
161    fn close(&mut self) {
162        log::debug!("Closing");
163
164        self.stream_signal.store(true, Ordering::Relaxed);
165        self.heartbeat_signal.store(true, Ordering::Relaxed);
166
167        if !self.pub_tx.is_closed() {
168            let msg = BusMessage {
169                topic: CLOSE_TOPIC.to_string(),
170                payload: Bytes::new(), // Empty
171            };
172            if let Err(e) = self.pub_tx.send(msg) {
173                log::error!("Failed to send close message: {e:?}");
174            }
175        }
176
177        // Keep close sync for now to avoid async trait method
178        tokio::task::block_in_place(|| {
179            get_runtime().block_on(async {
180                self.close_async().await;
181            });
182        });
183
184        log::debug!("Closed");
185    }
186}
187
188impl RedisMessageBusDatabase {
189    /// Gets the stream receiver for this instance.
190    pub fn get_stream_receiver(
191        &mut self,
192    ) -> anyhow::Result<tokio::sync::mpsc::Receiver<BusMessage>> {
193        self.stream_rx
194            .take()
195            .ok_or_else(|| anyhow::anyhow!("Stream receiver already taken"))
196    }
197
198    /// Streams messages arriving on the stream receiver channel.
199    pub fn stream(
200        mut stream_rx: tokio::sync::mpsc::Receiver<BusMessage>,
201    ) -> impl Stream<Item = BusMessage> + 'static {
202        async_stream::stream! {
203            while let Some(msg) = stream_rx.recv().await {
204                yield msg;
205            }
206        }
207    }
208
209    pub async fn close_async(&mut self) {
210        await_handle(self.pub_handle.take(), MSGBUS_PUBLISH).await;
211        await_handle(self.stream_handle.take(), MSGBUS_STREAM).await;
212        await_handle(self.heartbeat_handle.take(), MSGBUS_HEARTBEAT).await;
213    }
214}
215
216pub async fn publish_messages(
217    mut rx: tokio::sync::mpsc::UnboundedReceiver<BusMessage>,
218    trader_id: TraderId,
219    instance_id: UUID4,
220    config: MessageBusConfig,
221) -> anyhow::Result<()> {
222    tracing::debug!("Starting message publishing");
223
224    let db_config = config
225        .database
226        .as_ref()
227        .ok_or_else(|| anyhow::anyhow!("No database config"))?;
228    let mut con = create_redis_connection(MSGBUS_PUBLISH, db_config.clone()).await?;
229    let stream_key = get_stream_key(trader_id, instance_id, &config);
230
231    // Auto-trimming
232    let autotrim_duration = config
233        .autotrim_mins
234        .filter(|&mins| mins > 0)
235        .map(|mins| Duration::from_secs(mins as u64 * 60));
236    let mut last_trim_index: HashMap<String, usize> = HashMap::new();
237
238    // Buffering
239    let mut buffer: VecDeque<BusMessage> = VecDeque::new();
240    let mut last_drain = Instant::now();
241    let buffer_interval = Duration::from_millis(config.buffer_interval_ms.unwrap_or(0) as u64);
242
243    loop {
244        if last_drain.elapsed() >= buffer_interval && !buffer.is_empty() {
245            drain_buffer(
246                &mut con,
247                &stream_key,
248                config.stream_per_topic,
249                autotrim_duration,
250                &mut last_trim_index,
251                &mut buffer,
252            )
253            .await?;
254            last_drain = Instant::now();
255        } else {
256            match rx.recv().await {
257                Some(msg) => {
258                    if msg.topic == CLOSE_TOPIC {
259                        tracing::debug!("Received close message");
260                        drop(rx);
261                        break;
262                    }
263                    buffer.push_back(msg);
264                }
265                None => {
266                    tracing::debug!("Channel hung up");
267                    break;
268                }
269            }
270        }
271    }
272
273    // Drain any remaining messages
274    if !buffer.is_empty() {
275        drain_buffer(
276            &mut con,
277            &stream_key,
278            config.stream_per_topic,
279            autotrim_duration,
280            &mut last_trim_index,
281            &mut buffer,
282        )
283        .await?;
284    }
285
286    tracing::debug!("Stopped message publishing");
287    Ok(())
288}
289
290async fn drain_buffer(
291    conn: &mut redis::aio::ConnectionManager,
292    stream_key: &str,
293    stream_per_topic: bool,
294    autotrim_duration: Option<Duration>,
295    last_trim_index: &mut HashMap<String, usize>,
296    buffer: &mut VecDeque<BusMessage>,
297) -> anyhow::Result<()> {
298    let mut pipe = redis::pipe();
299    pipe.atomic();
300
301    for msg in buffer.drain(..) {
302        let items: Vec<(&str, &[u8])> = vec![
303            ("topic", msg.topic.as_ref()),
304            ("payload", msg.payload.as_ref()),
305        ];
306        let stream_key = match stream_per_topic {
307            true => format!("{stream_key}:{}", &msg.topic),
308            false => stream_key.to_string(),
309        };
310        pipe.xadd(&stream_key, "*", &items);
311
312        if autotrim_duration.is_none() {
313            continue; // Nothing else to do
314        }
315
316        // Autotrim stream
317        let last_trim_ms = last_trim_index.entry(stream_key.clone()).or_insert(0); // Remove clone
318        let unix_duration_now = duration_since_unix_epoch();
319        let trim_buffer = Duration::from_secs(TRIM_BUFFER_SECS);
320
321        // Improve efficiency of this by batching
322        if *last_trim_ms < (unix_duration_now - trim_buffer).as_millis() as usize {
323            let min_timestamp_ms =
324                (unix_duration_now - autotrim_duration.unwrap()).as_millis() as usize;
325            let result: Result<(), redis::RedisError> = redis::cmd(REDIS_XTRIM)
326                .arg(stream_key.clone())
327                .arg(REDIS_MINID)
328                .arg(min_timestamp_ms)
329                .query_async(conn)
330                .await;
331
332            if let Err(e) = result {
333                tracing::error!("Error trimming stream '{stream_key}': {e}");
334            } else {
335                last_trim_index.insert(
336                    stream_key.to_string(),
337                    unix_duration_now.as_millis() as usize,
338                );
339            }
340        }
341    }
342
343    pipe.query_async(conn).await.map_err(anyhow::Error::from)
344}
345
346pub async fn stream_messages(
347    tx: tokio::sync::mpsc::Sender<BusMessage>,
348    config: DatabaseConfig,
349    stream_keys: Vec<String>,
350    stream_signal: Arc<AtomicBool>,
351) -> anyhow::Result<()> {
352    tracing::info!("Starting message streaming");
353    let mut con = create_redis_connection(MSGBUS_STREAM, config).await?;
354
355    let stream_keys = &stream_keys
356        .iter()
357        .map(String::as_str)
358        .collect::<Vec<&str>>();
359
360    tracing::debug!("Listening to streams: [{}]", stream_keys.join(", "));
361
362    // Start streaming from current timestamp
363    let clock = get_atomic_clock_realtime();
364    let timestamp_ms = clock.get_time_ms();
365    let mut last_id = timestamp_ms.to_string();
366
367    let opts = StreamReadOptions::default().block(100);
368
369    'outer: loop {
370        if stream_signal.load(Ordering::Relaxed) {
371            tracing::debug!("Received streaming terminate signal");
372            break;
373        }
374        let result: Result<RedisStreamBulk, _> =
375            con.xread_options(&[&stream_keys], &[&last_id], &opts).await;
376        match result {
377            Ok(stream_bulk) => {
378                if stream_bulk.is_empty() {
379                    // Timeout occurred: no messages received
380                    continue;
381                }
382                for entry in stream_bulk.iter() {
383                    for (_stream_key, stream_msgs) in entry.iter() {
384                        for stream_msg in stream_msgs.iter() {
385                            for (id, array) in stream_msg {
386                                last_id.clear();
387                                last_id.push_str(id);
388                                match decode_bus_message(array) {
389                                    Ok(msg) => {
390                                        if let Err(e) = tx.send(msg).await {
391                                            tracing::debug!("Channel closed: {e:?}");
392                                            break 'outer; // End streaming
393                                        }
394                                    }
395                                    Err(e) => {
396                                        tracing::error!("{e:?}");
397                                        continue;
398                                    }
399                                }
400                            }
401                        }
402                    }
403                }
404            }
405            Err(e) => {
406                return Err(anyhow::anyhow!("Error reading from stream: {e:?}"));
407            }
408        }
409    }
410
411    tracing::debug!("Stopped message streaming");
412    Ok(())
413}
414
415fn decode_bus_message(stream_msg: &redis::Value) -> anyhow::Result<BusMessage> {
416    if let redis::Value::Array(stream_msg) = stream_msg {
417        if stream_msg.len() < 4 {
418            anyhow::bail!("Invalid stream message format: {stream_msg:?}");
419        }
420
421        let topic = match &stream_msg[1] {
422            redis::Value::BulkString(bytes) => match String::from_utf8(bytes.clone()) {
423                Ok(topic) => topic,
424                Err(e) => anyhow::bail!("Error parsing topic: {}", e),
425            },
426            _ => {
427                anyhow::bail!("Invalid topic format: {stream_msg:?}");
428            }
429        };
430
431        let payload = match &stream_msg[3] {
432            redis::Value::BulkString(bytes) => Bytes::copy_from_slice(bytes),
433            _ => {
434                anyhow::bail!("Invalid payload format: {stream_msg:?}");
435            }
436        };
437
438        Ok(BusMessage { topic, payload })
439    } else {
440        anyhow::bail!("Invalid stream message format: {stream_msg:?}")
441    }
442}
443
444async fn run_heartbeat(
445    heartbeat_interval_secs: u16,
446    signal: Arc<AtomicBool>,
447    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
448) {
449    tracing::debug!("Starting heartbeat at {heartbeat_interval_secs} second intervals");
450
451    let heartbeat_interval = Duration::from_secs(heartbeat_interval_secs as u64);
452    let heartbeat_timer = tokio::time::interval(heartbeat_interval);
453
454    let check_interval = Duration::from_millis(100);
455    let check_timer = tokio::time::interval(check_interval);
456
457    tokio::pin!(heartbeat_timer);
458    tokio::pin!(check_timer);
459
460    loop {
461        if signal.load(Ordering::Relaxed) {
462            tracing::debug!("Received heartbeat terminate signal");
463            break;
464        }
465
466        tokio::select! {
467            _ = heartbeat_timer.tick() => {
468                let heartbeat = create_heartbeat_msg();
469                if let Err(e) = pub_tx.send(heartbeat) {
470                    // We expect an error if the channel is closed during shutdown
471                    tracing::debug!("Error sending heartbeat: {e}");
472                }
473            },
474            _ = check_timer.tick() => {}
475        }
476    }
477
478    tracing::debug!("Stopped heartbeat");
479}
480
481fn create_heartbeat_msg() -> BusMessage {
482    BusMessage {
483        topic: HEARTBEAT_TOPIC.to_string(),
484        payload: Bytes::from(chrono::Utc::now().to_rfc3339().into_bytes()),
485    }
486}
487
488////////////////////////////////////////////////////////////////////////////////
489// Tests
490////////////////////////////////////////////////////////////////////////////////
// Unit tests for `decode_bus_message`: the valid layout and each rejection path, including the
// exact error text (which embeds the `Debug` rendering of the offending value).
#[cfg(test)]
mod tests {
    use redis::Value;
    use rstest::*;

    use super::*;

    // Well-formed entry: bulk strings with the topic at index 1 and the payload at index 3
    #[rstest]
    fn test_decode_bus_message_valid() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
            Value::BulkString(b"unused".to_vec()),
            Value::BulkString(b"data1".to_vec()),
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_ok());
        let msg = result.unwrap();
        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));
    }

    // Fewer than four elements is rejected before any element is inspected
    #[rstest]
    fn test_decode_bus_message_missing_fields() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid stream message format: [bulk-string('\"0\"'), bulk-string('\"topic1\"')]"
        );
    }

    // Non-bulk-string topic element
    #[rstest]
    fn test_decode_bus_message_invalid_topic_format() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::Int(42), // Invalid topic format
            Value::BulkString(b"unused".to_vec()),
            Value::BulkString(b"data1".to_vec()),
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid topic format: [bulk-string('\"0\"'), int(42), bulk-string('\"unused\"'), bulk-string('\"data1\"')]"
        );
    }

    // Non-bulk-string payload element
    #[rstest]
    fn test_decode_bus_message_invalid_payload_format() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
            Value::BulkString(b"unused".to_vec()),
            Value::Int(42), // Invalid payload format
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid payload format: [bulk-string('\"0\"'), bulk-string('\"topic1\"'), bulk-string('\"unused\"'), int(42)]"
        );
    }

    // A value that is not an array at all
    #[rstest]
    fn test_decode_bus_message_invalid_stream_msg_format() {
        let stream_msg = Value::BulkString(b"not an array".to_vec());

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_err());
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "Invalid stream message format: bulk-string('\"not an array\"')"
        );
    }
}
575
#[cfg(target_os = "linux")] // Run Redis tests on Linux platforms only
#[cfg(test)]
// Integration tests against a live Redis reachable via `DatabaseConfig::default()`; each test
// flushes Redis before and after use.
mod serial_tests {
    use nautilus_common::testing::wait_until_async;
    use redis::aio::ConnectionManager;
    use rstest::*;

    use super::*;
    use crate::redis::flush_redis;

    // Provides a Redis connection built from the default config, flushed before the test runs
    #[fixture]
    async fn redis_connection() -> ConnectionManager {
        let config = DatabaseConfig::default();
        let mut con = create_redis_connection(MSGBUS_STREAM, config)
            .await
            .unwrap();
        flush_redis(&mut con).await.unwrap();
        con
    }

    // Setting the terminate signal makes the streaming task return, dropping `tx`
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_terminate_signal(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Start the message streaming task
        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        stream_signal.store(true, Ordering::Relaxed);
        let _ = rx.recv().await; // Wait for the tx to close

        // Shutdown and cleanup
        rx.close();
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap()
    }

    // With the receiver closed up front, the first forwarded message fails to send and the
    // streaming task ends without error
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_when_receiver_closed(
        #[future] redis_connection: ConnectionManager,
    ) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Use a message ID in the future, as streaming begins
        // around the timestamp the thread is spawned.
        let clock = get_atomic_clock_realtime();
        let future_id = (clock.get_time_ms() + 1_000_000).to_string();

        // Publish test message
        let _: () = con
            .xadd(
                stream_key,
                future_id,
                &[("topic", "topic1"), ("payload", "data1")],
            )
            .await
            .unwrap();

        // Immediately close channel
        rx.close();

        // Start the message streaming task
        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        // Shutdown and cleanup
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap()
    }

    // A message added directly to the stream is decoded and forwarded to the channel
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Use a message ID in the future, as streaming begins
        // around the timestamp the task is spawned.
        let clock = get_atomic_clock_realtime();
        let future_id = (clock.get_time_ms() + 1_000_000).to_string();

        // Publish test message
        let _: () = con
            .xadd(
                stream_key,
                future_id,
                &[("topic", "topic1"), ("payload", "data1")],
            )
            .await
            .unwrap();

        // Start the message streaming task
        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        // Receive and verify the message
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));

        // Shutdown and cleanup
        rx.close();
        stream_signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap()
    }

    // A message sent to the publish task lands in the single (non per-topic) stream, and the
    // close message stops the task
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_publish_messages(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());
        config.stream_per_topic = false;
        let stream_key = get_stream_key(trader_id, instance_id, &config);

        // Start the publish_messages task
        let handle = tokio::spawn(async move {
            publish_messages(rx, trader_id, instance_id, config)
                .await
                .unwrap();
        });

        // Send a test message
        let msg = BusMessage {
            topic: "test_topic".to_string(),
            payload: Bytes::from("test_payload"),
        };
        tx.send(msg).unwrap();

        // Wait until the message is published to Redis
        wait_until_async(
            || {
                let mut con = con.clone();
                let stream_key = stream_key.clone();
                async move {
                    let messages: RedisStreamBulk =
                        con.xread(&[&stream_key], &["0"]).await.unwrap();
                    !messages.is_empty()
                }
            },
            Duration::from_secs(2),
        )
        .await;

        // Verify the message was published to Redis
        let messages: RedisStreamBulk = con.xread(&[&stream_key], &["0"]).await.unwrap();
        assert_eq!(messages.len(), 1);
        let stream_msgs = messages[0].get(&stream_key).unwrap();
        let stream_msg_array = &stream_msgs[0].values().next().unwrap();
        let decoded_message = decode_bus_message(stream_msg_array).unwrap();
        assert_eq!(decoded_message.topic, "test_topic");
        assert_eq!(decoded_message.payload, Bytes::from("test_payload"));

        // Stop publishing task
        let msg = BusMessage {
            topic: CLOSE_TOPIC.to_string(),
            payload: Bytes::new(), // Empty
        };
        tx.send(msg).unwrap();

        // Shutdown and cleanup
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    // `close` must return once all background tasks have been awaited
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_close() {
        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.database = Some(DatabaseConfig::default());

        let mut db = RedisMessageBusDatabase::new(trader_id, instance_id, config).unwrap();

        // Close the message bus database (test should not hang)
        db.close();
    }

    // Heartbeats are emitted on `HEARTBEAT_TOPIC` and the task stops when signalled
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_heartbeat_task() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
        let signal = Arc::new(AtomicBool::new(false));

        // Start the heartbeat task with a short interval
        let handle = tokio::spawn(run_heartbeat(1, signal.clone(), tx));

        // Wait for a couple of heartbeats
        tokio::time::sleep(Duration::from_secs(2)).await;

        // Stop the heartbeat task
        signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();

        // Ensure heartbeats were sent
        let mut heartbeats: Vec<BusMessage> = Vec::new();
        while let Ok(hb) = rx.try_recv() {
            heartbeats.push(hb);
        }

        assert!(!heartbeats.is_empty());

        for hb in heartbeats {
            assert_eq!(hb.topic, HEARTBEAT_TOPIC);
        }
    }
}