nautilus_coinbase_intx/websocket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    sync::{
18        Arc,
19        atomic::{AtomicBool, Ordering},
20    },
21    time::{Duration, SystemTime},
22};
23
24use ahash::{AHashMap, AHashSet};
25use chrono::Utc;
26use dashmap::DashMap;
27use futures_util::{Stream, StreamExt};
28use nautilus_common::{live::runtime::get_runtime, logging::log_task_stopped};
29use nautilus_core::{
30    consts::NAUTILUS_USER_AGENT, env::get_or_env_var, time::get_atomic_clock_realtime,
31};
32use nautilus_model::{
33    data::{BarType, Data, OrderBookDeltas_API},
34    identifiers::InstrumentId,
35    instruments::{Instrument, InstrumentAny},
36};
37use nautilus_network::{
38    http::USER_AGENT,
39    websocket::{MessageReader, WebSocketClient, WebSocketConfig},
40};
41use tokio_tungstenite::tungstenite::{Error, Message};
42use ustr::Ustr;
43
44use super::{
45    enums::{CoinbaseIntxWsChannel, WsOperation},
46    error::CoinbaseIntxWsError,
47    messages::{CoinbaseIntxSubscription, CoinbaseIntxWsMessage, NautilusWsMessage},
48    parse::{
49        parse_candle_msg, parse_index_price_msg, parse_mark_price_msg,
50        parse_orderbook_snapshot_msg, parse_orderbook_update_msg, parse_quote_msg,
51    },
52};
53use crate::{
54    common::{
55        consts::COINBASE_INTX_WS_URL, credential::Credential, parse::bar_spec_as_coinbase_channel,
56    },
57    websocket::parse::{parse_instrument_any, parse_trade_msg},
58};
59
/// Provides a WebSocket client for connecting to [Coinbase International](https://www.coinbase.com/en/international-exchange).
///
/// The connection handle, stop signal, subscription map and instruments cache are held
/// behind [`Arc`], so clones of the client share them.
#[derive(Debug, Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.adapters")
)]
pub struct CoinbaseIntxWebSocketClient {
    /// WebSocket endpoint URL.
    url: String,
    /// API credential used to sign subscribe/unsubscribe requests.
    credential: Credential,
    /// Heartbeat setting forwarded to the underlying WebSocket config on connect.
    heartbeat: Option<u64>,
    /// Underlying WebSocket client; `None` until `connect` has succeeded.
    inner: Arc<tokio::sync::RwLock<Option<WebSocketClient>>>,
    /// Receiving end of the parsed message channel; set by `connect`, taken by `stream`.
    rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>>,
    /// Stop flag; set by `close` and polled by the feed handler.
    signal: Arc<AtomicBool>,
    /// Handle of the message handler task spawned by `connect`.
    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    /// Tracked subscriptions: channel to the set of subscribed product IDs.
    subscriptions: Arc<DashMap<CoinbaseIntxWsChannel, AHashSet<Ustr>>>,
    /// Instruments keyed by symbol; copied into the handler task on connect.
    instruments_cache: Arc<AHashMap<Ustr, InstrumentAny>>,
}
77
78impl Default for CoinbaseIntxWebSocketClient {
79    fn default() -> Self {
80        Self::new(None, None, None, None, Some(10)).expect("Failed to create client")
81    }
82}
83
84impl CoinbaseIntxWebSocketClient {
85    /// Creates a new [`CoinbaseIntxWebSocketClient`] instance.
86    ///
87    /// # Errors
88    ///
89    /// Returns an error if required environment variables are missing or invalid.
90    pub fn new(
91        url: Option<String>,
92        api_key: Option<String>,
93        api_secret: Option<String>,
94        api_passphrase: Option<String>,
95        heartbeat: Option<u64>,
96    ) -> anyhow::Result<Self> {
97        let url = url.unwrap_or(COINBASE_INTX_WS_URL.to_string());
98        let api_key = get_or_env_var(api_key, "COINBASE_INTX_API_KEY")?;
99        let api_secret = get_or_env_var(api_secret, "COINBASE_INTX_API_SECRET")?;
100        let api_passphrase = get_or_env_var(api_passphrase, "COINBASE_INTX_API_PASSPHRASE")?;
101
102        let credential = Credential::new(api_key, api_secret, api_passphrase);
103        let signal = Arc::new(AtomicBool::new(false));
104        let subscriptions = Arc::new(DashMap::new());
105        let instruments_cache = Arc::new(AHashMap::new());
106
107        Ok(Self {
108            url,
109            credential,
110            heartbeat,
111            inner: Arc::new(tokio::sync::RwLock::new(None)),
112            rx: None,
113            signal,
114            task_handle: None,
115            subscriptions,
116            instruments_cache,
117        })
118    }
119
120    /// Creates a new authenticated [`CoinbaseIntxWebSocketClient`] using environment variables and
121    /// the default Coinbase International production websocket url.
122    ///
123    /// # Errors
124    ///
125    /// Returns an error if required environment variables are missing or invalid.
126    pub fn from_env() -> anyhow::Result<Self> {
127        Self::new(None, None, None, None, None)
128    }
129
130    /// Returns the websocket url being used by the client.
131    #[must_use]
132    pub const fn url(&self) -> &str {
133        self.url.as_str()
134    }
135
136    /// Returns the public API key being used by the client.
137    #[must_use]
138    pub fn api_key(&self) -> &str {
139        self.credential.api_key.as_str()
140    }
141
142    /// Returns a masked version of the API key for logging purposes.
143    #[must_use]
144    pub fn api_key_masked(&self) -> String {
145        self.credential.api_key_masked()
146    }
147
148    /// Returns a value indicating whether the client is active.
149    #[must_use]
150    pub fn is_active(&self) -> bool {
151        self.inner
152            .try_read()
153            .ok()
154            .and_then(|guard| guard.as_ref().map(WebSocketClient::is_active))
155            .unwrap_or(false)
156    }
157
158    /// Returns a value indicating whether the client is closed.
159    #[must_use]
160    pub fn is_closed(&self) -> bool {
161        self.inner
162            .try_read()
163            .ok()
164            .and_then(|guard| guard.as_ref().map(WebSocketClient::is_closed))
165            .unwrap_or(true)
166    }
167
168    /// Initialize the instruments cache with the given `instruments`.
169    pub fn cache_instruments(&mut self, instruments: Vec<InstrumentAny>) {
170        let mut instruments_cache: AHashMap<Ustr, InstrumentAny> = AHashMap::new();
171
172        for inst in instruments {
173            instruments_cache.insert(inst.symbol().inner(), inst.clone());
174        }
175
176        self.instruments_cache = Arc::new(instruments_cache);
177    }
178
179    /// Get active subscriptions for a specific instrument.
180    #[must_use]
181    pub fn get_subscriptions(&self, instrument_id: InstrumentId) -> Vec<CoinbaseIntxWsChannel> {
182        let product_id = instrument_id.symbol.inner();
183        let mut channels = Vec::new();
184
185        for entry in self.subscriptions.iter() {
186            let (channel, instruments) = entry.pair();
187            if instruments.contains(&product_id) {
188                channels.push(*channel);
189            }
190        }
191
192        channels
193    }
194
195    /// Connects the client to the server and caches the given instruments.
196    ///
197    /// # Errors
198    ///
199    /// Returns an error if the WebSocket connection or initial subscription fails.
200    pub async fn connect(&mut self) -> anyhow::Result<()> {
201        let client = self.clone();
202        let post_reconnect = Arc::new(move || {
203            let client = client.clone();
204
205            tokio::spawn(async move { client.resubscribe_all().await });
206        });
207
208        let config = WebSocketConfig {
209            url: self.url.clone(),
210            headers: vec![(USER_AGENT.to_string(), NAUTILUS_USER_AGENT.to_string())],
211            message_handler: None, // Will be handled by the returned reader
212            heartbeat: self.heartbeat,
213            heartbeat_msg: None,
214            ping_handler: None,
215            reconnect_timeout_ms: Some(5_000),
216            reconnect_delay_initial_ms: None, // Use default
217            reconnect_delay_max_ms: None,     // Use default
218            reconnect_backoff_factor: None,   // Use default
219            reconnect_jitter_ms: None,        // Use default
220            reconnect_max_attempts: None,
221        };
222        let (reader, client) =
223            WebSocketClient::connect_stream(config, vec![], None, Some(post_reconnect)).await?;
224
225        *self.inner.write().await = Some(client);
226
227        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
228        self.rx = Some(Arc::new(rx));
229        let signal = self.signal.clone();
230
231        // TODO: For now just clone the entire cache out of the arc on connect
232        let instruments_cache = (*self.instruments_cache).clone();
233
234        let stream_handle = get_runtime().spawn(async move {
235            CoinbaseIntxWsMessageHandler::new(reader, signal, tx, instruments_cache)
236                .run()
237                .await;
238        });
239
240        self.task_handle = Some(Arc::new(stream_handle));
241
242        Ok(())
243    }
244
245    /// Wait until the WebSocket connection is active.
246    ///
247    /// # Errors
248    ///
249    /// Returns an error if the connection times out.
250    pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), CoinbaseIntxWsError> {
251        let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
252
253        tokio::time::timeout(timeout, async {
254            while !self.is_active() {
255                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
256            }
257        })
258        .await
259        .map_err(|_| {
260            CoinbaseIntxWsError::ClientError(format!(
261                "WebSocket connection timeout after {timeout_secs} seconds"
262            ))
263        })?;
264
265        Ok(())
266    }
267
268    /// Provides the internal data stream as a channel-based stream.
269    ///
270    /// # Panics
271    ///
272    /// This function panics if:
273    /// - The websocket is not connected.
274    /// - If `stream_data` has already been called somewhere else (stream receiver is then taken).
275    pub fn stream(&mut self) -> impl Stream<Item = NautilusWsMessage> + 'static {
276        let rx = self
277            .rx
278            .take()
279            .expect("Data stream receiver already taken or not connected"); // Design-time error
280        let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
281        async_stream::stream! {
282            while let Some(data) = rx.recv().await {
283                yield data;
284            }
285        }
286    }
287
288    /// Closes the client.
289    ///
290    /// # Errors
291    ///
292    /// Returns an error if the WebSocket fails to close properly.
293    pub async fn close(&mut self) -> Result<(), Error> {
294        tracing::debug!("Closing");
295        self.signal.store(true, Ordering::Relaxed);
296
297        match tokio::time::timeout(Duration::from_secs(5), async {
298            if let Some(inner) = self.inner.read().await.as_ref() {
299                inner.disconnect().await;
300            } else {
301                log::error!("Error on close: not connected");
302            }
303        })
304        .await
305        {
306            Ok(()) => {
307                tracing::debug!("Inner disconnected");
308            }
309            Err(_) => {
310                tracing::error!("Timeout waiting for inner client to disconnect");
311            }
312        }
313
314        log::debug!("Closed");
315
316        Ok(())
317    }
318
319    /// Subscribes to the given channels and product IDs.
320    ///
321    /// # Errors
322    ///
323    /// Returns an error if the subscription message cannot be sent.
324    async fn subscribe(
325        &self,
326        channels: Vec<CoinbaseIntxWsChannel>,
327        product_ids: Vec<Ustr>,
328    ) -> Result<(), CoinbaseIntxWsError> {
329        // Update active subscriptions
330        for channel in &channels {
331            self.subscriptions
332                .entry(*channel)
333                .or_default()
334                .extend(product_ids.clone());
335        }
336        tracing::debug!(
337            "Added active subscription(s): channels={channels:?}, product_ids={product_ids:?}"
338        );
339
340        let time = chrono::DateTime::<Utc>::from(SystemTime::now())
341            .timestamp()
342            .to_string();
343        let signature = self.credential.sign_ws(&time);
344        let message = CoinbaseIntxSubscription {
345            op: WsOperation::Subscribe,
346            product_ids: Some(product_ids),
347            channels,
348            time,
349            key: self.credential.api_key,
350            passphrase: self.credential.api_passphrase,
351            signature,
352        };
353
354        let json_txt = serde_json::to_string(&message)
355            .map_err(|e| CoinbaseIntxWsError::JsonError(e.to_string()))?;
356
357        if let Some(inner) = self.inner.read().await.as_ref() {
358            if let Err(e) = inner.send_text(json_txt, None).await {
359                tracing::error!("Error sending message: {e:?}");
360            }
361        } else {
362            return Err(CoinbaseIntxWsError::ClientError(
363                "Cannot send message: not connected".to_string(),
364            ));
365        }
366
367        Ok(())
368    }
369
370    /// Unsubscribes from the given channels and product IDs.
371    async fn unsubscribe(
372        &self,
373        channels: Vec<CoinbaseIntxWsChannel>,
374        product_ids: Vec<Ustr>,
375    ) -> Result<(), CoinbaseIntxWsError> {
376        // Update active subscriptions
377        for channel in &channels {
378            if let Some(mut entry) = self.subscriptions.get_mut(channel) {
379                for product_id in &product_ids {
380                    entry.remove(product_id);
381                }
382                if entry.is_empty() {
383                    drop(entry);
384                    self.subscriptions.remove(channel);
385                }
386            }
387        }
388        tracing::debug!(
389            "Removed active subscription(s): channels={channels:?}, product_ids={product_ids:?}"
390        );
391
392        let time = chrono::DateTime::<Utc>::from(SystemTime::now())
393            .timestamp()
394            .to_string();
395        let signature = self.credential.sign_ws(&time);
396        let message = CoinbaseIntxSubscription {
397            op: WsOperation::Unsubscribe,
398            product_ids: Some(product_ids),
399            channels,
400            time,
401            key: self.credential.api_key,
402            passphrase: self.credential.api_passphrase,
403            signature,
404        };
405
406        let json_txt = serde_json::to_string(&message)
407            .map_err(|e| CoinbaseIntxWsError::JsonError(e.to_string()))?;
408
409        if let Some(inner) = self.inner.read().await.as_ref() {
410            if let Err(e) = inner.send_text(json_txt, None).await {
411                tracing::error!("Error sending message: {e:?}");
412            }
413        } else {
414            return Err(CoinbaseIntxWsError::ClientError(
415                "Cannot send message: not connected".to_string(),
416            ));
417        }
418
419        Ok(())
420    }
421
422    /// Resubscribes for all active subscriptions.
423    async fn resubscribe_all(&self) {
424        let mut subs = Vec::new();
425        for entry in self.subscriptions.iter() {
426            let (channel, product_ids) = entry.pair();
427            if !product_ids.is_empty() {
428                subs.push((*channel, product_ids.clone()));
429            }
430        }
431
432        for (channel, product_ids) in subs {
433            tracing::debug!("Resubscribing: channel={channel}, product_ids={product_ids:?}");
434
435            if let Err(e) = self
436                .subscribe(vec![channel], product_ids.into_iter().collect())
437                .await
438            {
439                tracing::error!("Failed to resubscribe to channel {channel}: {e}");
440            }
441        }
442    }
443
444    /// Subscribes to instrument definition updates for the given instrument IDs.
445    /// Subscribes to instrument updates for the specified instruments.
446    ///
447    /// # Errors
448    ///
449    /// Returns an error if the subscription fails.
450    pub async fn subscribe_instruments(
451        &self,
452        instrument_ids: Vec<InstrumentId>,
453    ) -> Result<(), CoinbaseIntxWsError> {
454        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
455        self.subscribe(vec![CoinbaseIntxWsChannel::Instruments], product_ids)
456            .await
457    }
458
459    /// Subscribes to funding message streams for the given instrument IDs.
460    /// Subscribes to funding rate updates for the specified instruments.
461    ///
462    /// # Errors
463    ///
464    /// Returns an error if the subscription fails.
465    pub async fn subscribe_funding_rates(
466        &self,
467        instrument_ids: Vec<InstrumentId>,
468    ) -> Result<(), CoinbaseIntxWsError> {
469        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
470        self.subscribe(vec![CoinbaseIntxWsChannel::Funding], product_ids)
471            .await
472    }
473
474    /// Subscribes to risk message streams for the given instrument IDs.
475    /// Subscribes to risk updates for the specified instruments.
476    ///
477    /// # Errors
478    ///
479    /// Returns an error if the subscription fails.
480    pub async fn subscribe_risk(
481        &self,
482        instrument_ids: Vec<InstrumentId>,
483    ) -> Result<(), CoinbaseIntxWsError> {
484        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
485        self.subscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
486            .await
487    }
488
489    /// Subscribes to order book (level 2) streams for the given instrument IDs.
490    /// Subscribes to order book snapshots and updates for the specified instruments.
491    ///
492    /// # Errors
493    ///
494    /// Returns an error if the subscription fails.
495    pub async fn subscribe_book(
496        &self,
497        instrument_ids: Vec<InstrumentId>,
498    ) -> Result<(), CoinbaseIntxWsError> {
499        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
500        self.subscribe(vec![CoinbaseIntxWsChannel::Level2], product_ids)
501            .await
502    }
503
504    /// Subscribes to quote (level 1) streams for the given instrument IDs.
505    /// Subscribes to top-of-book quote updates for the specified instruments.
506    ///
507    /// # Errors
508    ///
509    /// Returns an error if the subscription fails.
510    pub async fn subscribe_quotes(
511        &self,
512        instrument_ids: Vec<InstrumentId>,
513    ) -> Result<(), CoinbaseIntxWsError> {
514        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
515        self.subscribe(vec![CoinbaseIntxWsChannel::Level1], product_ids)
516            .await
517    }
518
519    /// Subscribes to trade (match) streams for the given instrument IDs.
520    /// Subscribes to trade updates for the specified instruments.
521    ///
522    /// # Errors
523    ///
524    /// Returns an error if the subscription fails.
525    pub async fn subscribe_trades(
526        &self,
527        instrument_ids: Vec<InstrumentId>,
528    ) -> Result<(), CoinbaseIntxWsError> {
529        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
530        self.subscribe(vec![CoinbaseIntxWsChannel::Match], product_ids)
531            .await
532    }
533
534    /// Subscribes to risk streams (for mark prices) for the given instrument IDs.
535    /// Subscribes to mark price updates for the specified instruments.
536    ///
537    /// # Errors
538    ///
539    /// Returns an error if the subscription fails.
540    pub async fn subscribe_mark_prices(
541        &self,
542        instrument_ids: Vec<InstrumentId>,
543    ) -> Result<(), CoinbaseIntxWsError> {
544        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
545        self.subscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
546            .await
547    }
548
549    /// Subscribes to risk streams (for index prices) for the given instrument IDs.
550    /// Subscribes to index price updates for the specified instruments.
551    ///
552    /// # Errors
553    ///
554    /// Returns an error if the subscription fails.
555    pub async fn subscribe_index_prices(
556        &self,
557        instrument_ids: Vec<InstrumentId>,
558    ) -> Result<(), CoinbaseIntxWsError> {
559        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
560        self.subscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
561            .await
562    }
563
564    /// Subscribes to bar (candle) streams for the given instrument IDs.
565    /// Subscribes to candlestick bar updates for the specified bar type.
566    ///
567    /// # Errors
568    ///
569    /// Returns an error if the subscription fails.
570    pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), CoinbaseIntxWsError> {
571        let channel = bar_spec_as_coinbase_channel(bar_type.spec())
572            .map_err(|e| CoinbaseIntxWsError::ClientError(e.to_string()))?;
573        let product_ids = vec![bar_type.standard().instrument_id().symbol.inner()];
574        self.subscribe(vec![channel], product_ids).await
575    }
576
577    /// Unsubscribes from instrument definition streams for the given instrument IDs.
578    /// Unsubscribes from instrument updates for the specified instruments.
579    ///
580    /// # Errors
581    ///
582    /// Returns an error if the unsubscription fails.
583    pub async fn unsubscribe_instruments(
584        &self,
585        instrument_ids: Vec<InstrumentId>,
586    ) -> Result<(), CoinbaseIntxWsError> {
587        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
588        self.unsubscribe(vec![CoinbaseIntxWsChannel::Instruments], product_ids)
589            .await
590    }
591
592    /// Unsubscribes from risk message streams for the given instrument IDs.
593    /// Unsubscribes from risk updates for the specified instruments.
594    ///
595    /// # Errors
596    ///
597    /// Returns an error if the unsubscription fails.
598    pub async fn unsubscribe_risk(
599        &self,
600        instrument_ids: Vec<InstrumentId>,
601    ) -> Result<(), CoinbaseIntxWsError> {
602        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
603        self.unsubscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
604            .await
605    }
606
607    /// Unsubscribes from funding message streams for the given instrument IDs.
608    /// Unsubscribes from funding updates for the specified instruments.
609    ///
610    /// # Errors
611    ///
612    /// Returns an error if the unsubscription fails.
613    pub async fn unsubscribe_funding(
614        &self,
615        instrument_ids: Vec<InstrumentId>,
616    ) -> Result<(), CoinbaseIntxWsError> {
617        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
618        self.unsubscribe(vec![CoinbaseIntxWsChannel::Funding], product_ids)
619            .await
620    }
621
622    /// Unsubscribes from order book (level 2) streams for the given instrument IDs.
623    /// Unsubscribes from order book updates for the specified instruments.
624    ///
625    /// # Errors
626    ///
627    /// Returns an error if the unsubscription fails.
628    pub async fn unsubscribe_book(
629        &self,
630        instrument_ids: Vec<InstrumentId>,
631    ) -> Result<(), CoinbaseIntxWsError> {
632        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
633        self.unsubscribe(vec![CoinbaseIntxWsChannel::Level2], product_ids)
634            .await
635    }
636
637    /// Unsubscribes from quote (level 1) streams for the given instrument IDs.
638    /// Unsubscribes from quote updates for the specified instruments.
639    ///
640    /// # Errors
641    ///
642    /// Returns an error if the unsubscription fails.
643    pub async fn unsubscribe_quotes(
644        &self,
645        instrument_ids: Vec<InstrumentId>,
646    ) -> Result<(), CoinbaseIntxWsError> {
647        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
648        self.unsubscribe(vec![CoinbaseIntxWsChannel::Level1], product_ids)
649            .await
650    }
651
652    /// Unsubscribes from trade (match) streams for the given instrument IDs.
653    /// Unsubscribes from trade updates for the specified instruments.
654    ///
655    /// # Errors
656    ///
657    /// Returns an error if the unsubscription fails.
658    pub async fn unsubscribe_trades(
659        &self,
660        instrument_ids: Vec<InstrumentId>,
661    ) -> Result<(), CoinbaseIntxWsError> {
662        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
663        self.unsubscribe(vec![CoinbaseIntxWsChannel::Match], product_ids)
664            .await
665    }
666
667    /// Unsubscribes from risk streams (for mark prices) for the given instrument IDs.
668    /// Unsubscribes from mark price updates for the specified instruments.
669    ///
670    /// # Errors
671    ///
672    /// Returns an error if the unsubscription fails.
673    pub async fn unsubscribe_mark_prices(
674        &self,
675        instrument_ids: Vec<InstrumentId>,
676    ) -> Result<(), CoinbaseIntxWsError> {
677        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
678        self.unsubscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
679            .await
680    }
681
682    /// Unsubscribes from risk streams (for index prices) for the given instrument IDs.
683    /// Unsubscribes from index price updates for the specified instruments.
684    ///
685    /// # Errors
686    ///
687    /// Returns an error if the unsubscription fails.
688    pub async fn unsubscribe_index_prices(
689        &self,
690        instrument_ids: Vec<InstrumentId>,
691    ) -> Result<(), CoinbaseIntxWsError> {
692        let product_ids = instrument_ids_to_product_ids(&instrument_ids);
693        self.unsubscribe(vec![CoinbaseIntxWsChannel::Risk], product_ids)
694            .await
695    }
696
697    /// Unsubscribes from bar (candle) streams for the given instrument IDs.
698    /// Unsubscribes from bar updates for the specified bar type.
699    ///
700    /// # Errors
701    ///
702    /// Returns an error if the unsubscription fails.
703    pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), CoinbaseIntxWsError> {
704        let channel = bar_spec_as_coinbase_channel(bar_type.spec())
705            .map_err(|e| CoinbaseIntxWsError::ClientError(e.to_string()))?;
706        let product_id = bar_type.standard().instrument_id().symbol.inner();
707        self.unsubscribe(vec![channel], vec![product_id]).await
708    }
709}
710
711fn instrument_ids_to_product_ids(instrument_ids: &[InstrumentId]) -> Vec<Ustr> {
712    instrument_ids.iter().map(|x| x.symbol.inner()).collect()
713}
714
/// Provides a raw message handler for Coinbase International WebSocket feed.
///
/// Reads frames from the WebSocket reader and deserializes text frames into
/// [`CoinbaseIntxWsMessage`] values.
struct CoinbaseIntxFeedHandler {
    /// Source of raw WebSocket frames.
    reader: MessageReader,
    /// Stop flag checked before each read attempt.
    signal: Arc<AtomicBool>,
}
720
721impl CoinbaseIntxFeedHandler {
722    /// Creates a new [`CoinbaseIntxFeedHandler`] instance.
723    pub const fn new(reader: MessageReader, signal: Arc<AtomicBool>) -> Self {
724        Self { reader, signal }
725    }
726
727    /// Gets the next message from the WebSocket message stream.
728    async fn next(&mut self) -> Option<CoinbaseIntxWsMessage> {
729        // Timeout awaiting the next message before checking signal
730        let timeout = Duration::from_millis(10);
731
732        loop {
733            if self.signal.load(Ordering::Relaxed) {
734                tracing::debug!("Stop signal received");
735                break;
736            }
737
738            match tokio::time::timeout(timeout, self.reader.next()).await {
739                Ok(Some(msg)) => match msg {
740                    Ok(Message::Pong(_)) => {
741                        tracing::trace!("Received pong");
742                    }
743                    Ok(Message::Ping(_)) => {
744                        tracing::trace!("Received pong"); // Coinbase send ping frames as pongs
745                    }
746                    Ok(Message::Text(text)) => {
747                        match serde_json::from_str(&text) {
748                            Ok(event) => match &event {
749                                CoinbaseIntxWsMessage::Reject(msg) => {
750                                    tracing::error!("{msg:?}");
751                                }
752                                CoinbaseIntxWsMessage::Confirmation(msg) => {
753                                    tracing::debug!("{msg:?}");
754                                    continue;
755                                }
756                                CoinbaseIntxWsMessage::Instrument(_) => return Some(event),
757                                CoinbaseIntxWsMessage::Funding(_) => return Some(event),
758                                CoinbaseIntxWsMessage::Risk(_) => return Some(event),
759                                CoinbaseIntxWsMessage::BookSnapshot(_) => return Some(event),
760                                CoinbaseIntxWsMessage::BookUpdate(_) => return Some(event),
761                                CoinbaseIntxWsMessage::Quote(_) => return Some(event),
762                                CoinbaseIntxWsMessage::Trade(_) => return Some(event),
763                                CoinbaseIntxWsMessage::CandleSnapshot(_) => return Some(event),
764                                CoinbaseIntxWsMessage::CandleUpdate(_) => continue, // Ignore
765                            },
766                            Err(e) => {
767                                tracing::error!("Failed to parse message: {e}: {text}");
768                                break;
769                            }
770                        }
771                    }
772                    Ok(Message::Binary(msg)) => {
773                        tracing::debug!("Raw binary: {msg:?}");
774                    }
775                    Ok(Message::Close(_)) => {
776                        tracing::debug!("Received close message");
777                        return None;
778                    }
779                    Ok(msg) => {
780                        tracing::warn!("Unexpected message: {msg:?}");
781                    }
782                    Err(e) => {
783                        tracing::error!("{e}, stopping client");
784                        break; // Break as indicates a bug in the code
785                    }
786                },
787                Ok(None) => {
788                    tracing::info!("WebSocket stream closed");
789                    break;
790                }
791                Err(_) => {} // Timeout occurred awaiting a message, continue loop to check signal
792            }
793        }
794
795        log_task_stopped("message-streaming");
796        None
797    }
798}
799
/// Provides a Nautilus parser for the Coinbase International WebSocket feed.
///
/// Pulls raw messages from the feed handler, converts them into [`NautilusWsMessage`]
/// values and forwards them over the channel.
struct CoinbaseIntxWsMessageHandler {
    /// Source of deserialized venue messages.
    handler: CoinbaseIntxFeedHandler,
    /// Sending end of the channel carrying parsed messages to the client.
    tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
    /// Instruments used when parsing messages, looked up by product ID.
    instruments_cache: AHashMap<Ustr, InstrumentAny>,
}
806
807impl CoinbaseIntxWsMessageHandler {
808    /// Creates a new [`CoinbaseIntxWsMessageHandler`] instance.
809    pub const fn new(
810        reader: MessageReader,
811        signal: Arc<AtomicBool>,
812        tx: tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>,
813        instruments_cache: AHashMap<Ustr, InstrumentAny>,
814    ) -> Self {
815        let handler = CoinbaseIntxFeedHandler::new(reader, signal);
816        Self {
817            handler,
818            tx,
819            instruments_cache,
820        }
821    }
822
823    /// Runs the WebSocket message feed.
824    async fn run(&mut self) {
825        while let Some(data) = self.next().await {
826            if let Err(e) = self.tx.send(data) {
827                tracing::error!("Error sending data: {e}");
828                break; // Stop processing on channel error
829            }
830        }
831    }
832
833    /// Gets the next message from the WebSocket message handler.
834    async fn next(&mut self) -> Option<NautilusWsMessage> {
835        let clock = get_atomic_clock_realtime();
836
837        while let Some(event) = self.handler.next().await {
838            match event {
839                CoinbaseIntxWsMessage::Instrument(msg) => {
840                    if let Some(inst) = parse_instrument_any(&msg, clock.get_time_ns()) {
841                        // Update instruments map
842                        self.instruments_cache
843                            .insert(inst.raw_symbol().inner(), inst.clone());
844                        return Some(NautilusWsMessage::Instrument(inst));
845                    }
846                }
847                CoinbaseIntxWsMessage::Funding(msg) => {
848                    tracing::warn!("Received {msg:?}"); // TODO: Implement
849                }
850                CoinbaseIntxWsMessage::BookSnapshot(msg) => {
851                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
852                        match parse_orderbook_snapshot_msg(
853                            &msg,
854                            inst.id(),
855                            inst.price_precision(),
856                            inst.size_precision(),
857                            clock.get_time_ns(),
858                        ) {
859                            Ok(deltas) => {
860                                let deltas = OrderBookDeltas_API::new(deltas);
861                                let data = Data::Deltas(deltas);
862                                return Some(NautilusWsMessage::Data(data));
863                            }
864                            Err(e) => {
865                                tracing::error!("Failed to parse orderbook snapshot: {e}");
866                                return None;
867                            }
868                        }
869                    }
870                    tracing::error!("No instrument found for {}", msg.product_id);
871                    return None;
872                }
873                CoinbaseIntxWsMessage::BookUpdate(msg) => {
874                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
875                        match parse_orderbook_update_msg(
876                            &msg,
877                            inst.id(),
878                            inst.price_precision(),
879                            inst.size_precision(),
880                            clock.get_time_ns(),
881                        ) {
882                            Ok(deltas) => {
883                                let deltas = OrderBookDeltas_API::new(deltas);
884                                let data = Data::Deltas(deltas);
885                                return Some(NautilusWsMessage::Data(data));
886                            }
887                            Err(e) => {
888                                tracing::error!("Failed to parse orderbook update: {e}");
889                            }
890                        }
891                    } else {
892                        tracing::error!("No instrument found for {}", msg.product_id);
893                    }
894                }
895                CoinbaseIntxWsMessage::Quote(msg) => {
896                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
897                        match parse_quote_msg(
898                            &msg,
899                            inst.id(),
900                            inst.price_precision(),
901                            inst.size_precision(),
902                            clock.get_time_ns(),
903                        ) {
904                            Ok(quote) => return Some(NautilusWsMessage::Data(Data::Quote(quote))),
905                            Err(e) => {
906                                tracing::error!("Failed to parse quote: {e}");
907                            }
908                        }
909                    } else {
910                        tracing::error!("No instrument found for {}", msg.product_id);
911                    }
912                }
913                CoinbaseIntxWsMessage::Trade(msg) => {
914                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
915                        match parse_trade_msg(
916                            &msg,
917                            inst.id(),
918                            inst.price_precision(),
919                            inst.size_precision(),
920                            clock.get_time_ns(),
921                        ) {
922                            Ok(trade) => return Some(NautilusWsMessage::Data(Data::Trade(trade))),
923                            Err(e) => {
924                                tracing::error!("Failed to parse trade: {e}");
925                            }
926                        }
927                    } else {
928                        tracing::error!("No instrument found for {}", msg.product_id);
929                    }
930                }
931                CoinbaseIntxWsMessage::Risk(msg) => {
932                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
933                        let mark_price = match parse_mark_price_msg(
934                            &msg,
935                            inst.id(),
936                            inst.price_precision(),
937                            clock.get_time_ns(),
938                        ) {
939                            Ok(mark_price) => Some(mark_price),
940                            Err(e) => {
941                                tracing::error!("Failed to parse mark price: {e}");
942                                None
943                            }
944                        };
945
946                        let index_price = match parse_index_price_msg(
947                            &msg,
948                            inst.id(),
949                            inst.price_precision(),
950                            clock.get_time_ns(),
951                        ) {
952                            Ok(index_price) => Some(index_price),
953                            Err(e) => {
954                                tracing::error!("Failed to parse index price: {e}");
955                                None
956                            }
957                        };
958
959                        match (mark_price, index_price) {
960                            (Some(mark), Some(index)) => {
961                                return Some(NautilusWsMessage::MarkAndIndex((mark, index)));
962                            }
963                            (Some(mark), None) => return Some(NautilusWsMessage::MarkPrice(mark)),
964                            (None, Some(index)) => {
965                                return Some(NautilusWsMessage::IndexPrice(index));
966                            }
967                            (None, None) => continue,
968                        };
969                    }
970                    tracing::error!("No instrument found for {}", msg.product_id);
971                }
972                CoinbaseIntxWsMessage::CandleSnapshot(msg) => {
973                    if let Some(inst) = self.instruments_cache.get(&msg.product_id) {
974                        match parse_candle_msg(
975                            &msg,
976                            inst.id(),
977                            inst.price_precision(),
978                            inst.size_precision(),
979                            clock.get_time_ns(),
980                        ) {
981                            Ok(bar) => return Some(NautilusWsMessage::Data(Data::Bar(bar))),
982                            Err(e) => {
983                                tracing::error!("Failed to parse candle: {e}");
984                            }
985                        }
986                    } else {
987                        tracing::error!("No instrument found for {}", msg.product_id);
988                    }
989                }
990                _ => {
991                    tracing::warn!("Not implemented: {event:?}");
992                }
993            }
994        }
995        None // Connection closed
996    }
997}