1use std::{
23 fmt::Debug,
24 num::NonZeroU32,
25 sync::{
26 Arc, LazyLock,
27 atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
28 },
29 time::Duration,
30};
31
32use arc_swap::ArcSwap;
33use dashmap::DashMap;
34use futures_util::Stream;
35use nautilus_common::live::runtime::get_runtime;
36use nautilus_core::{
37 consts::NAUTILUS_USER_AGENT, env::get_or_env_var_opt, time::get_atomic_clock_realtime,
38};
39use nautilus_model::{
40 identifiers::InstrumentId,
41 instruments::{Instrument, InstrumentAny},
42};
43use nautilus_network::{
44 http::USER_AGENT,
45 mode::ConnectionMode,
46 ratelimiter::quota::Quota,
47 websocket::{
48 AuthTracker, PingHandler, SubscriptionState, WebSocketClient, WebSocketConfig,
49 channel_message_handler,
50 },
51};
52use tokio_util::sync::CancellationToken;
53use ustr::Ustr;
54
55use super::{
56 auth::{AuthState, DEFAULT_SESSION_NAME, send_auth_request, spawn_token_refresh_task},
57 enums::{DeribitUpdateInterval, DeribitWsChannel},
58 error::{DeribitWsError, DeribitWsResult},
59 handler::{DeribitWsFeedHandler, HandlerCommand},
60 messages::NautilusWsMessage,
61};
62use crate::common::{
63 consts::{DERIBIT_TESTNET_WS_URL, DERIBIT_WS_URL},
64 credential::Credential,
65};
66
67pub static DERIBIT_WS_SUBSCRIPTION_QUOTA: LazyLock<Quota> =
69 LazyLock::new(|| Quota::per_second(NonZeroU32::new(20).unwrap()));
70
71const AUTHENTICATION_TIMEOUT_SECS: u64 = 30;
73
/// WebSocket client for Deribit.
///
/// Most state is held behind `Arc`, so clones created through the derived
/// `Clone` share the same flags, caches and command channel.
#[derive(Clone)]
pub struct DeribitWebSocketClient {
    /// WebSocket endpoint URL that `connect` dials.
    url: String,
    /// Whether the testnet environment was selected at construction.
    is_testnet: bool,
    /// Optional heartbeat interval forwarded to the WebSocket config and the
    /// feed handler (unit not visible here; presumably seconds, TODO confirm).
    heartbeat_interval: Option<u64>,
    /// API credential; `None` for a public-only client.
    credential: Option<Credential>,
    /// Set to `true` after a successful authentication and cleared when a
    /// reconnect triggers re-authentication.
    is_authenticated: Arc<AtomicBool>,
    /// Most recent authentication state, stored when an `Authenticated`
    /// message is processed by the background task.
    auth_state: Arc<tokio::sync::RwLock<Option<AuthState>>>,
    /// Flag set to `true` by `close` and reset to `false` by `connect`;
    /// shared with the feed handler.
    signal: Arc<AtomicBool>,
    /// Connection mode atomic; swapped for the underlying client's atomic on
    /// `connect`.
    connection_mode: Arc<ArcSwap<AtomicU8>>,
    /// Tracker used by `authenticate` to await the authentication result.
    auth_tracker: AuthTracker,
    /// Sender for commands to the feed handler; replaced on every `connect`.
    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
    /// Receiver for handler output; consumed by `stream`.
    out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
    /// Handle of the background task spawned by `connect`.
    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    /// Subscription state shared with the feed handler.
    subscriptions_state: SubscriptionState,
    /// Channel names currently requested; replayed after a reconnect.
    subscribed_channels: Arc<DashMap<String, ()>>,
    /// Instruments keyed by their raw symbol.
    instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
    /// Shared counter handed to the authentication request helpers.
    request_id_counter: Arc<AtomicU64>,
    /// Token cancelled by `cancel_all_requests`.
    cancellation_token: CancellationToken,
}
95
96impl Debug for DeribitWebSocketClient {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 f.debug_struct("DeribitWebSocketClient")
99 .field("url", &self.url)
100 .field("is_testnet", &self.is_testnet)
101 .field("has_credentials", &self.credential.is_some())
102 .field(
103 "is_authenticated",
104 &self.is_authenticated.load(Ordering::Relaxed),
105 )
106 .field(
107 "has_auth_state",
108 &self
109 .auth_state
110 .try_read()
111 .map(|s| s.is_some())
112 .unwrap_or(false),
113 )
114 .field("heartbeat_interval", &self.heartbeat_interval)
115 .finish_non_exhaustive()
116 }
117}
118
119impl DeribitWebSocketClient {
120 pub fn new(
126 url: Option<String>,
127 api_key: Option<String>,
128 api_secret: Option<String>,
129 heartbeat_interval: Option<u64>,
130 is_testnet: bool,
131 ) -> anyhow::Result<Self> {
132 let url = url.unwrap_or_else(|| {
133 if is_testnet {
134 DERIBIT_TESTNET_WS_URL.to_string()
135 } else {
136 DERIBIT_WS_URL.to_string()
137 }
138 });
139
140 let credential = match (api_key, api_secret) {
142 (Some(key), Some(secret)) => Some(Credential::new(key, secret)),
143 (None, None) => None,
144 _ => anyhow::bail!("Both api_key and api_secret must be provided together, or neither"),
145 };
146
147 let signal = Arc::new(AtomicBool::new(false));
148 let subscriptions_state = SubscriptionState::new('.');
149
150 Ok(Self {
151 url,
152 is_testnet,
153 heartbeat_interval,
154 credential,
155 is_authenticated: Arc::new(AtomicBool::new(false)),
156 auth_state: Arc::new(tokio::sync::RwLock::new(None)),
157 signal,
158 connection_mode: Arc::new(ArcSwap::from_pointee(AtomicU8::new(
159 ConnectionMode::Closed.as_u8(),
160 ))),
161 auth_tracker: AuthTracker::new(),
162 cmd_tx: {
163 let (tx, _) = tokio::sync::mpsc::unbounded_channel();
164 Arc::new(tokio::sync::RwLock::new(tx))
165 },
166 out_rx: None,
167 task_handle: None,
168 subscriptions_state,
169 subscribed_channels: Arc::new(DashMap::new()),
170 instruments_cache: Arc::new(DashMap::new()),
171 request_id_counter: Arc::new(AtomicU64::new(1)),
172 cancellation_token: CancellationToken::new(),
173 })
174 }
175
176 pub fn new_public(is_testnet: bool) -> anyhow::Result<Self> {
182 let heartbeat_interval = 10;
183 Self::new(None, None, None, Some(heartbeat_interval), is_testnet)
184 }
185
186 pub fn with_credentials(is_testnet: bool) -> anyhow::Result<Self> {
196 let (key_env, secret_env) = if is_testnet {
197 ("DERIBIT_TESTNET_API_KEY", "DERIBIT_TESTNET_API_SECRET")
198 } else {
199 ("DERIBIT_API_KEY", "DERIBIT_API_SECRET")
200 };
201
202 let api_key = get_or_env_var_opt(None, key_env)
203 .ok_or_else(|| anyhow::anyhow!("Missing environment variable: {key_env}"))?;
204 let api_secret = get_or_env_var_opt(None, secret_env)
205 .ok_or_else(|| anyhow::anyhow!("Missing environment variable: {secret_env}"))?;
206
207 let heartbeat_interval = 10;
208 Self::new(
209 None,
210 Some(api_key),
211 Some(api_secret),
212 Some(heartbeat_interval),
213 is_testnet,
214 )
215 }
216
217 fn connection_mode(&self) -> ConnectionMode {
219 let mode_u8 = self.connection_mode.load().load(Ordering::Relaxed);
220 ConnectionMode::from_u8(mode_u8)
221 }
222
223 #[must_use]
225 pub fn is_active(&self) -> bool {
226 self.connection_mode() == ConnectionMode::Active
227 }
228
229 #[must_use]
231 pub fn is_closed(&self) -> bool {
232 self.connection_mode() == ConnectionMode::Disconnect
233 }
234
/// Cancels the client's cancellation token.
///
/// Anything observing a token obtained from [`Self::cancellation_token`] sees
/// the cancellation.
pub fn cancel_all_requests(&self) {
    self.cancellation_token.cancel();
}
239
/// Returns a reference to the client's cancellation token.
#[must_use]
pub fn cancellation_token(&self) -> &CancellationToken {
    &self.cancellation_token
}
245
246 pub async fn wait_until_active(&self, timeout_secs: f64) -> DeribitWsResult<()> {
252 let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
253
254 tokio::time::timeout(timeout, async {
255 while !self.is_active() {
256 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
257 }
258 })
259 .await
260 .map_err(|_| {
261 DeribitWsError::Timeout(format!(
262 "WebSocket connection timeout after {timeout_secs} seconds"
263 ))
264 })?;
265
266 Ok(())
267 }
268
269 pub fn cache_instruments(&self, instruments: Vec<InstrumentAny>) {
271 self.instruments_cache.clear();
272 for inst in instruments {
273 self.instruments_cache
274 .insert(inst.raw_symbol().inner(), inst);
275 }
276 tracing::debug!("Cached {} instruments", self.instruments_cache.len());
277 }
278
279 pub fn cache_instrument(&self, instrument: InstrumentAny) {
281 let symbol = instrument.raw_symbol().inner();
282 self.instruments_cache.insert(symbol, instrument);
283
284 if self.is_active() {
286 let tx = self.cmd_tx.clone();
287 let inst = self.instruments_cache.get(&symbol).map(|r| r.clone());
288 if let Some(inst) = inst {
289 tokio::spawn(async move {
290 let _ = tx
291 .read()
292 .await
293 .send(HandlerCommand::UpdateInstrument(Box::new(inst)));
294 });
295 }
296 }
297 }
298
299 pub async fn connect(&mut self) -> anyhow::Result<()> {
305 tracing::info!("Connecting to Deribit WebSocket: {}", self.url);
306
307 self.signal.store(false, Ordering::Relaxed);
309
310 let (message_handler, raw_rx) = channel_message_handler();
312
313 let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {
315 });
317
318 let config = WebSocketConfig {
320 url: self.url.clone(),
321 headers: vec![(USER_AGENT.to_string(), NAUTILUS_USER_AGENT.to_string())],
322 heartbeat: self.heartbeat_interval,
323 heartbeat_msg: None, message_handler: Some(message_handler),
325 ping_handler: Some(ping_handler),
326 reconnect_timeout_ms: Some(5_000),
327 reconnect_delay_initial_ms: None,
328 reconnect_delay_max_ms: None,
329 reconnect_backoff_factor: None,
330 reconnect_jitter_ms: None,
331 reconnect_max_attempts: None,
332 };
333
334 let keyed_quotas = vec![("subscription".to_string(), *DERIBIT_WS_SUBSCRIPTION_QUOTA)];
336
337 let ws_client = WebSocketClient::connect(
339 config,
340 None, keyed_quotas,
342 Some(*DERIBIT_WS_SUBSCRIPTION_QUOTA), )
344 .await?;
345
346 self.connection_mode
348 .store(ws_client.connection_mode_atomic());
349
350 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();
352 let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel();
353
354 *self.cmd_tx.write().await = cmd_tx.clone();
356 self.out_rx = Some(Arc::new(out_rx));
357
358 let mut handler = DeribitWsFeedHandler::new(
360 self.signal.clone(),
361 cmd_rx,
362 raw_rx,
363 out_tx,
364 self.auth_tracker.clone(),
365 self.subscriptions_state.clone(),
366 );
367
368 let _ = cmd_tx.send(HandlerCommand::SetClient(ws_client));
370
371 let instruments: Vec<InstrumentAny> =
373 self.instruments_cache.iter().map(|r| r.clone()).collect();
374 if !instruments.is_empty() {
375 let _ = cmd_tx.send(HandlerCommand::InitializeInstruments(instruments));
376 }
377
378 if let Some(interval) = self.heartbeat_interval {
380 let _ = cmd_tx.send(HandlerCommand::SetHeartbeat { interval });
381 }
382
383 let subscriptions_state = self.subscriptions_state.clone();
385 let subscribed_channels = self.subscribed_channels.clone();
386 let credential = self.credential.clone();
387 let is_authenticated = self.is_authenticated.clone();
388 let auth_state = self.auth_state.clone();
389 let request_id_counter = self.request_id_counter.clone();
390
391 let task_handle = get_runtime().spawn(async move {
392 let mut pending_reauth = false;
394
395 loop {
396 match handler.next().await {
397 Some(msg) => match msg {
398 NautilusWsMessage::Reconnected => {
399 tracing::info!("Reconnected to Deribit WebSocket");
400
401 let channels: Vec<String> = subscribed_channels
403 .iter()
404 .map(|r| r.key().clone())
405 .collect();
406
407 for channel in &channels {
409 subscriptions_state.mark_failure(channel);
410 }
411
412 if let Some(cred) = &credential {
414 tracing::info!("Re-authenticating after reconnection...");
415
416 is_authenticated.store(false, Ordering::Release);
418 pending_reauth = true;
419
420 let previous_scope = auth_state
422 .read()
423 .await
424 .as_ref()
425 .map(|s| s.scope.clone());
426
427 send_auth_request(cred, previous_scope, &cmd_tx, &request_id_counter);
429 } else {
430 if !channels.is_empty() {
432 let _ = cmd_tx.send(HandlerCommand::Subscribe { channels });
433 }
434 }
435 }
436 NautilusWsMessage::Authenticated(result) => {
437 let timestamp = get_atomic_clock_realtime().get_time_ms();
438 let new_auth_state = AuthState::from_auth_result(&result, timestamp);
439 *auth_state.write().await = Some(new_auth_state);
440
441 spawn_token_refresh_task(
443 result.expires_in,
444 result.refresh_token.clone(),
445 cmd_tx.clone(),
446 request_id_counter.clone(),
447 );
448
449 if pending_reauth {
450 pending_reauth = false;
451 is_authenticated.store(true, Ordering::Release);
452 tracing::info!(
453 "Re-authentication successful (scope: {}), resubscribing to channels",
454 result.scope
455 );
456
457 let channels: Vec<String> = subscribed_channels
459 .iter()
460 .map(|r| r.key().clone())
461 .collect();
462
463 if !channels.is_empty() {
464 let _ = cmd_tx.send(HandlerCommand::Subscribe { channels });
465 }
466 } else {
467 is_authenticated.store(true, Ordering::Release);
469 tracing::debug!(
470 "Auth state stored: scope={}, expires_in={}s",
471 result.scope,
472 result.expires_in
473 );
474 }
475 }
476 _ => {}
477 },
478 None => {
479 tracing::debug!("Handler returned None, stopping task");
480 break;
481 }
482 }
483 }
484 });
485
486 self.task_handle = Some(Arc::new(task_handle));
487 tracing::info!("Connected to Deribit WebSocket");
488
489 Ok(())
490 }
491
492 pub async fn close(&self) -> DeribitWsResult<()> {
498 tracing::info!("Closing Deribit WebSocket connection");
499 self.signal.store(true, Ordering::Relaxed);
500
501 let _ = self.cmd_tx.read().await.send(HandlerCommand::Disconnect);
502
503 if let Some(_handle) = &self.task_handle {
505 let _ = tokio::time::timeout(Duration::from_secs(5), async {
506 tokio::time::sleep(Duration::from_millis(100)).await;
508 })
509 .await;
510 }
511
512 Ok(())
513 }
514
515 pub fn stream(&mut self) -> impl Stream<Item = NautilusWsMessage> + 'static {
521 let rx = self
522 .out_rx
523 .take()
524 .expect("Data stream receiver already taken or not connected");
525 let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
526
527 async_stream::stream! {
528 while let Some(msg) = rx.recv().await {
529 yield msg;
530 }
531 }
532 }
533
/// Returns `true` if the client was constructed with API credentials.
#[must_use]
pub fn has_credentials(&self) -> bool {
    self.credential.is_some()
}
539
/// Returns the current value of the authenticated flag.
///
/// The flag is loaded with `Acquire` ordering, pairing with the `Release`
/// stores performed when authentication succeeds or is reset.
#[must_use]
pub fn is_authenticated(&self) -> bool {
    self.is_authenticated.load(Ordering::Acquire)
}
545
546 pub async fn authenticate(&self, session_name: Option<&str>) -> DeribitWsResult<()> {
565 let credential = self.credential.as_ref().ok_or_else(|| {
566 DeribitWsError::Authentication("API credentials not configured".to_string())
567 })?;
568
569 let scope = session_name.map(|name| format!("session:{name}"));
571
572 tracing::info!(
573 "Authenticating WebSocket with API key: {}, scope: {}",
574 credential.api_key_masked(),
575 scope.as_deref().unwrap_or("connection (default)")
576 );
577
578 let rx = self.auth_tracker.begin();
579
580 let cmd_tx = self.cmd_tx.read().await;
582 send_auth_request(credential, scope, &cmd_tx, &self.request_id_counter);
583 drop(cmd_tx);
584
585 match self
587 .auth_tracker
588 .wait_for_result::<DeribitWsError>(Duration::from_secs(AUTHENTICATION_TIMEOUT_SECS), rx)
589 .await
590 {
591 Ok(()) => {
592 self.is_authenticated.store(true, Ordering::Release);
593 tracing::info!("WebSocket authenticated successfully");
594 Ok(())
595 }
596 Err(e) => {
597 tracing::error!(error = %e, "WebSocket authentication failed");
598 Err(e)
599 }
600 }
601 }
602
/// Authenticates using `DEFAULT_SESSION_NAME` as the session name.
///
/// # Errors
///
/// Returns whatever [`Self::authenticate`] returns.
pub async fn authenticate_session(&self) -> DeribitWsResult<()> {
    self.authenticate(Some(DEFAULT_SESSION_NAME)).await
}
615
/// Returns a clone of the stored authentication state, or `None` if no
/// authentication result has been recorded yet.
pub async fn auth_state(&self) -> Option<AuthState> {
    self.auth_state.read().await.clone()
}
622
623 pub async fn access_token(&self) -> Option<String> {
625 self.auth_state
626 .read()
627 .await
628 .as_ref()
629 .map(|s| s.access_token.clone())
630 }
631
632 async fn send_subscribe(&self, channels: Vec<String>) -> DeribitWsResult<()> {
637 for channel in &channels {
639 self.subscribed_channels.insert(channel.clone(), ());
640 }
641
642 self.cmd_tx
643 .read()
644 .await
645 .send(HandlerCommand::Subscribe {
646 channels: channels.clone(),
647 })
648 .map_err(|e| DeribitWsError::Send(e.to_string()))?;
649
650 tracing::debug!("Sent subscribe for {} channels", channels.len());
651 Ok(())
652 }
653
654 async fn send_unsubscribe(&self, channels: Vec<String>) -> DeribitWsResult<()> {
655 for channel in &channels {
657 self.subscribed_channels.remove(channel);
658 }
659
660 self.cmd_tx
661 .read()
662 .await
663 .send(HandlerCommand::Unsubscribe {
664 channels: channels.clone(),
665 })
666 .map_err(|e| DeribitWsError::Send(e.to_string()))?;
667
668 tracing::debug!("Sent unsubscribe for {} channels", channels.len());
669 Ok(())
670 }
671
672 pub async fn subscribe_trades(
683 &self,
684 instrument_id: InstrumentId,
685 interval: Option<DeribitUpdateInterval>,
686 ) -> DeribitWsResult<()> {
687 let interval = interval.unwrap_or_default();
688 self.check_auth_requirement(interval)?;
689 let channel =
690 DeribitWsChannel::Trades.format_channel(instrument_id.symbol.as_str(), Some(interval));
691 self.send_subscribe(vec![channel]).await
692 }
693
/// Subscribes to the trades channel for `instrument_id` using the
/// `DeribitUpdateInterval::Raw` interval.
///
/// # Errors
///
/// Returns whatever [`Self::subscribe_trades`] returns.
pub async fn subscribe_trades_raw(&self, instrument_id: InstrumentId) -> DeribitWsResult<()> {
    self.subscribe_trades(instrument_id, Some(DeribitUpdateInterval::Raw))
        .await
}
705
706 pub async fn unsubscribe_trades(
712 &self,
713 instrument_id: InstrumentId,
714 interval: Option<DeribitUpdateInterval>,
715 ) -> DeribitWsResult<()> {
716 let interval = interval.unwrap_or_default();
717 let channel =
718 DeribitWsChannel::Trades.format_channel(instrument_id.symbol.as_str(), Some(interval));
719 self.send_unsubscribe(vec![channel]).await
720 }
721
722 pub async fn subscribe_book(
733 &self,
734 instrument_id: InstrumentId,
735 interval: Option<DeribitUpdateInterval>,
736 ) -> DeribitWsResult<()> {
737 let interval = interval.unwrap_or_default();
738 self.check_auth_requirement(interval)?;
739 let channel =
740 DeribitWsChannel::Book.format_channel(instrument_id.symbol.as_str(), Some(interval));
741 self.send_subscribe(vec![channel]).await
742 }
743
/// Subscribes to the order book channel for `instrument_id` using the
/// `DeribitUpdateInterval::Raw` interval.
///
/// # Errors
///
/// Returns whatever [`Self::subscribe_book`] returns.
pub async fn subscribe_book_raw(&self, instrument_id: InstrumentId) -> DeribitWsResult<()> {
    self.subscribe_book(instrument_id, Some(DeribitUpdateInterval::Raw))
        .await
}
755
756 pub async fn unsubscribe_book(
762 &self,
763 instrument_id: InstrumentId,
764 interval: Option<DeribitUpdateInterval>,
765 ) -> DeribitWsResult<()> {
766 let interval = interval.unwrap_or_default();
767 let channel =
768 DeribitWsChannel::Book.format_channel(instrument_id.symbol.as_str(), Some(interval));
769 self.send_unsubscribe(vec![channel]).await
770 }
771
772 pub async fn subscribe_ticker(
783 &self,
784 instrument_id: InstrumentId,
785 interval: Option<DeribitUpdateInterval>,
786 ) -> DeribitWsResult<()> {
787 let interval = interval.unwrap_or_default();
788 self.check_auth_requirement(interval)?;
789 let channel =
790 DeribitWsChannel::Ticker.format_channel(instrument_id.symbol.as_str(), Some(interval));
791 self.send_subscribe(vec![channel]).await
792 }
793
/// Subscribes to the ticker channel for `instrument_id` using the
/// `DeribitUpdateInterval::Raw` interval.
///
/// # Errors
///
/// Returns whatever [`Self::subscribe_ticker`] returns.
pub async fn subscribe_ticker_raw(&self, instrument_id: InstrumentId) -> DeribitWsResult<()> {
    self.subscribe_ticker(instrument_id, Some(DeribitUpdateInterval::Raw))
        .await
}
805
806 pub async fn unsubscribe_ticker(
812 &self,
813 instrument_id: InstrumentId,
814 interval: Option<DeribitUpdateInterval>,
815 ) -> DeribitWsResult<()> {
816 let interval = interval.unwrap_or_default();
817 let channel =
818 DeribitWsChannel::Ticker.format_channel(instrument_id.symbol.as_str(), Some(interval));
819 self.send_unsubscribe(vec![channel]).await
820 }
821
822 pub async fn subscribe_quotes(&self, instrument_id: InstrumentId) -> DeribitWsResult<()> {
830 let channel = DeribitWsChannel::Quote.format_channel(instrument_id.symbol.as_str(), None);
831 self.send_subscribe(vec![channel]).await
832 }
833
834 pub async fn unsubscribe_quotes(&self, instrument_id: InstrumentId) -> DeribitWsResult<()> {
840 let channel = DeribitWsChannel::Quote.format_channel(instrument_id.symbol.as_str(), None);
841 self.send_unsubscribe(vec![channel]).await
842 }
843
844 fn check_auth_requirement(&self, interval: DeribitUpdateInterval) -> DeribitWsResult<()> {
850 if interval.requires_auth() && !self.is_authenticated() {
851 return Err(DeribitWsError::Authentication(
852 "Raw streams require authentication. Call authenticate() first.".to_string(),
853 ));
854 }
855 Ok(())
856 }
857
/// Subscribes to the given channel names as-is.
///
/// No authentication check is performed here.
///
/// # Errors
///
/// Returns a send error if the command cannot be delivered.
pub async fn subscribe(&self, channels: Vec<String>) -> DeribitWsResult<()> {
    self.send_subscribe(channels).await
}
866
/// Unsubscribes from the given channel names as-is.
///
/// # Errors
///
/// Returns a send error if the command cannot be delivered.
pub async fn unsubscribe(&self, channels: Vec<String>) -> DeribitWsResult<()> {
    self.send_unsubscribe(channels).await
}
875}