nautilus_binance/spot/websocket/streams/
client.rs1use std::{
27 fmt::Debug,
28 sync::{
29 Arc,
30 atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
31 },
32};
33
34use arc_swap::ArcSwap;
35use dashmap::DashMap;
36use futures_util::Stream;
37use nautilus_common::live::get_runtime;
38use nautilus_model::instruments::{Instrument, InstrumentAny};
39use nautilus_network::{
40 mode::ConnectionMode,
41 websocket::{
42 PingHandler, SubscriptionState, WebSocketClient, WebSocketConfig, channel_message_handler,
43 },
44};
45use tokio_util::sync::CancellationToken;
46use ustr::Ustr;
47
48use super::{
49 handler::BinanceSpotWsFeedHandler,
50 messages::{HandlerCommand, NautilusWsMessage},
51 subscription::MAX_STREAMS_PER_CONNECTION,
52};
53use crate::{
54 common::{consts::BINANCE_SPOT_SBE_WS_URL, credential::Ed25519Credential},
55 websocket::error::{BinanceWsError, BinanceWsResult},
56};
57
/// WebSocket client for Binance Spot SBE market-data streams.
///
/// Cloning shares all `Arc`-wrapped state (channels, flags, subscription state, instrument
/// cache). `task_handle` and `cancellation_token` are stored by value: `connect` and `close`
/// take `&mut self` and replace/take them only on the instance they are called on.
#[derive(Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.binance")
)]
pub struct BinanceSpotWebSocketClient {
    // WebSocket endpoint URL (defaults to `BINANCE_SPOT_SBE_WS_URL` in `new`).
    url: String,
    // Optional Ed25519 credential; when set, its API key is sent as the `X-MBX-APIKEY`
    // header on connect.
    credential: Option<Arc<Ed25519Credential>>,
    // Passed through unchanged as `WebSocketConfig::heartbeat` on connect.
    heartbeat: Option<u64>,
    // Shutdown flag: cleared by `connect`, set by `close`; shared with the feed handler and
    // read by the forwarding task to tell a requested shutdown from an unexpected end.
    signal: Arc<AtomicBool>,
    // Points at the current `WebSocketClient`'s connection-mode atomic (swapped in on
    // connect); starts as `ConnectionMode::Closed`.
    connection_mode: Arc<ArcSwap<AtomicU8>>,
    // Sender for commands to the feed handler; replaced on every connect. Before the first
    // connect its receiver has already been dropped, so sends fail.
    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
    // Receiver of handler output, created on connect and taken by `stream()`.
    out_rx: Arc<std::sync::Mutex<Option<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>>,
    // Handle of the task that drives the feed handler and forwards its output.
    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    // Tracked stream subscriptions, shared with the handler; used for the per-connection
    // stream limit and to resubscribe after a reconnect.
    subscriptions_state: SubscriptionState,
    // Request-id counter shared with the feed handler (starts at 1).
    request_id_counter: Arc<AtomicU64>,
    // Instruments keyed by symbol; replayed to the handler on every connect.
    instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
    // Stops the handler task; cancelled by `close` and replaced on connect.
    cancellation_token: CancellationToken,
}
78
79impl Debug for BinanceSpotWebSocketClient {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 f.debug_struct(stringify!(BinanceSpotWebSocketClient))
82 .field("url", &self.url)
83 .field(
84 "credential",
85 &self.credential.as_ref().map(|_| "<redacted>"),
86 )
87 .field("heartbeat", &self.heartbeat)
88 .finish_non_exhaustive()
89 }
90}
91
92impl Default for BinanceSpotWebSocketClient {
93 fn default() -> Self {
94 Self::new(None, None, None, None).unwrap()
95 }
96}
97
98impl BinanceSpotWebSocketClient {
99 pub fn new(
105 url: Option<String>,
106 api_key: Option<String>,
107 api_secret: Option<String>,
108 heartbeat: Option<u64>,
109 ) -> anyhow::Result<Self> {
110 let url = url.unwrap_or(BINANCE_SPOT_SBE_WS_URL.to_string());
111
112 let credential = match (api_key, api_secret) {
113 (Some(key), Some(secret)) => Some(Arc::new(Ed25519Credential::new(key, &secret)?)),
114 _ => None,
115 };
116
117 let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel();
118
119 Ok(Self {
120 url,
121 credential,
122 heartbeat,
123 signal: Arc::new(AtomicBool::new(false)),
124 connection_mode: Arc::new(ArcSwap::new(Arc::new(AtomicU8::new(
125 ConnectionMode::Closed as u8,
126 )))),
127 cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
128 out_rx: Arc::new(std::sync::Mutex::new(None)),
129 task_handle: None,
130 subscriptions_state: SubscriptionState::new('@'),
131 request_id_counter: Arc::new(AtomicU64::new(1)),
132 instruments_cache: Arc::new(DashMap::new()),
133 cancellation_token: CancellationToken::new(),
134 })
135 }
136
137 #[must_use]
139 pub fn is_active(&self) -> bool {
140 let mode_u8 = self.connection_mode.load().load(Ordering::Relaxed);
141 mode_u8 == ConnectionMode::Active as u8
142 }
143
144 #[must_use]
146 pub fn is_closed(&self) -> bool {
147 let mode_u8 = self.connection_mode.load().load(Ordering::Relaxed);
148 mode_u8 == ConnectionMode::Closed as u8
149 }
150
151 #[must_use]
153 pub fn subscription_count(&self) -> usize {
154 self.subscriptions_state.len()
155 }
156
157 pub async fn connect(&mut self) -> BinanceWsResult<()> {
167 self.signal.store(false, Ordering::Relaxed);
168 self.cancellation_token = CancellationToken::new();
169
170 let (raw_handler, raw_rx) = channel_message_handler();
171 let ping_handler: PingHandler = Arc::new(move |_| {});
172
173 let headers = if let Some(ref cred) = self.credential {
175 vec![("X-MBX-APIKEY".to_string(), cred.api_key().to_string())]
176 } else {
177 vec![]
178 };
179
180 log::info!(
181 "Connecting to Binance SBE WebSocket: url={}, auth={}",
182 self.url,
183 self.credential.is_some()
184 );
185
186 let config = WebSocketConfig {
187 url: self.url.clone(),
188 headers,
189 heartbeat: self.heartbeat,
190 heartbeat_msg: None,
191 reconnect_timeout_ms: Some(5_000),
192 reconnect_delay_initial_ms: Some(500),
193 reconnect_delay_max_ms: Some(5_000),
194 reconnect_backoff_factor: Some(2.0),
195 reconnect_jitter_ms: Some(250),
196 reconnect_max_attempts: None,
197 };
198
199 let client = WebSocketClient::connect(
200 config,
201 Some(raw_handler),
202 Some(ping_handler),
203 None,
204 vec![],
205 None,
206 )
207 .await
208 .map_err(|e| {
209 log::error!("WebSocket connection failed: {e}");
210 BinanceWsError::NetworkError(e.to_string())
211 })?;
212
213 self.connection_mode.store(client.connection_mode_atomic());
214
215 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();
216 let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel();
217 *self.cmd_tx.write().await = cmd_tx;
218 *self.out_rx.lock().expect("out_rx lock poisoned") = Some(out_rx);
219
220 let mut handler = BinanceSpotWsFeedHandler::new(
221 self.signal.clone(),
222 cmd_rx,
223 raw_rx,
224 out_tx.clone(),
225 self.subscriptions_state.clone(),
226 self.request_id_counter.clone(),
227 );
228
229 self.cmd_tx
230 .read()
231 .await
232 .send(HandlerCommand::SetClient(client))
233 .map_err(|e| BinanceWsError::ClientError(format!("Failed to set client: {e}")))?;
234
235 let instruments: Vec<InstrumentAny> = self
236 .instruments_cache
237 .iter()
238 .map(|entry| entry.value().clone())
239 .collect();
240
241 if !instruments.is_empty() {
242 self.cmd_tx
243 .read()
244 .await
245 .send(HandlerCommand::InitializeInstruments(instruments))
246 .map_err(|e| {
247 BinanceWsError::ClientError(format!("Failed to initialize instruments: {e}"))
248 })?;
249 }
250
251 let signal = self.signal.clone();
252 let cancellation_token = self.cancellation_token.clone();
253 let subscriptions_state = self.subscriptions_state.clone();
254 let cmd_tx = self.cmd_tx.clone();
255
256 let task_handle = get_runtime().spawn(async move {
257 loop {
258 tokio::select! {
259 () = cancellation_token.cancelled() => {
260 log::debug!("Handler task cancelled");
261 break;
262 }
263 result = handler.next() => {
264 match result {
265 Some(NautilusWsMessage::Reconnected) => {
266 log::info!("WebSocket reconnected, restoring subscriptions");
267 let all_topics = subscriptions_state.all_topics();
269 for topic in &all_topics {
270 subscriptions_state.mark_failure(topic);
271 }
272
273 let streams = subscriptions_state.all_topics();
275 if !streams.is_empty()
276 && let Err(e) = cmd_tx.read().await.send(HandlerCommand::Subscribe { streams }) {
277 log::error!("Failed to resubscribe after reconnect: {e}");
278 }
279
280 if out_tx.send(NautilusWsMessage::Reconnected).is_err() {
281 log::debug!("Output channel closed");
282 break;
283 }
284 }
285 Some(msg) => {
286 if out_tx.send(msg).is_err() {
287 log::debug!("Output channel closed");
288 break;
289 }
290 }
291 None => {
292 if signal.load(Ordering::Relaxed) {
293 log::debug!("Handler received shutdown signal");
294 } else {
295 log::warn!("Handler loop ended unexpectedly");
296 }
297 break;
298 }
299 }
300 }
301 }
302 }
303 });
304
305 self.task_handle = Some(Arc::new(task_handle));
306
307 log::info!("Connected to Binance Spot SBE stream: url={}", self.url);
308 Ok(())
309 }
310
311 pub async fn close(&mut self) -> BinanceWsResult<()> {
321 self.signal.store(true, Ordering::Relaxed);
322 self.cancellation_token.cancel();
323
324 let _ = self.cmd_tx.read().await.send(HandlerCommand::Disconnect);
325
326 if let Some(handle) = self.task_handle.take()
327 && let Ok(handle) = Arc::try_unwrap(handle)
328 {
329 let _ = handle.await;
330 }
331
332 *self.out_rx.lock().expect("out_rx lock poisoned") = None;
333
334 log::info!("Disconnected from Binance Spot SBE stream");
335 Ok(())
336 }
337
338 pub async fn subscribe(&self, streams: Vec<String>) -> BinanceWsResult<()> {
344 let current_count = self.subscriptions_state.len();
345 if current_count + streams.len() > MAX_STREAMS_PER_CONNECTION {
346 return Err(BinanceWsError::ClientError(format!(
347 "Would exceed max streams: {} + {} > {}",
348 current_count,
349 streams.len(),
350 MAX_STREAMS_PER_CONNECTION
351 )));
352 }
353
354 self.cmd_tx
355 .read()
356 .await
357 .send(HandlerCommand::Subscribe { streams })
358 .map_err(|e| BinanceWsError::ClientError(format!("Handler not available: {e}")))?;
359
360 Ok(())
361 }
362
363 pub async fn unsubscribe(&self, streams: Vec<String>) -> BinanceWsResult<()> {
369 self.cmd_tx
370 .read()
371 .await
372 .send(HandlerCommand::Unsubscribe { streams })
373 .map_err(|e| BinanceWsError::ClientError(format!("Handler not available: {e}")))?;
374
375 Ok(())
376 }
377
378 pub fn stream(&self) -> impl Stream<Item = NautilusWsMessage> + 'static {
388 let out_rx = self.out_rx.lock().expect("out_rx lock poisoned").take();
389 async_stream::stream! {
390 if let Some(mut rx) = out_rx {
391 while let Some(msg) = rx.recv().await {
392 yield msg;
393 }
394 }
395 }
396 }
397
398 pub fn cache_instruments(&self, instruments: Vec<InstrumentAny>) {
400 for inst in &instruments {
401 self.instruments_cache
402 .insert(inst.symbol().inner(), inst.clone());
403 }
404
405 if self.is_active() {
406 let cmd_tx = self.cmd_tx.clone();
407 let instruments_clone = instruments;
408 get_runtime().spawn(async move {
409 let _ = cmd_tx
410 .read()
411 .await
412 .send(HandlerCommand::InitializeInstruments(instruments_clone));
413 });
414 }
415 }
416
417 pub fn cache_instrument(&self, instrument: InstrumentAny) {
419 self.instruments_cache
420 .insert(instrument.symbol().inner(), instrument.clone());
421
422 if self.is_active() {
423 let cmd_tx = self.cmd_tx.clone();
424 get_runtime().spawn(async move {
425 let _ = cmd_tx
426 .read()
427 .await
428 .send(HandlerCommand::UpdateInstrument(instrument));
429 });
430 }
431 }
432
433 #[must_use]
435 pub fn get_instrument(&self, symbol: &str) -> Option<InstrumentAny> {
436 self.instruments_cache
437 .get(&Ustr::from(symbol))
438 .map(|entry| entry.value().clone())
439 }
440}