nautilus_kraken/websocket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! WebSocket client for the Kraken v2 streaming API.
17
18use std::sync::{
19    Arc,
20    atomic::{AtomicBool, AtomicU8, Ordering},
21};
22
23use arc_swap::ArcSwap;
24use dashmap::DashMap;
25use nautilus_common::runtime::get_runtime;
26use nautilus_model::{data::BarType, identifiers::InstrumentId, instruments::InstrumentAny};
27use nautilus_network::{
28    mode::ConnectionMode,
29    websocket::{WebSocketClient, WebSocketConfig, channel_message_handler},
30};
31use tokio::sync::RwLock;
32use tokio_util::sync::CancellationToken;
33use ustr::Ustr;
34
35use super::{
36    enums::{KrakenWsChannel, KrakenWsMethod},
37    error::KrakenWsError,
38    handler::{FeedHandler, HandlerCommand},
39    messages::{KrakenWsParams, KrakenWsRequest, NautilusWsMessage},
40};
41use crate::{config::KrakenDataClientConfig, http::client::KrakenHttpClient};
42
43#[derive(Debug)]
44#[cfg_attr(
45    feature = "python",
46    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.adapters")
47)]
48pub struct KrakenWebSocketClient {
49    url: String,
50    config: KrakenDataClientConfig,
51    signal: Arc<AtomicBool>,
52    connection_mode: Arc<ArcSwap<AtomicU8>>,
53    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
54    out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
55    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
56    subscriptions: Arc<DashMap<String, KrakenWsChannel>>,
57    cancellation_token: CancellationToken,
58    req_id_counter: Arc<RwLock<u64>>,
59    auth_token: Arc<RwLock<Option<String>>>,
60}
61
62impl Clone for KrakenWebSocketClient {
63    fn clone(&self) -> Self {
64        Self {
65            url: self.url.clone(),
66            config: self.config.clone(),
67            signal: Arc::clone(&self.signal),
68            connection_mode: Arc::clone(&self.connection_mode),
69            cmd_tx: Arc::clone(&self.cmd_tx),
70            out_rx: self.out_rx.clone(),
71            task_handle: self.task_handle.clone(),
72            subscriptions: self.subscriptions.clone(),
73            cancellation_token: self.cancellation_token.clone(),
74            req_id_counter: self.req_id_counter.clone(),
75            auth_token: self.auth_token.clone(),
76        }
77    }
78}
79
80impl KrakenWebSocketClient {
81    pub fn new(config: KrakenDataClientConfig, cancellation_token: CancellationToken) -> Self {
82        let url = config.ws_public_url();
83        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
84        let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
85        let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
86
87        Self {
88            url,
89            config,
90            signal: Arc::new(AtomicBool::new(false)),
91            connection_mode,
92            cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
93            out_rx: None,
94            task_handle: None,
95            subscriptions: Arc::new(DashMap::new()),
96            cancellation_token,
97            req_id_counter: Arc::new(RwLock::new(0)),
98            auth_token: Arc::new(RwLock::new(None)),
99        }
100    }
101
102    async fn get_next_req_id(&self) -> u64 {
103        let mut counter = self.req_id_counter.write().await;
104        *counter += 1;
105        *counter
106    }
107
108    pub async fn connect(&mut self) -> Result<(), KrakenWsError> {
109        tracing::debug!("Connecting to {}", self.url);
110
111        self.signal.store(false, Ordering::Relaxed);
112
113        let (raw_handler, raw_rx) = channel_message_handler();
114
115        let ws_config = WebSocketConfig {
116            url: self.url.clone(),
117            headers: vec![],
118            message_handler: Some(raw_handler),
119            ping_handler: None,
120            heartbeat: self.config.heartbeat_interval_secs,
121            heartbeat_msg: Some("ping".to_string()),
122            reconnect_timeout_ms: None,
123            reconnect_delay_initial_ms: None,
124            reconnect_delay_max_ms: None,
125            reconnect_backoff_factor: None,
126            reconnect_jitter_ms: None,
127            reconnect_max_attempts: None,
128        };
129
130        let ws_client = WebSocketClient::connect(
131            ws_config,
132            None,   // post_reconnection
133            vec![], // keyed_quotas
134            None,   // default_quota
135        )
136        .await
137        .map_err(|e| KrakenWsError::ConnectionError(e.to_string()))?;
138
139        // Share connection state across clones via ArcSwap
140        self.connection_mode
141            .store(ws_client.connection_mode_atomic());
142
143        let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
144        self.out_rx = Some(Arc::new(out_rx));
145
146        let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
147        *self.cmd_tx.write().await = cmd_tx.clone();
148
149        if let Err(e) = cmd_tx.send(HandlerCommand::SetClient(ws_client)) {
150            return Err(KrakenWsError::ConnectionError(format!(
151                "Failed to send WebSocketClient to handler: {e}"
152            )));
153        }
154
155        let signal = self.signal.clone();
156
157        let stream_handle = get_runtime().spawn(async move {
158            let mut handler = FeedHandler::new(signal.clone(), cmd_rx, raw_rx);
159
160            loop {
161                match handler.next().await {
162                    Some(msg) => {
163                        if out_tx.send(msg).is_err() {
164                            tracing::error!("Failed to send message (receiver dropped)");
165                            break;
166                        }
167                    }
168                    None => {
169                        if handler.is_stopped() {
170                            tracing::debug!("Stop signal received, ending message processing");
171                            break;
172                        }
173                        tracing::warn!("WebSocket stream ended unexpectedly");
174                        break;
175                    }
176                }
177            }
178
179            tracing::debug!("Handler task exiting");
180        });
181
182        self.task_handle = Some(Arc::new(stream_handle));
183
184        tracing::debug!("WebSocket connected successfully");
185        Ok(())
186    }
187
188    pub async fn disconnect(&mut self) -> Result<(), KrakenWsError> {
189        tracing::debug!("Disconnecting WebSocket");
190
191        self.signal.store(true, Ordering::Relaxed);
192
193        if let Err(e) = self.cmd_tx.read().await.send(HandlerCommand::Disconnect) {
194            tracing::debug!(
195                "Failed to send disconnect command (handler may already be shut down): {e}"
196            );
197        }
198
199        if let Some(task_handle) = self.task_handle.take() {
200            match Arc::try_unwrap(task_handle) {
201                Ok(handle) => {
202                    tracing::debug!("Waiting for task handle to complete");
203                    match tokio::time::timeout(tokio::time::Duration::from_secs(2), handle).await {
204                        Ok(Ok(())) => tracing::debug!("Task handle completed successfully"),
205                        Ok(Err(e)) => tracing::error!("Task handle encountered an error: {e:?}"),
206                        Err(_) => {
207                            tracing::warn!(
208                                "Timeout waiting for task handle, task may still be running"
209                            );
210                        }
211                    }
212                }
213                Err(arc_handle) => {
214                    tracing::debug!(
215                        "Cannot take ownership of task handle - other references exist, aborting task"
216                    );
217                    arc_handle.abort();
218                }
219            }
220        } else {
221            tracing::debug!("No task handle to await");
222        }
223
224        self.subscriptions.clear();
225
226        Ok(())
227    }
228
229    pub async fn close(&mut self) -> Result<(), KrakenWsError> {
230        self.disconnect().await
231    }
232
233    pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), KrakenWsError> {
234        let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
235
236        tokio::time::timeout(timeout, async {
237            while !self.is_active() {
238                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
239            }
240        })
241        .await
242        .map_err(|_| {
243            KrakenWsError::ConnectionError(format!(
244                "WebSocket connection timeout after {timeout_secs} seconds"
245            ))
246        })?;
247
248        Ok(())
249    }
250
251    pub async fn authenticate(&self) -> Result<(), KrakenWsError> {
252        if !self.config.has_api_credentials() {
253            return Err(KrakenWsError::AuthenticationError(
254                "API credentials required for authentication".to_string(),
255            ));
256        }
257
258        let api_key = self
259            .config
260            .api_key
261            .clone()
262            .ok_or_else(|| KrakenWsError::AuthenticationError("Missing API key".to_string()))?;
263        let api_secret =
264            self.config.api_secret.clone().ok_or_else(|| {
265                KrakenWsError::AuthenticationError("Missing API secret".to_string())
266            })?;
267
268        let http_client = KrakenHttpClient::with_credentials(
269            api_key,
270            api_secret,
271            Some(self.config.http_base_url()),
272            self.config.timeout_secs,
273            None,
274            None,
275            None,
276            self.config.http_proxy.clone(),
277        )
278        .map_err(|e| {
279            KrakenWsError::AuthenticationError(format!("Failed to create HTTP client: {e}"))
280        })?;
281
282        let ws_token = http_client.get_websockets_token().await.map_err(|e| {
283            KrakenWsError::AuthenticationError(format!("Failed to get WebSocket token: {e}"))
284        })?;
285
286        tracing::debug!(
287            token_length = ws_token.token.len(),
288            expires = ws_token.expires,
289            "WebSocket authentication token received"
290        );
291
292        let mut auth_token = self.auth_token.write().await;
293        *auth_token = Some(ws_token.token);
294
295        Ok(())
296    }
297
298    pub fn cache_instruments(&self, instruments: Vec<InstrumentAny>) {
299        // Before connect() the handler isn't running; this send will fail and that's expected
300        if let Ok(cmd_tx) = self.cmd_tx.try_read()
301            && let Err(e) = cmd_tx.send(HandlerCommand::InitializeInstruments(instruments))
302        {
303            tracing::debug!("Failed to send instruments to handler: {e}");
304        }
305    }
306
307    pub fn cache_instrument(&self, instrument: InstrumentAny) {
308        // Before connect() the handler isn't running; this send will fail and that's expected
309        if let Ok(cmd_tx) = self.cmd_tx.try_read()
310            && let Err(e) = cmd_tx.send(HandlerCommand::UpdateInstrument(instrument))
311        {
312            tracing::debug!("Failed to send instrument update to handler: {e}");
313        }
314    }
315
316    pub fn cancel_all_requests(&self) {
317        self.cancellation_token.cancel();
318    }
319
320    pub fn cancellation_token(&self) -> &CancellationToken {
321        &self.cancellation_token
322    }
323
324    pub async fn subscribe(
325        &self,
326        channel: KrakenWsChannel,
327        symbols: Vec<Ustr>,
328        depth: Option<u32>,
329    ) -> Result<(), KrakenWsError> {
330        let req_id = self.get_next_req_id().await;
331
332        // Check if channel requires authentication
333        let is_private = matches!(
334            channel,
335            KrakenWsChannel::Executions | KrakenWsChannel::Balances
336        );
337        let token = if is_private {
338            Some(self.auth_token.read().await.clone().ok_or_else(|| {
339                KrakenWsError::AuthenticationError(
340                    "Authentication token required for private channels. Call authenticate() first"
341                        .to_string(),
342                )
343            })?)
344        } else {
345            None
346        };
347
348        let request = KrakenWsRequest {
349            method: KrakenWsMethod::Subscribe,
350            params: Some(KrakenWsParams {
351                channel,
352                symbol: Some(symbols.clone()),
353                snapshot: None,
354                depth,
355                token,
356            }),
357            req_id: Some(req_id),
358        };
359
360        self.send_request(&request).await?;
361
362        for symbol in symbols {
363            let key = format!("{:?}:{}", channel, symbol);
364            self.subscriptions.insert(key, channel);
365        }
366
367        Ok(())
368    }
369
370    pub async fn unsubscribe(
371        &self,
372        channel: KrakenWsChannel,
373        symbols: Vec<Ustr>,
374    ) -> Result<(), KrakenWsError> {
375        let req_id = self.get_next_req_id().await;
376
377        // Check if channel requires authentication
378        let is_private = matches!(
379            channel,
380            KrakenWsChannel::Executions | KrakenWsChannel::Balances
381        );
382        let token = if is_private {
383            Some(self.auth_token.read().await.clone().ok_or_else(|| {
384                KrakenWsError::AuthenticationError(
385                    "Authentication token required for private channels. Call authenticate() first"
386                        .to_string(),
387                )
388            })?)
389        } else {
390            None
391        };
392
393        let request = KrakenWsRequest {
394            method: KrakenWsMethod::Unsubscribe,
395            params: Some(KrakenWsParams {
396                channel,
397                symbol: Some(symbols.clone()),
398                snapshot: None,
399                depth: None,
400                token,
401            }),
402            req_id: Some(req_id),
403        };
404
405        self.send_request(&request).await?;
406
407        for symbol in symbols {
408            let key = format!("{:?}:{}", channel, symbol);
409            self.subscriptions.remove(&key);
410        }
411
412        Ok(())
413    }
414
415    pub async fn send_ping(&self) -> Result<(), KrakenWsError> {
416        let req_id = self.get_next_req_id().await;
417
418        let request = KrakenWsRequest {
419            method: KrakenWsMethod::Ping,
420            params: None,
421            req_id: Some(req_id),
422        };
423
424        self.send_request(&request).await
425    }
426
427    async fn send_request(&self, request: &KrakenWsRequest) -> Result<(), KrakenWsError> {
428        let payload =
429            serde_json::to_string(request).map_err(|e| KrakenWsError::JsonError(e.to_string()))?;
430
431        tracing::trace!("Sending message: {payload}");
432
433        self.cmd_tx
434            .read()
435            .await
436            .send(HandlerCommand::SendText { payload })
437            .map_err(|e| KrakenWsError::ConnectionError(format!("Failed to send request: {e}")))?;
438
439        Ok(())
440    }
441
442    pub fn is_connected(&self) -> bool {
443        let connection_mode_arc = self.connection_mode.load();
444        !ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
445    }
446
447    pub fn is_active(&self) -> bool {
448        let connection_mode_arc = self.connection_mode.load();
449        ConnectionMode::from_atomic(&connection_mode_arc).is_active()
450            && !self.signal.load(Ordering::Relaxed)
451    }
452
453    pub fn is_closed(&self) -> bool {
454        let connection_mode_arc = self.connection_mode.load();
455        ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
456            || self.signal.load(Ordering::Relaxed)
457    }
458
459    pub fn url(&self) -> &str {
460        &self.url
461    }
462
463    pub fn get_subscriptions(&self) -> Vec<String> {
464        self.subscriptions
465            .iter()
466            .map(|entry| entry.key().clone())
467            .collect()
468    }
469
470    pub fn stream(&mut self) -> impl futures_util::Stream<Item = NautilusWsMessage> + use<> {
471        let rx = self
472            .out_rx
473            .take()
474            .expect("Stream receiver already taken or client not connected");
475        let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
476        async_stream::stream! {
477            while let Some(msg) = rx.recv().await {
478                yield msg;
479            }
480        }
481    }
482
483    pub async fn subscribe_book(
484        &self,
485        instrument_id: InstrumentId,
486        depth: Option<u32>,
487    ) -> Result<(), KrakenWsError> {
488        let symbol = Ustr::from(instrument_id.symbol.as_str());
489        self.subscribe(KrakenWsChannel::Book, vec![symbol], depth)
490            .await
491    }
492
493    pub async fn subscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), KrakenWsError> {
494        let symbol = Ustr::from(instrument_id.symbol.as_str());
495        self.subscribe(KrakenWsChannel::Ticker, vec![symbol], None)
496            .await
497    }
498
499    pub async fn subscribe_trades(&self, instrument_id: InstrumentId) -> Result<(), KrakenWsError> {
500        let symbol = Ustr::from(instrument_id.symbol.as_str());
501        self.subscribe(KrakenWsChannel::Trade, vec![symbol], None)
502            .await
503    }
504
505    pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), KrakenWsError> {
506        let symbol = Ustr::from(bar_type.instrument_id().symbol.as_str());
507        self.subscribe(KrakenWsChannel::Ohlc, vec![symbol], None)
508            .await
509    }
510
511    pub async fn unsubscribe_book(&self, instrument_id: InstrumentId) -> Result<(), KrakenWsError> {
512        let symbol = Ustr::from(instrument_id.symbol.as_str());
513        self.unsubscribe(KrakenWsChannel::Book, vec![symbol]).await
514    }
515
516    pub async fn unsubscribe_quotes(
517        &self,
518        instrument_id: InstrumentId,
519    ) -> Result<(), KrakenWsError> {
520        let symbol = Ustr::from(instrument_id.symbol.as_str());
521        self.unsubscribe(KrakenWsChannel::Ticker, vec![symbol])
522            .await
523    }
524
525    pub async fn unsubscribe_trades(
526        &self,
527        instrument_id: InstrumentId,
528    ) -> Result<(), KrakenWsError> {
529        let symbol = Ustr::from(instrument_id.symbol.as_str());
530        self.unsubscribe(KrakenWsChannel::Trade, vec![symbol]).await
531    }
532
533    pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), KrakenWsError> {
534        let symbol = Ustr::from(bar_type.instrument_id().symbol.as_str());
535        self.unsubscribe(KrakenWsChannel::Ohlc, vec![symbol]).await
536    }
537}