1use std::{
17 collections::HashMap,
18 env,
19 sync::{
20 Arc,
21 atomic::{AtomicBool, Ordering},
22 },
23};
24
25use futures_util::{Stream, StreamExt, pin_mut};
26use nautilus_model::data::Data;
27use ustr::Ustr;
28
29use super::{
30 Error,
31 message::WsMessage,
32 replay_normalized, stream_normalized,
33 types::{
34 ReplayNormalizedRequestOptions, StreamNormalizedRequestOptions, TardisInstrumentKey,
35 TardisInstrumentMiniInfo,
36 },
37};
38use crate::{
39 common::consts::TARDIS_MACHINE_WS_URL, config::BookSnapshotOutput,
40 machine::parse::parse_tardis_ws_message,
41};
42
/// Client for replaying and streaming normalized market data from a Tardis Machine server.
///
/// Holds the server base URL, the shared close signals handed to each replay/stream
/// connection, and the instrument definition info used to resolve incoming messages.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.tardis")
)]
#[derive(Debug, Clone)]
pub struct TardisMachineClient {
    /// Base URL of the Tardis Machine WebSocket server.
    pub base_url: String,
    /// Shared flag passed to every replay connection; set to `true` by `close`.
    pub replay_signal: Arc<AtomicBool>,
    /// Shared flag passed to every stream connection; set to `true` by `close`.
    pub stream_signal: Arc<AtomicBool>,
    /// Instrument definition info keyed by (symbol, exchange), used to resolve replayed messages.
    pub instruments: HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>>,
    /// Symbol normalization flag supplied at construction.
    /// NOTE(review): not read by the methods visible here — confirm where it is consumed.
    pub normalize_symbols: bool,
    /// Book snapshot output mode, forwarded to `parse_tardis_ws_message` for every message.
    pub book_snapshot_output: BookSnapshotOutput,
}
57
58impl TardisMachineClient {
59 pub fn new(
65 base_url: Option<&str>,
66 normalize_symbols: bool,
67 book_snapshot_output: BookSnapshotOutput,
68 ) -> anyhow::Result<Self> {
69 let base_url = base_url
70 .map(ToString::to_string)
71 .or_else(|| env::var(TARDIS_MACHINE_WS_URL).ok())
72 .ok_or_else(|| {
73 anyhow::anyhow!(
74 "Tardis Machine `base_url` must be provided or \
75 set in the '{TARDIS_MACHINE_WS_URL}' environment variable"
76 )
77 })?;
78
79 Ok(Self {
80 base_url,
81 replay_signal: Arc::new(AtomicBool::new(false)),
82 stream_signal: Arc::new(AtomicBool::new(false)),
83 instruments: HashMap::new(),
84 normalize_symbols,
85 book_snapshot_output,
86 })
87 }
88
89 pub fn add_instrument_info(&mut self, info: TardisInstrumentMiniInfo) {
90 let key = info.as_tardis_instrument_key();
91 self.instruments.insert(key, Arc::new(info));
92 }
93
94 #[must_use]
99 pub fn is_closed(&self) -> bool {
100 self.replay_signal.load(Ordering::Acquire) && self.stream_signal.load(Ordering::Acquire)
102 }
103
104 pub fn close(&mut self) {
105 log::debug!("Closing");
106
107 self.replay_signal.store(true, Ordering::Release);
109 self.stream_signal.store(true, Ordering::Release);
110
111 log::debug!("Closed");
112 }
113
114 pub async fn replay(
120 &self,
121 options: Vec<ReplayNormalizedRequestOptions>,
122 ) -> Result<impl Stream<Item = Result<Data, Error>>, Error> {
123 let stream = replay_normalized(&self.base_url, options, self.replay_signal.clone()).await?;
124
125 Ok(handle_ws_stream(
128 Box::pin(stream),
129 None,
130 Some(self.instruments.clone()),
131 self.book_snapshot_output.clone(),
132 ))
133 }
134
135 pub async fn stream(
141 &self,
142 instrument: TardisInstrumentMiniInfo,
143 options: Vec<StreamNormalizedRequestOptions>,
144 ) -> Result<impl Stream<Item = Result<Data, Error>>, Error> {
145 let stream = stream_normalized(&self.base_url, options, self.stream_signal.clone()).await?;
146
147 Ok(handle_ws_stream(
150 Box::pin(stream),
151 Some(Arc::new(instrument)),
152 None,
153 self.book_snapshot_output.clone(),
154 ))
155 }
156}
157
158fn handle_ws_stream<S>(
159 stream: S,
160 instrument: Option<Arc<TardisInstrumentMiniInfo>>,
161 instrument_map: Option<HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>>>,
162 book_snapshot_output: BookSnapshotOutput,
163) -> impl Stream<Item = Result<Data, Error>>
164where
165 S: Stream<Item = Result<WsMessage, Error>> + Unpin,
166{
167 assert!(
168 instrument.is_some() || instrument_map.is_some(),
169 "Either `instrument` or `instrument_map` must be provided"
170 );
171
172 async_stream::stream! {
173 pin_mut!(stream);
174
175 while let Some(result) = stream.next().await {
176 match result {
177 Ok(msg) => {
178 if matches!(msg, WsMessage::Disconnect(_)) {
179 log::debug!("Received disconnect message: {msg:?}");
180 continue;
181 }
182
183 let info = instrument.clone().or_else(|| {
184 instrument_map
185 .as_ref()
186 .and_then(|map| determine_instrument_info(&msg, map))
187 });
188
189 if let Some(info) = info {
190 if let Some(data) = parse_tardis_ws_message(msg, info, &book_snapshot_output) {
191 yield Ok(data);
192 }
193 } else {
194 log::error!("Missing instrument info for message: {msg:?}");
195 yield Err(Error::ConnectionClosed {
196 reason: "Missing instrument definition info".to_string()
197 });
198 break;
199 }
200 }
201 Err(e) => {
202 log::error!("Error in WebSocket stream: {e:?}");
203 yield Err(e);
204 break;
205 }
206 }
207 }
208 }
209}
210
211pub fn determine_instrument_info(
212 msg: &WsMessage,
213 instrument_map: &HashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>>,
214) -> Option<Arc<TardisInstrumentMiniInfo>> {
215 let key = match msg {
216 WsMessage::BookChange(msg) => {
217 TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange)
218 }
219 WsMessage::BookSnapshot(msg) => {
220 TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange)
221 }
222 WsMessage::Trade(msg) => TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange),
223 WsMessage::TradeBar(msg) => TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange),
224 WsMessage::DerivativeTicker(msg) => {
225 TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange)
226 }
227 WsMessage::Disconnect(_) => return None,
228 };
229 if let Some(inst) = instrument_map.get(&key) {
230 Some(inst.clone())
231 } else {
232 log::error!("Instrument definition info not available for {key:?}");
233 None
234 }
235}
236
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Builds a client with an explicit local base URL. Construction opens no connection.
    fn make_client() -> TardisMachineClient {
        TardisMachineClient::new(Some("ws://localhost:8001"), false, BookSnapshotOutput::Deltas)
            .unwrap()
    }

    #[rstest]
    fn test_is_closed_initial_state() {
        assert!(!make_client().is_closed());
    }

    #[rstest]
    fn test_is_closed_after_close() {
        let mut client = make_client();
        client.close();
        assert!(client.is_closed());
    }

    #[rstest]
    fn test_is_closed_partial_signal() {
        let client = make_client();

        // Only the replay signal is set, so the client must still report open.
        client.replay_signal.store(true, Ordering::Release);
        assert!(!client.is_closed());

        // With both signals set, the client reports closed.
        client.stream_signal.store(true, Ordering::Release);
        assert!(client.is_closed());
    }
}