nautilus_coinbase_intx/websocket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    sync::{
18        Arc,
19        atomic::{AtomicBool, Ordering},
20    },
21    time::{Duration, SystemTime},
22};
23
24use ahash::{AHashMap, AHashSet};
25use chrono::Utc;
26use dashmap::DashMap;
27use futures_util::{Stream, StreamExt};
28use nautilus_common::{logging::log_task_stopped, runtime::get_runtime};
29use nautilus_core::{
30    consts::NAUTILUS_USER_AGENT, env::get_or_env_var, time::get_atomic_clock_realtime,
31};
32use nautilus_model::{
33    data::{BarType, Data, OrderBookDeltas_API},
34    identifiers::InstrumentId,
35    instruments::{Instrument, InstrumentAny},
36};
37use nautilus_network::websocket::{MessageReader, WebSocketClient, WebSocketConfig};
38use reqwest::header::USER_AGENT;
39use tokio_tungstenite::tungstenite::{Error, Message};
40use ustr::Ustr;
41
42use super::{
43    enums::{CoinbaseIntxWsChannel, WsOperation},
44    error::CoinbaseIntxWsError,
45    messages::{CoinbaseIntxSubscription, CoinbaseIntxWsMessage, NautilusWsMessage},
46    parse::{
47        parse_candle_msg, parse_index_price_msg, parse_mark_price_msg,
48        parse_orderbook_snapshot_msg, parse_orderbook_update_msg, parse_quote_msg,
49    },
50};
51use crate::{
52    common::{
53        consts::COINBASE_INTX_WS_URL, credential::Credential, parse::bar_spec_as_coinbase_channel,
54    },
55    websocket::parse::{parse_instrument_any, parse_trade_msg},
56};
57
/// Provides a WebSocket client for connecting to [Coinbase International](https://www.coinbase.com/en/international-exchange).
///
/// The connection, stop signal, task handle, subscription state and instruments cache are held
/// behind `Arc`s, so clones share them (note that `cache_instruments` replaces the cache `Arc` on
/// the instance it is called on only).
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.adapters")
)]
pub struct CoinbaseIntxWebSocketClient {
    /// WebSocket endpoint URL.
    url: String,
    /// API credentials used to sign subscribe/unsubscribe requests.
    credential: Credential,
    /// Heartbeat setting passed through to `WebSocketConfig::heartbeat` (units defined by
    /// `nautilus_network`; presumably seconds — TODO confirm).
    heartbeat: Option<u64>,
    /// Underlying network client; `None` until `connect` succeeds.
    inner: Arc<tokio::sync::RwLock<Option<WebSocketClient>>>,
    /// Receiver for parsed messages; set by `connect` and taken by `stream`.
    rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
    /// Stop flag set by `close` and polled by the message-handling task.
    signal: Arc<AtomicBool>,
    /// Handle of the spawned message-handling task.
    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    /// Active subscriptions (product IDs per channel), replayed by `resubscribe_all` after a
    /// reconnect.
    subscriptions: Arc<DashMap<CoinbaseIntxWsChannel, AHashSet<Ustr>>>,
    /// Instruments keyed by symbol; copied into the message-handling task on `connect`.
    instruments_cache: Arc<AHashMap<Ustr, InstrumentAny>>,
}
75
76impl Default for CoinbaseIntxWebSocketClient {
77    fn default() -> Self {
78        Self::new(None, None, None, None, Some(10)).expect("Failed to create client")
79    }
80}
81
82impl CoinbaseIntxWebSocketClient {
83    /// Creates a new [`CoinbaseIntxWebSocketClient`] instance.
84    ///
85    /// # Errors
86    ///
87    /// Returns an error if required environment variables are missing or invalid.
88    pub fn new(
89        url: Option<String>,
90        api_key: Option<String>,
91        api_secret: Option<String>,
92        api_passphrase: Option<String>,
93        heartbeat: Option<u64>,
94    ) -> anyhow::Result<Self> {
95        let url = url.unwrap_or(COINBASE_INTX_WS_URL.to_string());
96        let api_key = get_or_env_var(api_key, "COINBASE_INTX_API_KEY")?;
97        let api_secret = get_or_env_var(api_secret, "COINBASE_INTX_API_SECRET")?;
98        let api_passphrase = get_or_env_var(api_passphrase, "COINBASE_INTX_API_PASSPHRASE")?;
99
100        let credential = Credential::new(api_key, api_secret, api_passphrase);
101        let signal = Arc::new(AtomicBool::new(false));
102        let subscriptions = Arc::new(DashMap::new());
103        let instruments_cache = Arc::new(AHashMap::new());
104
105        Ok(Self {
106            url,
107            credential,
108            heartbeat,
109            inner: Arc::new(tokio::sync::RwLock::new(None)),
110            rx: None,
111            signal,
112            task_handle: None,
113            subscriptions,
114            instruments_cache,
115        })
116    }
117
118    /// Creates a new authenticated [`CoinbaseIntxWebSocketClient`] using environment variables and
119    /// the default Coinbase International production websocket url.
120    ///
121    /// # Errors
122    ///
123    /// Returns an error if required environment variables are missing or invalid.
124    pub fn from_env() -> anyhow::Result<Self> {
125        Self::new(None, None, None, None, None)
126    }
127
128    /// Returns the websocket url being used by the client.
129    #[must_use]
130    pub const fn url(&self) -> &str {
131        self.url.as_str()
132    }
133
134    /// Returns the public API key being used by the client.
135    #[must_use]
136    pub fn api_key(&self) -> &str {
137        self.credential.api_key.as_str()
138    }
139
140    /// Returns a masked version of the API key for logging purposes.
141    #[must_use]
142    pub fn api_key_masked(&self) -> String {
143        self.credential.api_key_masked()
144    }
145
146    /// Returns a value indicating whether the client is active.
147    #[must_use]
148    pub fn is_active(&self) -> bool {
149        self.inner
150            .try_read()
151            .ok()
152            .and_then(|guard| {
153                guard
154                    .as_ref()
155                    .map(nautilus_network::websocket::WebSocketClient::is_active)
156            })
157            .unwrap_or(false)
158    }
159
160    /// Returns a value indicating whether the client is closed.
161    #[must_use]
162    pub fn is_closed(&self) -> bool {
163        self.inner
164            .try_read()
165            .ok()
166            .and_then(|guard| {
167                guard
168                    .as_ref()
169                    .map(nautilus_network::websocket::WebSocketClient::is_closed)
170            })
171            .unwrap_or(true)
172    }
173
174    /// Initialize the instruments cache with the given `instruments`.
175    pub fn cache_instruments(&mut self, instruments: Vec<InstrumentAny>) {
176        let mut instruments_cache: AHashMap<Ustr, InstrumentAny> = AHashMap::new();
177
178        for inst in instruments {
179            instruments_cache.insert(inst.symbol().inner(), inst.clone());
180        }
181
182        self.instruments_cache = Arc::new(instruments_cache);
183    }
184
185    /// Get active subscriptions for a specific instrument.
186    #[must_use]
187    pub fn get_subscriptions(&self, instrument_id: InstrumentId) -> Vec<CoinbaseIntxWsChannel> {
188        let product_id = instrument_id.symbol.inner();
189        let mut channels = Vec::new();
190
191        for entry in self.subscriptions.iter() {
192            let (channel, instruments) = entry.pair();
193            if instruments.contains(&product_id) {
194                channels.push(*channel);
195            }
196        }
197
198        channels
199    }
200
201    /// Connects the client to the server and caches the given instruments.
202    ///
203    /// # Errors
204    ///
205    /// Returns an error if the WebSocket connection or initial subscription fails.
206    pub async fn connect(&mut self) -> anyhow::Result<()> {
207        let client = self.clone();
208        let post_reconnect = Arc::new(move || {
209            let client = client.clone();
210
211            tokio::spawn(async move { client.resubscribe_all().await });
212        });
213
214        let config = WebSocketConfig {
215            url: self.url.clone(),
216            headers: vec![(USER_AGENT.to_string(), NAUTILUS_USER_AGENT.to_string())],
217            message_handler: None, // Will be handled by the returned reader
218            heartbeat: self.heartbeat,
219            heartbeat_msg: None,
220            ping_handler: None,
221            reconnect_timeout_ms: Some(5_000),
222            reconnect_delay_initial_ms: None, // Use default
223            reconnect_delay_max_ms: None,     // Use default
224            reconnect_backoff_factor: None,   // Use default
225            reconnect_jitter_ms: None,        // Use default
226            reconnect_max_attempts: None,
227        };
228        let (reader, client) =
229            WebSocketClient::connect_stream(config, vec![], None, Some(post_reconnect)).await?;
230
231        *self.inner.write().await = Some(client);
232
233        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
234        self.rx = Some(Arc::new(rx));
235        let signal = self.signal.clone();
236
237        // TODO: For now just clone the entire cache out of the arc on connect
238        let instruments_cache = (*self.instruments_cache).clone();
239
240        let stream_handle = get_runtime().spawn(async move {
241            CoinbaseIntxWsMessageHandler::new(reader, signal, tx, instruments_cache)
242                .run()
243                .await;
244        });
245
246        self.task_handle = Some(Arc::new(stream_handle));
247
248        Ok(())
249    }
250
251    /// Wait until the WebSocket connection is active.
252    ///
253    /// # Errors
254    ///
255    /// Returns an error if the connection times out.
256    pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), CoinbaseIntxWsError> {
257        let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
258
259        tokio::time::timeout(timeout, async {
260            while !self.is_active() {
261                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
262            }
263        })
264        .await
265        .map_err(|_| {
266            CoinbaseIntxWsError::ClientError(format!(
267                "WebSocket connection timeout after {timeout_secs} seconds"
268            ))
269        })?;
270
271        Ok(())
272    }
273
274    /// Provides the internal data stream as a channel-based stream.
275    ///
276    /// # Panics
277    ///
278    /// This function panics if:
279    /// - The websocket is not connected.
280    /// - If `stream_data` has already been called somewhere else (stream receiver is then taken).
281    pub fn stream(&mut self) -> impl Stream<Item = NautilusWsMessage> + 'static {
282        let rx = self
283            .rx
284            .take()
285            .expect("Data stream receiver already taken or not connected"); // Design-time error
286        let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
287        async_stream::stream! {
288            while let Some(data) = rx.recv().await {
289                yield data;
290            }
291        }
292    }
293
294    /// Closes the client.
295    ///
296    /// # Errors
297    ///
298    /// Returns an error if the WebSocket fails to close properly.
299    pub async fn close(&mut self) -> Result<(), Error> {
300        tracing::debug!("Closing");
301        self.signal.store(true, Ordering::Relaxed);
302
303        match tokio::time::timeout(Duration::from_secs(5), async {
304            if let Some(inner) = self.inner.read().await.as_ref() {
305                inner.disconnect().await;
306            } else {
307                log::error!("Error on close: not connected");
308            }
309        })
310        .await
311        {
312            Ok(()) => {
313                tracing::debug!("Inner disconnected");
314            }
315            Err(_) => {
316                tracing::error!("Timeout waiting for inner client to disconnect");
317            }
318        }
319
320        log::debug!("Closed");
321
322        Ok(())
323    }
324
325    /// Subscribes to the given channels and product IDs.
326    ///
327    /// # Errors
328    ///
329    /// Returns an error if the subscription message cannot be sent.
330    async fn subscribe(
331        &self,
332        channels: Vec<CoinbaseIntxWsChannel>,
333        product_ids: Vec<Ustr>,
334    ) -> Result<(), CoinbaseIntxWsError> {
335        // Update active subscriptions
336        for channel in &channels {
337            self.subscriptions
338                .entry(*channel)
339                .or_default()
340                .extend(product_ids.clone());
341        }
342        tracing::debug!(
343            "Added active subscription(s): channels={channels:?}, product_ids={product_ids:?}"
344        );
345
346        let time = chrono::DateTime::<Utc>::from(SystemTime::now())
347            .timestamp()
348            .to_string();
349        let signature = self.credential.sign_ws(&time);
350        let message = CoinbaseIntxSubscription {
351            op: WsOperation::Subscribe,
352            product_ids: Some(product_ids),
353            channels,
354            time,
355            key: self.credential.api_key,
356            passphrase: self.credential.api_passphrase,
357            signature,
358        };
359
360        let json_txt = serde_json::to_string(&message)
361            .map_err(|e| CoinbaseIntxWsError::JsonError(e.to_string()))?;
362
363        if let Some(inner) = self.inner.read().await.as_ref() {
364            if let Err(e) = inner.send_text(json_txt, None).await {
365                tracing::error!("Error sending message: {e:?}");
366            }
367        } else {
368            return Err(CoinbaseIntxWsError::ClientError(
369                "Cannot send message: not connected".to_string(),
370            ));
371        }
372
373        Ok(())
374    }
375
376    /// Unsubscribes from the given channels and product IDs.
377    async fn unsubscribe(
378        &self,
379        channels: Vec<CoinbaseIntxWsChannel>,
380        product_ids: Vec<Ustr>,
381    ) -> Result<(), CoinbaseIntxWsError> {
382        // Update active subscriptions
383        for channel in &channels {
384            if let Some(mut entry) = self.subscriptions.get_mut(channel) {
385                for product_id in &product_ids {
386                    entry.remove(product_id);
387                }
388                if entry.is_empty() {
389                    drop(entry);
390                    self.subscriptions.remove(channel);
391                }
392            }
393        }
394        tracing::debug!(
395            "Removed active subscription(s): channels={channels:?}, product_ids={product_ids:?}"
396        );
397
398        let time = chrono::DateTime::<Utc>::from(SystemTime::now())
399            .timestamp()
400            .to_string();
401        let signature = self.credential.sign_ws(&time);
402        let message = CoinbaseIntxSubscription {
403            op: WsOperation::Unsubscribe,
404            product_ids: Some(product_ids),
405            channels,
406            time,
407            key: self.credential.api_key,
408            passphrase: self.credential.api_passphrase,
409            signature,
410        };
411
412        let json_txt = serde_json::to_string(&message)
413            .map_err(|e| CoinbaseIntxWsError::JsonError(e.to_string()))?;
414
415        if let Some(inner) = self.inner.read().await.as_ref() {
416            if let Err(e) = inner.send_text(json_txt, None).await {
417                tracing::error!("Error sending message: {e:?}");
418            }
419        } else {
420            return Err(CoinbaseIntxWsError::ClientError(
421                "Cannot send message: not connected".to_string(),
422            ));
423        }
424
425        Ok(())
426    }
427
428    /// Resubscribes for all active subscriptions.
429    async fn resubscribe_all(&self) {
430        let mut subs = Vec::new();
431        for entry in self.subscriptions.iter() {
432            let (channel, product_ids) = entry.pair();
433            if !product_ids.is_empty() {
434                subs.push((*channel, product_ids.clone()));
435            }
436        }
437
438        for (channel, product_ids) in subs {
439            tracing::debug!("Resubscribing: channel={channel}, product_ids={product_ids:?}");
440
441            if let Err(e) = self
442                .subscribe(vec![channel], product_ids.into_iter().collect())
443                .await
444            {
445                tracing::error!("Failed to resubscribe to channel {channel}: {e}");
446            }
447        }
448    }
449
450    /// Subscribes to instrument definition updates for the given instrument IDs.
451    /// Subscribes to instrument updates for the specified instruments.
452    ///
453    /// # Errors
454    ///
455    /// Returns an error if the subscription fails.
456    pub async fn subscribe_instruments(
457        &self,
458        instrument_ids: Vec<InstrumentId>,
459    ) -> Result<(), CoinbaseIntxWsError> {
460        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
461        self.subscribe(vec![CoinbaseIntxWsChannel::Instruments], product_ids)
462            .await
463    }
464
465    /// Subscribes to funding message streams for the given instrument IDs.
466    /// Subscribes to funding rate updates for the specified instruments.
467    ///
468    /// # Errors
469    ///
470    /// Returns an error if the subscription fails.
471    pub async fn subscribe_funding_rates(
472        &self,
473        instrument_ids: Vec<InstrumentId>,
474    ) -> Result<(), CoinbaseIntxWsError> {
475        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
476        self.subscribe(vec![CoinbaseIntxWsChannel::Funding], product_ids)
477            .await
478    }
479
480    /// Subscribes to risk message streams for the given instrument IDs.
481    /// Subscribes to risk updates for the specified instruments.
482    ///
483    /// # Errors
484    ///
485    /// Returns an error if the subscription fails.
486    pub async fn subscribe_risk(
487        &self,
488        instrument_ids: Vec<InstrumentId>,
489    ) -> Result<(), CoinbaseIntxWsError> {
490        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
491        self.subscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
492            .await
493    }
494
495    /// Subscribes to order book (level 2) streams for the given instrument IDs.
496    /// Subscribes to order book snapshots and updates for the specified instruments.
497    ///
498    /// # Errors
499    ///
500    /// Returns an error if the subscription fails.
501    pub async fn subscribe_book(
502        &self,
503        instrument_ids: Vec<InstrumentId>,
504    ) -> Result<(), CoinbaseIntxWsError> {
505        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
506        self.subscribe(vec![CoinbaseIntxWsChannel::Level2], product_ids)
507            .await
508    }
509
510    /// Subscribes to quote (level 1) streams for the given instrument IDs.
511    /// Subscribes to top-of-book quote updates for the specified instruments.
512    ///
513    /// # Errors
514    ///
515    /// Returns an error if the subscription fails.
516    pub async fn subscribe_quotes(
517        &self,
518        instrument_ids: Vec<InstrumentId>,
519    ) -> Result<(), CoinbaseIntxWsError> {
520        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
521        self.subscribe(vec![CoinbaseIntxWsChannel::Level1], product_ids)
522            .await
523    }
524
525    /// Subscribes to trade (match) streams for the given instrument IDs.
526    /// Subscribes to trade updates for the specified instruments.
527    ///
528    /// # Errors
529    ///
530    /// Returns an error if the subscription fails.
531    pub async fn subscribe_trades(
532        &self,
533        instrument_ids: Vec<InstrumentId>,
534    ) -> Result<(), CoinbaseIntxWsError> {
535        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
536        self.subscribe(vec![CoinbaseIntxWsChannel::Match], product_ids)
537            .await
538    }
539
540    /// Subscribes to risk streams (for mark prices) for the given instrument IDs.
541    /// Subscribes to mark price updates for the specified instruments.
542    ///
543    /// # Errors
544    ///
545    /// Returns an error if the subscription fails.
546    pub async fn subscribe_mark_prices(
547        &self,
548        instrument_ids: Vec<InstrumentId>,
549    ) -> Result<(), CoinbaseIntxWsError> {
550        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
551        self.subscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
552            .await
553    }
554
555    /// Subscribes to risk streams (for index prices) for the given instrument IDs.
556    /// Subscribes to index price updates for the specified instruments.
557    ///
558    /// # Errors
559    ///
560    /// Returns an error if the subscription fails.
561    pub async fn subscribe_index_prices(
562        &self,
563        instrument_ids: Vec<InstrumentId>,
564    ) -> Result<(), CoinbaseIntxWsError> {
565        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
566        self.subscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
567            .await
568    }
569
570    /// Subscribes to bar (candle) streams for the given instrument IDs.
571    /// Subscribes to candlestick bar updates for the specified bar type.
572    ///
573    /// # Errors
574    ///
575    /// Returns an error if the subscription fails.
576    pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), CoinbaseIntxWsError> {
577        let channel = bar_spec_as_coinbase_channel(bar_type.spec())
578            .map_err(|e| CoinbaseIntxWsError::ClientError(e.to_string()))?;
579        let product_ids = vec![bar_type.standard().instrument_id().symbol.inner()];
580        self.subscribe(vec![channel], product_ids).await
581    }
582
583    /// Unsubscribes from instrument definition streams for the given instrument IDs.
584    /// Unsubscribes from instrument updates for the specified instruments.
585    ///
586    /// # Errors
587    ///
588    /// Returns an error if the unsubscription fails.
589    pub async fn unsubscribe_instruments(
590        &self,
591        instrument_ids: Vec<InstrumentId>,
592    ) -> Result<(), CoinbaseIntxWsError> {
593        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
594        self.unsubscribe(vec![CoinbaseIntxWsChannel::Instruments], product_ids)
595            .await
596    }
597
598    /// Unsubscribes from risk message streams for the given instrument IDs.
599    /// Unsubscribes from risk updates for the specified instruments.
600    ///
601    /// # Errors
602    ///
603    /// Returns an error if the unsubscription fails.
604    pub async fn unsubscribe_risk(
605        &self,
606        instrument_ids: Vec<InstrumentId>,
607    ) -> Result<(), CoinbaseIntxWsError> {
608        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
609        self.unsubscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
610            .await
611    }
612
613    /// Unsubscribes from funding message streams for the given instrument IDs.
614    /// Unsubscribes from funding updates for the specified instruments.
615    ///
616    /// # Errors
617    ///
618    /// Returns an error if the unsubscription fails.
619    pub async fn unsubscribe_funding(
620        &self,
621        instrument_ids: Vec<InstrumentId>,
622    ) -> Result<(), CoinbaseIntxWsError> {
623        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
624        self.unsubscribe(vec![CoinbaseIntxWsChannel::Funding], product_ids)
625            .await
626    }
627
628    /// Unsubscribes from order book (level 2) streams for the given instrument IDs.
629    /// Unsubscribes from order book updates for the specified instruments.
630    ///
631    /// # Errors
632    ///
633    /// Returns an error if the unsubscription fails.
634    pub async fn unsubscribe_book(
635        &self,
636        instrument_ids: Vec<InstrumentId>,
637    ) -> Result<(), CoinbaseIntxWsError> {
638        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
639        self.unsubscribe(vec![CoinbaseIntxWsChannel::Level2], product_ids)
640            .await
641    }
642
643    /// Unsubscribes from quote (level 1) streams for the given instrument IDs.
644    /// Unsubscribes from quote updates for the specified instruments.
645    ///
646    /// # Errors
647    ///
648    /// Returns an error if the unsubscription fails.
649    pub async fn unsubscribe_quotes(
650        &self,
651        instrument_ids: Vec<InstrumentId>,
652    ) -> Result<(), CoinbaseIntxWsError> {
653        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
654        self.unsubscribe(vec![CoinbaseIntxWsChannel::Level1], product_ids)
655            .await
656    }
657
658    /// Unsubscribes from trade (match) streams for the given instrument IDs.
659    /// Unsubscribes from trade updates for the specified instruments.
660    ///
661    /// # Errors
662    ///
663    /// Returns an error if the unsubscription fails.
664    pub async fn unsubscribe_trades(
665        &self,
666        instrument_ids: Vec<InstrumentId>,
667    ) -> Result<(), CoinbaseIntxWsError> {
668        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
669        self.unsubscribe(vec![CoinbaseIntxWsChannel::Match], product_ids)
670            .await
671    }
672
673    /// Unsubscribes from risk streams (for mark prices) for the given instrument IDs.
674    /// Unsubscribes from mark price updates for the specified instruments.
675    ///
676    /// # Errors
677    ///
678    /// Returns an error if the unsubscription fails.
679    pub async fn unsubscribe_mark_prices(
680        &self,
681        instrument_ids: Vec<InstrumentId>,
682    ) -> Result<(), CoinbaseIntxWsError> {
683        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
684        self.unsubscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
685            .await
686    }
687
688    /// Unsubscribes from risk streams (for index prices) for the given instrument IDs.
689    /// Unsubscribes from index price updates for the specified instruments.
690    ///
691    /// # Errors
692    ///
693    /// Returns an error if the unsubscription fails.
694    pub async fn unsubscribe_index_prices(
695        &self,
696        instrument_ids: Vec<InstrumentId>,
697    ) -> Result<(), CoinbaseIntxWsError> {
698        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
699        self.unsubscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
700            .await
701    }
702
703    /// Unsubscribes from bar (candle) streams for the given instrument IDs.
704    /// Unsubscribes from bar updates for the specified bar type.
705    ///
706    /// # Errors
707    ///
708    /// Returns an error if the unsubscription fails.
709    pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), CoinbaseIntxWsError> {
710        let channel = bar_spec_as_coinbase_channel(bar_type.spec())
711            .map_err(|e| CoinbaseIntxWsError::ClientError(e.to_string()))?;
712        let product_id = bar_type.standard().instrument_id().symbol.inner();
713        self.unsubscribe(vec![channel], vec![product_id]).await
714    }
715}
716
717fn instrument_ids_to_product_ids(instrument_ids: &[InstrumentId]) -> Vec<Ustr> {
718    instrument_ids.iter().map(|x| x.symbol.inner()).collect()
719}
720
/// Provides a raw message handler for Coinbase International WebSocket feed.
///
/// Reads frames from the connection and deserializes text frames into `CoinbaseIntxWsMessage`
/// values, stopping once `signal` is set.
struct CoinbaseIntxFeedHandler {
    /// Message stream returned by `WebSocketClient::connect_stream`.
    reader: MessageReader,
    /// Shared stop flag, set to `true` by `CoinbaseIntxWebSocketClient::close`.
    signal: Arc<AtomicBool>,
}
726
727impl CoinbaseIntxFeedHandler {
728    /// Creates a new [`CoinbaseIntxFeedHandler`] instance.
729    pub const fn new(reader: MessageReader, signal: Arc<AtomicBool>) -> Self {
730        Self { reader, signal }
731    }
732
733    /// Gets the next message from the WebSocket message stream.
734    async fn next(&mut self) -> Option<CoinbaseIntxWsMessage> {
735        // Timeout awaiting the next message before checking signal
736        let timeout = Duration::from_millis(10);
737
738        loop {
739            if self.signal.load(Ordering::Relaxed) {
740                tracing::debug!("Stop signal received");
741                break;
742            }
743
744            match tokio::time::timeout(timeout, self.reader.next()).await {
745                Ok(Some(msg)) => match msg {
746                    Ok(Message::Pong(_)) => {
747                        tracing::trace!("Received pong");
748                    }
749                    Ok(Message::Ping(_)) => {
750                        tracing::trace!("Received pong"); // Coinbase send ping frames as pongs
751                    }
752                    Ok(Message::Text(text)) => {
753                        match serde_json::from_str(&text) {
754                            Ok(event) => match &event {
755                                CoinbaseIntxWsMessage::Reject(msg) => {
756                                    tracing::error!("{msg:?}");
757                                }
758                                CoinbaseIntxWsMessage::Confirmation(msg) => {
759                                    tracing::debug!("{msg:?}");
760                                    continue;
761                                }
762                                CoinbaseIntxWsMessage::Instrument(_) => return Some(event),
763                                CoinbaseIntxWsMessage::Funding(_) => return Some(event),
764                                CoinbaseIntxWsMessage::Risk(_) => return Some(event),
765                                CoinbaseIntxWsMessage::BookSnapshot(_) => return Some(event),
766                                CoinbaseIntxWsMessage::BookUpdate(_) => return Some(event),
767                                CoinbaseIntxWsMessage::Quote(_) => return Some(event),
768                                CoinbaseIntxWsMessage::Trade(_) => return Some(event),
769                                CoinbaseIntxWsMessage::CandleSnapshot(_) => return Some(event),
770                                CoinbaseIntxWsMessage::CandleUpdate(_) => continue, // Ignore
771                            },
772                            Err(e) => {
773                                tracing::error!("Failed to parse message: {e}: {text}");
774                                break;
775                            }
776                        }
777                    }
778                    Ok(Message::Binary(msg)) => {
779                        tracing::debug!("Raw binary: {msg:?}");
780                    }
781                    Ok(Message::Close(_)) => {
782                        tracing::debug!("Received close message");
783                        return None;
784                    }
785                    Ok(msg) => {
786                        tracing::warn!("Unexpected message: {msg:?}");
787                    }
788                    Err(e) => {
789                        tracing::error!("{e}, stopping client");
790                        break; // Break as indicates a bug in the code
791                    }
792                },
793                Ok(None) => {
794                    tracing::info!("WebSocket stream closed");
795                    break;
796                }
797                Err(_) => {} // Timeout occurred awaiting a message, continue loop to check signal
798            }
799        }
800
801        log_task_stopped("message-streaming");
802        None
803    }
804}
805
/// Provides a Nautilus parser for the Coinbase International WebSocket feed.
///
/// Pulls venue messages from a `CoinbaseIntxFeedHandler`, converts them to `NautilusWsMessage`
/// values and forwards them over `tx`.
struct CoinbaseIntxWsMessageHandler {
    /// Source of deserialized venue messages.
    handler: CoinbaseIntxFeedHandler,
    /// Sender side of the channel whose receiver is exposed via
    /// `CoinbaseIntxWebSocketClient::stream`.
    tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
    /// Instruments keyed by symbol, refreshed (via `raw_symbol`) as instrument messages arrive;
    /// presumably also used to look up instruments when parsing market data — TODO confirm.
    instruments_cache: AHashMap<Ustr, InstrumentAny>,
}
812
813impl CoinbaseIntxWsMessageHandler {
814    /// Creates a new [`CoinbaseIntxWsMessageHandler`] instance.
815    pub const fn new(
816        reader: MessageReader,
817        signal: Arc<AtomicBool>,
818        tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
819        instruments_cache: AHashMap<Ustr, InstrumentAny>,
820    ) -> Self {
821        let handler = CoinbaseIntxFeedHandler::new(reader, signal);
822        Self {
823            handler,
824            tx,
825            instruments_cache,
826        }
827    }
828
829    /// Runs the WebSocket message feed.
830    async fn run(&mut self) {
831        while let Some(data) = self.next().await {
832            if let Err(e) = self.tx.send(data) {
833                tracing::error!("Error sending data: {e}");
834                break; // Stop processing on channel error
835            }
836        }
837    }
838
839    /// Gets the next message from the WebSocket message handler.
840    async fn next(&mut self) -> Option<NautilusWsMessage> {
841        let clock = get_atomic_clock_realtime();
842
843        while let Some(event) = self.handler.next().await {
844            match event {
845                CoinbaseIntxWsMessage::Instrument(msg) => {
846                    if let Some(inst) = parse_instrument_any(&msg, clock.get_time_ns()) {
847                        // Update instruments map
848                        self.instruments_cache
849                            .insert(inst.raw_symbol().inner(), inst.clone());
850                        return Some(NautilusWsMessage::Instrument(inst));
851                    }
852                }
853                CoinbaseIntxWsMessage::Funding(msg) => {
854                    tracing::warn!("Received {msg:?}"); // TODO: Implement
855                }
856                CoinbaseIntxWsMessage::BookSnapshot(msg) => {
857                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
858                        match parse_orderbook_snapshot_msg(
859                            &msg,
860                            inst.id(),
861                            inst.price_precision(),
862                            inst.size_precision(),
863                            clock.get_time_ns(),
864                        ) {
865                            Ok(deltas) => {
866                                let deltas = OrderBookDeltas_API::new(deltas);
867                                let data = Data::Deltas(deltas);
868                                return Some(NautilusWsMessage::Data(data));
869                            }
870                            Err(e) => {
871                                tracing::error!("Failed to parse orderbook snapshot: {e}");
872                                return None;
873                            }
874                        }
875                    }
876                    tracing::error!("No instrument found for {}", msg.product_id);
877                    return None;
878                }
879                CoinbaseIntxWsMessage::BookUpdate(msg) => {
880                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
881                        match parse_orderbook_update_msg(
882                            &msg,
883                            inst.id(),
884                            inst.price_precision(),
885                            inst.size_precision(),
886                            clock.get_time_ns(),
887                        ) {
888                            Ok(deltas) => {
889                                let deltas = OrderBookDeltas_API::new(deltas);
890                                let data = Data::Deltas(deltas);
891                                return Some(NautilusWsMessage::Data(data));
892                            }
893                            Err(e) => {
894                                tracing::error!("Failed to parse orderbook update: {e}");
895                            }
896                        }
897                    } else {
898                        tracing::error!("No instrument found for {}", msg.product_id);
899                    }
900                }
901                CoinbaseIntxWsMessage::Quote(msg) => {
902                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
903                        match parse_quote_msg(
904                            &msg,
905                            inst.id(),
906                            inst.price_precision(),
907                            inst.size_precision(),
908                            clock.get_time_ns(),
909                        ) {
910                            Ok(quote) => return Some(NautilusWsMessage::Data(Data::Quote(quote))),
911                            Err(e) => {
912                                tracing::error!("Failed to parse quote: {e}");
913                            }
914                        }
915                    } else {
916                        tracing::error!("No instrument found for {}", msg.product_id);
917                    }
918                }
919                CoinbaseIntxWsMessage::Trade(msg) => {
920                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
921                        match parse_trade_msg(
922                            &msg,
923                            inst.id(),
924                            inst.price_precision(),
925                            inst.size_precision(),
926                            clock.get_time_ns(),
927                        ) {
928                            Ok(trade) => return Some(NautilusWsMessage::Data(Data::Trade(trade))),
929                            Err(e) => {
930                                tracing::error!("Failed to parse trade: {e}");
931                            }
932                        }
933                    } else {
934                        tracing::error!("No instrument found for {}", msg.product_id);
935                    }
936                }
937                CoinbaseIntxWsMessage::Risk(msg) => {
938                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
939                        let mark_price = match parse_mark_price_msg(
940                            &msg,
941                            inst.id(),
942                            inst.price_precision(),
943                            clock.get_time_ns(),
944                        ) {
945                            Ok(mark_price) => Some(mark_price),
946                            Err(e) => {
947                                tracing::error!("Failed to parse mark price: {e}");
948                                None
949                            }
950                        };
951
952                        let index_price = match parse_index_price_msg(
953                            &msg,
954                            inst.id(),
955                            inst.price_precision(),
956                            clock.get_time_ns(),
957                        ) {
958                            Ok(index_price) => Some(index_price),
959                            Err(e) => {
960                                tracing::error!("Failed to parse index price: {e}");
961                                None
962                            }
963                        };
964
965                        match (mark_price, index_price) {
966                            (Some(mark), Some(index)) => {
967                                return Some(NautilusWsMessage::MarkAndIndex((mark, index)));
968                            }
969                            (Some(mark), None) => return Some(NautilusWsMessage::MarkPrice(mark)),
970                            (None, Some(index)) => {
971                                return Some(NautilusWsMessage::IndexPrice(index));
972                            }
973                            (None, None) => continue,
974                        };
975                    }
976                    tracing::error!("No instrument found for {}", msg.product_id);
977                }
978                CoinbaseIntxWsMessage::CandleSnapshot(msg) => {
979                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
980                        match parse_candle_msg(
981                            &msg,
982                            inst.id(),
983                            inst.price_precision(),
984                            inst.size_precision(),
985                            clock.get_time_ns(),
986                        ) {
987                            Ok(bar) => return Some(NautilusWsMessage::Data(Data::Bar(bar))),
988                            Err(e) => {
989                                tracing::error!("Failed to parse candle: {e}");
990                            }
991                        }
992                    } else {
993                        tracing::error!("No instrument found for {}", msg.product_id);
994                    }
995                }
996                _ => {
997                    tracing::warn!("Not implemented: {event:?}");
998                }
999            }
1000        }
1001        None // Connection closed
1002    }
1003}