1use std::{
29 collections::VecDeque,
30 sync::{
31 Arc,
32 atomic::{AtomicBool, AtomicU64, Ordering},
33 },
34};
35
36use ahash::AHashMap;
37use nautilus_core::nanos::UnixNanos;
38use nautilus_model::{
39 data::{BookOrder, Data, OrderBookDelta, OrderBookDeltas, QuoteTick, TradeTick},
40 enums::{AggressorSide, BookAction, OrderSide, RecordFlag},
41 identifiers::TradeId,
42 instruments::{Instrument, InstrumentAny},
43 types::{Price, Quantity},
44};
45use nautilus_network::{
46 RECONNECTED,
47 websocket::{SubscriptionState, WebSocketClient},
48};
49use tokio_tungstenite::tungstenite::Message;
50use ustr::Ustr;
51
52use super::messages::{
53 BinanceWsErrorMsg, BinanceWsErrorResponse, BinanceWsResponse, BinanceWsSubscription,
54 HandlerCommand, NautilusWsMessage,
55};
56use crate::common::sbe::stream::{
57 BestBidAskStreamEvent, DepthDiffStreamEvent, DepthSnapshotStreamEvent, MessageHeader,
58 StreamDecodeError, TradesStreamEvent, mantissa_to_f64, template_id,
59};
60
61#[derive(Debug)]
63pub enum MarketDataMessage {
64 Trades(TradesStreamEvent),
66 BestBidAsk(BestBidAskStreamEvent),
68 DepthSnapshot(DepthSnapshotStreamEvent),
70 DepthDiff(DepthDiffStreamEvent),
72}
73
74pub fn decode_market_data(buf: &[u8]) -> Result<MarketDataMessage, StreamDecodeError> {
79 let header = MessageHeader::decode(buf)?;
80 header.validate_schema()?;
81
82 match header.template_id {
83 template_id::TRADES_STREAM_EVENT => {
84 Ok(MarketDataMessage::Trades(TradesStreamEvent::decode(buf)?))
85 }
86 template_id::BEST_BID_ASK_STREAM_EVENT => Ok(MarketDataMessage::BestBidAsk(
87 BestBidAskStreamEvent::decode(buf)?,
88 )),
89 template_id::DEPTH_SNAPSHOT_STREAM_EVENT => Ok(MarketDataMessage::DepthSnapshot(
90 DepthSnapshotStreamEvent::decode(buf)?,
91 )),
92 template_id::DEPTH_DIFF_STREAM_EVENT => Ok(MarketDataMessage::DepthDiff(
93 DepthDiffStreamEvent::decode(buf)?,
94 )),
95 _ => Err(StreamDecodeError::UnknownTemplateId(header.template_id)),
96 }
97}
98
99pub(super) struct BinanceSpotWsFeedHandler {
104 #[allow(dead_code)] signal: Arc<AtomicBool>,
106 inner: Option<WebSocketClient>,
107 cmd_rx: tokio::sync::mpsc::UnboundedReceiver<HandlerCommand>,
108 raw_rx: tokio::sync::mpsc::UnboundedReceiver<Message>,
109 #[allow(dead_code)] out_tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
111 subscriptions: SubscriptionState,
112 instruments_cache: AHashMap<Ustr, InstrumentAny>,
113 request_id_counter: Arc<AtomicU64>,
114 pending_messages: VecDeque<NautilusWsMessage>,
115 pending_requests: AHashMap<u64, Vec<String>>,
116}
117
118impl BinanceSpotWsFeedHandler {
119 pub(super) fn new(
121 signal: Arc<AtomicBool>,
122 cmd_rx: tokio::sync::mpsc::UnboundedReceiver<HandlerCommand>,
123 raw_rx: tokio::sync::mpsc::UnboundedReceiver<Message>,
124 out_tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
125 subscriptions: SubscriptionState,
126 request_id_counter: Arc<AtomicU64>,
127 ) -> Self {
128 Self {
129 signal,
130 inner: None,
131 cmd_rx,
132 raw_rx,
133 out_tx,
134 subscriptions,
135 instruments_cache: AHashMap::new(),
136 request_id_counter,
137 pending_messages: VecDeque::new(),
138 pending_requests: AHashMap::new(),
139 }
140 }
141
142 pub(super) async fn next(&mut self) -> Option<NautilusWsMessage> {
146 if let Some(message) = self.pending_messages.pop_front() {
148 return Some(message);
149 }
150
151 loop {
152 tokio::select! {
153 Some(cmd) = self.cmd_rx.recv() => {
154 match cmd {
155 HandlerCommand::SetClient(client) => {
156 tracing::debug!("Handler received WebSocket client");
157 self.inner = Some(client);
158 }
159 HandlerCommand::Disconnect => {
160 tracing::debug!("Handler disconnecting WebSocket client");
161 self.inner = None;
162 return None;
163 }
164 HandlerCommand::InitializeInstruments(instruments) => {
165 for inst in instruments {
166 self.instruments_cache.insert(inst.symbol().inner(), inst);
167 }
168 }
169 HandlerCommand::UpdateInstrument(inst) => {
170 self.instruments_cache.insert(inst.symbol().inner(), inst);
171 }
172 HandlerCommand::Subscribe { streams } => {
173 if let Err(e) = self.handle_subscribe(streams).await {
174 tracing::error!(error = %e, "Failed to handle subscribe command");
175 }
176 }
177 HandlerCommand::Unsubscribe { streams } => {
178 if let Err(e) = self.handle_unsubscribe(streams).await {
179 tracing::error!(error = %e, "Failed to handle unsubscribe command");
180 }
181 }
182 }
183 }
184 Some(msg) = self.raw_rx.recv() => {
185 if let Message::Text(ref text) = msg
186 && text.as_str() == RECONNECTED
187 {
188 tracing::info!("Handler received reconnection signal");
189 return Some(NautilusWsMessage::Reconnected);
190 }
191
192 let messages = self.handle_message(msg);
193 if !messages.is_empty() {
194 let mut iter = messages.into_iter();
195 let first = iter.next();
196 self.pending_messages.extend(iter);
197 if let Some(msg) = first {
198 return Some(msg);
199 }
200 }
201 }
202 else => {
203 return None;
204 }
205 }
206 }
207 }
208
209 fn handle_message(&mut self, msg: Message) -> Vec<NautilusWsMessage> {
211 match msg {
212 Message::Binary(data) => self.handle_binary_frame(&data),
213 Message::Text(text) => self.handle_text_frame(&text),
214 Message::Close(_) => {
215 tracing::debug!("Received close frame");
216 vec![]
217 }
218 Message::Ping(_) | Message::Pong(_) | Message::Frame(_) => vec![],
219 }
220 }
221
222 fn handle_binary_frame(&mut self, data: &[u8]) -> Vec<NautilusWsMessage> {
224 match decode_market_data(data) {
225 Ok(MarketDataMessage::Trades(event)) => self.parse_trades_event(event),
226 Ok(MarketDataMessage::BestBidAsk(event)) => self.parse_bbo_event(event),
227 Ok(MarketDataMessage::DepthSnapshot(event)) => self.parse_depth_snapshot(event),
228 Ok(MarketDataMessage::DepthDiff(event)) => self.parse_depth_diff(event),
229 Err(e) => {
230 tracing::error!(error = %e, "SBE decode error");
231 vec![NautilusWsMessage::RawBinary(data.to_vec())]
232 }
233 }
234 }
235
236 fn handle_text_frame(&mut self, text: &str) -> Vec<NautilusWsMessage> {
238 if let Ok(response) = serde_json::from_str::<BinanceWsResponse>(text) {
239 self.handle_subscription_response(response);
240 return vec![];
241 }
242
243 if let Ok(error) = serde_json::from_str::<BinanceWsErrorResponse>(text) {
245 if let Some(id) = error.id
246 && let Some(streams) = self.pending_requests.remove(&id)
247 {
248 for stream in &streams {
249 self.subscriptions.mark_failure(stream);
250 }
251 tracing::warn!(
252 id,
253 streams = ?streams,
254 code = error.code,
255 msg = %error.msg,
256 "Subscription request failed"
257 );
258 }
259 return vec![NautilusWsMessage::Error(BinanceWsErrorMsg {
260 code: error.code,
261 msg: error.msg,
262 })];
263 }
264
265 if let Ok(value) = serde_json::from_str(text) {
266 vec![NautilusWsMessage::RawJson(value)]
267 } else {
268 tracing::warn!("Failed to parse JSON message: {text}");
269 vec![]
270 }
271 }
272
273 fn handle_subscription_response(&mut self, response: BinanceWsResponse) {
275 if let Some(streams) = self.pending_requests.remove(&response.id) {
276 if response.result.is_none() {
277 for stream in &streams {
279 self.subscriptions.confirm_subscribe(stream);
280 }
281 tracing::debug!(streams = ?streams, "Subscription confirmed");
282 } else {
283 for stream in &streams {
285 self.subscriptions.mark_failure(stream);
286 }
287 tracing::warn!(
288 streams = ?streams,
289 result = ?response.result,
290 "Subscription failed"
291 );
292 }
293 } else {
294 tracing::debug!(id = response.id, "Received response for unknown request");
295 }
296 }
297
298 fn parse_trades_event(&self, event: TradesStreamEvent) -> Vec<NautilusWsMessage> {
300 let symbol = Ustr::from(&event.symbol);
301
302 let Some(instrument) = self.instruments_cache.get(&symbol) else {
303 tracing::warn!(symbol = %event.symbol, "No instrument in cache for trades");
304 return vec![];
305 };
306
307 let instrument_id = instrument.id();
308 let price_precision = instrument.price_precision();
309 let size_precision = instrument.size_precision();
310
311 let trades: Vec<Data> = event
312 .trades
313 .iter()
314 .map(|t| {
315 let price_f64 = mantissa_to_f64(t.price_mantissa, event.price_exponent);
316 let qty_f64 = mantissa_to_f64(t.qty_mantissa, event.qty_exponent);
317 let ts_event = UnixNanos::from(event.transact_time_us as u64 * 1000); let trade = TradeTick::new(
320 instrument_id,
321 Price::new(price_f64, price_precision),
322 Quantity::new(qty_f64, size_precision),
323 if t.is_buyer_maker {
324 AggressorSide::Seller
325 } else {
326 AggressorSide::Buyer
327 },
328 TradeId::new(t.id.to_string()),
329 ts_event,
330 ts_event, );
332 Data::from(trade)
333 })
334 .collect();
335
336 if trades.is_empty() {
337 vec![]
338 } else {
339 vec![NautilusWsMessage::Data(trades)]
340 }
341 }
342
343 fn parse_bbo_event(&self, event: BestBidAskStreamEvent) -> Vec<NautilusWsMessage> {
345 let symbol = Ustr::from(&event.symbol);
346
347 let Some(instrument) = self.instruments_cache.get(&symbol) else {
348 tracing::warn!(symbol = %event.symbol, "No instrument in cache for BBO");
349 return vec![];
350 };
351
352 let instrument_id = instrument.id();
353 let price_precision = instrument.price_precision();
354 let size_precision = instrument.size_precision();
355
356 let bid_price = mantissa_to_f64(event.bid_price_mantissa, event.price_exponent);
357 let bid_size = mantissa_to_f64(event.bid_qty_mantissa, event.qty_exponent);
358 let ask_price = mantissa_to_f64(event.ask_price_mantissa, event.price_exponent);
359 let ask_size = mantissa_to_f64(event.ask_qty_mantissa, event.qty_exponent);
360 let ts_event = UnixNanos::from(event.event_time_us as u64 * 1000); let quote = QuoteTick::new(
363 instrument_id,
364 Price::new(bid_price, price_precision),
365 Price::new(ask_price, price_precision),
366 Quantity::new(bid_size, size_precision),
367 Quantity::new(ask_size, size_precision),
368 ts_event,
369 ts_event,
370 );
371
372 vec![NautilusWsMessage::Data(vec![Data::from(quote)])]
373 }
374
375 fn parse_depth_snapshot(&self, event: DepthSnapshotStreamEvent) -> Vec<NautilusWsMessage> {
377 let symbol = Ustr::from(&event.symbol);
378
379 let Some(instrument) = self.instruments_cache.get(&symbol) else {
380 tracing::warn!(symbol = %event.symbol, "No instrument in cache for depth snapshot");
381 return vec![];
382 };
383
384 let instrument_id = instrument.id();
385 let price_precision = instrument.price_precision();
386 let size_precision = instrument.size_precision();
387 let ts_event = UnixNanos::from(event.event_time_us as u64 * 1000);
388
389 let mut deltas = Vec::with_capacity(event.bids.len() + event.asks.len() + 1);
390
391 deltas.push(OrderBookDelta::clear(instrument_id, 0, ts_event, ts_event));
393
394 for (i, level) in event.bids.iter().enumerate() {
396 let price = mantissa_to_f64(level.price_mantissa, event.price_exponent);
397 let size = mantissa_to_f64(level.qty_mantissa, event.qty_exponent);
398 let flags = if i == event.bids.len() - 1 && event.asks.is_empty() {
399 RecordFlag::F_LAST as u8
400 } else {
401 0
402 };
403
404 let order = BookOrder::new(
405 OrderSide::Buy,
406 Price::new(price, price_precision),
407 Quantity::new(size, size_precision),
408 0, );
410
411 deltas.push(OrderBookDelta::new(
412 instrument_id,
413 BookAction::Add,
414 order,
415 flags,
416 0, ts_event,
418 ts_event,
419 ));
420 }
421
422 for (i, level) in event.asks.iter().enumerate() {
424 let price = mantissa_to_f64(level.price_mantissa, event.price_exponent);
425 let size = mantissa_to_f64(level.qty_mantissa, event.qty_exponent);
426 let flags = if i == event.asks.len() - 1 {
427 RecordFlag::F_LAST as u8
428 } else {
429 0
430 };
431
432 let order = BookOrder::new(
433 OrderSide::Sell,
434 Price::new(price, price_precision),
435 Quantity::new(size, size_precision),
436 0, );
438
439 deltas.push(OrderBookDelta::new(
440 instrument_id,
441 BookAction::Add,
442 order,
443 flags,
444 0, ts_event,
446 ts_event,
447 ));
448 }
449
450 if deltas.len() <= 1 {
451 return vec![];
452 }
453
454 vec![NautilusWsMessage::Deltas(OrderBookDeltas::new(
455 instrument_id,
456 deltas,
457 ))]
458 }
459
460 fn parse_depth_diff(&self, event: DepthDiffStreamEvent) -> Vec<NautilusWsMessage> {
462 let symbol = Ustr::from(&event.symbol);
463
464 let Some(instrument) = self.instruments_cache.get(&symbol) else {
465 tracing::warn!(symbol = %event.symbol, "No instrument in cache for depth diff");
466 return vec![];
467 };
468
469 let instrument_id = instrument.id();
470 let price_precision = instrument.price_precision();
471 let size_precision = instrument.size_precision();
472 let ts_event = UnixNanos::from(event.event_time_us as u64 * 1000);
473
474 let mut deltas = Vec::with_capacity(event.bids.len() + event.asks.len());
475
476 for (i, level) in event.bids.iter().enumerate() {
478 let price = mantissa_to_f64(level.price_mantissa, event.price_exponent);
479 let size = mantissa_to_f64(level.qty_mantissa, event.qty_exponent);
480
481 let action = if size == 0.0 {
483 BookAction::Delete
484 } else {
485 BookAction::Update
486 };
487
488 let flags = if i == event.bids.len() - 1 && event.asks.is_empty() {
489 RecordFlag::F_LAST as u8
490 } else {
491 0
492 };
493
494 let order = BookOrder::new(
495 OrderSide::Buy,
496 Price::new(price, price_precision),
497 Quantity::new(size, size_precision),
498 0, );
500
501 deltas.push(OrderBookDelta::new(
502 instrument_id,
503 action,
504 order,
505 flags,
506 0, ts_event,
508 ts_event,
509 ));
510 }
511
512 for (i, level) in event.asks.iter().enumerate() {
514 let price = mantissa_to_f64(level.price_mantissa, event.price_exponent);
515 let size = mantissa_to_f64(level.qty_mantissa, event.qty_exponent);
516
517 let action = if size == 0.0 {
518 BookAction::Delete
519 } else {
520 BookAction::Update
521 };
522
523 let flags = if i == event.asks.len() - 1 {
524 RecordFlag::F_LAST as u8
525 } else {
526 0
527 };
528
529 let order = BookOrder::new(
530 OrderSide::Sell,
531 Price::new(price, price_precision),
532 Quantity::new(size, size_precision),
533 0, );
535
536 deltas.push(OrderBookDelta::new(
537 instrument_id,
538 action,
539 order,
540 flags,
541 0, ts_event,
543 ts_event,
544 ));
545 }
546
547 if deltas.is_empty() {
548 return vec![];
549 }
550
551 vec![NautilusWsMessage::Deltas(OrderBookDeltas::new(
552 instrument_id,
553 deltas,
554 ))]
555 }
556
557 async fn handle_subscribe(&mut self, streams: Vec<String>) -> anyhow::Result<()> {
559 let request_id = self.request_id_counter.fetch_add(1, Ordering::SeqCst);
560 let request = BinanceWsSubscription::subscribe(streams.clone(), request_id);
561 let payload = serde_json::to_string(&request)?;
562
563 self.pending_requests.insert(request_id, streams.clone());
565
566 for stream in &streams {
568 self.subscriptions.mark_subscribe(stream);
569 }
570
571 self.send_text(payload).await?;
572 Ok(())
573 }
574
575 async fn handle_unsubscribe(&mut self, streams: Vec<String>) -> anyhow::Result<()> {
577 let request_id = self.request_id_counter.fetch_add(1, Ordering::SeqCst);
578 let request = BinanceWsSubscription::unsubscribe(streams.clone(), request_id);
579 let payload = serde_json::to_string(&request)?;
580
581 self.send_text(payload).await?;
582
583 for stream in &streams {
586 self.subscriptions.mark_unsubscribe(stream);
587 self.subscriptions.confirm_unsubscribe(stream);
588 }
589
590 Ok(())
591 }
592
593 async fn send_text(&self, payload: String) -> anyhow::Result<()> {
595 let Some(client) = &self.inner else {
596 anyhow::bail!("No active WebSocket client");
597 };
598 client
599 .send_text(payload, None)
600 .await
601 .map_err(|e| anyhow::anyhow!("Failed to send message: {e}"))?;
602 Ok(())
603 }
604}
605
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::common::sbe::stream::STREAM_SCHEMA_ID;

    /// Returns a zeroed 100-byte buffer whose first eight bytes hold the given SBE header
    /// fields (block length, template id, schema id, version), each already encoded
    /// little-endian.
    fn buffer_with_header(fields: [[u8; 2]; 4]) -> [u8; 100] {
        let mut buf = [0u8; 100];
        for (chunk, field) in buf.chunks_exact_mut(2).zip(fields) {
            chunk.copy_from_slice(&field);
        }
        buf
    }

    #[rstest]
    fn test_decode_empty_buffer() {
        let result = decode_market_data(&[]);
        assert!(matches!(
            result,
            Err(StreamDecodeError::BufferTooShort { .. })
        ));
    }

    #[rstest]
    fn test_decode_short_buffer() {
        let truncated = [0u8; 5];
        let result = decode_market_data(&truncated);
        assert!(matches!(
            result,
            Err(StreamDecodeError::BufferTooShort { .. })
        ));
    }

    #[rstest]
    fn test_decode_wrong_schema() {
        let buf = buffer_with_header([
            50u16.to_le_bytes(),
            template_id::BEST_BID_ASK_STREAM_EVENT.to_le_bytes(),
            99u16.to_le_bytes(),
            0u16.to_le_bytes(),
        ]);
        let result = decode_market_data(&buf);
        assert!(matches!(
            result,
            Err(StreamDecodeError::SchemaMismatch { .. })
        ));
    }

    #[rstest]
    fn test_decode_unknown_template() {
        let buf = buffer_with_header([
            50u16.to_le_bytes(),
            9999u16.to_le_bytes(),
            STREAM_SCHEMA_ID.to_le_bytes(),
            0u16.to_le_bytes(),
        ]);
        let result = decode_market_data(&buf);
        assert!(matches!(
            result,
            Err(StreamDecodeError::UnknownTemplateId(9999))
        ));
    }
}