nautilus_coinbase_intx/websocket/client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    sync::{
18        Arc,
19        atomic::{AtomicBool, Ordering},
20    },
21    time::{Duration, SystemTime},
22};
23
24use ahash::{AHashMap, AHashSet};
25use chrono::Utc;
26use dashmap::DashMap;
27use futures_util::{Stream, StreamExt};
28use nautilus_common::{live::get_runtime, logging::log_task_stopped};
29use nautilus_core::{
30    consts::NAUTILUS_USER_AGENT, env::get_or_env_var, time::get_atomic_clock_realtime,
31};
32use nautilus_model::{
33    data::{BarType, Data, OrderBookDeltas_API},
34    identifiers::InstrumentId,
35    instruments::{Instrument, InstrumentAny},
36};
37use nautilus_network::{
38    http::USER_AGENT,
39    websocket::{MessageReader, WebSocketClient, WebSocketConfig},
40};
41use tokio_tungstenite::tungstenite::{Error, Message};
42use ustr::Ustr;
43
44use super::{
45    enums::{CoinbaseIntxWsChannel, WsOperation},
46    error::CoinbaseIntxWsError,
47    messages::{CoinbaseIntxSubscription, CoinbaseIntxWsMessage, NautilusWsMessage},
48    parse::{
49        parse_candle_msg, parse_index_price_msg, parse_mark_price_msg,
50        parse_orderbook_snapshot_msg, parse_orderbook_update_msg, parse_quote_msg,
51    },
52};
53use crate::{
54    common::{
55        consts::COINBASE_INTX_WS_URL, credential::Credential, parse::bar_spec_as_coinbase_channel,
56    },
57    websocket::parse::{parse_instrument_any, parse_trade_msg},
58};
59
/// Provides a WebSocket client for connecting to [Coinbase International](https://www.coinbase.com/en/international-exchange).
///
/// The connection, stop signal, handler task handle, message receiver and subscription map are
/// held behind `Arc`s, so clones of the client share and control the same connection.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.coinbase_intx")
)]
pub struct CoinbaseIntxWebSocketClient {
    // WebSocket endpoint URL.
    url: String,
    // API credentials; used to sign every subscribe/unsubscribe request.
    credential: Credential,
    // Heartbeat setting forwarded unchanged to the underlying `WebSocketConfig`.
    heartbeat: Option<u64>,
    // Underlying connection; `None` until `connect` succeeds.
    inner: Arc<tokio::sync::RwLock<Option<WebSocketClient>>>,
    // Receiver for parsed messages; created by `connect`, taken by `stream`.
    rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
    // Stop flag polled by the message handler task; set by `close`.
    signal: Arc<AtomicBool>,
    // Handle of the message handler task spawned by `connect`.
    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    // Product IDs currently subscribed per channel; replayed after each reconnect.
    subscriptions: Arc<DashMap<CoinbaseIntxWsChannel, AHashSet<Ustr>>>,
    // Instruments keyed by symbol; a copy is handed to the handler task on `connect`.
    instruments_cache: Arc<AHashMap<Ustr, InstrumentAny>>,
}
77
78impl Default for CoinbaseIntxWebSocketClient {
79    fn default() -> Self {
80        Self::new(None, None, None, None, Some(10)).expect("Failed to create client")
81    }
82}
83
84impl CoinbaseIntxWebSocketClient {
85    /// Creates a new [`CoinbaseIntxWebSocketClient`] instance.
86    ///
87    /// # Errors
88    ///
89    /// Returns an error if required environment variables are missing or invalid.
90    pub fn new(
91        url: Option<String>,
92        api_key: Option<String>,
93        api_secret: Option<String>,
94        api_passphrase: Option<String>,
95        heartbeat: Option<u64>,
96    ) -> anyhow::Result<Self> {
97        let url = url.unwrap_or(COINBASE_INTX_WS_URL.to_string());
98        let api_key = get_or_env_var(api_key, "COINBASE_INTX_API_KEY")?;
99        let api_secret = get_or_env_var(api_secret, "COINBASE_INTX_API_SECRET")?;
100        let api_passphrase = get_or_env_var(api_passphrase, "COINBASE_INTX_API_PASSPHRASE")?;
101
102        let credential = Credential::new(api_key, api_secret, api_passphrase);
103        let signal = Arc::new(AtomicBool::new(false));
104        let subscriptions = Arc::new(DashMap::new());
105        let instruments_cache = Arc::new(AHashMap::new());
106
107        Ok(Self {
108            url,
109            credential,
110            heartbeat,
111            inner: Arc::new(tokio::sync::RwLock::new(None)),
112            rx: None,
113            signal,
114            task_handle: None,
115            subscriptions,
116            instruments_cache,
117        })
118    }
119
120    /// Creates a new authenticated [`CoinbaseIntxWebSocketClient`] using environment variables and
121    /// the default Coinbase International production websocket url.
122    ///
123    /// # Errors
124    ///
125    /// Returns an error if required environment variables are missing or invalid.
126    pub fn from_env() -> anyhow::Result<Self> {
127        Self::new(None, None, None, None, None)
128    }
129
130    /// Returns the websocket url being used by the client.
131    #[must_use]
132    pub const fn url(&self) -> &str {
133        self.url.as_str()
134    }
135
136    /// Returns the public API key being used by the client.
137    #[must_use]
138    pub fn api_key(&self) -> &str {
139        self.credential.api_key.as_str()
140    }
141
142    /// Returns a masked version of the API key for logging purposes.
143    #[must_use]
144    pub fn api_key_masked(&self) -> String {
145        self.credential.api_key_masked()
146    }
147
148    /// Returns a value indicating whether the client is active.
149    #[must_use]
150    pub fn is_active(&self) -> bool {
151        self.inner
152            .try_read()
153            .ok()
154            .and_then(|guard| guard.as_ref().map(WebSocketClient::is_active))
155            .unwrap_or(false)
156    }
157
158    /// Returns a value indicating whether the client is closed.
159    #[must_use]
160    pub fn is_closed(&self) -> bool {
161        self.inner
162            .try_read()
163            .ok()
164            .and_then(|guard| guard.as_ref().map(WebSocketClient::is_closed))
165            .unwrap_or(true)
166    }
167
168    /// Initialize the instruments cache with the given `instruments`.
169    pub fn cache_instruments(&mut self, instruments: Vec<InstrumentAny>) {
170        let mut instruments_cache: AHashMap<Ustr, InstrumentAny> = AHashMap::new();
171
172        for inst in instruments {
173            instruments_cache.insert(inst.symbol().inner(), inst.clone());
174        }
175
176        self.instruments_cache = Arc::new(instruments_cache);
177    }
178
179    /// Get active subscriptions for a specific instrument.
180    #[must_use]
181    pub fn get_subscriptions(&self, instrument_id: InstrumentId) -> Vec<CoinbaseIntxWsChannel> {
182        let product_id = instrument_id.symbol.inner();
183        let mut channels = Vec::new();
184
185        for entry in self.subscriptions.iter() {
186            let (channel, instruments) = entry.pair();
187            if instruments.contains(&product_id) {
188                channels.push(*channel);
189            }
190        }
191
192        channels
193    }
194
195    /// Connects the client to the server and caches the given instruments.
196    ///
197    /// # Errors
198    ///
199    /// Returns an error if the WebSocket connection or initial subscription fails.
200    pub async fn connect(&mut self) -> anyhow::Result<()> {
201        let client = self.clone();
202        let post_reconnect = Arc::new(move || {
203            let client = client.clone();
204
205            get_runtime().spawn(async move { client.resubscribe_all().await });
206        });
207
208        let config = WebSocketConfig {
209            url: self.url.clone(),
210            headers: vec![(USER_AGENT.to_string(), NAUTILUS_USER_AGENT.to_string())],
211            heartbeat: self.heartbeat,
212            heartbeat_msg: None,
213            reconnect_timeout_ms: Some(5_000),
214            reconnect_delay_initial_ms: None, // Use default
215            reconnect_delay_max_ms: None,     // Use default
216            reconnect_backoff_factor: None,   // Use default
217            reconnect_jitter_ms: None,        // Use default
218            reconnect_max_attempts: None,
219        };
220        let (reader, client) =
221            WebSocketClient::connect_stream(config, vec![], None, Some(post_reconnect)).await?;
222
223        *self.inner.write().await = Some(client);
224
225        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
226        self.rx = Some(Arc::new(rx));
227        let signal = self.signal.clone();
228
229        // TODO: For now just clone the entire cache out of the arc on connect
230        let instruments_cache = (*self.instruments_cache).clone();
231
232        let stream_handle = get_runtime().spawn(async move {
233            CoinbaseIntxWsMessageHandler::new(reader, signal, tx, instruments_cache)
234                .run()
235                .await;
236        });
237
238        self.task_handle = Some(Arc::new(stream_handle));
239
240        Ok(())
241    }
242
243    /// Wait until the WebSocket connection is active.
244    ///
245    /// # Errors
246    ///
247    /// Returns an error if the connection times out.
248    pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), CoinbaseIntxWsError> {
249        let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
250
251        tokio::time::timeout(timeout, async {
252            while !self.is_active() {
253                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
254            }
255        })
256        .await
257        .map_err(|_| {
258            CoinbaseIntxWsError::ClientError(format!(
259                "WebSocket connection timeout after {timeout_secs} seconds"
260            ))
261        })?;
262
263        Ok(())
264    }
265
266    /// Provides the internal data stream as a channel-based stream.
267    ///
268    /// # Panics
269    ///
270    /// This function panics if:
271    /// - The websocket is not connected.
272    /// - If `stream_data` has already been called somewhere else (stream receiver is then taken).
273    pub fn stream(&mut self) -> impl Stream<Item = NautilusWsMessage> + 'static {
274        let rx = self
275            .rx
276            .take()
277            .expect("Data stream receiver already taken or not connected"); // Design-time error
278        let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
279        async_stream::stream! {
280            while let Some(data) = rx.recv().await {
281                yield data;
282            }
283        }
284    }
285
286    /// Closes the client.
287    ///
288    /// # Errors
289    ///
290    /// Returns an error if the WebSocket fails to close properly.
291    pub async fn close(&mut self) -> Result<(), Error> {
292        tracing::debug!("Closing");
293        self.signal.store(true, Ordering::Relaxed);
294
295        match tokio::time::timeout(Duration::from_secs(5), async {
296            if let Some(inner) = self.inner.read().await.as_ref() {
297                inner.disconnect().await;
298            } else {
299                log::error!("Error on close: not connected");
300            }
301        })
302        .await
303        {
304            Ok(()) => {
305                tracing::debug!("Inner disconnected");
306            }
307            Err(_) => {
308                tracing::error!("Timeout waiting for inner client to disconnect");
309            }
310        }
311
312        log::debug!("Closed");
313
314        Ok(())
315    }
316
317    /// Subscribes to the given channels and product IDs.
318    ///
319    /// # Errors
320    ///
321    /// Returns an error if the subscription message cannot be sent.
322    async fn subscribe(
323        &self,
324        channels: Vec<CoinbaseIntxWsChannel>,
325        product_ids: Vec<Ustr>,
326    ) -> Result<(), CoinbaseIntxWsError> {
327        // Update active subscriptions
328        for channel in &channels {
329            self.subscriptions
330                .entry(*channel)
331                .or_default()
332                .extend(product_ids.clone());
333        }
334        tracing::debug!(
335            "Added active subscription(s): channels={channels:?}, product_ids={product_ids:?}"
336        );
337
338        let time = chrono::DateTime::<Utc>::from(SystemTime::now())
339            .timestamp()
340            .to_string();
341        let signature = self.credential.sign_ws(&time);
342        let message = CoinbaseIntxSubscription {
343            op: WsOperation::Subscribe,
344            product_ids: Some(product_ids),
345            channels,
346            time,
347            key: self.credential.api_key,
348            passphrase: self.credential.api_passphrase,
349            signature,
350        };
351
352        let json_txt = serde_json::to_string(&message)
353            .map_err(|e| CoinbaseIntxWsError::JsonError(e.to_string()))?;
354
355        if let Some(inner) = self.inner.read().await.as_ref() {
356            if let Err(e) = inner.send_text(json_txt, None).await {
357                tracing::error!("Error sending message: {e:?}");
358            }
359        } else {
360            return Err(CoinbaseIntxWsError::ClientError(
361                "Cannot send message: not connected".to_string(),
362            ));
363        }
364
365        Ok(())
366    }
367
368    /// Unsubscribes from the given channels and product IDs.
369    async fn unsubscribe(
370        &self,
371        channels: Vec<CoinbaseIntxWsChannel>,
372        product_ids: Vec<Ustr>,
373    ) -> Result<(), CoinbaseIntxWsError> {
374        // Update active subscriptions
375        for channel in &channels {
376            if let Some(mut entry) = self.subscriptions.get_mut(channel) {
377                for product_id in &product_ids {
378                    entry.remove(product_id);
379                }
380                if entry.is_empty() {
381                    drop(entry);
382                    self.subscriptions.remove(channel);
383                }
384            }
385        }
386        tracing::debug!(
387            "Removed active subscription(s): channels={channels:?}, product_ids={product_ids:?}"
388        );
389
390        let time = chrono::DateTime::<Utc>::from(SystemTime::now())
391            .timestamp()
392            .to_string();
393        let signature = self.credential.sign_ws(&time);
394        let message = CoinbaseIntxSubscription {
395            op: WsOperation::Unsubscribe,
396            product_ids: Some(product_ids),
397            channels,
398            time,
399            key: self.credential.api_key,
400            passphrase: self.credential.api_passphrase,
401            signature,
402        };
403
404        let json_txt = serde_json::to_string(&message)
405            .map_err(|e| CoinbaseIntxWsError::JsonError(e.to_string()))?;
406
407        if let Some(inner) = self.inner.read().await.as_ref() {
408            if let Err(e) = inner.send_text(json_txt, None).await {
409                tracing::error!("Error sending message: {e:?}");
410            }
411        } else {
412            return Err(CoinbaseIntxWsError::ClientError(
413                "Cannot send message: not connected".to_string(),
414            ));
415        }
416
417        Ok(())
418    }
419
420    /// Resubscribes for all active subscriptions.
421    async fn resubscribe_all(&self) {
422        let mut subs = Vec::new();
423        for entry in self.subscriptions.iter() {
424            let (channel, product_ids) = entry.pair();
425            if !product_ids.is_empty() {
426                subs.push((*channel, product_ids.clone()));
427            }
428        }
429
430        for (channel, product_ids) in subs {
431            tracing::debug!("Resubscribing: channel={channel}, product_ids={product_ids:?}");
432
433            if let Err(e) = self
434                .subscribe(vec![channel], product_ids.into_iter().collect())
435                .await
436            {
437                tracing::error!("Failed to resubscribe to channel {channel}: {e}");
438            }
439        }
440    }
441
442    /// Subscribes to instrument definition updates for the given instrument IDs.
443    /// Subscribes to instrument updates for the specified instruments.
444    ///
445    /// # Errors
446    ///
447    /// Returns an error if the subscription fails.
448    pub async fn subscribe_instruments(
449        &self,
450        instrument_ids: Vec<InstrumentId>,
451    ) -> Result<(), CoinbaseIntxWsError> {
452        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
453        self.subscribe(vec![CoinbaseIntxWsChannel::Instruments], product_ids)
454            .await
455    }
456
457    /// Subscribes to funding message streams for the given instrument IDs.
458    /// Subscribes to funding rate updates for the specified instruments.
459    ///
460    /// # Errors
461    ///
462    /// Returns an error if the subscription fails.
463    pub async fn subscribe_funding_rates(
464        &self,
465        instrument_ids: Vec<InstrumentId>,
466    ) -> Result<(), CoinbaseIntxWsError> {
467        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
468        self.subscribe(vec![CoinbaseIntxWsChannel::Funding], product_ids)
469            .await
470    }
471
472    /// Subscribes to risk message streams for the given instrument IDs.
473    /// Subscribes to risk updates for the specified instruments.
474    ///
475    /// # Errors
476    ///
477    /// Returns an error if the subscription fails.
478    pub async fn subscribe_risk(
479        &self,
480        instrument_ids: Vec<InstrumentId>,
481    ) -> Result<(), CoinbaseIntxWsError> {
482        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
483        self.subscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
484            .await
485    }
486
487    /// Subscribes to order book (level 2) streams for the given instrument IDs.
488    /// Subscribes to order book snapshots and updates for the specified instruments.
489    ///
490    /// # Errors
491    ///
492    /// Returns an error if the subscription fails.
493    pub async fn subscribe_book(
494        &self,
495        instrument_ids: Vec<InstrumentId>,
496    ) -> Result<(), CoinbaseIntxWsError> {
497        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
498        self.subscribe(vec![CoinbaseIntxWsChannel::Level2], product_ids)
499            .await
500    }
501
502    /// Subscribes to quote (level 1) streams for the given instrument IDs.
503    /// Subscribes to top-of-book quote updates for the specified instruments.
504    ///
505    /// # Errors
506    ///
507    /// Returns an error if the subscription fails.
508    pub async fn subscribe_quotes(
509        &self,
510        instrument_ids: Vec<InstrumentId>,
511    ) -> Result<(), CoinbaseIntxWsError> {
512        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
513        self.subscribe(vec![CoinbaseIntxWsChannel::Level1], product_ids)
514            .await
515    }
516
517    /// Subscribes to trade (match) streams for the given instrument IDs.
518    /// Subscribes to trade updates for the specified instruments.
519    ///
520    /// # Errors
521    ///
522    /// Returns an error if the subscription fails.
523    pub async fn subscribe_trades(
524        &self,
525        instrument_ids: Vec<InstrumentId>,
526    ) -> Result<(), CoinbaseIntxWsError> {
527        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
528        self.subscribe(vec![CoinbaseIntxWsChannel::Match], product_ids)
529            .await
530    }
531
532    /// Subscribes to risk streams (for mark prices) for the given instrument IDs.
533    /// Subscribes to mark price updates for the specified instruments.
534    ///
535    /// # Errors
536    ///
537    /// Returns an error if the subscription fails.
538    pub async fn subscribe_mark_prices(
539        &self,
540        instrument_ids: Vec<InstrumentId>,
541    ) -> Result<(), CoinbaseIntxWsError> {
542        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
543        self.subscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
544            .await
545    }
546
547    /// Subscribes to risk streams (for index prices) for the given instrument IDs.
548    /// Subscribes to index price updates for the specified instruments.
549    ///
550    /// # Errors
551    ///
552    /// Returns an error if the subscription fails.
553    pub async fn subscribe_index_prices(
554        &self,
555        instrument_ids: Vec<InstrumentId>,
556    ) -> Result<(), CoinbaseIntxWsError> {
557        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
558        self.subscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
559            .await
560    }
561
562    /// Subscribes to bar (candle) streams for the given instrument IDs.
563    /// Subscribes to candlestick bar updates for the specified bar type.
564    ///
565    /// # Errors
566    ///
567    /// Returns an error if the subscription fails.
568    pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), CoinbaseIntxWsError> {
569        let channel = bar_spec_as_coinbase_channel(bar_type.spec())
570            .map_err(|e| CoinbaseIntxWsError::ClientError(e.to_string()))?;
571        let product_ids = vec![bar_type.standard().instrument_id().symbol.inner()];
572        self.subscribe(vec![channel], product_ids).await
573    }
574
575    /// Unsubscribes from instrument definition streams for the given instrument IDs.
576    /// Unsubscribes from instrument updates for the specified instruments.
577    ///
578    /// # Errors
579    ///
580    /// Returns an error if the unsubscription fails.
581    pub async fn unsubscribe_instruments(
582        &self,
583        instrument_ids: Vec<InstrumentId>,
584    ) -> Result<(), CoinbaseIntxWsError> {
585        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
586        self.unsubscribe(vec![CoinbaseIntxWsChannel::Instruments], product_ids)
587            .await
588    }
589
590    /// Unsubscribes from risk message streams for the given instrument IDs.
591    /// Unsubscribes from risk updates for the specified instruments.
592    ///
593    /// # Errors
594    ///
595    /// Returns an error if the unsubscription fails.
596    pub async fn unsubscribe_risk(
597        &self,
598        instrument_ids: Vec<InstrumentId>,
599    ) -> Result<(), CoinbaseIntxWsError> {
600        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
601        self.unsubscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
602            .await
603    }
604
605    /// Unsubscribes from funding message streams for the given instrument IDs.
606    /// Unsubscribes from funding updates for the specified instruments.
607    ///
608    /// # Errors
609    ///
610    /// Returns an error if the unsubscription fails.
611    pub async fn unsubscribe_funding(
612        &self,
613        instrument_ids: Vec<InstrumentId>,
614    ) -> Result<(), CoinbaseIntxWsError> {
615        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
616        self.unsubscribe(vec![CoinbaseIntxWsChannel::Funding], product_ids)
617            .await
618    }
619
620    /// Unsubscribes from order book (level 2) streams for the given instrument IDs.
621    /// Unsubscribes from order book updates for the specified instruments.
622    ///
623    /// # Errors
624    ///
625    /// Returns an error if the unsubscription fails.
626    pub async fn unsubscribe_book(
627        &self,
628        instrument_ids: Vec<InstrumentId>,
629    ) -> Result<(), CoinbaseIntxWsError> {
630        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
631        self.unsubscribe(vec![CoinbaseIntxWsChannel::Level2], product_ids)
632            .await
633    }
634
635    /// Unsubscribes from quote (level 1) streams for the given instrument IDs.
636    /// Unsubscribes from quote updates for the specified instruments.
637    ///
638    /// # Errors
639    ///
640    /// Returns an error if the unsubscription fails.
641    pub async fn unsubscribe_quotes(
642        &self,
643        instrument_ids: Vec<InstrumentId>,
644    ) -> Result<(), CoinbaseIntxWsError> {
645        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
646        self.unsubscribe(vec![CoinbaseIntxWsChannel::Level1], product_ids)
647            .await
648    }
649
650    /// Unsubscribes from trade (match) streams for the given instrument IDs.
651    /// Unsubscribes from trade updates for the specified instruments.
652    ///
653    /// # Errors
654    ///
655    /// Returns an error if the unsubscription fails.
656    pub async fn unsubscribe_trades(
657        &self,
658        instrument_ids: Vec<InstrumentId>,
659    ) -> Result<(), CoinbaseIntxWsError> {
660        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
661        self.unsubscribe(vec![CoinbaseIntxWsChannel::Match], product_ids)
662            .await
663    }
664
665    /// Unsubscribes from risk streams (for mark prices) for the given instrument IDs.
666    /// Unsubscribes from mark price updates for the specified instruments.
667    ///
668    /// # Errors
669    ///
670    /// Returns an error if the unsubscription fails.
671    pub async fn unsubscribe_mark_prices(
672        &self,
673        instrument_ids: Vec<InstrumentId>,
674    ) -> Result<(), CoinbaseIntxWsError> {
675        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
676        self.unsubscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
677            .await
678    }
679
680    /// Unsubscribes from risk streams (for index prices) for the given instrument IDs.
681    /// Unsubscribes from index price updates for the specified instruments.
682    ///
683    /// # Errors
684    ///
685    /// Returns an error if the unsubscription fails.
686    pub async fn unsubscribe_index_prices(
687        &self,
688        instrument_ids: Vec<InstrumentId>,
689    ) -> Result<(), CoinbaseIntxWsError> {
690        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
691        self.unsubscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
692            .await
693    }
694
695    /// Unsubscribes from bar (candle) streams for the given instrument IDs.
696    /// Unsubscribes from bar updates for the specified bar type.
697    ///
698    /// # Errors
699    ///
700    /// Returns an error if the unsubscription fails.
701    pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), CoinbaseIntxWsError> {
702        let channel = bar_spec_as_coinbase_channel(bar_type.spec())
703            .map_err(|e| CoinbaseIntxWsError::ClientError(e.to_string()))?;
704        let product_id = bar_type.standard().instrument_id().symbol.inner();
705        self.unsubscribe(vec![channel], vec![product_id]).await
706    }
707}
708
709fn instrument_ids_to_product_ids(instrument_ids: &[InstrumentId]) -> Vec<Ustr> {
710    instrument_ids.iter().map(|x| x.symbol.inner()).collect()
711}
712
/// Provides a raw message handler for Coinbase International WebSocket feed.
///
/// Reads frames from `reader` and deserializes text frames into [`CoinbaseIntxWsMessage`]s.
struct CoinbaseIntxFeedHandler {
    // Stream of raw WebSocket messages from the connection.
    reader: MessageReader,
    // Stop flag shared with the client; once set, reading stops.
    signal: Arc<AtomicBool>,
}
718
719impl CoinbaseIntxFeedHandler {
720    /// Creates a new [`CoinbaseIntxFeedHandler`] instance.
721    pub const fn new(reader: MessageReader, signal: Arc<AtomicBool>) -> Self {
722        Self { reader, signal }
723    }
724
725    /// Gets the next message from the WebSocket message stream.
726    async fn next(&mut self) -> Option<CoinbaseIntxWsMessage> {
727        // Timeout awaiting the next message before checking signal
728        let timeout = Duration::from_millis(10);
729
730        loop {
731            if self.signal.load(Ordering::Relaxed) {
732                tracing::debug!("Stop signal received");
733                break;
734            }
735
736            match tokio::time::timeout(timeout, self.reader.next()).await {
737                Ok(Some(msg)) => match msg {
738                    Ok(Message::Pong(_)) => {
739                        tracing::trace!("Received pong");
740                    }
741                    Ok(Message::Ping(_)) => {
742                        tracing::trace!("Received pong"); // Coinbase send ping frames as pongs
743                    }
744                    Ok(Message::Text(text)) => {
745                        match serde_json::from_str(&text) {
746                            Ok(event) => match &event {
747                                CoinbaseIntxWsMessage::Reject(msg) => {
748                                    tracing::error!("{msg:?}");
749                                }
750                                CoinbaseIntxWsMessage::Confirmation(msg) => {
751                                    tracing::debug!("{msg:?}");
752                                    continue;
753                                }
754                                CoinbaseIntxWsMessage::Instrument(_) => return Some(event),
755                                CoinbaseIntxWsMessage::Funding(_) => return Some(event),
756                                CoinbaseIntxWsMessage::Risk(_) => return Some(event),
757                                CoinbaseIntxWsMessage::BookSnapshot(_) => return Some(event),
758                                CoinbaseIntxWsMessage::BookUpdate(_) => return Some(event),
759                                CoinbaseIntxWsMessage::Quote(_) => return Some(event),
760                                CoinbaseIntxWsMessage::Trade(_) => return Some(event),
761                                CoinbaseIntxWsMessage::CandleSnapshot(_) => return Some(event),
762                                CoinbaseIntxWsMessage::CandleUpdate(_) => continue, // Ignore
763                            },
764                            Err(e) => {
765                                tracing::error!("Failed to parse message: {e}: {text}");
766                                break;
767                            }
768                        }
769                    }
770                    Ok(Message::Binary(msg)) => {
771                        tracing::debug!("Raw binary: {msg:?}");
772                    }
773                    Ok(Message::Close(_)) => {
774                        tracing::debug!("Received close message");
775                        return None;
776                    }
777                    Ok(msg) => {
778                        tracing::warn!("Unexpected message: {msg:?}");
779                    }
780                    Err(e) => {
781                        tracing::error!("{e}, stopping client");
782                        break; // Break as indicates a bug in the code
783                    }
784                },
785                Ok(None) => {
786                    tracing::info!("WebSocket stream closed");
787                    break;
788                }
789                Err(_) => {} // Timeout occurred awaiting a message, continue loop to check signal
790            }
791        }
792
793        log_task_stopped("message-streaming");
794        None
795    }
796}
797
/// Provides a Nautilus parser for the Coinbase International WebSocket feed.
///
/// Wraps a [`CoinbaseIntxFeedHandler`], converts its venue messages into
/// [`NautilusWsMessage`]s and forwards them over `tx`.
struct CoinbaseIntxWsMessageHandler {
    // Source of deserialized venue messages.
    handler: CoinbaseIntxFeedHandler,
    // Channel the converted messages are sent on (consumed via the client's `stream`).
    tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
    // Instruments looked up by product ID when parsing; updated from instrument messages.
    instruments_cache: AHashMap<Ustr, InstrumentAny>,
}
804
805impl CoinbaseIntxWsMessageHandler {
806    /// Creates a new [`CoinbaseIntxWsMessageHandler`] instance.
807    pub const fn new(
808        reader: MessageReader,
809        signal: Arc<AtomicBool>,
810        tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
811        instruments_cache: AHashMap<Ustr, InstrumentAny>,
812    ) -> Self {
813        let handler = CoinbaseIntxFeedHandler::new(reader, signal);
814        Self {
815            handler,
816            tx,
817            instruments_cache,
818        }
819    }
820
821    /// Runs the WebSocket message feed.
822    async fn run(&mut self) {
823        while let Some(data) = self.next().await {
824            if let Err(e) = self.tx.send(data) {
825                tracing::error!("Error sending data: {e}");
826                break; // Stop processing on channel error
827            }
828        }
829    }
830
831    /// Gets the next message from the WebSocket message handler.
832    async fn next(&mut self) -> Option<NautilusWsMessage> {
833        let clock = get_atomic_clock_realtime();
834
835        while let Some(event) = self.handler.next().await {
836            match event {
837                CoinbaseIntxWsMessage::Instrument(msg) => {
838                    if let Some(inst) = parse_instrument_any(&msg, clock.get_time_ns()) {
839                        // Update instruments map
840                        self.instruments_cache
841                            .insert(inst.raw_symbol().inner(), inst.clone());
842                        return Some(NautilusWsMessage::Instrument(inst));
843                    }
844                }
845                CoinbaseIntxWsMessage::Funding(msg) => {
846                    tracing::warn!("Received {msg:?}"); // TODO: Implement
847                }
848                CoinbaseIntxWsMessage::BookSnapshot(msg) => {
849                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
850                        match parse_orderbook_snapshot_msg(
851                            &msg,
852                            inst.id(),
853                            inst.price_precision(),
854                            inst.size_precision(),
855                            clock.get_time_ns(),
856                        ) {
857                            Ok(deltas) => {
858                                let deltas = OrderBookDeltas_API::new(deltas);
859                                let data = Data::Deltas(deltas);
860                                return Some(NautilusWsMessage::Data(data));
861                            }
862                            Err(e) => {
863                                tracing::error!("Failed to parse orderbook snapshot: {e}");
864                                return None;
865                            }
866                        }
867                    }
868                    tracing::error!("No instrument found for {}", msg.product_id);
869                    return None;
870                }
871                CoinbaseIntxWsMessage::BookUpdate(msg) => {
872                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
873                        match parse_orderbook_update_msg(
874                            &msg,
875                            inst.id(),
876                            inst.price_precision(),
877                            inst.size_precision(),
878                            clock.get_time_ns(),
879                        ) {
880                            Ok(deltas) => {
881                                let deltas = OrderBookDeltas_API::new(deltas);
882                                let data = Data::Deltas(deltas);
883                                return Some(NautilusWsMessage::Data(data));
884                            }
885                            Err(e) => {
886                                tracing::error!("Failed to parse orderbook update: {e}");
887                            }
888                        }
889                    } else {
890                        tracing::error!("No instrument found for {}", msg.product_id);
891                    }
892                }
893                CoinbaseIntxWsMessage::Quote(msg) => {
894                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
895                        match parse_quote_msg(
896                            &msg,
897                            inst.id(),
898                            inst.price_precision(),
899                            inst.size_precision(),
900                            clock.get_time_ns(),
901                        ) {
902                            Ok(quote) => return Some(NautilusWsMessage::Data(Data::Quote(quote))),
903                            Err(e) => {
904                                tracing::error!("Failed to parse quote: {e}");
905                            }
906                        }
907                    } else {
908                        tracing::error!("No instrument found for {}", msg.product_id);
909                    }
910                }
911                CoinbaseIntxWsMessage::Trade(msg) => {
912                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
913                        match parse_trade_msg(
914                            &msg,
915                            inst.id(),
916                            inst.price_precision(),
917                            inst.size_precision(),
918                            clock.get_time_ns(),
919                        ) {
920                            Ok(trade) => return Some(NautilusWsMessage::Data(Data::Trade(trade))),
921                            Err(e) => {
922                                tracing::error!("Failed to parse trade: {e}");
923                            }
924                        }
925                    } else {
926                        tracing::error!("No instrument found for {}", msg.product_id);
927                    }
928                }
929                CoinbaseIntxWsMessage::Risk(msg) => {
930                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
931                        let mark_price = match parse_mark_price_msg(
932                            &msg,
933                            inst.id(),
934                            inst.price_precision(),
935                            clock.get_time_ns(),
936                        ) {
937                            Ok(mark_price) => Some(mark_price),
938                            Err(e) => {
939                                tracing::error!("Failed to parse mark price: {e}");
940                                None
941                            }
942                        };
943
944                        let index_price = match parse_index_price_msg(
945                            &msg,
946                            inst.id(),
947                            inst.price_precision(),
948                            clock.get_time_ns(),
949                        ) {
950                            Ok(index_price) => Some(index_price),
951                            Err(e) => {
952                                tracing::error!("Failed to parse index price: {e}");
953                                None
954                            }
955                        };
956
957                        match (mark_price, index_price) {
958                            (Some(mark), Some(index)) => {
959                                return Some(NautilusWsMessage::MarkAndIndex((mark, index)));
960                            }
961                            (Some(mark), None) => return Some(NautilusWsMessage::MarkPrice(mark)),
962                            (None, Some(index)) => {
963                                return Some(NautilusWsMessage::IndexPrice(index));
964                            }
965                            (None, None) => continue,
966                        };
967                    }
968                    tracing::error!("No instrument found for {}", msg.product_id);
969                }
970                CoinbaseIntxWsMessage::CandleSnapshot(msg) => {
971                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
972                        match parse_candle_msg(
973                            &msg,
974                            inst.id(),
975                            inst.price_precision(),
976                            inst.size_precision(),
977                            clock.get_time_ns(),
978                        ) {
979                            Ok(bar) => return Some(NautilusWsMessage::Data(Data::Bar(bar))),
980                            Err(e) => {
981                                tracing::error!("Failed to parse candle: {e}");
982                            }
983                        }
984                    } else {
985                        tracing::error!("No instrument found for {}", msg.product_id);
986                    }
987                }
988                _ => {
989                    tracing::warn!("Not implemented: {event:?}");
990                }
991            }
992        }
993        None // Connection closed
994    }
995}