1use std::sync::{
19 Arc,
20 atomic::{AtomicBool, AtomicU8, Ordering},
21};
22
23use arc_swap::ArcSwap;
24use dashmap::DashMap;
25use nautilus_common::runtime::get_runtime;
26use nautilus_model::{data::BarType, identifiers::InstrumentId, instruments::InstrumentAny};
27use nautilus_network::{
28 mode::ConnectionMode,
29 websocket::{WebSocketClient, WebSocketConfig, channel_message_handler},
30};
31use tokio::sync::RwLock;
32use tokio_util::sync::CancellationToken;
33use ustr::Ustr;
34
35use super::{
36 enums::{KrakenWsChannel, KrakenWsMethod},
37 error::KrakenWsError,
38 handler::{FeedHandler, HandlerCommand},
39 messages::{KrakenWsParams, KrakenWsRequest, NautilusWsMessage},
40};
41use crate::{config::KrakenDataClientConfig, http::client::KrakenHttpClient};
42
43#[derive(Debug)]
44#[cfg_attr(
45 feature = "python",
46 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.adapters")
47)]
48pub struct KrakenWebSocketClient {
49 url: String,
50 config: KrakenDataClientConfig,
51 signal: Arc<AtomicBool>,
52 connection_mode: Arc<ArcSwap<AtomicU8>>,
53 cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
54 out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
55 task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
56 subscriptions: Arc<DashMap<String, KrakenWsChannel>>,
57 cancellation_token: CancellationToken,
58 req_id_counter: Arc<RwLock<u64>>,
59 auth_token: Arc<RwLock<Option<String>>>,
60}
61
62impl Clone for KrakenWebSocketClient {
63 fn clone(&self) -> Self {
64 Self {
65 url: self.url.clone(),
66 config: self.config.clone(),
67 signal: Arc::clone(&self.signal),
68 connection_mode: Arc::clone(&self.connection_mode),
69 cmd_tx: Arc::clone(&self.cmd_tx),
70 out_rx: self.out_rx.clone(),
71 task_handle: self.task_handle.clone(),
72 subscriptions: self.subscriptions.clone(),
73 cancellation_token: self.cancellation_token.clone(),
74 req_id_counter: self.req_id_counter.clone(),
75 auth_token: self.auth_token.clone(),
76 }
77 }
78}
79
80impl KrakenWebSocketClient {
81 pub fn new(config: KrakenDataClientConfig, cancellation_token: CancellationToken) -> Self {
82 let url = config.ws_public_url();
83 let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
84 let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
85 let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
86
87 Self {
88 url,
89 config,
90 signal: Arc::new(AtomicBool::new(false)),
91 connection_mode,
92 cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
93 out_rx: None,
94 task_handle: None,
95 subscriptions: Arc::new(DashMap::new()),
96 cancellation_token,
97 req_id_counter: Arc::new(RwLock::new(0)),
98 auth_token: Arc::new(RwLock::new(None)),
99 }
100 }
101
102 async fn get_next_req_id(&self) -> u64 {
103 let mut counter = self.req_id_counter.write().await;
104 *counter += 1;
105 *counter
106 }
107
108 pub async fn connect(&mut self) -> Result<(), KrakenWsError> {
109 tracing::debug!("Connecting to {}", self.url);
110
111 self.signal.store(false, Ordering::Relaxed);
112
113 let (raw_handler, raw_rx) = channel_message_handler();
114
115 let ws_config = WebSocketConfig {
116 url: self.url.clone(),
117 headers: vec![],
118 message_handler: Some(raw_handler),
119 ping_handler: None,
120 heartbeat: self.config.heartbeat_interval_secs,
121 heartbeat_msg: Some("ping".to_string()),
122 reconnect_timeout_ms: None,
123 reconnect_delay_initial_ms: None,
124 reconnect_delay_max_ms: None,
125 reconnect_backoff_factor: None,
126 reconnect_jitter_ms: None,
127 reconnect_max_attempts: None,
128 };
129
130 let ws_client = WebSocketClient::connect(
131 ws_config,
132 None, vec![], None, )
136 .await
137 .map_err(|e| KrakenWsError::ConnectionError(e.to_string()))?;
138
139 self.connection_mode
141 .store(ws_client.connection_mode_atomic());
142
143 let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
144 self.out_rx = Some(Arc::new(out_rx));
145
146 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
147 *self.cmd_tx.write().await = cmd_tx.clone();
148
149 if let Err(e) = cmd_tx.send(HandlerCommand::SetClient(ws_client)) {
150 return Err(KrakenWsError::ConnectionError(format!(
151 "Failed to send WebSocketClient to handler: {e}"
152 )));
153 }
154
155 let signal = self.signal.clone();
156
157 let stream_handle = get_runtime().spawn(async move {
158 let mut handler = FeedHandler::new(signal.clone(), cmd_rx, raw_rx);
159
160 loop {
161 match handler.next().await {
162 Some(msg) => {
163 if out_tx.send(msg).is_err() {
164 tracing::error!("Failed to send message (receiver dropped)");
165 break;
166 }
167 }
168 None => {
169 if handler.is_stopped() {
170 tracing::debug!("Stop signal received, ending message processing");
171 break;
172 }
173 tracing::warn!("WebSocket stream ended unexpectedly");
174 break;
175 }
176 }
177 }
178
179 tracing::debug!("Handler task exiting");
180 });
181
182 self.task_handle = Some(Arc::new(stream_handle));
183
184 tracing::debug!("WebSocket connected successfully");
185 Ok(())
186 }
187
188 pub async fn disconnect(&mut self) -> Result<(), KrakenWsError> {
189 tracing::debug!("Disconnecting WebSocket");
190
191 self.signal.store(true, Ordering::Relaxed);
192
193 if let Err(e) = self.cmd_tx.read().await.send(HandlerCommand::Disconnect) {
194 tracing::debug!(
195 "Failed to send disconnect command (handler may already be shut down): {e}"
196 );
197 }
198
199 if let Some(task_handle) = self.task_handle.take() {
200 match Arc::try_unwrap(task_handle) {
201 Ok(handle) => {
202 tracing::debug!("Waiting for task handle to complete");
203 match tokio::time::timeout(tokio::time::Duration::from_secs(2), handle).await {
204 Ok(Ok(())) => tracing::debug!("Task handle completed successfully"),
205 Ok(Err(e)) => tracing::error!("Task handle encountered an error: {e:?}"),
206 Err(_) => {
207 tracing::warn!(
208 "Timeout waiting for task handle, task may still be running"
209 );
210 }
211 }
212 }
213 Err(arc_handle) => {
214 tracing::debug!(
215 "Cannot take ownership of task handle - other references exist, aborting task"
216 );
217 arc_handle.abort();
218 }
219 }
220 } else {
221 tracing::debug!("No task handle to await");
222 }
223
224 self.subscriptions.clear();
225
226 Ok(())
227 }
228
229 pub async fn close(&mut self) -> Result<(), KrakenWsError> {
230 self.disconnect().await
231 }
232
233 pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), KrakenWsError> {
234 let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
235
236 tokio::time::timeout(timeout, async {
237 while !self.is_active() {
238 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
239 }
240 })
241 .await
242 .map_err(|_| {
243 KrakenWsError::ConnectionError(format!(
244 "WebSocket connection timeout after {timeout_secs} seconds"
245 ))
246 })?;
247
248 Ok(())
249 }
250
251 pub async fn authenticate(&self) -> Result<(), KrakenWsError> {
252 if !self.config.has_api_credentials() {
253 return Err(KrakenWsError::AuthenticationError(
254 "API credentials required for authentication".to_string(),
255 ));
256 }
257
258 let api_key = self
259 .config
260 .api_key
261 .clone()
262 .ok_or_else(|| KrakenWsError::AuthenticationError("Missing API key".to_string()))?;
263 let api_secret =
264 self.config.api_secret.clone().ok_or_else(|| {
265 KrakenWsError::AuthenticationError("Missing API secret".to_string())
266 })?;
267
268 let http_client = KrakenHttpClient::with_credentials(
269 api_key,
270 api_secret,
271 Some(self.config.http_base_url()),
272 self.config.timeout_secs,
273 None,
274 None,
275 None,
276 self.config.http_proxy.clone(),
277 )
278 .map_err(|e| {
279 KrakenWsError::AuthenticationError(format!("Failed to create HTTP client: {e}"))
280 })?;
281
282 let ws_token = http_client.get_websockets_token().await.map_err(|e| {
283 KrakenWsError::AuthenticationError(format!("Failed to get WebSocket token: {e}"))
284 })?;
285
286 tracing::debug!(
287 token_length = ws_token.token.len(),
288 expires = ws_token.expires,
289 "WebSocket authentication token received"
290 );
291
292 let mut auth_token = self.auth_token.write().await;
293 *auth_token = Some(ws_token.token);
294
295 Ok(())
296 }
297
298 pub fn cache_instruments(&self, instruments: Vec<InstrumentAny>) {
299 if let Ok(cmd_tx) = self.cmd_tx.try_read()
301 && let Err(e) = cmd_tx.send(HandlerCommand::InitializeInstruments(instruments))
302 {
303 tracing::debug!("Failed to send instruments to handler: {e}");
304 }
305 }
306
307 pub fn cache_instrument(&self, instrument: InstrumentAny) {
308 if let Ok(cmd_tx) = self.cmd_tx.try_read()
310 && let Err(e) = cmd_tx.send(HandlerCommand::UpdateInstrument(instrument))
311 {
312 tracing::debug!("Failed to send instrument update to handler: {e}");
313 }
314 }
315
316 pub fn cancel_all_requests(&self) {
317 self.cancellation_token.cancel();
318 }
319
320 pub fn cancellation_token(&self) -> &CancellationToken {
321 &self.cancellation_token
322 }
323
324 pub async fn subscribe(
325 &self,
326 channel: KrakenWsChannel,
327 symbols: Vec<Ustr>,
328 depth: Option<u32>,
329 ) -> Result<(), KrakenWsError> {
330 let req_id = self.get_next_req_id().await;
331
332 let is_private = matches!(
334 channel,
335 KrakenWsChannel::Executions | KrakenWsChannel::Balances
336 );
337 let token = if is_private {
338 Some(self.auth_token.read().await.clone().ok_or_else(|| {
339 KrakenWsError::AuthenticationError(
340 "Authentication token required for private channels. Call authenticate() first"
341 .to_string(),
342 )
343 })?)
344 } else {
345 None
346 };
347
348 let request = KrakenWsRequest {
349 method: KrakenWsMethod::Subscribe,
350 params: Some(KrakenWsParams {
351 channel,
352 symbol: Some(symbols.clone()),
353 snapshot: None,
354 depth,
355 token,
356 }),
357 req_id: Some(req_id),
358 };
359
360 self.send_request(&request).await?;
361
362 for symbol in symbols {
363 let key = format!("{:?}:{}", channel, symbol);
364 self.subscriptions.insert(key, channel);
365 }
366
367 Ok(())
368 }
369
370 pub async fn unsubscribe(
371 &self,
372 channel: KrakenWsChannel,
373 symbols: Vec<Ustr>,
374 ) -> Result<(), KrakenWsError> {
375 let req_id = self.get_next_req_id().await;
376
377 let is_private = matches!(
379 channel,
380 KrakenWsChannel::Executions | KrakenWsChannel::Balances
381 );
382 let token = if is_private {
383 Some(self.auth_token.read().await.clone().ok_or_else(|| {
384 KrakenWsError::AuthenticationError(
385 "Authentication token required for private channels. Call authenticate() first"
386 .to_string(),
387 )
388 })?)
389 } else {
390 None
391 };
392
393 let request = KrakenWsRequest {
394 method: KrakenWsMethod::Unsubscribe,
395 params: Some(KrakenWsParams {
396 channel,
397 symbol: Some(symbols.clone()),
398 snapshot: None,
399 depth: None,
400 token,
401 }),
402 req_id: Some(req_id),
403 };
404
405 self.send_request(&request).await?;
406
407 for symbol in symbols {
408 let key = format!("{:?}:{}", channel, symbol);
409 self.subscriptions.remove(&key);
410 }
411
412 Ok(())
413 }
414
415 pub async fn send_ping(&self) -> Result<(), KrakenWsError> {
416 let req_id = self.get_next_req_id().await;
417
418 let request = KrakenWsRequest {
419 method: KrakenWsMethod::Ping,
420 params: None,
421 req_id: Some(req_id),
422 };
423
424 self.send_request(&request).await
425 }
426
427 async fn send_request(&self, request: &KrakenWsRequest) -> Result<(), KrakenWsError> {
428 let payload =
429 serde_json::to_string(request).map_err(|e| KrakenWsError::JsonError(e.to_string()))?;
430
431 tracing::trace!("Sending message: {payload}");
432
433 self.cmd_tx
434 .read()
435 .await
436 .send(HandlerCommand::SendText { payload })
437 .map_err(|e| KrakenWsError::ConnectionError(format!("Failed to send request: {e}")))?;
438
439 Ok(())
440 }
441
442 pub fn is_connected(&self) -> bool {
443 let connection_mode_arc = self.connection_mode.load();
444 !ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
445 }
446
447 pub fn is_active(&self) -> bool {
448 let connection_mode_arc = self.connection_mode.load();
449 ConnectionMode::from_atomic(&connection_mode_arc).is_active()
450 && !self.signal.load(Ordering::Relaxed)
451 }
452
453 pub fn is_closed(&self) -> bool {
454 let connection_mode_arc = self.connection_mode.load();
455 ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
456 || self.signal.load(Ordering::Relaxed)
457 }
458
459 pub fn url(&self) -> &str {
460 &self.url
461 }
462
463 pub fn get_subscriptions(&self) -> Vec<String> {
464 self.subscriptions
465 .iter()
466 .map(|entry| entry.key().clone())
467 .collect()
468 }
469
470 pub fn stream(&mut self) -> impl futures_util::Stream<Item = NautilusWsMessage> + use<> {
471 let rx = self
472 .out_rx
473 .take()
474 .expect("Stream receiver already taken or client not connected");
475 let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
476 async_stream::stream! {
477 while let Some(msg) = rx.recv().await {
478 yield msg;
479 }
480 }
481 }
482
483 pub async fn subscribe_book(
484 &self,
485 instrument_id: InstrumentId,
486 depth: Option<u32>,
487 ) -> Result<(), KrakenWsError> {
488 let symbol = Ustr::from(instrument_id.symbol.as_str());
489 self.subscribe(KrakenWsChannel::Book, vec![symbol], depth)
490 .await
491 }
492
493 pub async fn subscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), KrakenWsError> {
494 let symbol = Ustr::from(instrument_id.symbol.as_str());
495 self.subscribe(KrakenWsChannel::Ticker, vec![symbol], None)
496 .await
497 }
498
499 pub async fn subscribe_trades(&self, instrument_id: InstrumentId) -> Result<(), KrakenWsError> {
500 let symbol = Ustr::from(instrument_id.symbol.as_str());
501 self.subscribe(KrakenWsChannel::Trade, vec![symbol], None)
502 .await
503 }
504
505 pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), KrakenWsError> {
506 let symbol = Ustr::from(bar_type.instrument_id().symbol.as_str());
507 self.subscribe(KrakenWsChannel::Ohlc, vec![symbol], None)
508 .await
509 }
510
511 pub async fn unsubscribe_book(&self, instrument_id: InstrumentId) -> Result<(), KrakenWsError> {
512 let symbol = Ustr::from(instrument_id.symbol.as_str());
513 self.unsubscribe(KrakenWsChannel::Book, vec![symbol]).await
514 }
515
516 pub async fn unsubscribe_quotes(
517 &self,
518 instrument_id: InstrumentId,
519 ) -> Result<(), KrakenWsError> {
520 let symbol = Ustr::from(instrument_id.symbol.as_str());
521 self.unsubscribe(KrakenWsChannel::Ticker, vec![symbol])
522 .await
523 }
524
525 pub async fn unsubscribe_trades(
526 &self,
527 instrument_id: InstrumentId,
528 ) -> Result<(), KrakenWsError> {
529 let symbol = Ustr::from(instrument_id.symbol.as_str());
530 self.unsubscribe(KrakenWsChannel::Trade, vec![symbol]).await
531 }
532
533 pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), KrakenWsError> {
534 let symbol = Ustr::from(bar_type.instrument_id().symbol.as_str());
535 self.unsubscribe(KrakenWsChannel::Ohlc, vec![symbol]).await
536 }
537}