nautilus_infrastructure/redis/mod.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
//! Provides a Redis-backed `CacheDatabase` and `MessageBusDatabase` implementation.
17
18pub mod cache;
19pub mod msgbus;
20
21use std::time::Duration;
22
23use nautilus_common::msgbus::database::{DatabaseConfig, MessageBusConfig};
24use nautilus_core::UUID4;
25use nautilus_model::identifiers::TraderId;
26use redis::*;
27use semver::Version;
28
/// Minimum supported Redis server version. `create_redis_connection` logs an error (but still
/// returns the connection) when the server reports an older version.
/// NOTE(review): possibly chosen because `XTRIM ... MINID` requires Redis 6.2 — confirm.
const REDIS_MIN_VERSION: &str = "6.2.0";
/// Separator appended after the trader ID and instance ID components of a stream key.
const REDIS_DELIMITER: char = ':';
/// Redis `XTRIM` command name (presumably used by the `cache`/`msgbus` submodules — verify).
const REDIS_XTRIM: &str = "XTRIM";
/// `MINID` trimming-strategy argument for `XTRIM`.
const REDIS_MINID: &str = "MINID";
/// Redis `FLUSHDB` command name, issued by `flush_redis`.
const REDIS_FLUSHDB: &str = "FLUSHDB";
34
35async fn await_handle(handle: Option<tokio::task::JoinHandle<()>>, task_name: &str) {
36    if let Some(handle) = handle {
37        tracing::debug!("Awaiting task '{task_name}'");
38        let timeout = Duration::from_secs(2);
39        match tokio::time::timeout(timeout, handle).await {
40            Ok(result) => {
41                if let Err(e) = result {
42                    log::error!("Error awaiting task '{task_name}': {e:?}");
43                }
44            }
45            Err(_) => {
46                log::error!("Timeout {timeout:?} awaiting task '{task_name}'");
47            }
48        }
49    }
50}
51
52/// Parse a Redis connection url from the given database config.
53pub fn get_redis_url(config: DatabaseConfig) -> (String, String) {
54    let host = config.host.unwrap_or("127.0.0.1".to_string());
55    let port = config.port.unwrap_or(6379);
56    let username = config.username.unwrap_or("".to_string());
57    let password = config.password.unwrap_or("".to_string());
58    let use_ssl = config.ssl;
59
60    let redacted_password = if password.len() > 4 {
61        format!("{}...{}", &password[..2], &password[password.len() - 2..],)
62    } else {
63        password.to_string()
64    };
65
66    let auth_part = if !username.is_empty() && !password.is_empty() {
67        format!("{}:{}@", username, password)
68    } else {
69        String::new()
70    };
71
72    let redacted_auth_part = if !username.is_empty() && !password.is_empty() {
73        format!("{}:{}@", username, redacted_password)
74    } else {
75        String::new()
76    };
77
78    let url = format!(
79        "redis{}://{}{}:{}",
80        if use_ssl { "s" } else { "" },
81        auth_part,
82        host,
83        port
84    );
85
86    let redacted_url = format!(
87        "redis{}://{}{}:{}",
88        if use_ssl { "s" } else { "" },
89        redacted_auth_part,
90        host,
91        port
92    );
93
94    (url, redacted_url)
95}
96
97/// Create a new Redis database connection from the given database config.
98pub fn create_redis_connection(
99    con_name: &str,
100    config: DatabaseConfig,
101) -> anyhow::Result<redis::Connection> {
102    tracing::debug!("Creating {con_name} redis connection");
103    let (redis_url, redacted_url) = get_redis_url(config.clone());
104    tracing::debug!("Connecting to {redacted_url}");
105    let timeout = Duration::from_secs(config.timeout as u64);
106    let client = redis::Client::open(redis_url)?;
107    let mut con = client.get_connection_with_timeout(timeout)?;
108
109    let version = get_redis_version(&mut con)?;
110    let min_version = Version::parse(REDIS_MIN_VERSION)?;
111    let con_msg = format!("Connected to redis v{version}");
112
113    if version >= min_version {
114        tracing::info!(con_msg);
115    } else {
116        // TODO: Using `log` error here so that the message is displayed regardless of whether
117        // the logging config has pyo3 enabled. Later we can standardize this to `tracing`.
118        log::error!("{con_msg}, but minimum supported version is {REDIS_MIN_VERSION}");
119    }
120
121    Ok(con)
122}
123
124/// Flush the Redis database for the given connection.
125pub fn flush_redis(con: &mut redis::Connection) -> anyhow::Result<(), RedisError> {
126    redis::cmd(REDIS_FLUSHDB).exec(con)
127}
128
129/// Parse the stream key from the given identifiers and config.
130pub fn get_stream_key(
131    trader_id: TraderId,
132    instance_id: UUID4,
133    config: &MessageBusConfig,
134) -> String {
135    let mut stream_key = String::new();
136
137    if config.use_trader_prefix {
138        stream_key.push_str("trader-");
139    }
140
141    if config.use_trader_id {
142        stream_key.push_str(trader_id.as_str());
143        stream_key.push(REDIS_DELIMITER);
144    }
145
146    if config.use_instance_id {
147        stream_key.push_str(&format!("{instance_id}"));
148        stream_key.push(REDIS_DELIMITER);
149    }
150
151    stream_key.push_str(&config.streams_prefix);
152    stream_key
153}
154
155/// Parses the Redis version from the "INFO" command output.
156pub fn get_redis_version(conn: &mut Connection) -> anyhow::Result<Version> {
157    let info: String = redis::cmd("INFO").query(conn)?;
158    let version_str = info
159        .lines()
160        .find_map(|line| {
161            if line.starts_with("redis_version:") {
162                line.split(':').nth(1).map(|s| s.trim().to_string())
163            } else {
164                None
165            }
166        })
167        .expect("Redis version not available");
168
169    parse_redis_version(&version_str)
170}
171
172fn parse_redis_version(version_str: &str) -> anyhow::Result<Version> {
173    let mut components = version_str.split('.').map(|s| s.parse::<u64>());
174
175    let major = components.next().unwrap_or(Ok(0))?;
176    let minor = components.next().unwrap_or(Ok(0))?;
177    let patch = components.next().unwrap_or(Ok(0))?;
178
179    Ok(Version::new(major, minor, patch))
180}
181
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use serde_json::json;

    use super::*;

    #[rstest]
    fn test_get_redis_url_default_values() {
        let config: DatabaseConfig = serde_json::from_value(json!({})).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "redis://127.0.0.1:6379");
        assert_eq!(redacted_url, "redis://127.0.0.1:6379");
    }

    #[rstest]
    fn test_get_redis_url_full_config_with_ssl() {
        let config_json = json!({
            "host": "example.com",
            "port": 6380,
            "username": "user",
            "password": "pass",
            "ssl": true,
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "rediss://user:pass@example.com:6380");
        // Passwords of four characters or fewer are not redacted.
        assert_eq!(redacted_url, "rediss://user:pass@example.com:6380");
    }

    #[rstest]
    fn test_get_redis_url_full_config_without_ssl() {
        let config_json = json!({
            "host": "example.com",
            "port": 6380,
            "username": "username",
            "password": "password",
            "ssl": false,
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "redis://username:password@example.com:6380");
        assert_eq!(redacted_url, "redis://username:pa...rd@example.com:6380");
    }

    #[rstest]
    fn test_get_redis_url_missing_username_and_password() {
        let config_json = json!({
            "host": "example.com",
            "port": 6380,
            "ssl": false,
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "redis://example.com:6380");
        assert_eq!(redacted_url, "redis://example.com:6380");
    }

    #[rstest]
    fn test_get_redis_url_ssl_default_false() {
        let config_json = json!({
            "host": "example.com",
            "port": 6380,
            "username": "username",
            "password": "password",
            // "ssl" is intentionally omitted to test default behavior
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "redis://username:password@example.com:6380");
        assert_eq!(redacted_url, "redis://username:pa...rd@example.com:6380");
    }

    #[rstest]
    fn test_get_stream_key_with_trader_prefix_and_instance_id() {
        let trader_id = TraderId::from("tester-123");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.use_instance_id = true;

        let key = get_stream_key(trader_id, instance_id, &config);
        assert_eq!(key, format!("trader-tester-123:{instance_id}:stream"));
    }

    #[rstest]
    fn test_get_stream_key_without_trader_prefix_or_instance_id() {
        let trader_id = TraderId::from("tester-123");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.use_trader_prefix = false;
        config.use_trader_id = false;

        let key = get_stream_key(trader_id, instance_id, &config);
        assert_eq!(key, "stream");
    }

    #[rstest]
    fn test_get_stream_key_with_trader_id_only() {
        let trader_id = TraderId::from("tester-123");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.use_trader_prefix = false;

        let key = get_stream_key(trader_id, instance_id, &config);
        assert_eq!(key, "tester-123:stream");
    }

    #[rstest]
    #[case("7.2.4", Version::new(7, 2, 4))]
    #[case("6.2", Version::new(6, 2, 0))]
    #[case("7", Version::new(7, 0, 0))]
    #[case("7.0.15.1", Version::new(7, 0, 15))]
    fn test_parse_redis_version(#[case] input: &str, #[case] expected: Version) {
        assert_eq!(parse_redis_version(input).unwrap(), expected);
    }

    #[rstest]
    #[case("")]
    #[case("seven.0.0")]
    #[case("7.x.1")]
    fn test_parse_redis_version_invalid(#[case] input: &str) {
        assert!(parse_redis_version(input).is_err());
    }
}