nautilus_infrastructure/redis/
msgbus.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    collections::{HashMap, VecDeque},
18    fmt::Debug,
19    sync::{
20        Arc,
21        atomic::{AtomicBool, Ordering},
22    },
23    time::Duration,
24};
25
26use bytes::Bytes;
27use futures::stream::Stream;
28use nautilus_common::{
29    logging::{log_task_error, log_task_started, log_task_stopped},
30    msgbus::{
31        BusMessage,
32        database::{DatabaseConfig, MessageBusConfig, MessageBusDatabaseAdapter},
33        switchboard::CLOSE_TOPIC,
34    },
35    runtime::get_runtime,
36};
37use nautilus_core::{
38    UUID4,
39    time::{duration_since_unix_epoch, get_atomic_clock_realtime},
40};
41use nautilus_cryptography::providers::install_cryptographic_provider;
42use nautilus_model::identifiers::TraderId;
43use redis::{AsyncCommands, streams};
44use streams::StreamReadOptions;
45use tokio::time::Instant;
46use ustr::Ustr;
47
48use super::{REDIS_MINID, REDIS_XTRIM, await_handle};
49use crate::redis::{create_redis_connection, get_stream_key};
50
/// Task name used when spawning, logging and awaiting the publish task.
const MSGBUS_PUBLISH: &str = "msgbus-publish";
/// Task name used when spawning, logging and awaiting the external stream task.
const MSGBUS_STREAM: &str = "msgbus-stream";
/// Task name used for the heartbeat task.
const MSGBUS_HEARTBEAT: &str = "msgbus-heartbeat";
/// Topic on which heartbeat messages are published.
const HEARTBEAT_TOPIC: &str = "health:heartbeat";
/// Minimum spacing, in seconds, between two trims of the same stream.
const TRIM_BUFFER_SECS: u64 = 60;
56
57type RedisStreamBulk = Vec<HashMap<String, Vec<HashMap<String, redis::Value>>>>;
58
59#[cfg_attr(
60    feature = "python",
61    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.infrastructure")
62)]
63pub struct RedisMessageBusDatabase {
64    /// The trader ID for this message bus database.
65    pub trader_id: TraderId,
66    /// The instance ID for this message bus database.
67    pub instance_id: UUID4,
68    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
69    pub_handle: Option<tokio::task::JoinHandle<()>>,
70    stream_rx: Option<tokio::sync::mpsc::Receiver<BusMessage>>,
71    stream_handle: Option<tokio::task::JoinHandle<()>>,
72    stream_signal: Arc<AtomicBool>,
73    heartbeat_handle: Option<tokio::task::JoinHandle<()>>,
74    heartbeat_signal: Arc<AtomicBool>,
75}
76
77impl Debug for RedisMessageBusDatabase {
78    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79        f.debug_struct(stringify!(RedisMessageBusDatabase))
80            .field("trader_id", &self.trader_id)
81            .field("instance_id", &self.instance_id)
82            .finish()
83    }
84}
85
86impl MessageBusDatabaseAdapter for RedisMessageBusDatabase {
87    type DatabaseType = Self;
88
89    /// Creates a new [`RedisMessageBusDatabase`] instance for the given `trader_id`, `instance_id`, and `config`.
90    ///
91    /// # Errors
92    ///
93    /// Returns an error if:
94    /// - The database configuration is missing in `config`.
95    /// - Establishing the Redis connection for publishing fails.
96    fn new(
97        trader_id: TraderId,
98        instance_id: UUID4,
99        config: MessageBusConfig,
100    ) -> anyhow::Result<Self> {
101        install_cryptographic_provider();
102
103        let config_clone = config.clone();
104        let db_config = config
105            .database
106            .clone()
107            .ok_or_else(|| anyhow::anyhow!("No database config"))?;
108
109        let (pub_tx, pub_rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
110
111        // Create publish task (start the runtime here for now)
112        let pub_handle = Some(get_runtime().spawn(async move {
113            if let Err(e) = publish_messages(pub_rx, trader_id, instance_id, config_clone).await {
114                log_task_error(MSGBUS_PUBLISH, &e);
115            }
116        }));
117
118        // Conditionally create stream task and channel if external streams configured
119        let external_streams = config.external_streams.clone().unwrap_or_default();
120        let stream_signal = Arc::new(AtomicBool::new(false));
121        let (stream_rx, stream_handle) = if external_streams.is_empty() {
122            (None, None)
123        } else {
124            let stream_signal_clone = stream_signal.clone();
125            let (stream_tx, stream_rx) = tokio::sync::mpsc::channel::<BusMessage>(100_000);
126            (
127                Some(stream_rx),
128                Some(get_runtime().spawn(async move {
129                    if let Err(e) =
130                        stream_messages(stream_tx, db_config, external_streams, stream_signal_clone)
131                            .await
132                    {
133                        log_task_error(MSGBUS_STREAM, &e);
134                    }
135                })),
136            )
137        };
138
139        // Create heartbeat task
140        let heartbeat_signal = Arc::new(AtomicBool::new(false));
141        let heartbeat_handle = if let Some(heartbeat_interval_secs) = config.heartbeat_interval_secs
142        {
143            let signal = heartbeat_signal.clone();
144            let pub_tx_clone = pub_tx.clone();
145
146            Some(get_runtime().spawn(async move {
147                run_heartbeat(heartbeat_interval_secs, signal, pub_tx_clone).await;
148            }))
149        } else {
150            None
151        };
152
153        Ok(Self {
154            trader_id,
155            instance_id,
156            pub_tx,
157            pub_handle,
158            stream_rx,
159            stream_handle,
160            stream_signal,
161            heartbeat_handle,
162            heartbeat_signal,
163        })
164    }
165
166    /// Returns whether the message bus database adapter publishing channel is closed.
167    fn is_closed(&self) -> bool {
168        self.pub_tx.is_closed()
169    }
170
171    /// Publishes a message with the given `topic` and `payload`.
172    fn publish(&self, topic: Ustr, payload: Bytes) {
173        let msg = BusMessage::new(topic, payload);
174        if let Err(e) = self.pub_tx.send(msg) {
175            log::error!("Failed to send message: {e}");
176        }
177    }
178
179    /// Closes the message bus database adapter.
180    fn close(&mut self) {
181        log::debug!("Closing");
182
183        self.stream_signal.store(true, Ordering::Relaxed);
184        self.heartbeat_signal.store(true, Ordering::Relaxed);
185
186        if !self.pub_tx.is_closed() {
187            let msg = BusMessage::new_close();
188
189            if let Err(e) = self.pub_tx.send(msg) {
190                log::error!("Failed to send close message: {e:?}");
191            }
192        }
193
194        // Keep close sync for now to avoid async trait method
195        tokio::task::block_in_place(|| {
196            get_runtime().block_on(async {
197                self.close_async().await;
198            });
199        });
200
201        log::debug!("Closed");
202    }
203}
204
205impl RedisMessageBusDatabase {
206    /// Retrieves the Redis stream receiver for this message bus instance.
207    ///
208    /// # Errors
209    ///
210    /// Returns an error if the stream receiver has already been taken.
211    pub fn get_stream_receiver(
212        &mut self,
213    ) -> anyhow::Result<tokio::sync::mpsc::Receiver<BusMessage>> {
214        self.stream_rx
215            .take()
216            .ok_or_else(|| anyhow::anyhow!("Stream receiver already taken"))
217    }
218
219    /// Streams messages arriving on the stream receiver channel.
220    pub fn stream(
221        mut stream_rx: tokio::sync::mpsc::Receiver<BusMessage>,
222    ) -> impl Stream<Item = BusMessage> + 'static {
223        async_stream::stream! {
224            while let Some(msg) = stream_rx.recv().await {
225                yield msg;
226            }
227        }
228    }
229
230    pub async fn close_async(&mut self) {
231        await_handle(self.pub_handle.take(), MSGBUS_PUBLISH).await;
232        await_handle(self.stream_handle.take(), MSGBUS_STREAM).await;
233        await_handle(self.heartbeat_handle.take(), MSGBUS_HEARTBEAT).await;
234    }
235}
236
237/// Publishes messages received on `rx` to Redis streams for the given `trader_id` and `instance_id`, using `config`.
238///
239/// # Errors
240///
241/// Returns an error if:
242/// - The database configuration is missing in `config`.
243/// - Establishing the Redis connection fails.
244/// - Any Redis command fails during publishing.
245pub async fn publish_messages(
246    mut rx: tokio::sync::mpsc::UnboundedReceiver<BusMessage>,
247    trader_id: TraderId,
248    instance_id: UUID4,
249    config: MessageBusConfig,
250) -> anyhow::Result<()> {
251    log_task_started(MSGBUS_PUBLISH);
252
253    let db_config = config
254        .database
255        .as_ref()
256        .ok_or_else(|| anyhow::anyhow!("No database config"))?;
257    let mut con = create_redis_connection(MSGBUS_PUBLISH, db_config.clone()).await?;
258    let stream_key = get_stream_key(trader_id, instance_id, &config);
259
260    // Auto-trimming
261    let autotrim_duration = config
262        .autotrim_mins
263        .filter(|&mins| mins > 0)
264        .map(|mins| Duration::from_secs(u64::from(mins) * 60));
265    let mut last_trim_index: HashMap<String, usize> = HashMap::new();
266
267    // Buffering
268    let mut buffer: VecDeque<BusMessage> = VecDeque::new();
269    let buffer_interval = Duration::from_millis(u64::from(config.buffer_interval_ms.unwrap_or(0)));
270
271    // A sleep used to trigger periodic flushing of the buffer.
272    // When `buffer_interval` is zero we skip using the timer and flush immediately
273    // after every message.
274    let flush_timer = tokio::time::sleep(buffer_interval);
275    tokio::pin!(flush_timer);
276
277    loop {
278        tokio::select! {
279            maybe_msg = rx.recv() => {
280                if let Some(msg) = maybe_msg {
281                    if msg.topic == CLOSE_TOPIC {
282                        tracing::debug!("Received close message");
283                        // Ensure we exit the loop after flushing any remaining messages.
284                        if !buffer.is_empty() {
285                            drain_buffer(
286                                &mut con,
287                                &stream_key,
288                                config.stream_per_topic,
289                                autotrim_duration,
290                                &mut last_trim_index,
291                                &mut buffer,
292                            ).await?;
293                        }
294                        break;
295                    }
296
297                    buffer.push_back(msg);
298
299                    if buffer_interval.is_zero() {
300                        // Immediate flush mode
301                        drain_buffer(
302                            &mut con,
303                            &stream_key,
304                            config.stream_per_topic,
305                            autotrim_duration,
306                            &mut last_trim_index,
307                            &mut buffer,
308                        ).await?;
309                    }
310                } else {
311                    tracing::debug!("Channel hung up");
312                    break;
313                }
314            }
315            // Only poll the timer when the interval is non-zero. This avoids
316            // unnecessarily waking the task when immediate flushing is enabled.
317            () = &mut flush_timer, if !buffer_interval.is_zero() => {
318                if !buffer.is_empty() {
319                    drain_buffer(
320                        &mut con,
321                        &stream_key,
322                        config.stream_per_topic,
323                        autotrim_duration,
324                        &mut last_trim_index,
325                        &mut buffer,
326                    ).await?;
327                }
328
329                // Schedule the next tick
330                flush_timer.as_mut().reset(Instant::now() + buffer_interval);
331            }
332        }
333    }
334
335    // Drain any remaining messages
336    if !buffer.is_empty() {
337        drain_buffer(
338            &mut con,
339            &stream_key,
340            config.stream_per_topic,
341            autotrim_duration,
342            &mut last_trim_index,
343            &mut buffer,
344        )
345        .await?;
346    }
347
348    log_task_stopped(MSGBUS_PUBLISH);
349    Ok(())
350}
351
352async fn drain_buffer(
353    conn: &mut redis::aio::ConnectionManager,
354    stream_key: &str,
355    stream_per_topic: bool,
356    autotrim_duration: Option<Duration>,
357    last_trim_index: &mut HashMap<String, usize>,
358    buffer: &mut VecDeque<BusMessage>,
359) -> anyhow::Result<()> {
360    let mut pipe = redis::pipe();
361    pipe.atomic();
362
363    for msg in buffer.drain(..) {
364        let items: Vec<(&str, &[u8])> = vec![
365            ("topic", msg.topic.as_ref()),
366            ("payload", msg.payload.as_ref()),
367        ];
368        let stream_key = if stream_per_topic {
369            format!("{stream_key}:{}", &msg.topic)
370        } else {
371            stream_key.to_string()
372        };
373        pipe.xadd(&stream_key, "*", &items);
374
375        if autotrim_duration.is_none() {
376            continue; // Nothing else to do
377        }
378
379        // Autotrim stream
380        let last_trim_ms = last_trim_index.entry(stream_key.clone()).or_insert(0); // Remove clone
381        let unix_duration_now = duration_since_unix_epoch();
382        let trim_buffer = Duration::from_secs(TRIM_BUFFER_SECS);
383
384        // Improve efficiency of this by batching
385        if *last_trim_ms < (unix_duration_now - trim_buffer).as_millis() as usize {
386            let min_timestamp_ms =
387                (unix_duration_now - autotrim_duration.unwrap()).as_millis() as usize;
388            let result: Result<(), redis::RedisError> = redis::cmd(REDIS_XTRIM)
389                .arg(stream_key.clone())
390                .arg(REDIS_MINID)
391                .arg(min_timestamp_ms)
392                .query_async(conn)
393                .await;
394
395            if let Err(e) = result {
396                tracing::error!("Error trimming stream '{stream_key}': {e}");
397            } else {
398                last_trim_index.insert(stream_key.clone(), unix_duration_now.as_millis() as usize);
399            }
400        }
401    }
402
403    pipe.query_async(conn).await.map_err(anyhow::Error::from)
404}
405
406/// Streams messages from Redis streams and sends them over the provided `tx` channel.
407///
408/// # Errors
409///
410/// Returns an error if:
411/// - Establishing the Redis connection fails.
412/// - Any Redis read operation fails.
413pub async fn stream_messages(
414    tx: tokio::sync::mpsc::Sender<BusMessage>,
415    config: DatabaseConfig,
416    stream_keys: Vec<String>,
417    stream_signal: Arc<AtomicBool>,
418) -> anyhow::Result<()> {
419    log_task_started(MSGBUS_STREAM);
420
421    let mut con = create_redis_connection(MSGBUS_STREAM, config).await?;
422
423    let stream_keys = &stream_keys
424        .iter()
425        .map(String::as_str)
426        .collect::<Vec<&str>>();
427
428    tracing::debug!("Listening to streams: [{}]", stream_keys.join(", "));
429
430    // Start streaming from current timestamp
431    let clock = get_atomic_clock_realtime();
432    let timestamp_ms = clock.get_time_ms();
433    let initial_id = timestamp_ms.to_string();
434
435    let mut last_ids: HashMap<String, String> = stream_keys
436        .iter()
437        .map(|&key| (key.to_string(), initial_id.clone()))
438        .collect();
439
440    let opts = StreamReadOptions::default().block(100);
441
442    'outer: loop {
443        if stream_signal.load(Ordering::Relaxed) {
444            tracing::debug!("Received streaming terminate signal");
445            break;
446        }
447
448        let ids: Vec<String> = stream_keys
449            .iter()
450            .map(|&key| last_ids[key].clone())
451            .collect();
452        let id_refs: Vec<&str> = ids.iter().map(String::as_str).collect();
453
454        let result: Result<RedisStreamBulk, _> =
455            con.xread_options(&[&stream_keys], &[&id_refs], &opts).await;
456        match result {
457            Ok(stream_bulk) => {
458                if stream_bulk.is_empty() {
459                    // Timeout occurred: no messages received
460                    continue;
461                }
462                for entry in &stream_bulk {
463                    for (stream_key, stream_msgs) in entry {
464                        for stream_msg in stream_msgs {
465                            for (id, array) in stream_msg {
466                                last_ids.insert(stream_key.clone(), id.clone());
467
468                                match decode_bus_message(array) {
469                                    Ok(msg) => {
470                                        if let Err(e) = tx.send(msg).await {
471                                            tracing::debug!("Channel closed: {e:?}");
472                                            break 'outer; // End streaming
473                                        }
474                                    }
475                                    Err(e) => {
476                                        tracing::error!("{e:?}");
477                                        continue;
478                                    }
479                                }
480                            }
481                        }
482                    }
483                }
484            }
485            Err(e) => {
486                anyhow::bail!("Error reading from stream: {e:?}");
487            }
488        }
489    }
490
491    log_task_stopped(MSGBUS_STREAM);
492    Ok(())
493}
494
495/// Decodes a Redis stream message value into a `BusMessage`.
496///
497/// # Errors
498///
499/// Returns an error if:
500/// - The incoming `stream_msg` is not an array.
501/// - The array has fewer than four elements (invalid format).
502/// - Parsing the topic or payload fails.
503fn decode_bus_message(stream_msg: &redis::Value) -> anyhow::Result<BusMessage> {
504    if let redis::Value::Array(stream_msg) = stream_msg {
505        if stream_msg.len() < 4 {
506            anyhow::bail!("Invalid stream message format: {stream_msg:?}");
507        }
508
509        let topic = match &stream_msg[1] {
510            redis::Value::BulkString(bytes) => match String::from_utf8(bytes.clone()) {
511                Ok(topic) => topic,
512                Err(e) => anyhow::bail!("Error parsing topic: {e}"),
513            },
514            _ => {
515                anyhow::bail!("Invalid topic format: {stream_msg:?}");
516            }
517        };
518
519        let payload = match &stream_msg[3] {
520            redis::Value::BulkString(bytes) => Bytes::copy_from_slice(bytes),
521            _ => {
522                anyhow::bail!("Invalid payload format: {stream_msg:?}");
523            }
524        };
525
526        Ok(BusMessage::with_str_topic(topic, payload))
527    } else {
528        anyhow::bail!("Invalid stream message format: {stream_msg:?}")
529    }
530}
531
532async fn run_heartbeat(
533    heartbeat_interval_secs: u16,
534    signal: Arc<AtomicBool>,
535    pub_tx: tokio::sync::mpsc::UnboundedSender<BusMessage>,
536) {
537    log_task_started("heartbeat");
538    tracing::debug!("Heartbeat at {heartbeat_interval_secs} second intervals");
539
540    let heartbeat_interval = Duration::from_secs(u64::from(heartbeat_interval_secs));
541    let heartbeat_timer = tokio::time::interval(heartbeat_interval);
542
543    let check_interval = Duration::from_millis(100);
544    let check_timer = tokio::time::interval(check_interval);
545
546    tokio::pin!(heartbeat_timer);
547    tokio::pin!(check_timer);
548
549    loop {
550        if signal.load(Ordering::Relaxed) {
551            tracing::debug!("Received heartbeat terminate signal");
552            break;
553        }
554
555        tokio::select! {
556            _ = heartbeat_timer.tick() => {
557                let heartbeat = create_heartbeat_msg();
558                if let Err(e) = pub_tx.send(heartbeat) {
559                    // We expect an error if the channel is closed during shutdown
560                    tracing::debug!("Error sending heartbeat: {e}");
561                }
562            },
563            _ = check_timer.tick() => {}
564        }
565    }
566
567    log_task_stopped("heartbeat");
568}
569
570fn create_heartbeat_msg() -> BusMessage {
571    let payload = Bytes::from(chrono::Utc::now().to_rfc3339().into_bytes());
572    BusMessage::with_str_topic(HEARTBEAT_TOPIC, payload)
573}
574
575////////////////////////////////////////////////////////////////////////////////
576// Tests
577////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use redis::Value;
    use rstest::*;

    use super::*;

    #[rstest]
    fn test_decode_bus_message_valid() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
            Value::BulkString(b"unused".to_vec()),
            Value::BulkString(b"data1".to_vec()),
        ]);

        let msg = decode_bus_message(&stream_msg).unwrap();
        assert_eq!(msg.topic, "topic1");
        assert_eq!(msg.payload, Bytes::from("data1"));
    }

    #[rstest]
    fn test_decode_bus_message_missing_fields() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
        ]);

        let err = decode_bus_message(&stream_msg).unwrap_err();
        assert_eq!(
            err.to_string(),
            "Invalid stream message format: [bulk-string('\"0\"'), bulk-string('\"topic1\"')]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_topic_format() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::Int(42), // Invalid topic format
            Value::BulkString(b"unused".to_vec()),
            Value::BulkString(b"data1".to_vec()),
        ]);

        let err = decode_bus_message(&stream_msg).unwrap_err();
        assert_eq!(
            err.to_string(),
            "Invalid topic format: [bulk-string('\"0\"'), int(42), bulk-string('\"unused\"'), bulk-string('\"data1\"')]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_utf8_topic() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"topic".to_vec()),
            Value::BulkString(vec![0xff]), // Not valid UTF-8
            Value::BulkString(b"payload".to_vec()),
            Value::BulkString(b"data1".to_vec()),
        ]);

        let err = decode_bus_message(&stream_msg).unwrap_err();
        assert_eq!(
            err.to_string(),
            "Error parsing topic: invalid utf-8 sequence of 1 bytes from index 0"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_payload_format() {
        let stream_msg = Value::Array(vec![
            Value::BulkString(b"0".to_vec()),
            Value::BulkString(b"topic1".to_vec()),
            Value::BulkString(b"unused".to_vec()),
            Value::Int(42), // Invalid payload format
        ]);

        let err = decode_bus_message(&stream_msg).unwrap_err();
        assert_eq!(
            err.to_string(),
            "Invalid payload format: [bulk-string('\"0\"'), bulk-string('\"topic1\"'), bulk-string('\"unused\"'), int(42)]"
        );
    }

    #[rstest]
    fn test_decode_bus_message_invalid_stream_msg_format() {
        let stream_msg = Value::BulkString(b"not an array".to_vec());

        let err = decode_bus_message(&stream_msg).unwrap_err();
        assert_eq!(
            err.to_string(),
            "Invalid stream message format: bulk-string('\"not an array\"')"
        );
    }

    #[rstest]
    fn test_create_heartbeat_msg() {
        let msg = create_heartbeat_msg();
        assert_eq!(msg.topic, HEARTBEAT_TOPIC);

        let payload = std::str::from_utf8(&msg.payload).unwrap();
        assert!(chrono::DateTime::parse_from_rfc3339(payload).is_ok());
    }
}
662
663#[cfg(target_os = "linux")] // Run Redis tests on Linux platforms only
664#[cfg(test)]
665mod serial_tests {
666    use nautilus_common::testing::wait_until_async;
667    use redis::aio::ConnectionManager;
668    use rstest::*;
669
670    use super::*;
671    use crate::redis::flush_redis;
672
673    #[fixture]
674    async fn redis_connection() -> ConnectionManager {
675        let config = DatabaseConfig::default();
676        let mut con = create_redis_connection(MSGBUS_STREAM, config)
677            .await
678            .unwrap();
679        flush_redis(&mut con).await.unwrap();
680        con
681    }
682
683    #[rstest]
684    #[tokio::test(flavor = "multi_thread")]
685    async fn test_stream_messages_terminate_signal(#[future] redis_connection: ConnectionManager) {
686        let mut con = redis_connection.await;
687        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);
688
689        let trader_id = TraderId::from("tester-001");
690        let instance_id = UUID4::new();
691        let config = MessageBusConfig {
692            database: Some(DatabaseConfig::default()),
693            ..Default::default()
694        };
695
696        let stream_key = get_stream_key(trader_id, instance_id, &config);
697        let external_streams = vec![stream_key.clone()];
698        let stream_signal = Arc::new(AtomicBool::new(false));
699        let stream_signal_clone = stream_signal.clone();
700
701        // Start the message streaming task
702        let handle = tokio::spawn(async move {
703            stream_messages(
704                tx,
705                DatabaseConfig::default(),
706                external_streams,
707                stream_signal_clone,
708            )
709            .await
710            .unwrap();
711        });
712
713        stream_signal.store(true, Ordering::Relaxed);
714        let _ = rx.recv().await; // Wait for the tx to close
715
716        // Shutdown and cleanup
717        rx.close();
718        handle.await.unwrap();
719        flush_redis(&mut con).await.unwrap();
720    }
721
722    #[rstest]
723    #[tokio::test(flavor = "multi_thread")]
724    async fn test_stream_messages_when_receiver_closed(
725        #[future] redis_connection: ConnectionManager,
726    ) {
727        let mut con = redis_connection.await;
728        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);
729
730        let trader_id = TraderId::from("tester-001");
731        let instance_id = UUID4::new();
732        let config = MessageBusConfig {
733            database: Some(DatabaseConfig::default()),
734            ..Default::default()
735        };
736
737        let stream_key = get_stream_key(trader_id, instance_id, &config);
738        let external_streams = vec![stream_key.clone()];
739        let stream_signal = Arc::new(AtomicBool::new(false));
740        let stream_signal_clone = stream_signal.clone();
741
742        // Use a message ID in the future, as streaming begins
743        // around the timestamp the task is spawned.
744        let clock = get_atomic_clock_realtime();
745        let future_id = (clock.get_time_ms() + 1_000_000).to_string();
746
747        // Publish test message
748        let _: () = con
749            .xadd(
750                stream_key,
751                future_id,
752                &[("topic", "topic1"), ("payload", "data1")],
753            )
754            .await
755            .unwrap();
756
757        // Immediately close channel
758        rx.close();
759
760        // Start the message streaming task
761        let handle = tokio::spawn(async move {
762            stream_messages(
763                tx,
764                DatabaseConfig::default(),
765                external_streams,
766                stream_signal_clone,
767            )
768            .await
769            .unwrap();
770        });
771
772        // Shutdown and cleanup
773        handle.await.unwrap();
774        flush_redis(&mut con).await.unwrap();
775    }
776
777    #[rstest]
778    #[tokio::test(flavor = "multi_thread")]
779    async fn test_stream_messages(#[future] redis_connection: ConnectionManager) {
780        let mut con = redis_connection.await;
781        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);
782
783        let trader_id = TraderId::from("tester-001");
784        let instance_id = UUID4::new();
785        let config = MessageBusConfig {
786            database: Some(DatabaseConfig::default()),
787            ..Default::default()
788        };
789
790        let stream_key = get_stream_key(trader_id, instance_id, &config);
791        let external_streams = vec![stream_key.clone()];
792        let stream_signal = Arc::new(AtomicBool::new(false));
793        let stream_signal_clone = stream_signal.clone();
794
795        // Use a message ID in the future, as streaming begins
796        // around the timestamp the task is spawned.
797        let clock = get_atomic_clock_realtime();
798        let future_id = (clock.get_time_ms() + 1_000_000).to_string();
799
800        // Publish test message
801        let _: () = con
802            .xadd(
803                stream_key,
804                future_id,
805                &[("topic", "topic1"), ("payload", "data1")],
806            )
807            .await
808            .unwrap();
809
810        // Start the message streaming task
811        let handle = tokio::spawn(async move {
812            stream_messages(
813                tx,
814                DatabaseConfig::default(),
815                external_streams,
816                stream_signal_clone,
817            )
818            .await
819            .unwrap();
820        });
821
822        // Receive and verify the message
823        let msg = rx.recv().await.unwrap();
824        assert_eq!(msg.topic, "topic1");
825        assert_eq!(msg.payload, Bytes::from("data1"));
826
827        // Shutdown and cleanup
828        rx.close();
829        stream_signal.store(true, Ordering::Relaxed);
830        handle.await.unwrap();
831        flush_redis(&mut con).await.unwrap();
832    }
833
834    #[rstest]
835    #[tokio::test(flavor = "multi_thread")]
836    async fn test_publish_messages(#[future] redis_connection: ConnectionManager) {
837        let mut con = redis_connection.await;
838        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
839
840        let trader_id = TraderId::from("tester-001");
841        let instance_id = UUID4::new();
842        let config = MessageBusConfig {
843            database: Some(DatabaseConfig::default()),
844            stream_per_topic: false,
845            ..Default::default()
846        };
847        let stream_key = get_stream_key(trader_id, instance_id, &config);
848
849        // Start the publish_messages task
850        let handle = tokio::spawn(async move {
851            publish_messages(rx, trader_id, instance_id, config)
852                .await
853                .unwrap();
854        });
855
856        // Send a test message
857        let msg = BusMessage::with_str_topic("test_topic", Bytes::from("test_payload"));
858        tx.send(msg).unwrap();
859
860        // Wait until the message is published to Redis
861        wait_until_async(
862            || {
863                let mut con = con.clone();
864                let stream_key = stream_key.clone();
865                async move {
866                    let messages: RedisStreamBulk =
867                        con.xread(&[&stream_key], &["0"]).await.unwrap();
868                    !messages.is_empty()
869                }
870            },
871            Duration::from_secs(3),
872        )
873        .await;
874
875        // Verify the message was published to Redis
876        let messages: RedisStreamBulk = con.xread(&[&stream_key], &["0"]).await.unwrap();
877        assert_eq!(messages.len(), 1);
878        let stream_msgs = messages[0].get(&stream_key).unwrap();
879        let stream_msg_array = &stream_msgs[0].values().next().unwrap();
880        let decoded_message = decode_bus_message(stream_msg_array).unwrap();
881        assert_eq!(decoded_message.topic, "test_topic");
882        assert_eq!(decoded_message.payload, Bytes::from("test_payload"));
883
884        // Stop publishing task
885        let msg = BusMessage::new_close();
886        tx.send(msg).unwrap();
887
888        // Shutdown and cleanup
889        handle.await.unwrap();
890        flush_redis(&mut con).await.unwrap();
891    }
892
893    #[rstest]
894    #[tokio::test(flavor = "multi_thread")]
895    async fn test_stream_messages_multiple_streams(#[future] redis_connection: ConnectionManager) {
896        let mut con = redis_connection.await;
897        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);
898
899        // Setup multiple stream keys
900        let stream_key1 = "test:stream:1".to_string();
901        let stream_key2 = "test:stream:2".to_string();
902        let external_streams = vec![stream_key1.clone(), stream_key2.clone()];
903        let stream_signal = Arc::new(AtomicBool::new(false));
904        let stream_signal_clone = stream_signal.clone();
905
906        let clock = get_atomic_clock_realtime();
907        let base_id = clock.get_time_ms() + 1_000_000;
908
909        // Start streaming task
910        let handle = tokio::spawn(async move {
911            stream_messages(
912                tx,
913                DatabaseConfig::default(),
914                external_streams,
915                stream_signal_clone,
916            )
917            .await
918            .unwrap();
919        });
920
921        tokio::time::sleep(Duration::from_millis(200)).await;
922
923        // Publish to stream 1 at higher ID
924        let _: () = con
925            .xadd(
926                &stream_key1,
927                format!("{}", base_id + 100),
928                &[("topic", "stream1-first"), ("payload", "data")],
929            )
930            .await
931            .unwrap();
932
933        let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
934            .await
935            .expect("Stream 1 message should be received")
936            .unwrap();
937        assert_eq!(msg.topic, "stream1-first");
938
939        // Publish to stream 2 at lower ID (tests independent cursor tracking)
940        let _: () = con
941            .xadd(
942                &stream_key2,
943                format!("{}", base_id + 50),
944                &[("topic", "stream2-second"), ("payload", "data")],
945            )
946            .await
947            .unwrap();
948
949        let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
950            .await
951            .expect("Stream 2 message should be received")
952            .unwrap();
953        assert_eq!(msg.topic, "stream2-second");
954
955        // Shutdown and cleanup
956        rx.close();
957        stream_signal.store(true, Ordering::Relaxed);
958        handle.await.unwrap();
959        flush_redis(&mut con).await.unwrap();
960    }
961
962    #[rstest]
963    #[tokio::test(flavor = "multi_thread")]
964    async fn test_stream_messages_interleaved_at_different_rates(
965        #[future] redis_connection: ConnectionManager,
966    ) {
967        let mut con = redis_connection.await;
968        let (tx, mut rx) = tokio::sync::mpsc::channel::<BusMessage>(100);
969
970        // Setup multiple stream keys
971        let stream_key1 = "test:stream:interleaved:1".to_string();
972        let stream_key2 = "test:stream:interleaved:2".to_string();
973        let stream_key3 = "test:stream:interleaved:3".to_string();
974        let external_streams = vec![
975            stream_key1.clone(),
976            stream_key2.clone(),
977            stream_key3.clone(),
978        ];
979        let stream_signal = Arc::new(AtomicBool::new(false));
980        let stream_signal_clone = stream_signal.clone();
981
982        let clock = get_atomic_clock_realtime();
983        let base_id = clock.get_time_ms() + 1_000_000;
984
985        let handle = tokio::spawn(async move {
986            stream_messages(
987                tx,
988                DatabaseConfig::default(),
989                external_streams,
990                stream_signal_clone,
991            )
992            .await
993            .unwrap();
994        });
995
996        tokio::time::sleep(Duration::from_millis(200)).await;
997
998        // Stream 1 advances with high ID
999        let _: () = con
1000            .xadd(
1001                &stream_key1,
1002                format!("{}", base_id + 100),
1003                &[("topic", "s1m1"), ("payload", "data")],
1004            )
1005            .await
1006            .unwrap();
1007        let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1008            .await
1009            .expect("Stream 1 message should be received")
1010            .unwrap();
1011        assert_eq!(msg.topic, "s1m1");
1012
1013        // Stream 2 gets message at lower ID - would be skipped with global cursor
1014        let _: () = con
1015            .xadd(
1016                &stream_key2,
1017                format!("{}", base_id + 50),
1018                &[("topic", "s2m1"), ("payload", "data")],
1019            )
1020            .await
1021            .unwrap();
1022        let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1023            .await
1024            .expect("Stream 2 message should be received")
1025            .unwrap();
1026        assert_eq!(msg.topic, "s2m1");
1027
1028        // Stream 3 gets message at even lower ID
1029        let _: () = con
1030            .xadd(
1031                &stream_key3,
1032                format!("{}", base_id + 25),
1033                &[("topic", "s3m1"), ("payload", "data")],
1034            )
1035            .await
1036            .unwrap();
1037        let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1038            .await
1039            .expect("Stream 3 message should be received")
1040            .unwrap();
1041        assert_eq!(msg.topic, "s3m1");
1042
1043        // Shutdown and cleanup
1044        rx.close();
1045        stream_signal.store(true, Ordering::Relaxed);
1046        handle.await.unwrap();
1047        flush_redis(&mut con).await.unwrap();
1048    }
1049
1050    #[rstest]
1051    #[tokio::test(flavor = "multi_thread")]
1052    async fn test_close() {
1053        let trader_id = TraderId::from("tester-001");
1054        let instance_id = UUID4::new();
1055        let config = MessageBusConfig {
1056            database: Some(DatabaseConfig::default()),
1057            ..Default::default()
1058        };
1059
1060        let mut db = RedisMessageBusDatabase::new(trader_id, instance_id, config).unwrap();
1061
1062        // Close the message bus database (test should not hang)
1063        db.close();
1064    }
1065
1066    #[rstest]
1067    #[tokio::test(flavor = "multi_thread")]
1068    async fn test_heartbeat_task() {
1069        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<BusMessage>();
1070        let signal = Arc::new(AtomicBool::new(false));
1071
1072        // Start the heartbeat task with a short interval
1073        let handle = tokio::spawn(run_heartbeat(1, signal.clone(), tx));
1074
1075        // Wait for a couple of heartbeats
1076        tokio::time::sleep(Duration::from_secs(2)).await;
1077
1078        // Stop the heartbeat task
1079        signal.store(true, Ordering::Relaxed);
1080        handle.await.unwrap();
1081
1082        // Ensure heartbeats were sent
1083        let mut heartbeats: Vec<BusMessage> = Vec::new();
1084        while let Ok(hb) = rx.try_recv() {
1085            heartbeats.push(hb);
1086        }
1087
1088        assert!(!heartbeats.is_empty());
1089
1090        for hb in heartbeats {
1091            assert_eq!(hb.topic, HEARTBEAT_TOPIC);
1092        }
1093    }
1094}