nautilus_binance/spot/websocket/
client.rs1use std::{
27 fmt::Debug,
28 sync::{
29 Arc,
30 atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
31 },
32};
33
34use arc_swap::ArcSwap;
35use dashmap::DashMap;
36use futures_util::Stream;
37use nautilus_common::live::get_runtime;
38use nautilus_model::instruments::{Instrument, InstrumentAny};
39use nautilus_network::{
40 mode::ConnectionMode,
41 websocket::{
42 PingHandler, SubscriptionState, WebSocketClient, WebSocketConfig, channel_message_handler,
43 },
44};
45use tokio_util::sync::CancellationToken;
46use ustr::Ustr;
47
48use super::{
49 handler::BinanceSpotWsFeedHandler,
50 messages::{HandlerCommand, NautilusWsMessage},
51 streams::MAX_STREAMS_PER_CONNECTION,
52};
53use crate::{
54 common::credential::Ed25519Credential,
55 websocket::error::{BinanceWsError, BinanceWsResult},
56};
57
/// Default Binance Spot SBE market-data stream endpoint.
///
/// Used by `BinanceSpotWebSocketClient::new` when no URL is supplied.
pub const SBE_STREAM_URL: &str = "wss://stream-sbe.binance.com/ws";

/// Testnet endpoint for the Spot SBE client.
///
/// NOTE(review): this path (`/ws-api/v3`) looks like the WebSocket *API* endpoint
/// rather than a market-data *stream* endpoint — confirm this is the intended URL.
pub const SBE_STREAM_TESTNET_URL: &str = "wss://testnet.binance.vision/ws-api/v3";
/// WebSocket client for Binance Spot SBE market-data streams.
///
/// The client owns connection configuration and shared state. The actual socket is
/// driven by a `BinanceSpotWsFeedHandler` running on a spawned task, which the client
/// talks to through a command channel. Clones share all `Arc`-wrapped state.
#[derive(Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.binance")
)]
pub struct BinanceSpotWebSocketClient {
    // Endpoint passed to `WebSocketConfig` on `connect`.
    url: String,
    // Optional Ed25519 credential; its API key is sent as the `X-MBX-APIKEY` header.
    credential: Option<Arc<Ed25519Credential>>,
    // Forwarded verbatim to `WebSocketConfig::heartbeat` (units defined there — confirm).
    heartbeat: Option<u64>,
    // Shutdown flag: set by `close`, cleared by `connect`, read by the handler loop
    // to tell an expected end of stream from an unexpected one.
    signal: Arc<AtomicBool>,
    // Connection mode; swapped for the underlying client's atomic on `connect`.
    connection_mode: Arc<ArcSwap<AtomicU8>>,
    // Sender half of the handler command channel; replaced on each `connect`.
    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
    // Receiver for handler output; handed out (once) by `stream`.
    out_rx: Arc<std::sync::Mutex<Option<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>>,
    // Handle of the spawned handler task, if connected.
    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    // Subscription bookkeeping shared with the handler (constructed with '@' in `new`).
    subscriptions_state: SubscriptionState,
    // Request id source shared with the handler; starts at 1.
    request_id_counter: Arc<AtomicU64>,
    // Instruments keyed by symbol; replayed to the handler on `connect`.
    instruments_cache: Arc<DashMap<Ustr, InstrumentAny>>,
    // Cancelled by `close` to stop the handler task.
    cancellation_token: CancellationToken,
}
84
85impl Debug for BinanceSpotWebSocketClient {
86 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87 f.debug_struct(stringify!(BinanceSpotWebSocketClient))
88 .field("url", &self.url)
89 .field(
90 "credential",
91 &self.credential.as_ref().map(|_| "<redacted>"),
92 )
93 .field("heartbeat", &self.heartbeat)
94 .finish_non_exhaustive()
95 }
96}
97
98impl Default for BinanceSpotWebSocketClient {
99 fn default() -> Self {
100 Self::new(None, None, None, None).unwrap()
101 }
102}
103
104impl BinanceSpotWebSocketClient {
105 pub fn new(
111 url: Option<String>,
112 api_key: Option<String>,
113 api_secret: Option<String>,
114 heartbeat: Option<u64>,
115 ) -> anyhow::Result<Self> {
116 let url = url.unwrap_or(SBE_STREAM_URL.to_string());
117
118 let credential = match (api_key, api_secret) {
119 (Some(key), Some(secret)) => Some(Arc::new(Ed25519Credential::new(key, &secret)?)),
120 _ => None,
121 };
122
123 let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel();
124
125 Ok(Self {
126 url,
127 credential,
128 heartbeat,
129 signal: Arc::new(AtomicBool::new(false)),
130 connection_mode: Arc::new(ArcSwap::new(Arc::new(AtomicU8::new(
131 ConnectionMode::Closed as u8,
132 )))),
133 cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
134 out_rx: Arc::new(std::sync::Mutex::new(None)),
135 task_handle: None,
136 subscriptions_state: SubscriptionState::new('@'),
137 request_id_counter: Arc::new(AtomicU64::new(1)),
138 instruments_cache: Arc::new(DashMap::new()),
139 cancellation_token: CancellationToken::new(),
140 })
141 }
142
143 #[must_use]
145 pub fn is_active(&self) -> bool {
146 let mode_u8 = self.connection_mode.load().load(Ordering::Relaxed);
147 mode_u8 == ConnectionMode::Active as u8
148 }
149
150 #[must_use]
152 pub fn is_closed(&self) -> bool {
153 let mode_u8 = self.connection_mode.load().load(Ordering::Relaxed);
154 mode_u8 == ConnectionMode::Closed as u8
155 }
156
157 #[must_use]
159 pub fn subscription_count(&self) -> usize {
160 self.subscriptions_state.len()
161 }
162
163 pub async fn connect(&mut self) -> BinanceWsResult<()> {
173 self.signal.store(false, Ordering::Relaxed);
174
175 let (raw_handler, raw_rx) = channel_message_handler();
176 let ping_handler: PingHandler = Arc::new(move |_| {});
177
178 let headers = if let Some(ref cred) = self.credential {
180 vec![("X-MBX-APIKEY".to_string(), cred.api_key().to_string())]
181 } else {
182 vec![]
183 };
184
185 let config = WebSocketConfig {
186 url: self.url.clone(),
187 headers,
188 heartbeat: self.heartbeat,
189 heartbeat_msg: None,
190 reconnect_timeout_ms: Some(5_000),
191 reconnect_delay_initial_ms: Some(500),
192 reconnect_delay_max_ms: Some(5_000),
193 reconnect_backoff_factor: Some(2.0),
194 reconnect_jitter_ms: Some(250),
195 reconnect_max_attempts: None,
196 };
197
198 let client = WebSocketClient::connect(
199 config,
200 Some(raw_handler),
201 Some(ping_handler),
202 None,
203 vec![],
204 None,
205 )
206 .await
207 .map_err(|e| BinanceWsError::NetworkError(e.to_string()))?;
208
209 self.connection_mode.store(client.connection_mode_atomic());
210
211 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();
212 let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel();
213 *self.cmd_tx.write().await = cmd_tx;
214 *self.out_rx.lock().expect("out_rx lock poisoned") = Some(out_rx);
215
216 let mut handler = BinanceSpotWsFeedHandler::new(
217 self.signal.clone(),
218 cmd_rx,
219 raw_rx,
220 out_tx.clone(),
221 self.subscriptions_state.clone(),
222 self.request_id_counter.clone(),
223 );
224
225 self.cmd_tx
226 .read()
227 .await
228 .send(HandlerCommand::SetClient(client))
229 .map_err(|e| BinanceWsError::ClientError(format!("Failed to set client: {e}")))?;
230
231 let instruments: Vec<InstrumentAny> = self
232 .instruments_cache
233 .iter()
234 .map(|entry| entry.value().clone())
235 .collect();
236
237 if !instruments.is_empty() {
238 self.cmd_tx
239 .read()
240 .await
241 .send(HandlerCommand::InitializeInstruments(instruments))
242 .map_err(|e| {
243 BinanceWsError::ClientError(format!("Failed to initialize instruments: {e}"))
244 })?;
245 }
246
247 let signal = self.signal.clone();
248 let cancellation_token = self.cancellation_token.clone();
249 let subscriptions_state = self.subscriptions_state.clone();
250 let cmd_tx = self.cmd_tx.clone();
251
252 let task_handle = get_runtime().spawn(async move {
253 loop {
254 tokio::select! {
255 _ = cancellation_token.cancelled() => {
256 tracing::debug!("Handler task cancelled");
257 break;
258 }
259 result = handler.next() => {
260 match result {
261 Some(NautilusWsMessage::Reconnected) => {
262 tracing::info!("WebSocket reconnected, restoring subscriptions");
263 let all_topics = subscriptions_state.all_topics();
265 for topic in &all_topics {
266 subscriptions_state.mark_failure(topic);
267 }
268
269 let streams = subscriptions_state.all_topics();
271 if !streams.is_empty()
272 && let Err(e) = cmd_tx.read().await.send(HandlerCommand::Subscribe { streams }) {
273 tracing::error!(error = %e, "Failed to resubscribe after reconnect");
274 }
275
276 if out_tx.send(NautilusWsMessage::Reconnected).is_err() {
277 tracing::debug!("Output channel closed");
278 break;
279 }
280 }
281 Some(msg) => {
282 if out_tx.send(msg).is_err() {
283 tracing::debug!("Output channel closed");
284 break;
285 }
286 }
287 None => {
288 if signal.load(Ordering::Relaxed) {
289 tracing::debug!("Handler received shutdown signal");
290 } else {
291 tracing::warn!("Handler loop ended unexpectedly");
292 }
293 break;
294 }
295 }
296 }
297 }
298 }
299 });
300
301 self.task_handle = Some(Arc::new(task_handle));
302
303 tracing::info!(url = %self.url, "Connected to Binance Spot SBE stream");
304 Ok(())
305 }
306
307 pub async fn close(&mut self) -> BinanceWsResult<()> {
317 self.signal.store(true, Ordering::Relaxed);
318 self.cancellation_token.cancel();
319
320 let _ = self.cmd_tx.read().await.send(HandlerCommand::Disconnect);
321
322 if let Some(handle) = self.task_handle.take()
323 && let Ok(handle) = Arc::try_unwrap(handle)
324 {
325 let _ = handle.await;
326 }
327
328 *self.out_rx.lock().expect("out_rx lock poisoned") = None;
329
330 tracing::info!("Disconnected from Binance Spot SBE stream");
331 Ok(())
332 }
333
334 pub async fn subscribe(&self, streams: Vec<String>) -> BinanceWsResult<()> {
340 let current_count = self.subscriptions_state.len();
341 if current_count + streams.len() > MAX_STREAMS_PER_CONNECTION {
342 return Err(BinanceWsError::ClientError(format!(
343 "Would exceed max streams: {} + {} > {}",
344 current_count,
345 streams.len(),
346 MAX_STREAMS_PER_CONNECTION
347 )));
348 }
349
350 self.cmd_tx
351 .read()
352 .await
353 .send(HandlerCommand::Subscribe { streams })
354 .map_err(|e| BinanceWsError::ClientError(format!("Handler not available: {e}")))?;
355
356 Ok(())
357 }
358
359 pub async fn unsubscribe(&self, streams: Vec<String>) -> BinanceWsResult<()> {
365 self.cmd_tx
366 .read()
367 .await
368 .send(HandlerCommand::Unsubscribe { streams })
369 .map_err(|e| BinanceWsError::ClientError(format!("Handler not available: {e}")))?;
370
371 Ok(())
372 }
373
374 pub fn stream(&self) -> impl Stream<Item = NautilusWsMessage> + 'static {
384 let out_rx = self.out_rx.lock().expect("out_rx lock poisoned").take();
385 async_stream::stream! {
386 if let Some(mut rx) = out_rx {
387 while let Some(msg) = rx.recv().await {
388 yield msg;
389 }
390 }
391 }
392 }
393
394 pub fn cache_instruments(&self, instruments: Vec<InstrumentAny>) {
396 for inst in &instruments {
397 self.instruments_cache
398 .insert(inst.symbol().inner(), inst.clone());
399 }
400
401 if self.is_active() {
402 let cmd_tx = self.cmd_tx.clone();
403 let instruments_clone = instruments;
404 get_runtime().spawn(async move {
405 let _ = cmd_tx
406 .read()
407 .await
408 .send(HandlerCommand::InitializeInstruments(instruments_clone));
409 });
410 }
411 }
412
413 pub fn cache_instrument(&self, instrument: InstrumentAny) {
415 self.instruments_cache
416 .insert(instrument.symbol().inner(), instrument.clone());
417
418 if self.is_active() {
419 let cmd_tx = self.cmd_tx.clone();
420 get_runtime().spawn(async move {
421 let _ = cmd_tx
422 .read()
423 .await
424 .send(HandlerCommand::UpdateInstrument(instrument));
425 });
426 }
427 }
428
429 #[must_use]
431 pub fn get_instrument(&self, symbol: &str) -> Option<InstrumentAny> {
432 self.instruments_cache
433 .get(&Ustr::from(symbol))
434 .map(|entry| entry.value().clone())
435 }
436}