nautilus_tardis/machine/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    collections::HashMap,
18    env,
19    sync::{
20        Arc,
21        atomic::{AtomicBool, Ordering},
22    },
23};
24
25use futures_util::{Stream, StreamExt, pin_mut};
26use nautilus_model::data::Data;
27use ustr::Ustr;
28
29use super::{
30    Error,
31    message::WsMessage,
32    replay_normalized, stream_normalized,
33    types::{
34        InstrumentMiniInfo, ReplayNormalizedRequestOptions, StreamNormalizedRequestOptions,
35        TardisInstrumentKey,
36    },
37};
38use crate::machine::parse::parse_tardis_ws_message;
39
40/// Provides a client for connecting to a [Tardis Machine Server](https://docs.tardis.dev/api/tardis-machine).
41#[cfg_attr(
42    feature = "python",
43    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.adapters")
44)]
45#[derive(Debug, Clone)]
46pub struct TardisMachineClient {
47    pub base_url: String,
48    pub replay_signal: Arc<AtomicBool>,
49    pub stream_signal: Arc<AtomicBool>,
50    pub instruments: HashMap<TardisInstrumentKey, Arc<InstrumentMiniInfo>>,
51    pub normalize_symbols: bool,
52}
53
54impl TardisMachineClient {
55    /// Creates a new [`TardisMachineClient`] instance.
56    pub fn new(base_url: Option<&str>, normalize_symbols: bool) -> anyhow::Result<Self> {
57        let base_url = base_url
58            .map(ToString::to_string)
59            .or_else(|| env::var("TARDIS_MACHINE_WS_URL").ok())
60            .ok_or_else(|| {
61                anyhow::anyhow!(
62                    "Tardis Machine `base_url` must be provided or set in the 'TARDIS_MACHINE_WS_URL' environment variable"
63                )
64            })?;
65
66        Ok(Self {
67            base_url,
68            replay_signal: Arc::new(AtomicBool::new(false)),
69            stream_signal: Arc::new(AtomicBool::new(false)),
70            instruments: HashMap::new(),
71            normalize_symbols,
72        })
73    }
74
75    pub fn add_instrument_info(&mut self, info: InstrumentMiniInfo) {
76        let key = info.as_tardis_instrument_key();
77        self.instruments.insert(key, Arc::new(info));
78    }
79
80    #[must_use]
81    pub fn is_closed(&self) -> bool {
82        self.replay_signal.load(Ordering::Relaxed) && self.stream_signal.load(Ordering::Relaxed)
83    }
84
85    pub fn close(&mut self) {
86        tracing::debug!("Closing");
87
88        self.replay_signal.store(true, Ordering::Relaxed);
89        self.stream_signal.store(true, Ordering::Relaxed);
90
91        tracing::debug!("Closed");
92    }
93
94    pub async fn replay(
95        &self,
96        options: Vec<ReplayNormalizedRequestOptions>,
97    ) -> impl Stream<Item = Data> {
98        let stream = replay_normalized(&self.base_url, options, self.replay_signal.clone())
99            .await
100            .expect("Failed to connect to WebSocket");
101
102        // We use Box::pin to heap-allocate the stream and ensure it implements
103        // Unpin for safe async handling across lifetimes.
104        handle_ws_stream(Box::pin(stream), None, Some(self.instruments.clone()))
105    }
106
107    pub async fn stream(
108        &self,
109        instrument: InstrumentMiniInfo,
110        options: Vec<StreamNormalizedRequestOptions>,
111    ) -> impl Stream<Item = Data> {
112        let stream = stream_normalized(&self.base_url, options, self.stream_signal.clone())
113            .await
114            .expect("Failed to connect to WebSocket");
115
116        // We use Box::pin to heap-allocate the stream and ensure it implements
117        // Unpin for safe async handling across lifetimes.
118        handle_ws_stream(Box::pin(stream), Some(Arc::new(instrument)), None)
119    }
120}
121
122fn handle_ws_stream<S>(
123    stream: S,
124    instrument: Option<Arc<InstrumentMiniInfo>>,
125    instrument_map: Option<HashMap<TardisInstrumentKey, Arc<InstrumentMiniInfo>>>,
126) -> impl Stream<Item = Data>
127where
128    S: Stream<Item = Result<WsMessage, Error>> + Unpin,
129{
130    assert!(
131        instrument.is_some() || instrument_map.is_some(),
132        "Either `instrument` or `instrument_map` must be provided"
133    );
134
135    async_stream::stream! {
136        pin_mut!(stream);
137        while let Some(result) = stream.next().await {
138            match result {
139                Ok(msg) => {
140                    let info = instrument.clone().or_else(|| {
141                        instrument_map
142                            .as_ref()
143                            .and_then(|map| determine_instrument_info(&msg, map))
144                    });
145
146                    if let Some(info) = info {
147                        if let Some(data) = parse_tardis_ws_message(msg, info) {
148                            yield data;
149                        }
150                    }
151                }
152                Err(e) => {
153                    tracing::error!("Error in WebSocket stream: {e:?}");
154                    break;
155                }
156            }
157        }
158    }
159}
160
161pub fn determine_instrument_info(
162    msg: &WsMessage,
163    instrument_map: &HashMap<TardisInstrumentKey, Arc<InstrumentMiniInfo>>,
164) -> Option<Arc<InstrumentMiniInfo>> {
165    let key = match msg {
166        WsMessage::BookChange(msg) => {
167            TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange.clone())
168        }
169        WsMessage::BookSnapshot(msg) => {
170            TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange.clone())
171        }
172        WsMessage::Trade(msg) => {
173            TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange.clone())
174        }
175        WsMessage::TradeBar(msg) => {
176            TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange.clone())
177        }
178        WsMessage::DerivativeTicker(_) => return None,
179        WsMessage::Disconnect(_) => return None,
180    };
181    if let Some(inst) = instrument_map.get(&key) {
182        Some(inst.clone())
183    } else {
184        tracing::error!("Instrument definition info not available for {key:?}");
185        None
186    }
187}