nautilus_binance/spot/websocket/streams/
client.rs1use std::{
27 fmt::Debug,
28 sync::{
29 Arc,
30 atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
31 },
32};
33
34use arc_swap::ArcSwap;
35use dashmap::DashMap;
36use futures_util::Stream;
37use nautilus_common::live::get_runtime;
38use nautilus_model::instruments::{Instrument, InstrumentAny};
39use nautilus_network::{
40 mode::ConnectionMode,
41 websocket::{
42 PingHandler, SubscriptionState, WebSocketClient, WebSocketConfig, channel_message_handler,
43 },
44};
45use tokio_util::sync::CancellationToken;
46use ustr::Ustr;
47
48use super::{
49 super::error::{BinanceWsError, BinanceWsResult},
50 handler::BinanceSpotWsFeedHandler,
51 messages::{BinanceSpotWsMessage, HandlerCommand},
52 subscription::MAX_STREAMS_PER_CONNECTION,
53};
54use crate::common::{
55 consts::{
56 BINANCE_RATE_LIMIT_KEY_SUBSCRIPTION, BINANCE_SPOT_SBE_WS_URL, BINANCE_WS_CONNECTION_QUOTA,
57 BINANCE_WS_SUBSCRIPTION_QUOTA,
58 },
59 credential::Ed25519Credential,
60};
61
62#[derive(Clone)]
64#[cfg_attr(
65 feature = "python",
66 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.binance")
67)]
68pub struct BinanceSpotWebSocketClient {
69 url: String,
70 credential: Option<Arc<Ed25519Credential>>,
71 heartbeat: Option<u64>,
72 signal: Arc<AtomicBool>,
73 connection_mode: Arc<ArcSwap<AtomicU8>>,
74 cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
75 out_rx:
76 Arc<std::sync::Mutex<Option<tokio::sync::mpsc::UnboundedReceiver<BinanceSpotWsMessage>>>>,
77 task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
78 subscriptions_state: SubscriptionState,
79 request_id_counter: Arc<AtomicU64>,
80 instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
81 cancellation_token: CancellationToken,
82}
83
84impl Debug for BinanceSpotWebSocketClient {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 f.debug_struct(stringify!(BinanceSpotWebSocketClient))
87 .field("url", &self.url)
88 .field(
89 "credential",
90 &self.credential.as_ref().map(|_| "<redacted>"),
91 )
92 .field("heartbeat", &self.heartbeat)
93 .finish_non_exhaustive()
94 }
95}
96
97impl Default for BinanceSpotWebSocketClient {
98 fn default() -> Self {
99 Self::new(None, None, None, None).unwrap()
100 }
101}
102
103impl BinanceSpotWebSocketClient {
104 pub fn new(
110 url: Option<String>,
111 api_key: Option<String>,
112 api_secret: Option<String>,
113 heartbeat: Option<u64>,
114 ) -> anyhow::Result<Self> {
115 let url = url.unwrap_or(BINANCE_SPOT_SBE_WS_URL.to_string());
116
117 let credential = match (api_key, api_secret) {
118 (Some(key), Some(secret)) => Some(Arc::new(Ed25519Credential::new(key, &secret)?)),
119 _ => None,
120 };
121
122 let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel();
123
124 Ok(Self {
125 url,
126 credential,
127 heartbeat,
128 signal: Arc::new(AtomicBool::new(false)),
129 connection_mode: Arc::new(ArcSwap::new(Arc::new(AtomicU8::new(
130 ConnectionMode::Closed as u8,
131 )))),
132 cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
133 out_rx: Arc::new(std::sync::Mutex::new(None)),
134 task_handle: None,
135 subscriptions_state: SubscriptionState::new('@'),
136 request_id_counter: Arc::new(AtomicU64::new(1)),
137 instruments_cache: Arc::new(DashMap::new()),
138 cancellation_token: CancellationToken::new(),
139 })
140 }
141
142 #[must_use]
144 pub fn is_active(&self) -> bool {
145 let mode_u8 = self.connection_mode.load().load(Ordering::Relaxed);
146 mode_u8 == ConnectionMode::Active as u8
147 }
148
149 #[must_use]
151 pub fn is_closed(&self) -> bool {
152 let mode_u8 = self.connection_mode.load().load(Ordering::Relaxed);
153 mode_u8 == ConnectionMode::Closed as u8
154 }
155
156 #[must_use]
158 pub fn subscription_count(&self) -> usize {
159 self.subscriptions_state.len()
160 }
161
162 pub async fn connect(&mut self) -> BinanceWsResult<()> {
172 self.signal.store(false, Ordering::Relaxed);
173 self.cancellation_token = CancellationToken::new();
174
175 let (raw_handler, raw_rx) = channel_message_handler();
176 let ping_handler: PingHandler = Arc::new(move |_| {});
177
178 let headers = if let Some(ref cred) = self.credential {
180 vec![("X-MBX-APIKEY".to_string(), cred.api_key().to_string())]
181 } else {
182 vec![]
183 };
184
185 log::info!(
186 "Connecting to Binance SBE WebSocket: url={}, auth={}",
187 self.url,
188 self.credential.is_some()
189 );
190
191 let config = WebSocketConfig {
192 url: self.url.clone(),
193 headers,
194 heartbeat: self.heartbeat,
195 heartbeat_msg: None,
196 reconnect_timeout_ms: Some(5_000),
197 reconnect_delay_initial_ms: Some(500),
198 reconnect_delay_max_ms: Some(5_000),
199 reconnect_backoff_factor: Some(2.0),
200 reconnect_jitter_ms: Some(250),
201 reconnect_max_attempts: None,
202 };
203
204 let keyed_quotas = vec![(
206 BINANCE_RATE_LIMIT_KEY_SUBSCRIPTION[0].as_str().to_string(),
207 *BINANCE_WS_SUBSCRIPTION_QUOTA,
208 )];
209
210 let client = WebSocketClient::connect(
211 config,
212 Some(raw_handler),
213 Some(ping_handler),
214 None,
215 keyed_quotas,
216 Some(*BINANCE_WS_CONNECTION_QUOTA),
217 )
218 .await
219 .map_err(|e| {
220 log::error!("WebSocket connection failed: {e}");
221 BinanceWsError::NetworkError(e.to_string())
222 })?;
223
224 self.connection_mode.store(client.connection_mode_atomic());
225
226 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();
227 let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel();
228 *self.cmd_tx.write().await = cmd_tx;
229 *self.out_rx.lock().expect("out_rx lock poisoned") = Some(out_rx);
230
231 let mut handler = BinanceSpotWsFeedHandler::new(
232 self.signal.clone(),
233 cmd_rx,
234 raw_rx,
235 out_tx.clone(),
236 self.subscriptions_state.clone(),
237 self.request_id_counter.clone(),
238 );
239
240 self.cmd_tx
241 .read()
242 .await
243 .send(HandlerCommand::SetClient(client))
244 .map_err(|e| BinanceWsError::ClientError(format!("Failed to set client: {e}")))?;
245
246 let instruments: Vec<InstrumentAny> = self
247 .instruments_cache
248 .iter()
249 .map(|entry| entry.value().clone())
250 .collect();
251
252 if !instruments.is_empty() {
253 self.cmd_tx
254 .read()
255 .await
256 .send(HandlerCommand::InitializeInstruments(instruments))
257 .map_err(|e| {
258 BinanceWsError::ClientError(format!("Failed to initialize instruments: {e}"))
259 })?;
260 }
261
262 let signal = self.signal.clone();
263 let cancellation_token = self.cancellation_token.clone();
264 let subscriptions_state = self.subscriptions_state.clone();
265 let cmd_tx = self.cmd_tx.clone();
266
267 let task_handle = get_runtime().spawn(async move {
268 loop {
269 tokio::select! {
270 () = cancellation_token.cancelled() => {
271 log::debug!("Handler task cancelled");
272 break;
273 }
274 result = handler.next() => {
275 match result {
276 Some(BinanceSpotWsMessage::Reconnected) => {
277 log::info!("WebSocket reconnected, restoring subscriptions");
278 let all_topics = subscriptions_state.all_topics();
280 for topic in &all_topics {
281 subscriptions_state.mark_failure(topic);
282 }
283
284 let streams = subscriptions_state.all_topics();
286 if !streams.is_empty()
287 && let Err(e) = cmd_tx.read().await.send(HandlerCommand::Subscribe { streams }) {
288 log::error!("Failed to resubscribe after reconnect: {e}");
289 }
290
291 if out_tx.send(BinanceSpotWsMessage::Reconnected).is_err() {
292 log::debug!("Output channel closed");
293 break;
294 }
295 }
296 Some(msg) => {
297 if out_tx.send(msg).is_err() {
298 log::debug!("Output channel closed");
299 break;
300 }
301 }
302 None => {
303 if signal.load(Ordering::Relaxed) {
304 log::debug!("Handler received shutdown signal");
305 } else {
306 log::warn!("Handler loop ended unexpectedly");
307 }
308 break;
309 }
310 }
311 }
312 }
313 }
314 });
315
316 self.task_handle = Some(Arc::new(task_handle));
317
318 log::info!("Connected to Binance Spot SBE stream: url={}", self.url);
319 Ok(())
320 }
321
322 pub async fn close(&mut self) -> BinanceWsResult<()> {
332 self.signal.store(true, Ordering::Relaxed);
333 self.cancellation_token.cancel();
334
335 let _ = self.cmd_tx.read().await.send(HandlerCommand::Disconnect);
336
337 if let Some(handle) = self.task_handle.take()
338 && let Ok(handle) = Arc::try_unwrap(handle)
339 {
340 let _ = handle.await;
341 }
342
343 *self.out_rx.lock().expect("out_rx lock poisoned") = None;
344
345 log::info!("Disconnected from Binance Spot SBE stream");
346 Ok(())
347 }
348
349 pub async fn subscribe(&self, streams: Vec<String>) -> BinanceWsResult<()> {
355 let current_count = self.subscriptions_state.len();
356 if current_count + streams.len() > MAX_STREAMS_PER_CONNECTION {
357 return Err(BinanceWsError::ClientError(format!(
358 "Would exceed max streams: {} + {} > {}",
359 current_count,
360 streams.len(),
361 MAX_STREAMS_PER_CONNECTION
362 )));
363 }
364
365 self.cmd_tx
366 .read()
367 .await
368 .send(HandlerCommand::Subscribe { streams })
369 .map_err(|e| BinanceWsError::ClientError(format!("Handler not available: {e}")))?;
370
371 Ok(())
372 }
373
374 pub async fn unsubscribe(&self, streams: Vec<String>) -> BinanceWsResult<()> {
380 self.cmd_tx
381 .read()
382 .await
383 .send(HandlerCommand::Unsubscribe { streams })
384 .map_err(|e| BinanceWsError::ClientError(format!("Handler not available: {e}")))?;
385
386 Ok(())
387 }
388
389 pub fn stream(&self) -> impl Stream<Item = BinanceSpotWsMessage> + 'static {
399 let out_rx = self.out_rx.lock().expect("out_rx lock poisoned").take();
400 async_stream::stream! {
401 if let Some(mut rx) = out_rx {
402 while let Some(msg) = rx.recv().await {
403 yield msg;
404 }
405 }
406 }
407 }
408
409 pub fn cache_instruments(&self, instruments: Vec<InstrumentAny>) {
411 for inst in &instruments {
412 self.instruments_cache
413 .insert(inst.symbol().inner(), inst.clone());
414 }
415
416 if self.is_active() {
417 let cmd_tx = self.cmd_tx.clone();
418 let instruments_clone = instruments;
419 get_runtime().spawn(async move {
420 let _ = cmd_tx
421 .read()
422 .await
423 .send(HandlerCommand::InitializeInstruments(instruments_clone));
424 });
425 }
426 }
427
428 pub fn cache_instrument(&self, instrument: InstrumentAny) {
430 self.instruments_cache
431 .insert(instrument.symbol().inner(), instrument.clone());
432
433 if self.is_active() {
434 let cmd_tx = self.cmd_tx.clone();
435 get_runtime().spawn(async move {
436 let _ = cmd_tx
437 .read()
438 .await
439 .send(HandlerCommand::UpdateInstrument(instrument));
440 });
441 }
442 }
443
444 #[must_use]
446 pub fn get_instrument(&self, symbol: &str) -> Option<InstrumentAny> {
447 self.instruments_cache
448 .get(&Ustr::from(symbol))
449 .map(|entry| entry.value().clone())
450 }
451}