nautilus_infrastructure/redis/
msgbus.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    collections::{HashMap, VecDeque},
18    fmt::Debug,
19    sync::{
20        Arc,
21        atomic::{AtomicBool, Ordering},
22    },
23    time::Duration,
24};
25
26use bytes::Bytes;
27use futures::stream::Stream;
28use nautilus_common::{
29    live::runtime::get_runtime,
30    logging::{log_task_error, log_task_started, log_task_stopped},
31    msgbus::{
32        BusMessage,
33        database::{DatabaseConfig, MessageBusConfig, MessageBusDatabaseAdapter},
34        switchboard::CLOSE_TOPIC,
35    },
36};
37use nautilus_core::{
38    UUID4,
39    time::{duration_since_unix_epoch, get_atomic_clock_realtime},
40};
41use nautilus_cryptography::providers::install_cryptographic_provider;
42use nautilus_model::identifiers::TraderId;
43use redis::{AsyncCommands, streams};
44use streams::StreamReadOptions;
45use ustr::Ustr;
46
47use super::{REDIS_MINID, REDIS_XTRIM, await_handle};
48use crate::redis::{create_redis_connection, get_stream_key};
49
/// Name of the publish task: used for its Redis connection and for task logging.
const MSGBUS_PUBLISH: &str = "msgbus-publish";
/// Name of the external stream reading task: used for its Redis connection and for task logging.
const MSGBUS_STREAM: &str = "msgbus-stream";
/// Name under which the heartbeat task handle is awaited on close.
const MSGBUS_HEARTBEAT: &str = "msgbus-heartbeat";
/// Topic on which heartbeat messages are published.
const HEARTBEAT_TOPIC: &str = "health:heartbeat";
/// Minimum number of seconds between two trims of the same stream.
const TRIM_BUFFER_SECS: u64 = 60;
55
56type RedisStreamBulk = Vec<HashMap<String, Vec<HashMap<String, redis::Value>>>>;
57
58#[cfg_attr(
59    feature = "python",
60    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.infrastructure")
61)]
62pub struct RedisMessageBusDatabase {
63    /// The trader ID for this message bus database.
64    pub trader_id: TraderId,
65    /// The instance ID for this message bus database.
66    pub instance_id: UUID4,
67    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
68    pub_handle: Option<tokio::task::JoinHandle<()>>,
69    stream_rx: Option<tokio::sync::mpsc::Receiver<BusMessage>>,
70    stream_handle: Option<tokio::task::JoinHandle<()>>,
71    stream_signal: Arc<AtomicBool>,
72    heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
73    heartbeat_signal: Arc<AtomicBool>,
74}
75
76impl Debug for RedisMessageBusDatabase {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        f.debug_struct(stringify!(RedisMessageBusDatabase))
79            .field("trader_id", &self.trader_id)
80            .field("instance_id", &self.instance_id)
81            .finish()
82    }
83}
84
85impl MessageBusDatabaseAdapter for RedisMessageBusDatabase {
86    type DatabaseType = Self;
87
88    /// Creates a new [`RedisMessageBusDatabase`] instance for the given `trader_id`, `instance_id`, and `config`.
89    ///
90    /// # Errors
91    ///
92    /// Returns an error if:
93    /// - The database configuration is missing in `config`.
94    /// - Establishing the Redis connection for publishing fails.
95    fn new(
96        trader_id: TraderId,
97        instance_id: UUID4,
98        config: MessageBusConfig,
99    ) -> anyhow::Result<Self> {
100        install_cryptographic_provider();
101
102        let config_clone = config.clone();
103        let db_config = config
104            .database
105            .clone()
106            .ok_or_else(|| anyhow::anyhow!("No database config"))?;
107
108        let (pub_tx, pub_rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
109
110        // Create publish task (start the runtime here for now)
111        let pub_handle = Some(get_runtime().spawn(async move {
112            if let Err(e) = publish_messages(pub_rx, trader_id, instance_id, config_clone).await {
113                log_task_error(MSGBUS_PUBLISH, &e);
114            }
115        }));
116
117        // Conditionally create stream task and channel if external streams configured
118        let external_streams = config.external_streams.clone().unwrap_or_default();
119        let stream_signal = Arc::new(AtomicBool::new(false));
120        let (stream_rx, stream_handle) = if external_streams.is_empty() {
121            (None, None)
122        } else {
123            let stream_signal_clone = stream_signal.clone();
124            let (stream_tx, stream_rx) = tokio::sync::mpsc::channel::<BusMessage>(100_000);
125            (
126                Some(stream_rx),
127                Some(get_runtime().spawn(async move {
128                    if let Err(e) =
129                        stream_messages(stream_tx, db_config, external_streams, stream_signal_clone)
130                            .await
131                    {
132                        log_task_error(MSGBUS_STREAM, &e);
133                    }
134                })),
135            )
136        };
137
138        // Create heartbeat task
139        let heartbeat_signal = Arc::new(AtomicBool::new(false));
140        let heartbeat_handle = if let Some(heartbeat_interval_secs) = config.heartbeat_interval_secs
141        {
142            let signal = heartbeat_signal.clone();
143            let pub_tx_clone = pub_tx.clone();
144
145            Some(get_runtime().spawn(async move {
146                run_heartbeat(heartbeat_interval_secs, signal, pub_tx_clone).await;
147            }))
148        } else {
149            None
150        };
151
152        Ok(Self {
153            trader_id,
154            instance_id,
155            pub_tx,
156            pub_handle,
157            stream_rx,
158            stream_handle,
159            stream_signal,
160            heartbeat_handle,
161            heartbeat_signal,
162        })
163    }
164
165    /// Returns whether the message bus database adapter publishing channel is closed.
166    fn is_closed(&self) -> bool {
167        self.pub_tx.is_closed()
168    }
169
170    /// Publishes a message with the given `topic` and `payload`.
171    fn publish(&self, topic: Ustr, payload: Bytes) {
172        let msg = BusMessage::new(topic, payload);
173        if let Err(e) = self.pub_tx.send(msg) {
174            log::error!("Failed to send message: {e}");
175        }
176    }
177
178    /// Closes the message bus database adapter.
179    fn close(&mut self) {
180        log::debug!("Closing");
181
182        self.stream_signal.store(true, Ordering::Relaxed);
183        self.heartbeat_signal.store(true, Ordering::Relaxed);
184
185        if !self.pub_tx.is_closed() {
186            let msg = BusMessage::new_close();
187
188            if let Err(e) = self.pub_tx.send(msg) {
189                log::error!("Failed to send close message: {e:?}");
190            }
191        }
192
193        // Keep close sync for now to avoid async trait method
194        tokio::task::block_in_place(|| {
195            get_runtime().block_on(async {
196                self.close_async().await;
197            });
198        });
199
200        log::debug!("Closed");
201    }
202}
203
204impl RedisMessageBusDatabase {
205    /// Retrieves the Redis stream receiver for this message bus instance.
206    ///
207    /// # Errors
208    ///
209    /// Returns an error if the stream receiver has already been taken.
210    pub fn get_stream_receiver(
211        &mut self,
212    ) -> anyhow::Result<tokio::sync::mpsc::Receiver<BusMessage>> {
213        self.stream_rx
214            .take()
215            .ok_or_else(|| anyhow::anyhow!("Stream receiver already taken"))
216    }
217
218    /// Streams messages arriving on the stream receiver channel.
219    pub fn stream(
220        mut stream_rx: tokio::sync::mpsc::Receiver<BusMessage>,
221    ) -> impl Stream<Item = BusMessage> + 'static {
222        async_stream::stream! {
223            while let Some(msg) = stream_rx.recv().await {
224                yield msg;
225            }
226        }
227    }
228
229    pub async fn close_async(&mut self) {
230        await_handle(self.pub_handle.take(), MSGBUS_PUBLISH).await;
231        await_handle(self.stream_handle.take(), MSGBUS_STREAM).await;
232        await_handle(self.heartbeat_handle.take(), MSGBUS_HEARTBEAT).await;
233    }
234}
235
236/// Publishes messages received on `rx` to Redis streams for the given `trader_id` and `instance_id`, using `config`.
237///
238/// # Errors
239///
240/// Returns an error if:
241/// - The database configuration is missing in `config`.
242/// - Establishing the Redis connection fails.
243/// - Any Redis command fails during publishing.
244pub async fn publish_messages(
245    mut rx: tokio::sync::mpsc::UnboundedReceiver<BusMessage>,
246    trader_id: TraderId,
247    instance_id: UUID4,
248    config: MessageBusConfig,
249) -> anyhow::Result<()> {
250    log_task_started(MSGBUS_PUBLISH);
251
252    let db_config = config
253        .database
254        .as_ref()
255        .ok_or_else(|| anyhow::anyhow!("No database config"))?;
256    let mut con = create_redis_connection(MSGBUS_PUBLISH, db_config.clone()).await?;
257    let stream_key = get_stream_key(trader_id, instance_id, &config);
258
259    // Auto-trimming
260    let autotrim_duration = config
261        .autotrim_mins
262        .filter(|&mins| mins > 0)
263        .map(|mins| Duration::from_secs(u64::from(mins) * 60));
264    let mut last_trim_index: HashMap<String, usize> = HashMap::new();
265
266    // Buffering
267    let mut buffer: VecDeque<BusMessage> = VecDeque::new();
268    let buffer_interval = Duration::from_millis(u64::from(config.buffer_interval_ms.unwrap_or(0)));
269
270    // A sleep used to trigger periodic flushing of the buffer.
271    // When `buffer_interval` is zero we skip using the timer and flush immediately
272    // after every message.
273    let flush_timer = tokio::time::sleep(buffer_interval);
274    tokio::pin!(flush_timer);
275
276    loop {
277        tokio::select! {
278            maybe_msg = rx.recv() => {
279                if let Some(msg) = maybe_msg {
280                    if msg.topic == CLOSE_TOPIC {
281                        tracing::debug!("Received close message");
282                        // Ensure we exit the loop after flushing any remaining messages.
283                        if !buffer.is_empty() {
284                            drain_buffer(
285                                &mut con,
286                                &stream_key,
287                                config.stream_per_topic,
288                                autotrim_duration,
289                                &mut last_trim_index,
290                                &mut buffer,
291                            ).await?;
292                        }
293                        break;
294                    }
295
296                    buffer.push_back(msg);
297
298                    if buffer_interval.is_zero() {
299                        // Immediate flush mode
300                        drain_buffer(
301                            &mut con,
302                            &stream_key,
303                            config.stream_per_topic,
304                            autotrim_duration,
305                            &mut last_trim_index,
306                            &mut buffer,
307                        ).await?;
308                    }
309                } else {
310                    tracing::debug!("Channel hung up");
311                    break;
312                }
313            }
314            // Only poll the timer when the interval is non-zero. This avoids
315            // unnecessarily waking the task when immediate flushing is enabled.
316            () = &mut flush_timer, if !buffer_interval.is_zero() => {
317                if !buffer.is_empty() {
318                    drain_buffer(
319                        &mut con,
320                        &stream_key,
321                        config.stream_per_topic,
322                        autotrim_duration,
323                        &mut last_trim_index,
324                        &mut buffer,
325                    ).await?;
326                }
327
328                // Schedule the next tick
329                flush_timer.as_mut().reset(tokio::time::Instant::now() + buffer_interval);
330            }
331        }
332    }
333
334    // Drain any remaining messages
335    if !buffer.is_empty() {
336        drain_buffer(
337            &mut con,
338            &stream_key,
339            config.stream_per_topic,
340            autotrim_duration,
341            &mut last_trim_index,
342            &mut buffer,
343        )
344        .await?;
345    }
346
347    log_task_stopped(MSGBUS_PUBLISH);
348    Ok(())
349}
350
351async fn drain_buffer(
352    conn: &mut redis::aio::ConnectionManager,
353    stream_key: &str,
354    stream_per_topic: bool,
355    autotrim_duration: Option<Duration>,
356    last_trim_index: &mut HashMap<String, usize>,
357    buffer: &mut VecDeque<BusMessage>,
358) -> anyhow::Result<()> {
359    let mut pipe = redis::pipe();
360    pipe.atomic();
361
362    for msg in buffer.drain(..) {
363        let items: Vec<(&str, &[u8])> = vec![
364            ("topic", msg.topic.as_ref()),
365            ("payload", msg.payload.as_ref()),
366        ];
367        let stream_key = if stream_per_topic {
368            format!("{stream_key}:{}", &msg.topic)
369        } else {
370            stream_key.to_string()
371        };
372        pipe.xadd(&stream_key, "*", &items);
373
374        if autotrim_duration.is_none() {
375            continue; // Nothing else to do
376        }
377
378        // Autotrim stream
379        let last_trim_ms = last_trim_index.entry(stream_key.clone()).or_insert(0); // Remove clone
380        let unix_duration_now = duration_since_unix_epoch();
381        let trim_buffer = Duration::from_secs(TRIM_BUFFER_SECS);
382
383        // Improve efficiency of this by batching
384        if *last_trim_ms < (unix_duration_now - trim_buffer).as_millis() as usize {
385            let min_timestamp_ms =
386                (unix_duration_now - autotrim_duration.unwrap()).as_millis() as usize;
387            let result: Result<(), redis::RedisError> = redis::cmd(REDIS_XTRIM)
388                .arg(stream_key.clone())
389                .arg(REDIS_MINID)
390                .arg(min_timestamp_ms)
391                .query_async(conn)
392                .await;
393
394            if let Err(e) = result {
395                tracing::error!("Error trimming stream '{stream_key}': {e}");
396            } else {
397                last_trim_index.insert(stream_key.clone(), unix_duration_now.as_millis() as usize);
398            }
399        }
400    }
401
402    pipe.query_async(conn).await.map_err(anyhow::Error::from)
403}
404
405/// Streams messages from Redis streams and sends them over the provided `tx` channel.
406///
407/// # Errors
408///
409/// Returns an error if:
410/// - Establishing the Redis connection fails.
411/// - Any Redis read operation fails.
412pub async fn stream_messages(
413    tx: tokio::sync::mpsc::Sender<BusMessage>,
414    config: DatabaseConfig,
415    stream_keys: Vec<String>,
416    stream_signal: Arc<AtomicBool>,
417) -> anyhow::Result<()> {
418    log_task_started(MSGBUS_STREAM);
419
420    let mut con = create_redis_connection(MSGBUS_STREAM, config).await?;
421
422    let stream_keys = &stream_keys
423        .iter()
424        .map(String::as_str)
425        .collect::<Vec<&str>>();
426
427    tracing::debug!("Listening to streams: [{}]", stream_keys.join(", "));
428
429    // Start streaming from current timestamp
430    let clock = get_atomic_clock_realtime();
431    let timestamp_ms = clock.get_time_ms();
432    let initial_id = timestamp_ms.to_string();
433
434    let mut last_ids: HashMap<String, String> = stream_keys
435        .iter()
436        .map(|&key| (key.to_string(), initial_id.clone()))
437        .collect();
438
439    let opts = StreamReadOptions::default().block(100);
440
441    'outer: loop {
442        if stream_signal.load(Ordering::Relaxed) {
443            tracing::debug!("Received streaming terminate signal");
444            break;
445        }
446
447        let ids: Vec<String> = stream_keys
448            .iter()
449            .map(|&key| last_ids[key].clone())
450            .collect();
451        let id_refs: Vec<&str> = ids.iter().map(String::as_str).collect();
452
453        let result: Result<RedisStreamBulk, _> =
454            con.xread_options(&[&stream_keys], &[&id_refs], &opts).await;
455        match result {
456            Ok(stream_bulk) => {
457                if stream_bulk.is_empty() {
458                    // Timeout occurred: no messages received
459                    continue;
460                }
461                for entry in &stream_bulk {
462                    for (stream_key, stream_msgs) in entry {
463                        for stream_msg in stream_msgs {
464                            for (id, array) in stream_msg {
465                                last_ids.insert(stream_key.clone(), id.clone());
466
467                                match decode_bus_message(array) {
468                                    Ok(msg) => {
469                                        if let Err(e) = tx.send(msg).await {
470                                            tracing::debug!("Channel closed: {e:?}");
471                                            break 'outer; // End streaming
472                                        }
473                                    }
474                                    Err(e) => {
475                                        tracing::error!("{e:?}");
476                                        continue;
477                                    }
478                                }
479                            }
480                        }
481                    }
482                }
483            }
484            Err(e) => {
485                anyhow::bail!("Error reading from stream: {e:?}");
486            }
487        }
488    }
489
490    log_task_stopped(MSGBUS_STREAM);
491    Ok(())
492}
493
494/// Decodes a Redis stream message value into a `BusMessage`.
495///
496/// # Errors
497///
498/// Returns an error if:
499/// - The incoming `stream_msg` is not an array.
500/// - The array has fewer than four elements (invalid format).
501/// - Parsing the topic or payload fails.
502fn decode_bus_message(stream_msg: &redis::Value) -> anyhow::Result<BusMessage> {
503    if let redis::Value::Array(stream_msg) = stream_msg {
504        if stream_msg.len() < 4 {
505            anyhow::bail!("Invalid stream message format: {stream_msg:?}");
506        }
507
508        let topic = match &stream_msg[1] {
509            redis::Value::BulkString(bytes) => match String::from_utf8(bytes.clone()) {
510                Ok(topic) => topic,
511                Err(e) => anyhow::bail!("Error parsing topic: {e}"),
512            },
513            _ => {
514                anyhow::bail!("Invalid topic format: {stream_msg:?}");
515            }
516        };
517
518        let payload = match &stream_msg[3] {
519            redis::Value::BulkString(bytes) => Bytes::copy_from_slice(bytes),
520            _ => {
521                anyhow::bail!("Invalid payload format: {stream_msg:?}");
522            }
523        };
524
525        Ok(BusMessage::with_str_topic(topic, payload))
526    } else {
527        anyhow::bail!("Invalid stream message format: {stream_msg:?}")
528    }
529}
530
531async fn run_heartbeat(
532    heartbeat_interval_secs: u16,
533    signal: Arc<AtomicBool>,
534    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
535) {
536    log_task_started("heartbeat");
537    tracing::debug!("Heartbeat at {heartbeat_interval_secs} second intervals");
538
539    let heartbeat_interval = Duration::from_secs(u64::from(heartbeat_interval_secs));
540    let heartbeat_timer = tokio::time::interval(heartbeat_interval);
541
542    let check_interval = Duration::from_millis(100);
543    let check_timer = tokio::time::interval(check_interval);
544
545    tokio::pin!(heartbeat_timer);
546    tokio::pin!(check_timer);
547
548    loop {
549        if signal.load(Ordering::Relaxed) {
550            tracing::debug!("Received heartbeat terminate signal");
551            break;
552        }
553
554        tokio::select! {
555            _ = heartbeat_timer.tick() => {
556                let heartbeat = create_heartbeat_msg();
557                if let Err(e) = pub_tx.send(heartbeat) {
558                    // We expect an error if the channel is closed during shutdown
559                    tracing::debug!("Error sending heartbeat: {e}");
560                }
561            },
562            _ = check_timer.tick() => {}
563        }
564    }
565
566    log_task_stopped("heartbeat");
567}
568
569fn create_heartbeat_msg() -> BusMessage {
570    let payload = Bytes::from(chrono::Utc::now().to_rfc3339().into_bytes());
571    BusMessage::with_str_topic(HEARTBEAT_TOPIC, payload)
572}
573
#[cfg(test)]
mod tests {
    use redis::Value;
    use rstest::*;

    use super::*;

    /// Wraps raw bytes in a bulk-string value.
    fn bulk(bytes: &[u8]) -> Value {
        Value::BulkString(bytes.to_vec())
    }

    /// Decodes `stream_msg`, which must fail, and returns the rendered error message.
    fn decode_error(stream_msg: &Value) -> String {
        let result = decode_bus_message(stream_msg);
        assert!(result.is_err());
        result.unwrap_err().to_string()
    }

    #[rstest]
    fn test_decode_bus_message_valid() {
        let stream_msg = Value::Array(vec![
            bulk(b"0"),
            bulk(b"topic1"),
            bulk(b"unused"),
            bulk(b"data1"),
        ]);

        let msg = decode_bus_message(&stream_msg).unwrap();

        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));
    }

    #[rstest]
    fn test_decode_bus_message_missing_fields() {
        let stream_msg = Value::Array(vec![bulk(b"0"), bulk(b"topic1")]);

        assert_eq!(
            decode_error(&stream_msg),
            "Invalid stream message format: [bulk-string('\"0\"'), bulk-string('\"topic1\"')]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_topic_format() {
        let stream_msg = Value::Array(vec![
            bulk(b"0"),
            Value::Int(42), // Topic must be a bulk string
            bulk(b"unused"),
            bulk(b"data1"),
        ]);

        assert_eq!(
            decode_error(&stream_msg),
            "Invalid topic format: [bulk-string('\"0\"'), int(42), bulk-string('\"unused\"'), bulk-string('\"data1\"')]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_payload_format() {
        let stream_msg = Value::Array(vec![
            bulk(b"0"),
            bulk(b"topic1"),
            bulk(b"unused"),
            Value::Int(42), // Payload must be a bulk string
        ]);

        assert_eq!(
            decode_error(&stream_msg),
            "Invalid payload format: [bulk-string('\"0\"'), bulk-string('\"topic1\"'), bulk-string('\"unused\"'), int(42)]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_stream_msg_format() {
        let stream_msg = bulk(b"not an array");

        assert_eq!(
            decode_error(&stream_msg),
            "Invalid stream message format: bulk-string('\"not an array\"')"
        );
    }
}
658
659#[cfg(target_os = "linux")] // Run Redis tests on Linux platforms only
660#[cfg(test)]
661mod serial_tests {
662    use nautilus_common::testing::wait_until_async;
663    use redis::aio::ConnectionManager;
664    use rstest::*;
665
666    use super::*;
667    use crate::redis::flush_redis;
668
669    #[fixture]
670    async fn redis_connection() -> ConnectionManager {
671        let config = DatabaseConfig::default();
672        let mut con = create_redis_connection(MSGBUS_STREAM, config)
673            .await
674            .unwrap();
675        flush_redis(&mut con).await.unwrap();
676        con
677    }
678
679    #[rstest]
680    #[tokio::test(flavor = "multi_thread")]
681    async fn test_stream_messages_terminate_signal(#[future] redis_connection: ConnectionManager) {
682        let mut con = redis_connection.await;
683        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);
684
685        let trader_id = TraderId::from("tester-001");
686        let instance_id = UUID4::new();
687        let config = MessageBusConfig {
688            database: Some(DatabaseConfig::default()),
689            ..Default::default()
690        };
691
692        let stream_key = get_stream_key(trader_id, instance_id, &config);
693        let external_streams = vec![stream_key.clone()];
694        let stream_signal = Arc::new(AtomicBool::new(false));
695        let stream_signal_clone = stream_signal.clone();
696
697        // Start the message streaming task
698        let handle = tokio::spawn(async move {
699            stream_messages(
700                tx,
701                DatabaseConfig::default(),
702                external_streams,
703                stream_signal_clone,
704            )
705            .await
706            .unwrap();
707        });
708
709        stream_signal.store(true, Ordering::Relaxed);
710        let _ = rx.recv().await; // Wait for the tx to close
711
712        // Shutdown and cleanup
713        rx.close();
714        handle.await.unwrap();
715        flush_redis(&mut con).await.unwrap();
716    }
717
718    #[rstest]
719    #[tokio::test(flavor = "multi_thread")]
720    async fn test_stream_messages_when_receiver_closed(
721        #[future] redis_connection: ConnectionManager,
722    ) {
723        let mut con = redis_connection.await;
724        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);
725
726        let trader_id = TraderId::from("tester-001");
727        let instance_id = UUID4::new();
728        let config = MessageBusConfig {
729            database: Some(DatabaseConfig::default()),
730            ..Default::default()
731        };
732
733        let stream_key = get_stream_key(trader_id, instance_id, &config);
734        let external_streams = vec![stream_key.clone()];
735        let stream_signal = Arc::new(AtomicBool::new(false));
736        let stream_signal_clone = stream_signal.clone();
737
738        // Use a message ID in the future, as streaming begins
739        // around the timestamp the task is spawned.
740        let clock = get_atomic_clock_realtime();
741        let future_id = (clock.get_time_ms() + 1_000_000).to_string();
742
743        // Publish test message
744        let _: () = con
745            .xadd(
746                stream_key,
747                future_id,
748                &[("topic", "topic1"), ("payload", "data1")],
749            )
750            .await
751            .unwrap();
752
753        // Immediately close channel
754        rx.close();
755
756        // Start the message streaming task
757        let handle = tokio::spawn(async move {
758            stream_messages(
759                tx,
760                DatabaseConfig::default(),
761                external_streams,
762                stream_signal_clone,
763            )
764            .await
765            .unwrap();
766        });
767
768        // Shutdown and cleanup
769        handle.await.unwrap();
770        flush_redis(&mut con).await.unwrap();
771    }
772
773    #[rstest]
774    #[tokio::test(flavor = "multi_thread")]
775    async fn test_stream_messages(#[future] redis_connection: ConnectionManager) {
776        let mut con = redis_connection.await;
777        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);
778
779        let trader_id = TraderId::from("tester-001");
780        let instance_id = UUID4::new();
781        let config = MessageBusConfig {
782            database: Some(DatabaseConfig::default()),
783            ..Default::default()
784        };
785
786        let stream_key = get_stream_key(trader_id, instance_id, &config);
787        let external_streams = vec![stream_key.clone()];
788        let stream_signal = Arc::new(AtomicBool::new(false));
789        let stream_signal_clone = stream_signal.clone();
790
791        // Use a message ID in the future, as streaming begins
792        // around the timestamp the task is spawned.
793        let clock = get_atomic_clock_realtime();
794        let future_id = (clock.get_time_ms() + 1_000_000).to_string();
795
796        // Publish test message
797        let _: () = con
798            .xadd(
799                stream_key,
800                future_id,
801                &[("topic", "topic1"), ("payload", "data1")],
802            )
803            .await
804            .unwrap();
805
806        // Start the message streaming task
807        let handle = tokio::spawn(async move {
808            stream_messages(
809                tx,
810                DatabaseConfig::default(),
811                external_streams,
812                stream_signal_clone,
813            )
814            .await
815            .unwrap();
816        });
817
818        // Receive and verify the message
819        let msg = rx.recv().await.unwrap();
820        assert_eq!(msg.topic, "topic1");
821        assert_eq!(msg.payload, Bytes::from("data1"));
822
823        // Shutdown and cleanup
824        rx.close();
825        stream_signal.store(true, Ordering::Relaxed);
826        handle.await.unwrap();
827        flush_redis(&mut con).await.unwrap();
828    }
829
830    #[rstest]
831    #[tokio::test(flavor = "multi_thread")]
832    async fn test_publish_messages(#[future] redis_connection: ConnectionManager) {
833        let mut con = redis_connection.await;
834        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
835
836        let trader_id = TraderId::from("tester-001");
837        let instance_id = UUID4::new();
838        let config = MessageBusConfig {
839            database: Some(DatabaseConfig::default()),
840            stream_per_topic: false,
841            ..Default::default()
842        };
843        let stream_key = get_stream_key(trader_id, instance_id, &config);
844
845        // Start the publish_messages task
846        let handle = tokio::spawn(async move {
847            publish_messages(rx, trader_id, instance_id, config)
848                .await
849                .unwrap();
850        });
851
852        // Send a test message
853        let msg = BusMessage::with_str_topic("test_topic", Bytes::from("test_payload"));
854        tx.send(msg).unwrap();
855
856        // Wait until the message is published to Redis
857        wait_until_async(
858            || {
859                let mut con = con.clone();
860                let stream_key = stream_key.clone();
861                async move {
862                    let messages: RedisStreamBulk =
863                        con.xread(&[&stream_key], &["0"]).await.unwrap();
864                    !messages.is_empty()
865                }
866            },
867            Duration::from_secs(3),
868        )
869        .await;
870
871        // Verify the message was published to Redis
872        let messages: RedisStreamBulk = con.xread(&[&stream_key], &["0"]).await.unwrap();
873        assert_eq!(messages.len(), 1);
874        let stream_msgs = messages[0].get(&stream_key).unwrap();
875        let stream_msg_array = &stream_msgs[0].values().next().unwrap();
876        let decoded_message = decode_bus_message(stream_msg_array).unwrap();
877        assert_eq!(decoded_message.topic, "test_topic");
878        assert_eq!(decoded_message.payload, Bytes::from("test_payload"));
879
880        // Stop publishing task
881        let msg = BusMessage::new_close();
882        tx.send(msg).unwrap();
883
884        // Shutdown and cleanup
885        handle.await.unwrap();
886        flush_redis(&mut con).await.unwrap();
887    }
888
889    #[rstest]
890    #[tokio::test(flavor = "multi_thread")]
891    async fn test_stream_messages_multiple_streams(#[future] redis_connection: ConnectionManager) {
892        let mut con = redis_connection.await;
893        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);
894
895        // Setup multiple stream keys
896        let stream_key1 = "test:stream:1".to_string();
897        let stream_key2 = "test:stream:2".to_string();
898        let external_streams = vec![stream_key1.clone(), stream_key2.clone()];
899        let stream_signal = Arc::new(AtomicBool::new(false));
900        let stream_signal_clone = stream_signal.clone();
901
902        let clock = get_atomic_clock_realtime();
903        let base_id = clock.get_time_ms() + 1_000_000;
904
905        // Start streaming task
906        let handle = tokio::spawn(async move {
907            stream_messages(
908                tx,
909                DatabaseConfig::default(),
910                external_streams,
911                stream_signal_clone,
912            )
913            .await
914            .unwrap();
915        });
916
917        tokio::time::sleep(Duration::from_millis(200)).await;
918
919        // Publish to stream 1 at higher ID
920        let _: () = con
921            .xadd(
922                &stream_key1,
923                format!("{}", base_id + 100),
924                &[("topic", "stream1-first"), ("payload", "data")],
925            )
926            .await
927            .unwrap();
928
929        let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
930            .await
931            .expect("Stream 1 message should be received")
932            .unwrap();
933        assert_eq!(msg.topic, "stream1-first");
934
935        // Publish to stream 2 at lower ID (tests independent cursor tracking)
936        let _: () = con
937            .xadd(
938                &stream_key2,
939                format!("{}", base_id + 50),
940                &[("topic", "stream2-second"), ("payload", "data")],
941            )
942            .await
943            .unwrap();
944
945        let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
946            .await
947            .expect("Stream 2 message should be received")
948            .unwrap();
949        assert_eq!(msg.topic, "stream2-second");
950
951        // Shutdown and cleanup
952        rx.close();
953        stream_signal.store(true, Ordering::Relaxed);
954        handle.await.unwrap();
955        flush_redis(&mut con).await.unwrap();
956    }
957
958    #[rstest]
959    #[tokio::test(flavor = "multi_thread")]
960    async fn test_stream_messages_interleaved_at_different_rates(
961        #[future] redis_connection: ConnectionManager,
962    ) {
963        let mut con = redis_connection.await;
964        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);
965
966        // Setup multiple stream keys
967        let stream_key1 = "test:stream:interleaved:1".to_string();
968        let stream_key2 = "test:stream:interleaved:2".to_string();
969        let stream_key3 = "test:stream:interleaved:3".to_string();
970        let external_streams = vec![
971            stream_key1.clone(),
972            stream_key2.clone(),
973            stream_key3.clone(),
974        ];
975        let stream_signal = Arc::new(AtomicBool::new(false));
976        let stream_signal_clone = stream_signal.clone();
977
978        let clock = get_atomic_clock_realtime();
979        let base_id = clock.get_time_ms() + 1_000_000;
980
981        let handle = tokio::spawn(async move {
982            stream_messages(
983                tx,
984                DatabaseConfig::default(),
985                external_streams,
986                stream_signal_clone,
987            )
988            .await
989            .unwrap();
990        });
991
992        tokio::time::sleep(Duration::from_millis(200)).await;
993
994        // Stream 1 advances with high ID
995        let _: () = con
996            .xadd(
997                &stream_key1,
998                format!("{}", base_id + 100),
999                &[("topic", "s1m1"), ("payload", "data")],
1000            )
1001            .await
1002            .unwrap();
1003        let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1004            .await
1005            .expect("Stream 1 message should be received")
1006            .unwrap();
1007        assert_eq!(msg.topic, "s1m1");
1008
1009        // Stream 2 gets message at lower ID - would be skipped with global cursor
1010        let _: () = con
1011            .xadd(
1012                &stream_key2,
1013                format!("{}", base_id + 50),
1014                &[("topic", "s2m1"), ("payload", "data")],
1015            )
1016            .await
1017            .unwrap();
1018        let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1019            .await
1020            .expect("Stream 2 message should be received")
1021            .unwrap();
1022        assert_eq!(msg.topic, "s2m1");
1023
1024        // Stream 3 gets message at even lower ID
1025        let _: () = con
1026            .xadd(
1027                &stream_key3,
1028                format!("{}", base_id + 25),
1029                &[("topic", "s3m1"), ("payload", "data")],
1030            )
1031            .await
1032            .unwrap();
1033        let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1034            .await
1035            .expect("Stream 3 message should be received")
1036            .unwrap();
1037        assert_eq!(msg.topic, "s3m1");
1038
1039        // Shutdown and cleanup
1040        rx.close();
1041        stream_signal.store(true, Ordering::Relaxed);
1042        handle.await.unwrap();
1043        flush_redis(&mut con).await.unwrap();
1044    }
1045
1046    #[rstest]
1047    #[tokio::test(flavor = "multi_thread")]
1048    async fn test_close() {
1049        let trader_id = TraderId::from("tester-001");
1050        let instance_id = UUID4::new();
1051        let config = MessageBusConfig {
1052            database: Some(DatabaseConfig::default()),
1053            ..Default::default()
1054        };
1055
1056        let mut db = RedisMessageBusDatabase::new(trader_id, instance_id, config).unwrap();
1057
1058        // Close the message bus database (test should not hang)
1059        db.close();
1060    }
1061
1062    #[rstest]
1063    #[tokio::test(flavor = "multi_thread")]
1064    async fn test_heartbeat_task() {
1065        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
1066        let signal = Arc::new(AtomicBool::new(false));
1067
1068        // Start the heartbeat task with a short interval
1069        let handle = tokio::spawn(run_heartbeat(1, signal.clone(), tx));
1070
1071        // Wait for a couple of heartbeats
1072        tokio::time::sleep(Duration::from_secs(2)).await;
1073
1074        // Stop the heartbeat task
1075        signal.store(true, Ordering::Relaxed);
1076        handle.await.unwrap();
1077
1078        // Ensure heartbeats were sent
1079        let mut heartbeats: Vec<BusMessage> = Vec::new();
1080        while let Ok(hb) = rx.try_recv() {
1081            heartbeats.push(hb);
1082        }
1083
1084        assert!(!heartbeats.is_empty());
1085
1086        for hb in heartbeats {
1087            assert_eq!(hb.topic, HEARTBEAT_TOPIC);
1088        }
1089    }
1090}