nautilus_infrastructure/redis/
msgbus.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    collections::{HashMap, VecDeque},
18    fmt::Debug,
19    sync::{
20        Arc,
21        atomic::{AtomicBool, Ordering},
22    },
23    time::Duration,
24};
25
26use bytes::Bytes;
27use futures::stream::Stream;
28use nautilus_common::{
29    logging::{log_task_error, log_task_started, log_task_stopped},
30    msgbus::{
31        BusMessage,
32        database::{DatabaseConfig, MessageBusConfig, MessageBusDatabaseAdapter},
33        switchboard::CLOSE_TOPIC,
34    },
35    runtime::get_runtime,
36};
37use nautilus_core::{
38    UUID4,
39    time::{duration_since_unix_epoch, get_atomic_clock_realtime},
40};
41use nautilus_cryptography::providers::install_cryptographic_provider;
42use nautilus_model::identifiers::TraderId;
43use redis::{AsyncCommands, streams};
44use streams::StreamReadOptions;
45use tokio::time::Instant;
46use ustr::Ustr;
47
48use super::{REDIS_MINID, REDIS_XTRIM, await_handle};
49use crate::redis::{create_redis_connection, get_stream_key};
50
/// Name of the publish task; used for its Redis connection, task logging, and when awaiting its handle.
const MSGBUS_PUBLISH: &str = "msgbus-publish";
/// Name of the external stream reading task; used for its Redis connection, task logging, and when awaiting its handle.
const MSGBUS_STREAM: &str = "msgbus-stream";
/// Name passed to `await_handle` when waiting for the heartbeat task on close.
const MSGBUS_HEARTBEAT: &str = "msgbus-heartbeat";
/// Topic on which heartbeat messages are published.
const HEARTBEAT_TOPIC: &str = "health:heartbeat";
/// Minimum age (in seconds) of the last recorded trim before a stream is trimmed again.
const TRIM_BUFFER_SECS: u64 = 60;

/// Deserialized shape of a stream read reply as consumed in this module: a list of maps from
/// stream key to that stream's messages, where each message is a map from message ID to the
/// raw Redis value holding its fields.
type RedisStreamBulk = Vec<HashMap<String, Vec<HashMap<String, redis::Value>>>>;
58
/// Redis-backed message bus database adapter.
///
/// Holds the sending half of the publish channel together with the join handles and stop
/// signals of the background publish, stream, and heartbeat tasks spawned by `new`.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.infrastructure")
)]
pub struct RedisMessageBusDatabase {
    /// The trader ID for this message bus database.
    pub trader_id: TraderId,
    /// The instance ID for this message bus database.
    pub instance_id: UUID4,
    /// Sender used to hand messages (including heartbeats and the close message) to the publish task.
    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
    /// Join handle of the publish task; taken when awaited in `close_async`.
    pub_handle: Option<tokio::task::JoinHandle<()>>,
    /// Receiver for messages read from external streams; `None` when no external streams are
    /// configured or once taken via `get_stream_receiver`.
    stream_rx: Option<tokio::sync::mpsc::Receiver<BusMessage>>,
    /// Join handle of the stream task; `None` when no external streams are configured.
    stream_handle: Option<tokio::task::JoinHandle<()>>,
    /// Set to `true` on close to ask the stream task to stop.
    stream_signal: Arc<AtomicBool>,
    /// Join handle of the heartbeat task; `None` when no heartbeat interval is configured.
    heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
    /// Set to `true` on close to ask the heartbeat task to stop.
    heartbeat_signal: Arc<AtomicBool>,
}
76
77impl Debug for RedisMessageBusDatabase {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct(stringify!(RedisMessageBusDatabase))
80            .field("trader_id", &self.trader_id)
81            .field("instance_id", &self.instance_id)
82            .finish()
83    }
84}
85
86impl MessageBusDatabaseAdapter for RedisMessageBusDatabase {
87    type DatabaseType = Self;
88
89    /// Creates a new [`RedisMessageBusDatabase`] instance for the given `trader_id`, `instance_id`, and `config`.
90    ///
91    /// # Errors
92    ///
93    /// Returns an error if:
94    /// - The database configuration is missing in `config`.
95    /// - Establishing the Redis connection for publishing fails.
96    fn new(
97        trader_id: TraderId,
98        instance_id: UUID4,
99        config: MessageBusConfig,
100    ) -> anyhow::Result<Self> {
101        install_cryptographic_provider();
102
103        let config_clone = config.clone();
104        let db_config = config
105            .database
106            .clone()
107            .ok_or_else(|| anyhow::anyhow!("No database config"))?;
108
109        let (pub_tx, pub_rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
110
111        // Create publish task (start the runtime here for now)
112        let pub_handle = Some(get_runtime().spawn(async move {
113            if let Err(e) = publish_messages(pub_rx, trader_id, instance_id, config_clone).await {
114                log_task_error(MSGBUS_PUBLISH, &e);
115            }
116        }));
117
118        // Conditionally create stream task and channel if external streams configured
119        let external_streams = config.external_streams.clone().unwrap_or_default();
120        let stream_signal = Arc::new(AtomicBool::new(false));
121        let (stream_rx, stream_handle) = if external_streams.is_empty() {
122            (None, None)
123        } else {
124            let stream_signal_clone = stream_signal.clone();
125            let (stream_tx, stream_rx) = tokio::sync::mpsc::channel::<BusMessage>(100_000);
126            (
127                Some(stream_rx),
128                Some(get_runtime().spawn(async move {
129                    if let Err(e) =
130                        stream_messages(stream_tx, db_config, external_streams, stream_signal_clone)
131                            .await
132                    {
133                        log_task_error(MSGBUS_STREAM, &e);
134                    }
135                })),
136            )
137        };
138
139        // Create heartbeat task
140        let heartbeat_signal = Arc::new(AtomicBool::new(false));
141        let heartbeat_handle = if let Some(heartbeat_interval_secs) = config.heartbeat_interval_secs
142        {
143            let signal = heartbeat_signal.clone();
144            let pub_tx_clone = pub_tx.clone();
145
146            Some(get_runtime().spawn(async move {
147                run_heartbeat(heartbeat_interval_secs, signal, pub_tx_clone).await;
148            }))
149        } else {
150            None
151        };
152
153        Ok(Self {
154            trader_id,
155            instance_id,
156            pub_tx,
157            pub_handle,
158            stream_rx,
159            stream_handle,
160            stream_signal,
161            heartbeat_handle,
162            heartbeat_signal,
163        })
164    }
165
166    /// Returns whether the message bus database adapter publishing channel is closed.
167    fn is_closed(&self) -> bool {
168        self.pub_tx.is_closed()
169    }
170
171    /// Publishes a message with the given `topic` and `payload`.
172    fn publish(&self, topic: Ustr, payload: Bytes) {
173        let msg = BusMessage::new(topic, payload);
174        if let Err(e) = self.pub_tx.send(msg) {
175            log::error!("Failed to send message: {e}");
176        }
177    }
178
179    /// Closes the message bus database adapter.
180    fn close(&mut self) {
181        log::debug!("Closing");
182
183        self.stream_signal.store(true, Ordering::Relaxed);
184        self.heartbeat_signal.store(true, Ordering::Relaxed);
185
186        if !self.pub_tx.is_closed() {
187            let msg = BusMessage::new_close();
188
189            if let Err(e) = self.pub_tx.send(msg) {
190                log::error!("Failed to send close message: {e:?}");
191            }
192        }
193
194        // Keep close sync for now to avoid async trait method
195        tokio::task::block_in_place(|| {
196            get_runtime().block_on(async {
197                self.close_async().await;
198            });
199        });
200
201        log::debug!("Closed");
202    }
203}
204
205impl RedisMessageBusDatabase {
206    /// Retrieves the Redis stream receiver for this message bus instance.
207    ///
208    /// # Errors
209    ///
210    /// Returns an error if the stream receiver has already been taken.
211    pub fn get_stream_receiver(
212        &mut self,
213    ) -> anyhow::Result<tokio::sync::mpsc::Receiver<BusMessage>> {
214        self.stream_rx
215            .take()
216            .ok_or_else(|| anyhow::anyhow!("Stream receiver already taken"))
217    }
218
219    /// Streams messages arriving on the stream receiver channel.
220    pub fn stream(
221        mut stream_rx: tokio::sync::mpsc::Receiver<BusMessage>,
222    ) -> impl Stream<Item = BusMessage> + 'static {
223        async_stream::stream! {
224            while let Some(msg) = stream_rx.recv().await {
225                yield msg;
226            }
227        }
228    }
229
230    pub async fn close_async(&mut self) {
231        await_handle(self.pub_handle.take(), MSGBUS_PUBLISH).await;
232        await_handle(self.stream_handle.take(), MSGBUS_STREAM).await;
233        await_handle(self.heartbeat_handle.take(), MSGBUS_HEARTBEAT).await;
234    }
235}
236
237/// Publishes messages received on `rx` to Redis streams for the given `trader_id` and `instance_id`, using `config`.
238///
239/// # Errors
240///
241/// Returns an error if:
242/// - The database configuration is missing in `config`.
243/// - Establishing the Redis connection fails.
244/// - Any Redis command fails during publishing.
245pub async fn publish_messages(
246    mut rx: tokio::sync::mpsc::UnboundedReceiver<BusMessage>,
247    trader_id: TraderId,
248    instance_id: UUID4,
249    config: MessageBusConfig,
250) -> anyhow::Result<()> {
251    log_task_started(MSGBUS_PUBLISH);
252
253    let db_config = config
254        .database
255        .as_ref()
256        .ok_or_else(|| anyhow::anyhow!("No database config"))?;
257    let mut con = create_redis_connection(MSGBUS_PUBLISH, db_config.clone()).await?;
258    let stream_key = get_stream_key(trader_id, instance_id, &config);
259
260    // Auto-trimming
261    let autotrim_duration = config
262        .autotrim_mins
263        .filter(|&mins| mins > 0)
264        .map(|mins| Duration::from_secs(u64::from(mins) * 60));
265    let mut last_trim_index: HashMap<String, usize> = HashMap::new();
266
267    // Buffering
268    let mut buffer: VecDeque<BusMessage> = VecDeque::new();
269    let buffer_interval = Duration::from_millis(u64::from(config.buffer_interval_ms.unwrap_or(0)));
270
271    // A sleep used to trigger periodic flushing of the buffer.
272    // When `buffer_interval` is zero we skip using the timer and flush immediately
273    // after every message.
274    let flush_timer = tokio::time::sleep(buffer_interval);
275    tokio::pin!(flush_timer);
276
277    loop {
278        tokio::select! {
279            maybe_msg = rx.recv() => {
280                if let Some(msg) = maybe_msg {
281                    if msg.topic == CLOSE_TOPIC {
282                        tracing::debug!("Received close message");
283                        // Ensure we exit the loop after flushing any remaining messages.
284                        if !buffer.is_empty() {
285                            drain_buffer(
286                                &mut con,
287                                &stream_key,
288                                config.stream_per_topic,
289                                autotrim_duration,
290                                &mut last_trim_index,
291                                &mut buffer,
292                            ).await?;
293                        }
294                        break;
295                    }
296
297                    buffer.push_back(msg);
298
299                    if buffer_interval.is_zero() {
300                        // Immediate flush mode
301                        drain_buffer(
302                            &mut con,
303                            &stream_key,
304                            config.stream_per_topic,
305                            autotrim_duration,
306                            &mut last_trim_index,
307                            &mut buffer,
308                        ).await?;
309                    }
310                } else {
311                    tracing::debug!("Channel hung up");
312                    break;
313                }
314            }
315            // Only poll the timer when the interval is non-zero. This avoids
316            // unnecessarily waking the task when immediate flushing is enabled.
317            () = &mut flush_timer, if !buffer_interval.is_zero() => {
318                if !buffer.is_empty() {
319                    drain_buffer(
320                        &mut con,
321                        &stream_key,
322                        config.stream_per_topic,
323                        autotrim_duration,
324                        &mut last_trim_index,
325                        &mut buffer,
326                    ).await?;
327                }
328
329                // Schedule the next tick
330                flush_timer.as_mut().reset(Instant::now() + buffer_interval);
331            }
332        }
333    }
334
335    // Drain any remaining messages
336    if !buffer.is_empty() {
337        drain_buffer(
338            &mut con,
339            &stream_key,
340            config.stream_per_topic,
341            autotrim_duration,
342            &mut last_trim_index,
343            &mut buffer,
344        )
345        .await?;
346    }
347
348    log_task_stopped(MSGBUS_PUBLISH);
349    Ok(())
350}
351
352async fn drain_buffer(
353    conn: &mut redis::aio::ConnectionManager,
354    stream_key: &str,
355    stream_per_topic: bool,
356    autotrim_duration: Option<Duration>,
357    last_trim_index: &mut HashMap<String, usize>,
358    buffer: &mut VecDeque<BusMessage>,
359) -> anyhow::Result<()> {
360    let mut pipe = redis::pipe();
361    pipe.atomic();
362
363    for msg in buffer.drain(..) {
364        let items: Vec<(&str, &[u8])> = vec![
365            ("topic", msg.topic.as_ref()),
366            ("payload", msg.payload.as_ref()),
367        ];
368        let stream_key = if stream_per_topic {
369            format!("{stream_key}:{}", &msg.topic)
370        } else {
371            stream_key.to_string()
372        };
373        pipe.xadd(&stream_key, "*", &items);
374
375        if autotrim_duration.is_none() {
376            continue; // Nothing else to do
377        }
378
379        // Autotrim stream
380        let last_trim_ms = last_trim_index.entry(stream_key.clone()).or_insert(0); // Remove clone
381        let unix_duration_now = duration_since_unix_epoch();
382        let trim_buffer = Duration::from_secs(TRIM_BUFFER_SECS);
383
384        // Improve efficiency of this by batching
385        if *last_trim_ms < (unix_duration_now - trim_buffer).as_millis() as usize {
386            let min_timestamp_ms =
387                (unix_duration_now - autotrim_duration.unwrap()).as_millis() as usize;
388            let result: Result<(), redis::RedisError> = redis::cmd(REDIS_XTRIM)
389                .arg(stream_key.clone())
390                .arg(REDIS_MINID)
391                .arg(min_timestamp_ms)
392                .query_async(conn)
393                .await;
394
395            if let Err(e) = result {
396                tracing::error!("Error trimming stream '{stream_key}': {e}");
397            } else {
398                last_trim_index.insert(
399                    stream_key.to_string(),
400                    unix_duration_now.as_millis() as usize,
401                );
402            }
403        }
404    }
405
406    pipe.query_async(conn).await.map_err(anyhow::Error::from)
407}
408
409/// Streams messages from Redis streams and sends them over the provided `tx` channel.
410///
411/// # Errors
412///
413/// Returns an error if:
414/// - Establishing the Redis connection fails.
415/// - Any Redis read operation fails.
416pub async fn stream_messages(
417    tx: tokio::sync::mpsc::Sender<BusMessage>,
418    config: DatabaseConfig,
419    stream_keys: Vec<String>,
420    stream_signal: Arc<AtomicBool>,
421) -> anyhow::Result<()> {
422    log_task_started(MSGBUS_STREAM);
423
424    let mut con = create_redis_connection(MSGBUS_STREAM, config).await?;
425
426    let stream_keys = &stream_keys
427        .iter()
428        .map(String::as_str)
429        .collect::<Vec<&str>>();
430
431    tracing::debug!("Listening to streams: [{}]", stream_keys.join(", "));
432
433    // Start streaming from current timestamp
434    let clock = get_atomic_clock_realtime();
435    let timestamp_ms = clock.get_time_ms();
436    let mut last_id = timestamp_ms.to_string();
437
438    let opts = StreamReadOptions::default().block(100);
439
440    'outer: loop {
441        if stream_signal.load(Ordering::Relaxed) {
442            tracing::debug!("Received streaming terminate signal");
443            break;
444        }
445        let result: Result<RedisStreamBulk, _> =
446            con.xread_options(&[&stream_keys], &[&last_id], &opts).await;
447        match result {
448            Ok(stream_bulk) => {
449                if stream_bulk.is_empty() {
450                    // Timeout occurred: no messages received
451                    continue;
452                }
453                for entry in &stream_bulk {
454                    for stream_msgs in entry.values() {
455                        for stream_msg in stream_msgs {
456                            for (id, array) in stream_msg {
457                                last_id.clear();
458                                last_id.push_str(id);
459                                match decode_bus_message(array) {
460                                    Ok(msg) => {
461                                        if let Err(e) = tx.send(msg).await {
462                                            tracing::debug!("Channel closed: {e:?}");
463                                            break 'outer; // End streaming
464                                        }
465                                    }
466                                    Err(e) => {
467                                        tracing::error!("{e:?}");
468                                        continue;
469                                    }
470                                }
471                            }
472                        }
473                    }
474                }
475            }
476            Err(e) => {
477                anyhow::bail!("Error reading from stream: {e:?}");
478            }
479        }
480    }
481
482    log_task_stopped(MSGBUS_STREAM);
483    Ok(())
484}
485
486/// Decodes a Redis stream message value into a `BusMessage`.
487///
488/// # Errors
489///
490/// Returns an error if:
491/// - The incoming `stream_msg` is not an array.
492/// - The array has fewer than four elements (invalid format).
493/// - Parsing the topic or payload fails.
494fn decode_bus_message(stream_msg: &redis::Value) -> anyhow::Result<BusMessage> {
495    if let redis::Value::Array(stream_msg) = stream_msg {
496        if stream_msg.len() < 4 {
497            anyhow::bail!("Invalid stream message format: {stream_msg:?}");
498        }
499
500        let topic = match &stream_msg[1] {
501            redis::Value::BulkString(bytes) => match String::from_utf8(bytes.clone()) {
502                Ok(topic) => topic,
503                Err(e) => anyhow::bail!("Error parsing topic: {e}"),
504            },
505            _ => {
506                anyhow::bail!("Invalid topic format: {stream_msg:?}");
507            }
508        };
509
510        let payload = match &stream_msg[3] {
511            redis::Value::BulkString(bytes) => Bytes::copy_from_slice(bytes),
512            _ => {
513                anyhow::bail!("Invalid payload format: {stream_msg:?}");
514            }
515        };
516
517        Ok(BusMessage::with_str_topic(topic, payload))
518    } else {
519        anyhow::bail!("Invalid stream message format: {stream_msg:?}")
520    }
521}
522
523async fn run_heartbeat(
524    heartbeat_interval_secs: u16,
525    signal: Arc<AtomicBool>,
526    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
527) {
528    log_task_started("heartbeat");
529    tracing::debug!("Heartbeat at {heartbeat_interval_secs} second intervals");
530
531    let heartbeat_interval = Duration::from_secs(u64::from(heartbeat_interval_secs));
532    let heartbeat_timer = tokio::time::interval(heartbeat_interval);
533
534    let check_interval = Duration::from_millis(100);
535    let check_timer = tokio::time::interval(check_interval);
536
537    tokio::pin!(heartbeat_timer);
538    tokio::pin!(check_timer);
539
540    loop {
541        if signal.load(Ordering::Relaxed) {
542            tracing::debug!("Received heartbeat terminate signal");
543            break;
544        }
545
546        tokio::select! {
547            _ = heartbeat_timer.tick() => {
548                let heartbeat = create_heartbeat_msg();
549                if let Err(e) = pub_tx.send(heartbeat) {
550                    // We expect an error if the channel is closed during shutdown
551                    tracing::debug!("Error sending heartbeat: {e}");
552                }
553            },
554            _ = check_timer.tick() => {}
555        }
556    }
557
558    log_task_stopped("heartbeat");
559}
560
561fn create_heartbeat_msg() -> BusMessage {
562    let payload = Bytes::from(chrono::Utc::now().to_rfc3339().into_bytes());
563    BusMessage::with_str_topic(HEARTBEAT_TOPIC, payload)
564}
565
566////////////////////////////////////////////////////////////////////////////////
567// Tests
568////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use redis::Value;
    use rstest::*;

    use super::*;

    /// Builds a bulk-string value holding the bytes of `text`.
    fn bulk(text: &str) -> Value {
        Value::BulkString(text.as_bytes().to_vec())
    }

    /// Decodes `value`, expecting failure, and returns the rendered error message.
    fn decode_error(value: &Value) -> String {
        let result = decode_bus_message(value);
        assert!(result.is_err());
        result.unwrap_err().to_string()
    }

    #[rstest]
    fn test_decode_bus_message_valid() {
        let stream_msg = Value::Array(vec![
            bulk("0"),
            bulk("topic1"),
            bulk("unused"),
            bulk("data1"),
        ]);

        let result = decode_bus_message(&stream_msg);
        assert!(result.is_ok());

        let msg = result.unwrap();
        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));
    }

    #[rstest]
    fn test_decode_bus_message_missing_fields() {
        // Only two of the four required elements are present
        let stream_msg = Value::Array(vec![bulk("0"), bulk("topic1")]);

        assert_eq!(
            decode_error(&stream_msg),
            "Invalid stream message format: [bulk-string('\"0\"'), bulk-string('\"topic1\"')]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_topic_format() {
        let stream_msg = Value::Array(vec![
            bulk("0"),
            Value::Int(42), // Topic must be a bulk string
            bulk("unused"),
            bulk("data1"),
        ]);

        assert_eq!(
            decode_error(&stream_msg),
            "Invalid topic format: [bulk-string('\"0\"'), int(42), bulk-string('\"unused\"'), bulk-string('\"data1\"')]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_payload_format() {
        let stream_msg = Value::Array(vec![
            bulk("0"),
            bulk("topic1"),
            bulk("unused"),
            Value::Int(42), // Payload must be a bulk string
        ]);

        assert_eq!(
            decode_error(&stream_msg),
            "Invalid payload format: [bulk-string('\"0\"'), bulk-string('\"topic1\"'), bulk-string('\"unused\"'), int(42)]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_stream_msg_format() {
        // A top-level value that is not an array at all
        let stream_msg = bulk("not an array");

        assert_eq!(
            decode_error(&stream_msg),
            "Invalid stream message format: bulk-string('\"not an array\"')"
        );
    }
}
653
// Integration tests that talk to a Redis server reachable through `DatabaseConfig::default()`.
// Each test flushes Redis through the shared fixture before running (and most flush again
// afterwards).
#[cfg(target_os = "linux")] // Run Redis tests on Linux platforms only
#[cfg(test)]
mod serial_tests {
    use nautilus_common::testing::wait_until_async;
    use redis::aio::ConnectionManager;
    use rstest::*;

    use super::*;
    use crate::redis::flush_redis;

    /// Provides a Redis connection created from the default database config, with Redis
    /// flushed before the connection is handed to the test.
    #[fixture]
    async fn redis_connection() -> ConnectionManager {
        let config = DatabaseConfig::default();
        let mut con = create_redis_connection(MSGBUS_STREAM, config)
            .await
            .unwrap();
        flush_redis(&mut con).await.unwrap();
        con
    }

    /// Checks that `stream_messages` returns once the terminate signal is set.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_terminate_signal(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let config = MessageBusConfig {
            database: Some(DatabaseConfig::default()),
            ..Default::default()
        };

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Start the message streaming task
        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        // No messages are published in this test, so `recv` only completes once the task
        // has stopped and dropped `tx`.
        stream_signal.store(true, Ordering::Relaxed);
        let _ = rx.recv().await; // Wait for the tx to close

        // Shutdown and cleanup
        rx.close();
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    /// Checks that `stream_messages` returns without error when the receiving side of its
    /// channel is already closed at the time a message is read from the stream.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages_when_receiver_closed(
        #[future] redis_connection: ConnectionManager,
    ) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let config = MessageBusConfig {
            database: Some(DatabaseConfig::default()),
            ..Default::default()
        };

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Use a message ID in the future, as streaming begins
        // around the timestamp the task is spawned.
        let clock = get_atomic_clock_realtime();
        let future_id = (clock.get_time_ms() + 1_000_000).to_string();

        // Publish test message
        let _: () = con
            .xadd(
                stream_key,
                future_id,
                &[("topic", "topic1"), ("payload", "data1")],
            )
            .await
            .unwrap();

        // Immediately close channel
        rx.close();

        // Start the message streaming task
        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        // Shutdown and cleanup
        // The terminate signal is never set here; the task is expected to end by itself
        // when forwarding the message fails on the closed channel.
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    /// Checks that a message added to the stream is decoded and forwarded over the channel.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_stream_messages(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let config = MessageBusConfig {
            database: Some(DatabaseConfig::default()),
            ..Default::default()
        };

        let stream_key = get_stream_key(trader_id, instance_id, &config);
        let external_streams = vec![stream_key.clone()];
        let stream_signal = Arc::new(AtomicBool::new(false));
        let stream_signal_clone = stream_signal.clone();

        // Use a message ID in the future, as streaming begins
        // around the timestamp the task is spawned.
        let clock = get_atomic_clock_realtime();
        let future_id = (clock.get_time_ms() + 1_000_000).to_string();

        // Publish test message
        let _: () = con
            .xadd(
                stream_key,
                future_id,
                &[("topic", "topic1"), ("payload", "data1")],
            )
            .await
            .unwrap();

        // Start the message streaming task
        let handle = tokio::spawn(async move {
            stream_messages(
                tx,
                DatabaseConfig::default(),
                external_streams,
                stream_signal_clone,
            )
            .await
            .unwrap();
        });

        // Receive and verify the message
        let msg = rx.recv().await.unwrap();
        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));

        // Shutdown and cleanup
        rx.close();
        stream_signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    /// Checks that a message sent to `publish_messages` ends up in the Redis stream and that
    /// the task stops after receiving the close message.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_publish_messages(#[future] redis_connection: ConnectionManager) {
        let mut con = redis_connection.await;
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();

        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let config = MessageBusConfig {
            database: Some(DatabaseConfig::default()),
            stream_per_topic: false,
            ..Default::default()
        };
        let stream_key = get_stream_key(trader_id, instance_id, &config);

        // Start the publish_messages task
        let handle = tokio::spawn(async move {
            publish_messages(rx, trader_id, instance_id, config)
                .await
                .unwrap();
        });

        // Send a test message
        let msg = BusMessage::with_str_topic("test_topic", Bytes::from("test_payload"));
        tx.send(msg).unwrap();

        // Wait until the message is published to Redis
        wait_until_async(
            || {
                let mut con = con.clone();
                let stream_key = stream_key.clone();
                async move {
                    let messages: RedisStreamBulk =
                        con.xread(&[&stream_key], &["0"]).await.unwrap();
                    !messages.is_empty()
                }
            },
            Duration::from_secs(3),
        )
        .await;

        // Verify the message was published to Redis
        let messages: RedisStreamBulk = con.xread(&[&stream_key], &["0"]).await.unwrap();
        assert_eq!(messages.len(), 1);
        let stream_msgs = messages[0].get(&stream_key).unwrap();
        let stream_msg_array = &stream_msgs[0].values().next().unwrap();
        let decoded_message = decode_bus_message(stream_msg_array).unwrap();
        assert_eq!(decoded_message.topic, "test_topic");
        assert_eq!(decoded_message.payload, Bytes::from("test_payload"));

        // Stop publishing task
        let msg = BusMessage::new_close();
        tx.send(msg).unwrap();

        // Shutdown and cleanup
        handle.await.unwrap();
        flush_redis(&mut con).await.unwrap();
    }

    /// Checks that closing a freshly created adapter completes without hanging.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_close() {
        let trader_id = TraderId::from("tester-001");
        let instance_id = UUID4::new();
        let config = MessageBusConfig {
            database: Some(DatabaseConfig::default()),
            ..Default::default()
        };

        let mut db = RedisMessageBusDatabase::new(trader_id, instance_id, config).unwrap();

        // Close the message bus database (test should not hang)
        db.close();
    }

    /// Checks that the heartbeat task emits messages on the heartbeat topic and stops when
    /// its signal is set.
    #[rstest]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_heartbeat_task() {
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
        let signal = Arc::new(AtomicBool::new(false));

        // Start the heartbeat task with a short interval
        let handle = tokio::spawn(run_heartbeat(1, signal.clone(), tx));

        // Wait for a couple of heartbeats
        tokio::time::sleep(Duration::from_secs(2)).await;

        // Stop the heartbeat task
        signal.store(true, Ordering::Relaxed);
        handle.await.unwrap();

        // Ensure heartbeats were sent
        // The task has finished, so everything it sent is already queued and can be
        // collected without waiting.
        let mut heartbeats: Vec<BusMessage> = Vec::new();
        while let Ok(hb) = rx.try_recv() {
            heartbeats.push(hb);
        }

        assert!(!heartbeats.is_empty());

        for hb in heartbeats {
            assert_eq!(hb.topic, HEARTBEAT_TOPIC);
        }
    }
}