nautilus_binance/futures/websocket/
client.rs1use std::{
27 fmt::Debug,
28 sync::{
29 Arc,
30 atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
31 },
32};
33
34use arc_swap::ArcSwap;
35use dashmap::DashMap;
36use futures_util::Stream;
37use nautilus_common::live::get_runtime;
38use nautilus_core::time::get_atomic_clock_realtime;
39use nautilus_model::instruments::{Instrument, InstrumentAny};
40use nautilus_network::{
41 mode::ConnectionMode,
42 websocket::{
43 PingHandler, SubscriptionState, WebSocketClient, WebSocketConfig, channel_message_handler,
44 },
45};
46use tokio_tungstenite::tungstenite::Message;
47use tokio_util::sync::CancellationToken;
48use ustr::Ustr;
49
50use super::{
51 error::{BinanceWsError, BinanceWsResult},
52 handler_data::BinanceFuturesDataWsFeedHandler,
53 messages::{DataHandlerCommand, NautilusWsMessage},
54};
55use crate::common::{
56 consts::{
57 BINANCE_RATE_LIMIT_KEY_SUBSCRIPTION, BINANCE_WS_CONNECTION_QUOTA,
58 BINANCE_WS_SUBSCRIPTION_QUOTA,
59 },
60 credential::Credential,
61 enums::{BinanceEnvironment, BinanceProductType},
62 urls::get_ws_base_url,
63};
64
/// Maximum number of streams a single client connection will track.
///
/// `BinanceFuturesWebSocketClient::subscribe` rejects any request that would push the
/// tracked subscription count above this value. Intended to mirror Binance Futures'
/// per-connection stream limit (verify against the current Binance API docs).
pub const MAX_STREAMS_PER_CONNECTION: usize = 200;
67
68#[derive(Clone)]
70#[cfg_attr(
71 feature = "python",
72 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.binance")
73)]
74pub struct BinanceFuturesWebSocketClient {
75 url: String,
76 product_type: BinanceProductType,
77 credential: Option<Arc<Credential>>,
78 heartbeat: Option<u64>,
79 signal: Arc<AtomicBool>,
80 connection_mode: Arc<ArcSwap<AtomicU8>>,
81 cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<DataHandlerCommand>>>,
82 out_rx: Arc<std::sync::Mutex<Option<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>>,
83 task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
84 subscriptions_state: SubscriptionState,
85 request_id_counter: Arc<AtomicU64>,
86 instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
87 cancellation_token: CancellationToken,
88}
89
90impl Debug for BinanceFuturesWebSocketClient {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 f.debug_struct(stringify!(BinanceFuturesWebSocketClient))
93 .field("url", &self.url)
94 .field("product_type", &self.product_type)
95 .field(
96 "credential",
97 &self.credential.as_ref().map(|_| "<redacted>"),
98 )
99 .field("heartbeat", &self.heartbeat)
100 .finish_non_exhaustive()
101 }
102}
103
104impl BinanceFuturesWebSocketClient {
105 pub fn new(
113 product_type: BinanceProductType,
114 environment: BinanceEnvironment,
115 api_key: Option<String>,
116 api_secret: Option<String>,
117 url_override: Option<String>,
118 heartbeat: Option<u64>,
119 ) -> anyhow::Result<Self> {
120 match product_type {
121 BinanceProductType::UsdM | BinanceProductType::CoinM => {}
122 _ => {
123 anyhow::bail!(
124 "BinanceFuturesWebSocketClient requires UsdM or CoinM product type, was {product_type:?}"
125 );
126 }
127 }
128
129 let url =
130 url_override.unwrap_or_else(|| get_ws_base_url(product_type, environment).to_string());
131
132 let credential = match (api_key, api_secret) {
133 (Some(key), Some(secret)) => Some(Arc::new(Credential::new(key, secret))),
134 _ => None,
135 };
136
137 let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel();
138
139 Ok(Self {
140 url,
141 product_type,
142 credential,
143 heartbeat,
144 signal: Arc::new(AtomicBool::new(false)),
145 connection_mode: Arc::new(ArcSwap::new(Arc::new(AtomicU8::new(
146 ConnectionMode::Closed as u8,
147 )))),
148 cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
149 out_rx: Arc::new(std::sync::Mutex::new(None)),
150 task_handle: None,
151 subscriptions_state: SubscriptionState::new('@'),
152 request_id_counter: Arc::new(AtomicU64::new(1)),
153 instruments_cache: Arc::new(DashMap::new()),
154 cancellation_token: CancellationToken::new(),
155 })
156 }
157
158 #[must_use]
160 pub const fn product_type(&self) -> BinanceProductType {
161 self.product_type
162 }
163
164 #[must_use]
166 pub fn is_active(&self) -> bool {
167 let mode_u8 = self.connection_mode.load().load(Ordering::Relaxed);
168 mode_u8 == ConnectionMode::Active as u8
169 }
170
171 #[must_use]
173 pub fn is_closed(&self) -> bool {
174 let mode_u8 = self.connection_mode.load().load(Ordering::Relaxed);
175 mode_u8 == ConnectionMode::Closed as u8
176 }
177
178 #[must_use]
180 pub fn subscription_count(&self) -> usize {
181 self.subscriptions_state.len()
182 }
183
184 pub async fn connect(&mut self) -> BinanceWsResult<()> {
194 self.signal.store(false, Ordering::Relaxed);
195
196 let (raw_handler, raw_rx) = channel_message_handler();
197 let ping_handler: PingHandler = Arc::new(move |_| {});
198
199 let headers = if let Some(ref cred) = self.credential {
201 vec![("X-MBX-APIKEY".to_string(), cred.api_key().to_string())]
202 } else {
203 vec![]
204 };
205
206 let config = WebSocketConfig {
207 url: self.url.clone(),
208 headers,
209 heartbeat: self.heartbeat,
210 heartbeat_msg: None,
211 reconnect_timeout_ms: Some(5_000),
212 reconnect_delay_initial_ms: Some(500),
213 reconnect_delay_max_ms: Some(5_000),
214 reconnect_backoff_factor: Some(2.0),
215 reconnect_jitter_ms: Some(250),
216 reconnect_max_attempts: None,
217 };
218
219 let keyed_quotas = vec![(
221 BINANCE_RATE_LIMIT_KEY_SUBSCRIPTION[0].as_str().to_string(),
222 *BINANCE_WS_SUBSCRIPTION_QUOTA,
223 )];
224
225 let client = WebSocketClient::connect(
226 config,
227 Some(raw_handler),
228 Some(ping_handler),
229 None,
230 keyed_quotas,
231 Some(*BINANCE_WS_CONNECTION_QUOTA),
232 )
233 .await
234 .map_err(|e| BinanceWsError::NetworkError(e.to_string()))?;
235
236 self.connection_mode.store(client.connection_mode_atomic());
237
238 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();
239 let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel();
240 *self.cmd_tx.write().await = cmd_tx;
241 *self.out_rx.lock().expect("out_rx lock poisoned") = Some(out_rx);
242
243 let (bytes_tx, bytes_rx) = tokio::sync::mpsc::unbounded_channel::<Vec<u8>>();
244
245 let bytes_task = get_runtime().spawn(async move {
246 let mut raw_rx = raw_rx;
247 while let Some(msg) = raw_rx.recv().await {
248 let data = match msg {
249 Message::Binary(data) => data.to_vec(),
250 Message::Text(text) => text.as_bytes().to_vec(),
251 Message::Close(_) => break,
252 Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => continue,
253 };
254 if bytes_tx.send(data).is_err() {
255 break;
256 }
257 }
258 });
259
260 let mut handler = BinanceFuturesDataWsFeedHandler::new(
261 get_atomic_clock_realtime(),
262 self.signal.clone(),
263 cmd_rx,
264 bytes_rx,
265 out_tx.clone(),
266 self.subscriptions_state.clone(),
267 self.request_id_counter.clone(),
268 );
269
270 self.cmd_tx
271 .read()
272 .await
273 .send(DataHandlerCommand::SetClient(client))
274 .map_err(|e| BinanceWsError::ClientError(format!("Failed to set client: {e}")))?;
275
276 let instruments: Vec<InstrumentAny> = self
277 .instruments_cache
278 .iter()
279 .map(|entry| entry.value().clone())
280 .collect();
281
282 if !instruments.is_empty() {
283 self.cmd_tx
284 .read()
285 .await
286 .send(DataHandlerCommand::InitializeInstruments(instruments))
287 .map_err(|e| {
288 BinanceWsError::ClientError(format!("Failed to initialize instruments: {e}"))
289 })?;
290 }
291
292 let signal = self.signal.clone();
293 let cancellation_token = self.cancellation_token.clone();
294 let subscriptions_state = self.subscriptions_state.clone();
295 let cmd_tx = self.cmd_tx.clone();
296
297 let task_handle = get_runtime().spawn(async move {
298 loop {
299 tokio::select! {
300 () = cancellation_token.cancelled() => {
301 log::debug!("Handler task cancelled");
302 break;
303 }
304 result = handler.next() => {
305 match result {
306 Some(NautilusWsMessage::Reconnected) => {
307 log::info!("WebSocket reconnected, restoring subscriptions");
308 let all_topics = subscriptions_state.all_topics();
310 for topic in &all_topics {
311 subscriptions_state.mark_failure(topic);
312 }
313
314 let streams = subscriptions_state.all_topics();
316 if !streams.is_empty()
317 && let Err(e) = cmd_tx.read().await.send(DataHandlerCommand::Subscribe { streams }) {
318 log::error!("Failed to resubscribe after reconnect: {e}");
319 }
320
321 if out_tx.send(NautilusWsMessage::Reconnected).is_err() {
322 log::debug!("Output channel closed");
323 break;
324 }
325 }
326 Some(msg) => {
327 if out_tx.send(msg).is_err() {
328 log::debug!("Output channel closed");
329 break;
330 }
331 }
332 None => {
333 if signal.load(Ordering::Relaxed) {
334 log::debug!("Handler received shutdown signal");
335 } else {
336 log::warn!("Handler loop ended unexpectedly");
337 }
338 break;
339 }
340 }
341 }
342 }
343 }
344 bytes_task.abort();
345 });
346
347 self.task_handle = Some(Arc::new(task_handle));
348
349 log::info!(
350 "Connected to Binance Futures stream: url={}, product_type={:?}",
351 self.url,
352 self.product_type
353 );
354 Ok(())
355 }
356
357 pub async fn close(&mut self) -> BinanceWsResult<()> {
367 self.signal.store(true, Ordering::Relaxed);
368 self.cancellation_token.cancel();
369
370 let _ = self
371 .cmd_tx
372 .read()
373 .await
374 .send(DataHandlerCommand::Disconnect);
375
376 if let Some(handle) = self.task_handle.take()
377 && let Ok(handle) = Arc::try_unwrap(handle)
378 {
379 let _ = handle.await;
380 }
381
382 *self.out_rx.lock().expect("out_rx lock poisoned") = None;
383
384 log::info!("Disconnected from Binance Futures stream");
385 Ok(())
386 }
387
388 pub async fn subscribe(&self, streams: Vec<String>) -> BinanceWsResult<()> {
394 let current_count = self.subscriptions_state.len();
395 if current_count + streams.len() > MAX_STREAMS_PER_CONNECTION {
396 return Err(BinanceWsError::ClientError(format!(
397 "Would exceed max streams: {} + {} > {}",
398 current_count,
399 streams.len(),
400 MAX_STREAMS_PER_CONNECTION
401 )));
402 }
403
404 self.cmd_tx
405 .read()
406 .await
407 .send(DataHandlerCommand::Subscribe { streams })
408 .map_err(|e| BinanceWsError::ClientError(format!("Handler not available: {e}")))?;
409
410 Ok(())
411 }
412
413 pub async fn unsubscribe(&self, streams: Vec<String>) -> BinanceWsResult<()> {
419 self.cmd_tx
420 .read()
421 .await
422 .send(DataHandlerCommand::Unsubscribe { streams })
423 .map_err(|e| BinanceWsError::ClientError(format!("Handler not available: {e}")))?;
424
425 Ok(())
426 }
427
428 pub fn stream(&self) -> impl Stream<Item = NautilusWsMessage> + 'static {
438 let out_rx = self.out_rx.lock().expect("out_rx lock poisoned").take();
439 async_stream::stream! {
440 if let Some(mut rx) = out_rx {
441 while let Some(msg) = rx.recv().await {
442 yield msg;
443 }
444 }
445 }
446 }
447
448 pub fn cache_instruments(&self, instruments: Vec<InstrumentAny>) {
453 for inst in &instruments {
454 self.instruments_cache
455 .insert(inst.raw_symbol().inner(), inst.clone());
456 }
457
458 if self.is_active() {
459 let cmd_tx = self.cmd_tx.clone();
460 let instruments_clone = instruments;
461 get_runtime().spawn(async move {
462 let _ = cmd_tx
463 .read()
464 .await
465 .send(DataHandlerCommand::InitializeInstruments(instruments_clone));
466 });
467 }
468 }
469
470 pub fn cache_instrument(&self, instrument: InstrumentAny) {
475 self.instruments_cache
476 .insert(instrument.raw_symbol().inner(), instrument.clone());
477
478 if self.is_active() {
479 let cmd_tx = self.cmd_tx.clone();
480 get_runtime().spawn(async move {
481 let _ = cmd_tx
482 .read()
483 .await
484 .send(DataHandlerCommand::UpdateInstrument(instrument));
485 });
486 }
487 }
488
489 #[must_use]
491 pub fn get_instrument(&self, symbol: &str) -> Option<InstrumentAny> {
492 self.instruments_cache
493 .get(&Ustr::from(symbol))
494 .map(|entry| entry.value().clone())
495 }
496}