nautilus_infrastructure/redis/
mod.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
//! Provides Redis-backed `CacheDatabase` and `MessageBusDatabase` implementations.
17
// Redis cache database (presumably the `CacheDatabase` implementation named in the module
// docs — verify).
pub mod cache;
// Redis message bus database (presumably the `MessageBusDatabase` implementation named in
// the module docs — verify).
pub mod msgbus;
// Redis query helpers (contents not visible from this file — verify).
pub mod queries;
21
22use std::time::Duration;
23
24use nautilus_common::msgbus::database::{DatabaseConfig, MessageBusConfig};
25use nautilus_core::UUID4;
26use nautilus_model::identifiers::TraderId;
27use redis::*;
28use semver::Version;
29
/// Minimum supported Redis server version. `create_redis_connection` logs an error (but
/// still returns the connection) when the server reports an older version.
const REDIS_MIN_VERSION: &str = "6.2.0";
/// Separator appended after the trader ID and instance ID segments of a stream key.
const REDIS_DELIMITER: char = ':';
/// Redis `XTRIM` command name (not referenced in this file; presumably used by a
/// submodule — verify).
const REDIS_XTRIM: &str = "XTRIM";
/// Redis `MINID` keyword, presumably the trimming strategy argument to `XTRIM` (not
/// referenced in this file — verify).
const REDIS_MINID: &str = "MINID";
/// Redis `FLUSHDB` command name, issued by `flush_redis`.
const REDIS_FLUSHDB: &str = "FLUSHDB";
35
36async fn await_handle(handle: Option<tokio::task::JoinHandle<()>>, task_name: &str) {
37    if let Some(handle) = handle {
38        tracing::debug!("Awaiting task '{task_name}'");
39        let timeout = Duration::from_secs(2);
40        match tokio::time::timeout(timeout, handle).await {
41            Ok(result) => {
42                if let Err(e) = result {
43                    log::error!("Error awaiting task '{task_name}': {e:?}");
44                }
45            }
46            Err(_) => {
47                log::error!("Timeout {timeout:?} awaiting task '{task_name}'");
48            }
49        }
50    }
51}
52
53/// Parse a Redis connection url from the given database config.
54pub fn get_redis_url(config: DatabaseConfig) -> (String, String) {
55    let host = config.host.unwrap_or("127.0.0.1".to_string());
56    let port = config.port.unwrap_or(6379);
57    let username = config.username.unwrap_or("".to_string());
58    let password = config.password.unwrap_or("".to_string());
59    let use_ssl = config.ssl;
60
61    let redacted_password = if password.len() > 4 {
62        format!("{}...{}", &password[..2], &password[password.len() - 2..],)
63    } else {
64        password.to_string()
65    };
66
67    let auth_part = if !username.is_empty() && !password.is_empty() {
68        format!("{}:{}@", username, password)
69    } else {
70        String::new()
71    };
72
73    let redacted_auth_part = if !username.is_empty() && !password.is_empty() {
74        format!("{}:{}@", username, redacted_password)
75    } else {
76        String::new()
77    };
78
79    let url = format!(
80        "redis{}://{}{}:{}",
81        if use_ssl { "s" } else { "" },
82        auth_part,
83        host,
84        port
85    );
86
87    let redacted_url = format!(
88        "redis{}://{}{}:{}",
89        if use_ssl { "s" } else { "" },
90        redacted_auth_part,
91        host,
92        port
93    );
94
95    (url, redacted_url)
96}
97
98/// Create a new Redis database connection from the given database config.
99///
100/// In case of reconnection issues, the connection will retry reconnection
101/// `number_of_retries` times, with an exponentially increasing delay, calculated as
102/// `rand(0 .. factor * (exponent_base ^ current-try))`.
103///
104/// Apply a maximum delay. No retry delay will be longer than this `max_delay` .
105///
106/// The new connection will time out operations after `response_timeout` has passed.
107/// Each connection attempt to the server will time out after `connection_timeout`.
108pub async fn create_redis_connection(
109    con_name: &str,
110    config: DatabaseConfig,
111) -> anyhow::Result<redis::aio::ConnectionManager> {
112    tracing::debug!("Creating {con_name} redis connection");
113    let (redis_url, redacted_url) = get_redis_url(config.clone());
114    tracing::debug!("Connecting to {redacted_url}");
115
116    let connection_timeout = Duration::from_secs(config.connection_timeout as u64);
117    let response_timeout = Duration::from_secs(config.response_timeout as u64);
118    let number_of_retries = config.number_of_retries;
119    let exponent_base = config.exponent_base;
120    let factor = config.factor;
121
122    // into milliseconds
123    let max_delay = config.max_delay * 1000;
124
125    let client = redis::Client::open(redis_url)?;
126
127    let connection_manager_config = redis::aio::ConnectionManagerConfig::new()
128        .set_exponent_base(exponent_base)
129        .set_factor(factor)
130        .set_number_of_retries(number_of_retries)
131        .set_response_timeout(response_timeout)
132        .set_connection_timeout(connection_timeout)
133        .set_max_delay(max_delay);
134
135    let mut con = client
136        .get_connection_manager_with_config(connection_manager_config)
137        .await?;
138
139    let version = get_redis_version(&mut con).await?;
140    let min_version = Version::parse(REDIS_MIN_VERSION)?;
141    let con_msg = format!("Connected to redis v{version}");
142
143    if version >= min_version {
144        tracing::info!(con_msg);
145    } else {
146        // TODO: Using `log` error here so that the message is displayed regardless of whether
147        // the logging config has pyo3 enabled. Later we can standardize this to `tracing`.
148        log::error!("{con_msg}, but minimum supported version is {REDIS_MIN_VERSION}");
149    }
150
151    Ok(con)
152}
153
154/// Flush the Redis database for the given connection.
155pub async fn flush_redis(
156    con: &mut redis::aio::ConnectionManager,
157) -> anyhow::Result<(), RedisError> {
158    redis::cmd(REDIS_FLUSHDB).exec_async(con).await
159}
160
161/// Parse the stream key from the given identifiers and config.
162pub fn get_stream_key(
163    trader_id: TraderId,
164    instance_id: UUID4,
165    config: &MessageBusConfig,
166) -> String {
167    let mut stream_key = String::new();
168
169    if config.use_trader_prefix {
170        stream_key.push_str("trader-");
171    }
172
173    if config.use_trader_id {
174        stream_key.push_str(trader_id.as_str());
175        stream_key.push(REDIS_DELIMITER);
176    }
177
178    if config.use_instance_id {
179        stream_key.push_str(&format!("{instance_id}"));
180        stream_key.push(REDIS_DELIMITER);
181    }
182
183    stream_key.push_str(&config.streams_prefix);
184    stream_key
185}
186
187/// Parses the Redis version from the "INFO" command output.
188pub async fn get_redis_version(
189    conn: &mut redis::aio::ConnectionManager,
190) -> anyhow::Result<Version> {
191    let info: String = redis::cmd("INFO").query_async(conn).await?;
192    let version_str = match info.lines().find_map(|line| {
193        if line.starts_with("redis_version:") {
194            line.split(':').nth(1).map(|s| s.trim().to_string())
195        } else {
196            None
197        }
198    }) {
199        Some(info) => info,
200        None => {
201            anyhow::bail!("Redis version not available");
202        }
203    };
204
205    parse_redis_version(&version_str)
206}
207
208fn parse_redis_version(version_str: &str) -> anyhow::Result<Version> {
209    let mut components = version_str.split('.').map(|s| s.parse::<u64>());
210
211    let major = components.next().unwrap_or(Ok(0))?;
212    let minor = components.next().unwrap_or(Ok(0))?;
213    let patch = components.next().unwrap_or(Ok(0))?;
214
215    Ok(Version::new(major, minor, patch))
216}
217
218////////////////////////////////////////////////////////////////////////////////
219// Tests
220////////////////////////////////////////////////////////////////////////////////
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use serde_json::json;

    use super::*;

    #[rstest]
    fn test_get_redis_url_default_values() {
        let config: DatabaseConfig = serde_json::from_value(json!({})).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "redis://127.0.0.1:6379");
        assert_eq!(redacted_url, "redis://127.0.0.1:6379");
    }

    #[rstest]
    fn test_get_redis_url_full_config_with_ssl() {
        let config_json = json!({
            "host": "example.com",
            "port": 6380,
            "username": "user",
            "password": "pass",
            "ssl": true,
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "rediss://user:pass@example.com:6380");
        // Passwords of four characters or fewer are not redacted.
        assert_eq!(redacted_url, "rediss://user:pass@example.com:6380");
    }

    #[rstest]
    fn test_get_redis_url_full_config_without_ssl() {
        let config_json = json!({
            "host": "example.com",
            "port": 6380,
            "username": "username",
            "password": "password",
            "ssl": false,
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "redis://username:password@example.com:6380");
        assert_eq!(redacted_url, "redis://username:pa...rd@example.com:6380");
    }

    #[rstest]
    fn test_get_redis_url_missing_username_and_password() {
        let config_json = json!({
            "host": "example.com",
            "port": 6380,
            "ssl": false,
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "redis://example.com:6380");
        assert_eq!(redacted_url, "redis://example.com:6380");
    }

    #[rstest]
    fn test_get_redis_url_ssl_default_false() {
        let config_json = json!({
            "host": "example.com",
            "port": 6380,
            "username": "username",
            "password": "password",
            // "ssl" is intentionally omitted to test default behavior
        });
        let config: DatabaseConfig = serde_json::from_value(config_json).unwrap();
        let (url, redacted_url) = get_redis_url(config);
        assert_eq!(url, "redis://username:password@example.com:6380");
        assert_eq!(redacted_url, "redis://username:pa...rd@example.com:6380");
    }

    #[rstest]
    fn test_get_stream_key_with_trader_prefix_and_instance_id() {
        let trader_id = TraderId::from("tester-123");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.use_instance_id = true;

        let key = get_stream_key(trader_id, instance_id, &config);
        assert_eq!(key, format!("trader-tester-123:{instance_id}:stream"));
    }

    #[rstest]
    fn test_get_stream_key_without_trader_prefix_or_instance_id() {
        let trader_id = TraderId::from("tester-123");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.use_trader_prefix = false;
        config.use_trader_id = false;

        let key = get_stream_key(trader_id, instance_id, &config);
        assert_eq!(key, "stream");
    }

    #[rstest]
    fn test_get_stream_key_trader_prefix_without_trader_id() {
        let trader_id = TraderId::from("tester-123");
        let instance_id = UUID4::new();
        let mut config = MessageBusConfig::default();
        config.use_trader_id = false;

        // The "trader-" prefix is applied independently of `use_trader_id`.
        let key = get_stream_key(trader_id, instance_id, &config);
        assert_eq!(key, "trader-stream");
    }

    #[rstest]
    #[case("7.2.4", Version::new(7, 2, 4))]
    #[case("6.2", Version::new(6, 2, 0))]
    #[case("7", Version::new(7, 0, 0))]
    fn test_parse_redis_version(#[case] input: &str, #[case] expected: Version) {
        assert_eq!(parse_redis_version(input).unwrap(), expected);
    }

    #[rstest]
    fn test_parse_redis_version_invalid() {
        assert!(parse_redis_version("").is_err());
        assert!(parse_redis_version("abc").is_err());
    }
}