1use std::{
17 collections::HashMap,
18 env,
19 sync::{
20 Arc,
21 atomic::{AtomicBool, Ordering},
22 },
23};
24
25use futures_util::{Stream, StreamExt, pin_mut};
26use nautilus_model::data::Data;
27use ustr::Ustr;
28
29use super::{
30 Error,
31 message::WsMessage,
32 replay_normalized, stream_normalized,
33 types::{
34 ReplayNormalizedRequestOptions, StreamNormalizedRequestOptions, TardisInstrumentKey,
35 TardisInstrumentMiniInfo,
36 },
37};
38use crate::{config::BookSnapshotOutput, machine::parse::parse_tardis_ws_message};
39
/// Client for a Tardis Machine server's normalized WebSocket replay and streaming endpoints.
///
/// Holds the server base URL, the close signals handed to open connections, and the
/// instrument info used to turn raw [`WsMessage`]s into Nautilus [`Data`].
///
/// Cloning the client shares `replay_signal` and `stream_signal` (they are `Arc`s), so calling
/// `close` on one clone is visible to all clones.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.adapters")
)]
#[derive(Debug, Clone)]
pub struct TardisMachineClient {
    /// WebSocket base URL of the Tardis Machine server.
    pub base_url: String,
    /// Set to `true` by `close`; a clone is passed to `replay_normalized` on each `replay` call
    /// (presumably checked there to stop the replay — confirm in `replay_normalized`).
    pub replay_signal: Arc<AtomicBool>,
    /// Set to `true` by `close`; a clone is passed to `stream_normalized` on each `stream` call
    /// (presumably checked there to stop the stream — confirm in `stream_normalized`).
    pub stream_signal: Arc<AtomicBool>,
    /// Instrument info keyed by symbol/exchange, used to resolve each message during `replay`.
    pub instruments: HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>>,
    /// Symbol normalization flag. NOTE(review): not read by the methods visible here — confirm
    /// where it is consumed.
    pub normalize_symbols: bool,
    /// Book snapshot output mode, forwarded to `parse_tardis_ws_message` for every message.
    pub book_snapshot_output: BookSnapshotOutput,
}
54
55impl TardisMachineClient {
56 pub fn new(
62 base_url: Option<&str>,
63 normalize_symbols: bool,
64 book_snapshot_output: BookSnapshotOutput,
65 ) -> anyhow::Result<Self> {
66 let base_url = base_url
67 .map(ToString::to_string)
68 .or_else(|| env::var("TARDIS_MACHINE_WS_URL").ok())
69 .ok_or_else(|| {
70 anyhow::anyhow!(
71 "Tardis Machine `base_url` must be provided or set in the 'TARDIS_MACHINE_WS_URL' environment variable"
72 )
73 })?;
74
75 Ok(Self {
76 base_url,
77 replay_signal: Arc::new(AtomicBool::new(false)),
78 stream_signal: Arc::new(AtomicBool::new(false)),
79 instruments: HashMap::new(),
80 normalize_symbols,
81 book_snapshot_output,
82 })
83 }
84
85 pub fn add_instrument_info(&mut self, info: TardisInstrumentMiniInfo) {
86 let key = info.as_tardis_instrument_key();
87 self.instruments.insert(key, Arc::new(info));
88 }
89
90 #[must_use]
95 pub fn is_closed(&self) -> bool {
96 self.replay_signal.load(Ordering::Acquire) && self.stream_signal.load(Ordering::Acquire)
98 }
99
100 pub fn close(&mut self) {
101 tracing::debug!("Closing");
102
103 self.replay_signal.store(true, Ordering::Release);
105 self.stream_signal.store(true, Ordering::Release);
106
107 tracing::debug!("Closed");
108 }
109
110 pub async fn replay(
116 &self,
117 options: Vec<ReplayNormalizedRequestOptions>,
118 ) -> Result<impl Stream<Item = Result<Data, Error>>, Error> {
119 let stream = replay_normalized(&self.base_url, options, self.replay_signal.clone()).await?;
120
121 Ok(handle_ws_stream(
124 Box::pin(stream),
125 None,
126 Some(self.instruments.clone()),
127 self.book_snapshot_output.clone(),
128 ))
129 }
130
131 pub async fn stream(
137 &self,
138 instrument: TardisInstrumentMiniInfo,
139 options: Vec<StreamNormalizedRequestOptions>,
140 ) -> Result<impl Stream<Item = Result<Data, Error>>, Error> {
141 let stream = stream_normalized(&self.base_url, options, self.stream_signal.clone()).await?;
142
143 Ok(handle_ws_stream(
146 Box::pin(stream),
147 Some(Arc::new(instrument)),
148 None,
149 self.book_snapshot_output.clone(),
150 ))
151 }
152}
153
154fn handle_ws_stream<S>(
155 stream: S,
156 instrument: Option<Arc<TardisInstrumentMiniInfo>>,
157 instrument_map: Option<HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>>>,
158 book_snapshot_output: BookSnapshotOutput,
159) -> impl Stream<Item = Result<Data, Error>>
160where
161 S: Stream<Item = Result<WsMessage, Error>> + Unpin,
162{
163 assert!(
164 instrument.is_some() || instrument_map.is_some(),
165 "Either `instrument` or `instrument_map` must be provided"
166 );
167
168 async_stream::stream! {
169 pin_mut!(stream);
170
171 while let Some(result) = stream.next().await {
172 match result {
173 Ok(msg) => {
174 if matches!(msg, WsMessage::Disconnect(_)) {
175 tracing::debug!("Received disconnect message: {msg:?}");
176 continue;
177 }
178
179 let info = instrument.clone().or_else(|| {
180 instrument_map
181 .as_ref()
182 .and_then(|map| determine_instrument_info(&msg, map))
183 });
184
185 if let Some(info) = info {
186 if let Some(data) = parse_tardis_ws_message(msg, info, &book_snapshot_output) {
187 yield Ok(data);
188 }
189 } else {
190 tracing::error!("Missing instrument info for message: {msg:?}");
191 yield Err(Error::ConnectionClosed {
192 reason: "Missing instrument definition info".to_string()
193 });
194 break;
195 }
196 }
197 Err(e) => {
198 tracing::error!("Error in WebSocket stream: {e:?}");
199 yield Err(e);
200 break;
201 }
202 }
203 }
204 }
205}
206
207pub fn determine_instrument_info(
208 msg: &WsMessage,
209 instrument_map: &HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>>,
210) -> Option<Arc<TardisInstrumentMiniInfo>> {
211 let key = match msg {
212 WsMessage::BookChange(msg) => {
213 TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange)
214 }
215 WsMessage::BookSnapshot(msg) => {
216 TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange)
217 }
218 WsMessage::Trade(msg) => TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange),
219 WsMessage::TradeBar(msg) => TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange),
220 WsMessage::DerivativeTicker(msg) => {
221 TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange)
222 }
223 WsMessage::Disconnect(_) => return None,
224 };
225 if let Some(inst) = instrument_map.get(&key) {
226 Some(inst.clone())
227 } else {
228 tracing::error!("Instrument definition info not available for {key:?}");
229 None
230 }
231}
232
#[cfg(test)]
mod tests {
    use rstest::{fixture, rstest};

    use super::*;

    /// Client with an explicit local URL, so construction never depends on the environment.
    #[fixture]
    fn client() -> TardisMachineClient {
        TardisMachineClient::new(
            Some("ws://localhost:8001"),
            false,
            BookSnapshotOutput::Deltas,
        )
        .unwrap()
    }

    #[rstest]
    fn test_new_uses_explicit_base_url(client: TardisMachineClient) {
        assert_eq!(client.base_url, "ws://localhost:8001");
        assert!(client.instruments.is_empty());
    }

    #[rstest]
    fn test_is_closed_initial_state(client: TardisMachineClient) {
        assert!(!client.is_closed());
    }

    #[rstest]
    fn test_is_closed_after_close(mut client: TardisMachineClient) {
        client.close();
        assert!(client.is_closed());
        // `close` must set both signals, not just flip the combined check.
        assert!(client.replay_signal.load(Ordering::Acquire));
        assert!(client.stream_signal.load(Ordering::Acquire));
    }

    #[rstest]
    fn test_is_closed_partial_signal(client: TardisMachineClient) {
        client.replay_signal.store(true, Ordering::Release);
        assert!(!client.is_closed());

        client.stream_signal.store(true, Ordering::Release);
        assert!(client.is_closed());
    }

    #[rstest]
    fn test_is_closed_partial_stream_signal_only(client: TardisMachineClient) {
        client.stream_signal.store(true, Ordering::Release);
        assert!(!client.is_closed());
    }
}