1use std::{collections::HashMap, str::FromStr};
17
18use ahash::AHashMap;
19use bytes::Bytes;
20use chrono::{DateTime, Utc};
21use futures::future::join_all;
22use nautilus_common::{cache::database::CacheMap, enums::SerializationEncoding};
23use nautilus_model::{
24 accounts::AccountAny,
25 identifiers::{AccountId, ClientOrderId, InstrumentId, PositionId},
26 instruments::{InstrumentAny, SyntheticInstrument},
27 orders::OrderAny,
28 position::Position,
29 types::Currency,
30};
31use redis::{AsyncCommands, aio::ConnectionManager};
32use serde::{Serialize, de::DeserializeOwned};
33use serde_json::Value;
34use ustr::Ustr;
35
36use super::get_index_key;
37
// Separator between the segments of every Redis key built by this module.
const REDIS_DELIMITER: char = ':';

// Collection names. `DatabaseQueries::read` dispatches on the segment that precedes
// the first delimiter of a (trader-relative) key.
const INDEX: &str = "index";
const GENERAL: &str = "general";
const CURRENCIES: &str = "currencies";
const INSTRUMENTS: &str = "instruments";
const SYNTHETICS: &str = "synthetics";
const ACCOUNTS: &str = "accounts";
const ORDERS: &str = "orders";
const POSITIONS: &str = "positions";
const ACTORS: &str = "actors";
const STRATEGIES: &str = "strategies";

// Index names, matched by `read_index` against the value returned by `get_index_key`.
// Order-related indexes:
const INDEX_ORDER_IDS: &str = "index:order_ids";
const INDEX_ORDER_POSITION: &str = "index:order_position";
const INDEX_ORDER_CLIENT: &str = "index:order_client";
const INDEX_ORDERS: &str = "index:orders";
const INDEX_ORDERS_OPEN: &str = "index:orders_open";
const INDEX_ORDERS_CLOSED: &str = "index:orders_closed";
const INDEX_ORDERS_EMULATED: &str = "index:orders_emulated";
const INDEX_ORDERS_INFLIGHT: &str = "index:orders_inflight";
// Position-related indexes:
const INDEX_POSITIONS: &str = "index:positions";
const INDEX_POSITIONS_OPEN: &str = "index:positions_open";
const INDEX_POSITIONS_CLOSED: &str = "index:positions_closed";
63
/// Stateless namespace grouping the Redis query helpers defined in its `impl` block:
/// payload (de)serialization, key scanning, raw reads, and loading cached domain objects.
#[derive(Debug)]
pub struct DatabaseQueries;
66
67impl DatabaseQueries {
68 pub fn serialize_payload<T: Serialize>(
74 encoding: SerializationEncoding,
75 payload: &T,
76 ) -> anyhow::Result<Vec<u8>> {
77 let mut value = serde_json::to_value(payload)?;
78 convert_timestamps(&mut value);
79 match encoding {
80 SerializationEncoding::MsgPack => rmp_serde::to_vec(&value)
81 .map_err(|e| anyhow::anyhow!("Failed to serialize msgpack `payload`: {e}")),
82 SerializationEncoding::Json => serde_json::to_vec(&value)
83 .map_err(|e| anyhow::anyhow!("Failed to serialize json `payload`: {e}")),
84 }
85 }
86
87 pub fn deserialize_payload<T: DeserializeOwned>(
93 encoding: SerializationEncoding,
94 payload: &[u8],
95 ) -> anyhow::Result<T> {
96 let mut value = match encoding {
97 SerializationEncoding::MsgPack => rmp_serde::from_slice(payload)
98 .map_err(|e| anyhow::anyhow!("Failed to deserialize msgpack `payload`: {e}"))?,
99 SerializationEncoding::Json => serde_json::from_slice(payload)
100 .map_err(|e| anyhow::anyhow!("Failed to deserialize json `payload`: {e}"))?,
101 };
102
103 convert_timestamp_strings(&mut value);
104
105 serde_json::from_value(value)
106 .map_err(|e| anyhow::anyhow!("Failed to convert value to target type: {e}"))
107 }
108
109 pub async fn scan_keys(
115 con: &mut ConnectionManager,
116 pattern: String,
117 ) -> anyhow::Result<Vec<String>> {
118 let mut result = Vec::new();
119 let mut cursor = 0u64;
120
121 loop {
122 let scan_result: (u64, Vec<String>) = redis::cmd("SCAN")
123 .arg(cursor)
124 .arg("MATCH")
125 .arg(&pattern)
126 .arg("COUNT")
127 .arg(5000)
128 .query_async(con)
129 .await?;
130
131 let (new_cursor, keys) = scan_result;
132 result.extend(keys);
133
134 if new_cursor == 0 {
136 break;
137 }
138
139 cursor = new_cursor;
140 }
141
142 Ok(result)
143 }
144
145 pub async fn read_bulk(
151 con: &ConnectionManager,
152 keys: &[String],
153 ) -> anyhow::Result<Vec<Option<Bytes>>> {
154 if keys.is_empty() {
155 return Ok(vec![]);
156 }
157
158 let mut con = con.clone();
159
160 let results: Vec<Option<Vec<u8>>> =
162 redis::cmd("MGET").arg(keys).query_async(&mut con).await?;
163
164 let bytes_results: Vec<Option<Bytes>> = results
166 .into_iter()
167 .map(|opt| opt.map(Bytes::from))
168 .collect();
169
170 Ok(bytes_results)
171 }
172
173 pub async fn read_bulk_batched(
182 con: &ConnectionManager,
183 keys: &[String],
184 batch_size: usize,
185 ) -> anyhow::Result<Vec<Option<Bytes>>> {
186 if batch_size == 0 {
187 anyhow::bail!("`batch_size` must be greater than zero");
188 }
189
190 if keys.is_empty() {
191 return Ok(vec![]);
192 }
193
194 let mut all_results: Vec<Option<Bytes>> = Vec::with_capacity(keys.len());
195
196 for chunk in keys.chunks(batch_size) {
197 let mut con = con.clone();
198
199 let results: Vec<Option<Vec<u8>>> =
200 redis::cmd("MGET").arg(chunk).query_async(&mut con).await?;
201
202 all_results.extend(results.into_iter().map(|opt| opt.map(Bytes::from)));
203 }
204
205 Ok(all_results)
206 }
207
208 pub async fn read(
214 con: &ConnectionManager,
215 trader_key: &str,
216 key: &str,
217 ) -> anyhow::Result<Vec<Bytes>> {
218 let collection = Self::get_collection_key(key)?;
219 let full_key = format!("{trader_key}{REDIS_DELIMITER}{key}");
220
221 let mut con = con.clone();
222
223 match collection {
224 INDEX => Self::read_index(&mut con, &full_key).await,
225 GENERAL => Self::read_string(&mut con, &full_key).await,
226 CURRENCIES => Self::read_string(&mut con, &full_key).await,
227 INSTRUMENTS => Self::read_string(&mut con, &full_key).await,
228 SYNTHETICS => Self::read_string(&mut con, &full_key).await,
229 ACCOUNTS => Self::read_list(&mut con, &full_key).await,
230 ORDERS => Self::read_list(&mut con, &full_key).await,
231 POSITIONS => Self::read_list(&mut con, &full_key).await,
232 ACTORS => Self::read_string(&mut con, &full_key).await,
233 STRATEGIES => Self::read_string(&mut con, &full_key).await,
234 _ => anyhow::bail!("Unsupported operation: `read` for collection '{collection}'"),
235 }
236 }
237
238 pub async fn load_all(
244 con: &ConnectionManager,
245 encoding: SerializationEncoding,
246 trader_key: &str,
247 ) -> anyhow::Result<CacheMap> {
248 let (currencies, instruments, synthetics, accounts, orders, positions) = tokio::try_join!(
249 Self::load_currencies(con, trader_key, encoding),
250 Self::load_instruments(con, trader_key, encoding),
251 Self::load_synthetics(con, trader_key, encoding),
252 Self::load_accounts(con, trader_key, encoding),
253 Self::load_orders(con, trader_key, encoding),
254 Self::load_positions(con, trader_key, encoding)
255 )
256 .map_err(|e| anyhow::anyhow!("Error loading cache data: {e}"))?;
257
258 let greeks = AHashMap::new();
261 let yield_curves = AHashMap::new();
262
263 Ok(CacheMap {
264 currencies,
265 instruments,
266 synthetics,
267 accounts,
268 orders,
269 positions,
270 greeks,
271 yield_curves,
272 })
273 }
274
275 pub async fn load_currencies(
281 con: &ConnectionManager,
282 trader_key: &str,
283 encoding: SerializationEncoding,
284 ) -> anyhow::Result<AHashMap<Ustr, Currency>> {
285 let mut currencies = AHashMap::new();
286 let pattern = format!("{trader_key}{REDIS_DELIMITER}{CURRENCIES}*");
287 log::debug!("Loading {pattern}");
288
289 let mut con = con.clone();
290 let keys = Self::scan_keys(&mut con, pattern).await?;
291
292 if keys.is_empty() {
293 return Ok(currencies);
294 }
295
296 let bulk_values = Self::read_bulk(&con, &keys).await?;
298
299 for (key, value_opt) in keys.iter().zip(bulk_values.iter()) {
301 let currency_code = if let Some(code) = key.as_str().rsplit(':').next() {
302 Ustr::from(code)
303 } else {
304 log::error!("Invalid key format: {key}");
305 continue;
306 };
307
308 if let Some(value_bytes) = value_opt {
309 match Self::deserialize_payload(encoding, value_bytes) {
310 Ok(currency) => {
311 currencies.insert(currency_code, currency);
312 }
313 Err(e) => {
314 log::error!("Failed to deserialize currency {currency_code}: {e}");
315 }
316 }
317 } else {
318 log::error!("Currency not found in Redis: {currency_code}");
319 }
320 }
321
322 log::debug!("Loaded {} currencies(s)", currencies.len());
323
324 Ok(currencies)
325 }
326
327 pub async fn load_instruments(
338 con: &ConnectionManager,
339 trader_key: &str,
340 encoding: SerializationEncoding,
341 ) -> anyhow::Result<AHashMap<InstrumentId, InstrumentAny>> {
342 let mut instruments = AHashMap::new();
343 let pattern = format!("{trader_key}{REDIS_DELIMITER}{INSTRUMENTS}*");
344 log::debug!("Loading {pattern}");
345
346 let mut con = con.clone();
347 let keys = Self::scan_keys(&mut con, pattern).await?;
348
349 let futures: Vec<_> = keys
350 .iter()
351 .map(|key| {
352 let con = con.clone();
353 async move {
354 let instrument_id = key
355 .as_str()
356 .rsplit(':')
357 .next()
358 .ok_or_else(|| {
359 log::error!("Invalid key format: {key}");
360 "Invalid key format"
361 })
362 .and_then(|code| {
363 InstrumentId::from_str(code).map_err(|e| {
364 log::error!("Failed to convert to InstrumentId for {key}: {e}");
365 "Invalid instrument ID"
366 })
367 });
368
369 let instrument_id = match instrument_id {
370 Ok(id) => id,
371 Err(_) => return None,
372 };
373
374 match Self::load_instrument(&con, trader_key, &instrument_id, encoding).await {
375 Ok(Some(instrument)) => Some((instrument_id, instrument)),
376 Ok(None) => {
377 log::error!("Instrument not found: {instrument_id}");
378 None
379 }
380 Err(e) => {
381 log::error!("Failed to load instrument {instrument_id}: {e}");
382 None
383 }
384 }
385 }
386 })
387 .collect();
388
389 instruments.extend(join_all(futures).await.into_iter().flatten());
391 log::debug!("Loaded {} instruments(s)", instruments.len());
392
393 Ok(instruments)
394 }
395
396 pub async fn load_synthetics(
407 con: &ConnectionManager,
408 trader_key: &str,
409 encoding: SerializationEncoding,
410 ) -> anyhow::Result<AHashMap<InstrumentId, SyntheticInstrument>> {
411 let mut synthetics = AHashMap::new();
412 let pattern = format!("{trader_key}{REDIS_DELIMITER}{SYNTHETICS}*");
413 log::debug!("Loading {pattern}");
414
415 let mut con = con.clone();
416 let keys = Self::scan_keys(&mut con, pattern).await?;
417
418 let futures: Vec<_> = keys
419 .iter()
420 .map(|key| {
421 let con = con.clone();
422 async move {
423 let instrument_id = key
424 .as_str()
425 .rsplit(':')
426 .next()
427 .ok_or_else(|| {
428 log::error!("Invalid key format: {key}");
429 "Invalid key format"
430 })
431 .and_then(|code| {
432 InstrumentId::from_str(code).map_err(|e| {
433 log::error!("Failed to parse InstrumentId for {key}: {e}");
434 "Invalid instrument ID"
435 })
436 });
437
438 let instrument_id = match instrument_id {
439 Ok(id) => id,
440 Err(_) => return None,
441 };
442
443 match Self::load_synthetic(&con, trader_key, &instrument_id, encoding).await {
444 Ok(Some(synthetic)) => Some((instrument_id, synthetic)),
445 Ok(None) => {
446 log::error!("Synthetic not found: {instrument_id}");
447 None
448 }
449 Err(e) => {
450 log::error!("Failed to load synthetic {instrument_id}: {e}");
451 None
452 }
453 }
454 }
455 })
456 .collect();
457
458 synthetics.extend(join_all(futures).await.into_iter().flatten());
460 log::debug!("Loaded {} synthetics(s)", synthetics.len());
461
462 Ok(synthetics)
463 }
464
465 pub async fn load_accounts(
476 con: &ConnectionManager,
477 trader_key: &str,
478 encoding: SerializationEncoding,
479 ) -> anyhow::Result<AHashMap<AccountId, AccountAny>> {
480 let mut accounts = AHashMap::new();
481 let pattern = format!("{trader_key}{REDIS_DELIMITER}{ACCOUNTS}*");
482 log::debug!("Loading {pattern}");
483
484 let mut con = con.clone();
485 let keys = Self::scan_keys(&mut con, pattern).await?;
486
487 let futures: Vec<_> = keys
488 .iter()
489 .map(|key| {
490 let con = con.clone();
491 async move {
492 let account_id = if let Some(code) = key.as_str().rsplit(':').next() {
493 AccountId::from(code)
494 } else {
495 log::error!("Invalid key format: {key}");
496 return None;
497 };
498
499 match Self::load_account(&con, trader_key, &account_id, encoding).await {
500 Ok(Some(account)) => Some((account_id, account)),
501 Ok(None) => {
502 log::error!("Account not found: {account_id}");
503 None
504 }
505 Err(e) => {
506 log::error!("Failed to load account {account_id}: {e}");
507 None
508 }
509 }
510 }
511 })
512 .collect();
513
514 accounts.extend(join_all(futures).await.into_iter().flatten());
516 log::debug!("Loaded {} accounts(s)", accounts.len());
517
518 Ok(accounts)
519 }
520
521 pub async fn load_orders(
532 con: &ConnectionManager,
533 trader_key: &str,
534 encoding: SerializationEncoding,
535 ) -> anyhow::Result<AHashMap<ClientOrderId, OrderAny>> {
536 let mut orders = AHashMap::new();
537 let pattern = format!("{trader_key}{REDIS_DELIMITER}{ORDERS}*");
538 log::debug!("Loading {pattern}");
539
540 let mut con = con.clone();
541 let keys = Self::scan_keys(&mut con, pattern).await?;
542
543 let futures: Vec<_> = keys
544 .iter()
545 .map(|key| {
546 let con = con.clone();
547 async move {
548 let client_order_id = if let Some(code) = key.as_str().rsplit(':').next() {
549 ClientOrderId::from(code)
550 } else {
551 log::error!("Invalid key format: {key}");
552 return None;
553 };
554
555 match Self::load_order(&con, trader_key, &client_order_id, encoding).await {
556 Ok(Some(order)) => Some((client_order_id, order)),
557 Ok(None) => {
558 log::error!("Order not found: {client_order_id}");
559 None
560 }
561 Err(e) => {
562 log::error!("Failed to load order {client_order_id}: {e}");
563 None
564 }
565 }
566 }
567 })
568 .collect();
569
570 orders.extend(join_all(futures).await.into_iter().flatten());
572 log::debug!("Loaded {} order(s)", orders.len());
573
574 Ok(orders)
575 }
576
577 pub async fn load_positions(
588 con: &ConnectionManager,
589 trader_key: &str,
590 encoding: SerializationEncoding,
591 ) -> anyhow::Result<AHashMap<PositionId, Position>> {
592 let mut positions = AHashMap::new();
593 let pattern = format!("{trader_key}{REDIS_DELIMITER}{POSITIONS}*");
594 log::debug!("Loading {pattern}");
595
596 let mut con = con.clone();
597 let keys = Self::scan_keys(&mut con, pattern).await?;
598
599 let futures: Vec<_> = keys
600 .iter()
601 .map(|key| {
602 let con = con.clone();
603 async move {
604 let position_id = if let Some(code) = key.as_str().rsplit(':').next() {
605 PositionId::from(code)
606 } else {
607 log::error!("Invalid key format: {key}");
608 return None;
609 };
610
611 match Self::load_position(&con, trader_key, &position_id, encoding).await {
612 Ok(Some(position)) => Some((position_id, position)),
613 Ok(None) => {
614 log::error!("Position not found: {position_id}");
615 None
616 }
617 Err(e) => {
618 log::error!("Failed to load position {position_id}: {e}");
619 None
620 }
621 }
622 }
623 })
624 .collect();
625
626 positions.extend(join_all(futures).await.into_iter().flatten());
628 log::debug!("Loaded {} position(s)", positions.len());
629
630 Ok(positions)
631 }
632
633 pub async fn load_currency(
639 con: &ConnectionManager,
640 trader_key: &str,
641 code: &Ustr,
642 encoding: SerializationEncoding,
643 ) -> anyhow::Result<Option<Currency>> {
644 let key = format!("{CURRENCIES}{REDIS_DELIMITER}{code}");
645 let result = Self::read(con, trader_key, &key).await?;
646
647 if result.is_empty() {
648 return Ok(None);
649 }
650
651 let currency = Self::deserialize_payload(encoding, &result[0])?;
652 Ok(currency)
653 }
654
655 pub async fn load_instrument(
661 con: &ConnectionManager,
662 trader_key: &str,
663 instrument_id: &InstrumentId,
664 encoding: SerializationEncoding,
665 ) -> anyhow::Result<Option<InstrumentAny>> {
666 let key = format!("{INSTRUMENTS}{REDIS_DELIMITER}{instrument_id}");
667 let result = Self::read(con, trader_key, &key).await?;
668 if result.is_empty() {
669 return Ok(None);
670 }
671
672 let instrument: InstrumentAny = Self::deserialize_payload(encoding, &result[0])?;
673 Ok(Some(instrument))
674 }
675
676 pub async fn load_synthetic(
682 con: &ConnectionManager,
683 trader_key: &str,
684 instrument_id: &InstrumentId,
685 encoding: SerializationEncoding,
686 ) -> anyhow::Result<Option<SyntheticInstrument>> {
687 let key = format!("{SYNTHETICS}{REDIS_DELIMITER}{instrument_id}");
688 let result = Self::read(con, trader_key, &key).await?;
689 if result.is_empty() {
690 return Ok(None);
691 }
692
693 let synthetic: SyntheticInstrument = Self::deserialize_payload(encoding, &result[0])?;
694 Ok(Some(synthetic))
695 }
696
697 pub async fn load_account(
703 con: &ConnectionManager,
704 trader_key: &str,
705 account_id: &AccountId,
706 encoding: SerializationEncoding,
707 ) -> anyhow::Result<Option<AccountAny>> {
708 let key = format!("{ACCOUNTS}{REDIS_DELIMITER}{account_id}");
709 let result = Self::read(con, trader_key, &key).await?;
710 if result.is_empty() {
711 return Ok(None);
712 }
713
714 let account: AccountAny = Self::deserialize_payload(encoding, &result[0])?;
715 Ok(Some(account))
716 }
717
718 pub async fn load_order(
724 con: &ConnectionManager,
725 trader_key: &str,
726 client_order_id: &ClientOrderId,
727 encoding: SerializationEncoding,
728 ) -> anyhow::Result<Option<OrderAny>> {
729 let key = format!("{ORDERS}{REDIS_DELIMITER}{client_order_id}");
730 let result = Self::read(con, trader_key, &key).await?;
731 if result.is_empty() {
732 return Ok(None);
733 }
734
735 let order: OrderAny = Self::deserialize_payload(encoding, &result[0])?;
736 Ok(Some(order))
737 }
738
739 pub async fn load_position(
745 con: &ConnectionManager,
746 trader_key: &str,
747 position_id: &PositionId,
748 encoding: SerializationEncoding,
749 ) -> anyhow::Result<Option<Position>> {
750 let key = format!("{POSITIONS}{REDIS_DELIMITER}{position_id}");
751 let result = Self::read(con, trader_key, &key).await?;
752 if result.is_empty() {
753 return Ok(None);
754 }
755
756 let position: Position = Self::deserialize_payload(encoding, &result[0])?;
757 Ok(Some(position))
758 }
759
760 fn get_collection_key(key: &str) -> anyhow::Result<&str> {
761 key.split_once(REDIS_DELIMITER)
762 .map(|(collection, _)| collection)
763 .ok_or_else(|| {
764 anyhow::anyhow!("Invalid `key`, missing a '{REDIS_DELIMITER}' delimiter, was {key}")
765 })
766 }
767
768 async fn read_index(conn: &mut ConnectionManager, key: &str) -> anyhow::Result<Vec<Bytes>> {
769 let index_key = get_index_key(key)?;
770 match index_key {
771 INDEX_ORDER_IDS => Self::read_set(conn, key).await,
772 INDEX_ORDER_POSITION => Self::read_hset(conn, key).await,
773 INDEX_ORDER_CLIENT => Self::read_hset(conn, key).await,
774 INDEX_ORDERS => Self::read_set(conn, key).await,
775 INDEX_ORDERS_OPEN => Self::read_set(conn, key).await,
776 INDEX_ORDERS_CLOSED => Self::read_set(conn, key).await,
777 INDEX_ORDERS_EMULATED => Self::read_set(conn, key).await,
778 INDEX_ORDERS_INFLIGHT => Self::read_set(conn, key).await,
779 INDEX_POSITIONS => Self::read_set(conn, key).await,
780 INDEX_POSITIONS_OPEN => Self::read_set(conn, key).await,
781 INDEX_POSITIONS_CLOSED => Self::read_set(conn, key).await,
782 _ => anyhow::bail!("Index unknown '{index_key}' on read"),
783 }
784 }
785
786 async fn read_string(conn: &mut ConnectionManager, key: &str) -> anyhow::Result<Vec<Bytes>> {
787 let result: Vec<u8> = conn.get(key).await?;
788
789 if result.is_empty() {
790 Ok(vec![])
791 } else {
792 Ok(vec![Bytes::from(result)])
793 }
794 }
795
796 async fn read_set(conn: &mut ConnectionManager, key: &str) -> anyhow::Result<Vec<Bytes>> {
797 let result: Vec<Bytes> = conn.smembers(key).await?;
798 Ok(result)
799 }
800
801 async fn read_hset(conn: &mut ConnectionManager, key: &str) -> anyhow::Result<Vec<Bytes>> {
802 let result: HashMap<String, String> = conn.hgetall(key).await?;
803 let json = serde_json::to_string(&result)?;
804 Ok(vec![Bytes::from(json.into_bytes())])
805 }
806
807 async fn read_list(conn: &mut ConnectionManager, key: &str) -> anyhow::Result<Vec<Bytes>> {
808 let result: Vec<Bytes> = conn.lrange(key, 0, -1).await?;
809 Ok(result)
810 }
811}
812
/// Reports whether an object field named `key` holds a nanosecond timestamp.
///
/// Matches any field whose name starts with `ts_`, plus the `expire_time_ns` field.
fn is_timestamp_field(key: &str) -> bool {
    key.starts_with("ts_") || matches!(key, "expire_time_ns")
}
818
819fn convert_timestamps(value: &mut Value) {
820 match value {
821 Value::Object(map) => {
822 for (key, v) in map {
823 if is_timestamp_field(key)
824 && let Value::Number(n) = v
825 && let Some(n) = n.as_u64()
826 {
827 let dt = DateTime::<Utc>::from_timestamp_nanos(n as i64);
828 *v = Value::String(dt.to_rfc3339_opts(chrono::SecondsFormat::Nanos, true));
829 }
830 convert_timestamps(v);
831 }
832 }
833 Value::Array(arr) => {
834 for item in arr {
835 convert_timestamps(item);
836 }
837 }
838 _ => {}
839 }
840}
841
842fn convert_timestamp_strings(value: &mut Value) {
843 match value {
844 Value::Object(map) => {
845 for (key, v) in map {
846 if is_timestamp_field(key)
847 && let Value::String(s) = v
848 && let Ok(dt) = DateTime::parse_from_rfc3339(s)
849 {
850 *v = Value::Number(
851 (dt.with_timezone(&Utc)
852 .timestamp_nanos_opt()
853 .expect("Invalid DateTime") as u64)
854 .into(),
855 );
856 }
857 convert_timestamp_strings(v);
858 }
859 }
860 Value::Array(arr) => {
861 for item in arr {
862 convert_timestamp_strings(item);
863 }
864 }
865 _ => {}
866 }
867}