1use std::{
23 fmt::Debug,
24 num::NonZeroU32,
25 sync::{
26 Arc, LazyLock,
27 atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
28 },
29 time::Duration,
30};
31
32use arc_swap::ArcSwap;
33use dashmap::DashMap;
34use futures_util::Stream;
35use nautilus_common::live::get_runtime;
36use nautilus_core::{
37 consts::NAUTILUS_USER_AGENT, env::get_or_env_var_opt, time::get_atomic_clock_realtime,
38};
39use nautilus_model::{
40 identifiers::InstrumentId,
41 instruments::{Instrument, InstrumentAny},
42};
43use nautilus_network::{
44 http::USER_AGENT,
45 mode::ConnectionMode,
46 ratelimiter::quota::Quota,
47 websocket::{
48 AuthTracker, PingHandler, SubscriptionState, WebSocketClient, WebSocketConfig,
49 channel_message_handler,
50 },
51};
52use tokio_util::sync::CancellationToken;
53use ustr::Ustr;
54
55use super::{
56 auth::{AuthState, DEFAULT_SESSION_NAME, send_auth_request, spawn_token_refresh_task},
57 enums::{DeribitUpdateInterval, DeribitWsChannel},
58 error::{DeribitWsError, DeribitWsResult},
59 handler::{DeribitWsFeedHandler, HandlerCommand},
60 messages::NautilusWsMessage,
61};
62use crate::common::{
63 consts::{DERIBIT_TESTNET_WS_URL, DERIBIT_WS_URL},
64 credential::Credential,
65};
66
67pub static DERIBIT_WS_SUBSCRIPTION_QUOTA: LazyLock<Quota> =
69 LazyLock::new(|| Quota::per_second(NonZeroU32::new(20).unwrap()));
70
/// Maximum number of seconds `authenticate` waits for the authentication result.
const AUTHENTICATION_TIMEOUT_SECS: u64 = 30;
73
74#[derive(Clone)]
76#[cfg_attr(
77 feature = "python",
78 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.deribit")
79)]
80pub struct DeribitWebSocketClient {
81 url: String,
82 is_testnet: bool,
83 heartbeat_interval: Option<u64>,
84 credential: Option<Credential>,
85 is_authenticated: Arc<AtomicBool>,
86 auth_state: Arc<tokio::sync::RwLock<Option<AuthState>>>,
87 signal: Arc<AtomicBool>,
88 connection_mode: Arc<ArcSwap<AtomicU8>>,
89 auth_tracker: AuthTracker,
90 cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
91 out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
92 task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
93 subscriptions_state: SubscriptionState,
94 subscribed_channels: Arc<DashMap<String, ()>>,
95 instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
96 request_id_counter: Arc<AtomicU64>,
97 cancellation_token: CancellationToken,
98}
99
100impl Debug for DeribitWebSocketClient {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 f.debug_struct("DeribitWebSocketClient")
103 .field("url", &self.url)
104 .field("is_testnet", &self.is_testnet)
105 .field("has_credentials", &self.credential.is_some())
106 .field(
107 "is_authenticated",
108 &self.is_authenticated.load(Ordering::Relaxed),
109 )
110 .field(
111 "has_auth_state",
112 &self
113 .auth_state
114 .try_read()
115 .map(|s| s.is_some())
116 .unwrap_or(false),
117 )
118 .field("heartbeat_interval", &self.heartbeat_interval)
119 .finish_non_exhaustive()
120 }
121}
122
123impl DeribitWebSocketClient {
124 pub fn new(
130 url: Option<String>,
131 api_key: Option<String>,
132 api_secret: Option<String>,
133 heartbeat_interval: Option<u64>,
134 is_testnet: bool,
135 ) -> anyhow::Result<Self> {
136 let url = url.unwrap_or_else(|| {
137 if is_testnet {
138 DERIBIT_TESTNET_WS_URL.to_string()
139 } else {
140 DERIBIT_WS_URL.to_string()
141 }
142 });
143
144 let credential = match (api_key, api_secret) {
146 (Some(key), Some(secret)) => Some(Credential::new(key, secret)),
147 (None, None) => None,
148 _ => anyhow::bail!("Both api_key and api_secret must be provided together, or neither"),
149 };
150
151 let signal = Arc::new(AtomicBool::new(false));
152 let subscriptions_state = SubscriptionState::new('.');
153
154 Ok(Self {
155 url,
156 is_testnet,
157 heartbeat_interval,
158 credential,
159 is_authenticated: Arc::new(AtomicBool::new(false)),
160 auth_state: Arc::new(tokio::sync::RwLock::new(None)),
161 signal,
162 connection_mode: Arc::new(ArcSwap::from_pointee(AtomicU8::new(
163 ConnectionMode::Closed.as_u8(),
164 ))),
165 auth_tracker: AuthTracker::new(),
166 cmd_tx: {
167 let (tx, _) = tokio::sync::mpsc::unbounded_channel();
168 Arc::new(tokio::sync::RwLock::new(tx))
169 },
170 out_rx: None,
171 task_handle: None,
172 subscriptions_state,
173 subscribed_channels: Arc::new(DashMap::new()),
174 instruments_cache: Arc::new(DashMap::new()),
175 request_id_counter: Arc::new(AtomicU64::new(1)),
176 cancellation_token: CancellationToken::new(),
177 })
178 }
179
180 pub fn new_public(is_testnet: bool) -> anyhow::Result<Self> {
186 let heartbeat_interval = 10;
187 Self::new(None, None, None, Some(heartbeat_interval), is_testnet)
188 }
189
190 pub fn with_credentials(is_testnet: bool) -> anyhow::Result<Self> {
200 let (key_env, secret_env) = if is_testnet {
201 ("DERIBIT_TESTNET_API_KEY", "DERIBIT_TESTNET_API_SECRET")
202 } else {
203 ("DERIBIT_API_KEY", "DERIBIT_API_SECRET")
204 };
205
206 let api_key = get_or_env_var_opt(None, key_env)
207 .ok_or_else(|| anyhow::anyhow!("Missing environment variable: {key_env}"))?;
208 let api_secret = get_or_env_var_opt(None, secret_env)
209 .ok_or_else(|| anyhow::anyhow!("Missing environment variable: {secret_env}"))?;
210
211 let heartbeat_interval = 10;
212 Self::new(
213 None,
214 Some(api_key),
215 Some(api_secret),
216 Some(heartbeat_interval),
217 is_testnet,
218 )
219 }
220
221 fn connection_mode(&self) -> ConnectionMode {
223 let mode_u8 = self.connection_mode.load().load(Ordering::Relaxed);
224 ConnectionMode::from_u8(mode_u8)
225 }
226
227 #[must_use]
229 pub fn is_active(&self) -> bool {
230 self.connection_mode() == ConnectionMode::Active
231 }
232
233 #[must_use]
235 pub fn url(&self) -> &str {
236 &self.url
237 }
238
239 #[must_use]
241 pub fn is_closed(&self) -> bool {
242 self.connection_mode() == ConnectionMode::Disconnect
243 }
244
245 pub fn cancel_all_requests(&self) {
247 self.cancellation_token.cancel();
248 }
249
250 #[must_use]
252 pub fn cancellation_token(&self) -> &CancellationToken {
253 &self.cancellation_token
254 }
255
256 pub async fn wait_until_active(&self, timeout_secs: f64) -> DeribitWsResult<()> {
262 let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
263
264 tokio::time::timeout(timeout, async {
265 while !self.is_active() {
266 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
267 }
268 })
269 .await
270 .map_err(|_| {
271 DeribitWsError::Timeout(format!(
272 "WebSocket connection timeout after {timeout_secs} seconds"
273 ))
274 })?;
275
276 Ok(())
277 }
278
279 pub fn cache_instruments(&self, instruments: Vec<InstrumentAny>) {
281 self.instruments_cache.clear();
282 for inst in instruments {
283 self.instruments_cache
284 .insert(inst.raw_symbol().inner(), inst);
285 }
286 tracing::debug!("Cached {} instruments", self.instruments_cache.len());
287 }
288
289 pub fn cache_instrument(&self, instrument: InstrumentAny) {
291 let symbol = instrument.raw_symbol().inner();
292 self.instruments_cache.insert(symbol, instrument);
293
294 if self.is_active() {
296 let tx = self.cmd_tx.clone();
297 let inst = self.instruments_cache.get(&symbol).map(|r| r.clone());
298 if let Some(inst) = inst {
299 get_runtime().spawn(async move {
300 let _ = tx
301 .read()
302 .await
303 .send(HandlerCommand::UpdateInstrument(Box::new(inst)));
304 });
305 }
306 }
307 }
308
309 pub async fn connect(&mut self) -> anyhow::Result<()> {
315 tracing::info!("Connecting to Deribit WebSocket: {}", self.url);
316
317 self.signal.store(false, Ordering::Relaxed);
319
320 let (message_handler, raw_rx) = channel_message_handler();
322
323 let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {
325 });
327
328 let config = WebSocketConfig {
330 url: self.url.clone(),
331 headers: vec![(USER_AGENT.to_string(), NAUTILUS_USER_AGENT.to_string())],
332 heartbeat: self.heartbeat_interval,
333 heartbeat_msg: None, reconnect_timeout_ms: Some(5_000),
335 reconnect_delay_initial_ms: None,
336 reconnect_delay_max_ms: None,
337 reconnect_backoff_factor: None,
338 reconnect_jitter_ms: None,
339 reconnect_max_attempts: None,
340 };
341
342 let keyed_quotas = vec![("subscription".to_string(), *DERIBIT_WS_SUBSCRIPTION_QUOTA)];
344
345 let ws_client = WebSocketClient::connect(
347 config,
348 Some(message_handler),
349 Some(ping_handler),
350 None, keyed_quotas,
352 Some(*DERIBIT_WS_SUBSCRIPTION_QUOTA), )
354 .await?;
355
356 self.connection_mode
358 .store(ws_client.connection_mode_atomic());
359
360 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();
362 let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel();
363
364 *self.cmd_tx.write().await = cmd_tx.clone();
366 self.out_rx = Some(Arc::new(out_rx));
367
368 let mut handler = DeribitWsFeedHandler::new(
370 self.signal.clone(),
371 cmd_rx,
372 raw_rx,
373 out_tx,
374 self.auth_tracker.clone(),
375 self.subscriptions_state.clone(),
376 );
377
378 let _ = cmd_tx.send(HandlerCommand::SetClient(ws_client));
380
381 let instruments: Vec<InstrumentAny> =
383 self.instruments_cache.iter().map(|r| r.clone()).collect();
384 if !instruments.is_empty() {
385 let _ = cmd_tx.send(HandlerCommand::InitializeInstruments(instruments));
386 }
387
388 if let Some(interval) = self.heartbeat_interval {
390 let _ = cmd_tx.send(HandlerCommand::SetHeartbeat { interval });
391 }
392
393 let subscriptions_state = self.subscriptions_state.clone();
395 let subscribed_channels = self.subscribed_channels.clone();
396 let credential = self.credential.clone();
397 let is_authenticated = self.is_authenticated.clone();
398 let auth_state = self.auth_state.clone();
399 let request_id_counter = self.request_id_counter.clone();
400
401 let task_handle = get_runtime().spawn(async move {
402 let mut pending_reauth = false;
404
405 loop {
406 match handler.next().await {
407 Some(msg) => match msg {
408 NautilusWsMessage::Reconnected => {
409 tracing::info!("Reconnected to Deribit WebSocket");
410
411 let channels: Vec<String> = subscribed_channels
413 .iter()
414 .map(|r| r.key().clone())
415 .collect();
416
417 for channel in &channels {
419 subscriptions_state.mark_failure(channel);
420 }
421
422 if let Some(cred) = &credential {
424 tracing::info!("Re-authenticating after reconnection...");
425
426 is_authenticated.store(false, Ordering::Release);
428 pending_reauth = true;
429
430 let previous_scope = auth_state
432 .read()
433 .await
434 .as_ref()
435 .map(|s| s.scope.clone());
436
437 send_auth_request(cred, previous_scope, &cmd_tx, &request_id_counter);
439 } else {
440 if !channels.is_empty() {
442 let _ = cmd_tx.send(HandlerCommand::Subscribe { channels });
443 }
444 }
445 }
446 NautilusWsMessage::Authenticated(result) => {
447 let timestamp = get_atomic_clock_realtime().get_time_ms();
448 let new_auth_state = AuthState::from_auth_result(&result, timestamp);
449 *auth_state.write().await = Some(new_auth_state);
450
451 spawn_token_refresh_task(
453 result.expires_in,
454 result.refresh_token.clone(),
455 cmd_tx.clone(),
456 request_id_counter.clone(),
457 );
458
459 if pending_reauth {
460 pending_reauth = false;
461 is_authenticated.store(true, Ordering::Release);
462 tracing::info!(
463 "Re-authentication successful (scope: {}), resubscribing to channels",
464 result.scope
465 );
466
467 let channels: Vec<String> = subscribed_channels
469 .iter()
470 .map(|r| r.key().clone())
471 .collect();
472
473 if !channels.is_empty() {
474 let _ = cmd_tx.send(HandlerCommand::Subscribe { channels });
475 }
476 } else {
477 is_authenticated.store(true, Ordering::Release);
479 tracing::debug!(
480 "Auth state stored: scope={}, expires_in={}s",
481 result.scope,
482 result.expires_in
483 );
484 }
485 }
486 _ => {}
487 },
488 None => {
489 tracing::debug!("Handler returned None, stopping task");
490 break;
491 }
492 }
493 }
494 });
495
496 self.task_handle = Some(Arc::new(task_handle));
497 tracing::info!("Connected to Deribit WebSocket");
498
499 Ok(())
500 }
501
502 pub async fn close(&self) -> DeribitWsResult<()> {
508 tracing::info!("Closing Deribit WebSocket connection");
509 self.signal.store(true, Ordering::Relaxed);
510
511 let _ = self.cmd_tx.read().await.send(HandlerCommand::Disconnect);
512
513 if let Some(_handle) = &self.task_handle {
515 let _ = tokio::time::timeout(Duration::from_secs(5), async {
516 tokio::time::sleep(Duration::from_millis(100)).await;
518 })
519 .await;
520 }
521
522 Ok(())
523 }
524
525 pub fn stream(&mut self) -> impl Stream<Item = NautilusWsMessage> + 'static {
531 let rx = self
532 .out_rx
533 .take()
534 .expect("Data stream receiver already taken or not connected");
535 let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
536
537 async_stream::stream! {
538 while let Some(msg) = rx.recv().await {
539 yield msg;
540 }
541 }
542 }
543
544 #[must_use]
546 pub fn has_credentials(&self) -> bool {
547 self.credential.is_some()
548 }
549
550 #[must_use]
552 pub fn is_authenticated(&self) -> bool {
553 self.is_authenticated.load(Ordering::Acquire)
554 }
555
556 pub async fn authenticate(&self, session_name: Option<&str>) -> DeribitWsResult<()> {
575 let credential = self.credential.as_ref().ok_or_else(|| {
576 DeribitWsError::Authentication("API credentials not configured".to_string())
577 })?;
578
579 let scope = session_name.map(|name| format!("session:{name}"));
581
582 tracing::info!(
583 "Authenticating WebSocket with API key: {}, scope: {}",
584 credential.api_key_masked(),
585 scope.as_deref().unwrap_or("connection (default)")
586 );
587
588 let rx = self.auth_tracker.begin();
589
590 let cmd_tx = self.cmd_tx.read().await;
592 send_auth_request(credential, scope, &cmd_tx, &self.request_id_counter);
593 drop(cmd_tx);
594
595 match self
597 .auth_tracker
598 .wait_for_result::<DeribitWsError>(Duration::from_secs(AUTHENTICATION_TIMEOUT_SECS), rx)
599 .await
600 {
601 Ok(()) => {
602 self.is_authenticated.store(true, Ordering::Release);
603 tracing::info!("WebSocket authenticated successfully");
604 Ok(())
605 }
606 Err(e) => {
607 tracing::error!(error = %e, "WebSocket authentication failed");
608 Err(e)
609 }
610 }
611 }
612
613 pub async fn authenticate_session(&self) -> DeribitWsResult<()> {
623 self.authenticate(Some(DEFAULT_SESSION_NAME)).await
624 }
625
626 pub async fn auth_state(&self) -> Option<AuthState> {
630 self.auth_state.read().await.clone()
631 }
632
633 pub async fn access_token(&self) -> Option<String> {
635 self.auth_state
636 .read()
637 .await
638 .as_ref()
639 .map(|s| s.access_token.clone())
640 }
641
642 async fn send_subscribe(&self, channels: Vec<String>) -> DeribitWsResult<()> {
647 for channel in &channels {
649 self.subscribed_channels.insert(channel.clone(), ());
650 }
651
652 self.cmd_tx
653 .read()
654 .await
655 .send(HandlerCommand::Subscribe {
656 channels: channels.clone(),
657 })
658 .map_err(|e| DeribitWsError::Send(e.to_string()))?;
659
660 tracing::debug!("Sent subscribe for {} channels", channels.len());
661 Ok(())
662 }
663
664 async fn send_unsubscribe(&self, channels: Vec<String>) -> DeribitWsResult<()> {
665 for channel in &channels {
667 self.subscribed_channels.remove(channel);
668 }
669
670 self.cmd_tx
671 .read()
672 .await
673 .send(HandlerCommand::Unsubscribe {
674 channels: channels.clone(),
675 })
676 .map_err(|e| DeribitWsError::Send(e.to_string()))?;
677
678 tracing::debug!("Sent unsubscribe for {} channels", channels.len());
679 Ok(())
680 }
681
682 pub async fn subscribe_trades(
693 &self,
694 instrument_id: InstrumentId,
695 interval: Option<DeribitUpdateInterval>,
696 ) -> DeribitWsResult<()> {
697 let interval = interval.unwrap_or_default();
698 self.check_auth_requirement(interval)?;
699 let channel =
700 DeribitWsChannel::Trades.format_channel(instrument_id.symbol.as_str(), Some(interval));
701 self.send_subscribe(vec![channel]).await
702 }
703
704 pub async fn subscribe_trades_raw(&self, instrument_id: InstrumentId) -> DeribitWsResult<()> {
712 self.subscribe_trades(instrument_id, Some(DeribitUpdateInterval::Raw))
713 .await
714 }
715
716 pub async fn unsubscribe_trades(
722 &self,
723 instrument_id: InstrumentId,
724 interval: Option<DeribitUpdateInterval>,
725 ) -> DeribitWsResult<()> {
726 let interval = interval.unwrap_or_default();
727 let channel =
728 DeribitWsChannel::Trades.format_channel(instrument_id.symbol.as_str(), Some(interval));
729 self.send_unsubscribe(vec![channel]).await
730 }
731
732 pub async fn subscribe_book(
743 &self,
744 instrument_id: InstrumentId,
745 interval: Option<DeribitUpdateInterval>,
746 ) -> DeribitWsResult<()> {
747 let interval = interval.unwrap_or_default();
748 self.check_auth_requirement(interval)?;
749 let channel =
750 DeribitWsChannel::Book.format_channel(instrument_id.symbol.as_str(), Some(interval));
751 self.send_subscribe(vec![channel]).await
752 }
753
754 pub async fn subscribe_book_raw(&self, instrument_id: InstrumentId) -> DeribitWsResult<()> {
762 self.subscribe_book(instrument_id, Some(DeribitUpdateInterval::Raw))
763 .await
764 }
765
766 pub async fn unsubscribe_book(
772 &self,
773 instrument_id: InstrumentId,
774 interval: Option<DeribitUpdateInterval>,
775 ) -> DeribitWsResult<()> {
776 let interval = interval.unwrap_or_default();
777 let channel =
778 DeribitWsChannel::Book.format_channel(instrument_id.symbol.as_str(), Some(interval));
779 self.send_unsubscribe(vec![channel]).await
780 }
781
782 pub async fn subscribe_ticker(
793 &self,
794 instrument_id: InstrumentId,
795 interval: Option<DeribitUpdateInterval>,
796 ) -> DeribitWsResult<()> {
797 let interval = interval.unwrap_or_default();
798 self.check_auth_requirement(interval)?;
799 let channel =
800 DeribitWsChannel::Ticker.format_channel(instrument_id.symbol.as_str(), Some(interval));
801 self.send_subscribe(vec![channel]).await
802 }
803
804 pub async fn subscribe_ticker_raw(&self, instrument_id: InstrumentId) -> DeribitWsResult<()> {
812 self.subscribe_ticker(instrument_id, Some(DeribitUpdateInterval::Raw))
813 .await
814 }
815
816 pub async fn unsubscribe_ticker(
822 &self,
823 instrument_id: InstrumentId,
824 interval: Option<DeribitUpdateInterval>,
825 ) -> DeribitWsResult<()> {
826 let interval = interval.unwrap_or_default();
827 let channel =
828 DeribitWsChannel::Ticker.format_channel(instrument_id.symbol.as_str(), Some(interval));
829 self.send_unsubscribe(vec![channel]).await
830 }
831
832 pub async fn subscribe_quotes(&self, instrument_id: InstrumentId) -> DeribitWsResult<()> {
840 let channel = DeribitWsChannel::Quote.format_channel(instrument_id.symbol.as_str(), None);
841 self.send_subscribe(vec![channel]).await
842 }
843
844 pub async fn unsubscribe_quotes(&self, instrument_id: InstrumentId) -> DeribitWsResult<()> {
850 let channel = DeribitWsChannel::Quote.format_channel(instrument_id.symbol.as_str(), None);
851 self.send_unsubscribe(vec![channel]).await
852 }
853
854 fn check_auth_requirement(&self, interval: DeribitUpdateInterval) -> DeribitWsResult<()> {
860 if interval.requires_auth() && !self.is_authenticated() {
861 return Err(DeribitWsError::Authentication(
862 "Raw streams require authentication. Call authenticate() first.".to_string(),
863 ));
864 }
865 Ok(())
866 }
867
868 pub async fn subscribe(&self, channels: Vec<String>) -> DeribitWsResult<()> {
874 self.send_subscribe(channels).await
875 }
876
877 pub async fn unsubscribe(&self, channels: Vec<String>) -> DeribitWsResult<()> {
883 self.send_unsubscribe(channels).await
884 }
885}