nautilus_bybit/websocket/
auth.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Authentication coordination for the Bybit WebSocket client.
17
18use std::{
19    sync::{Arc, Mutex},
20    time::Duration,
21};
22
23use super::error::BybitWsError;
24
25pub(crate) type AuthResultSender = tokio::sync::oneshot::Sender<Result<(), String>>;
26pub(crate) type AuthResultReceiver = tokio::sync::oneshot::Receiver<Result<(), String>>;
27
/// Authentication round-trip timeout, in whole seconds.
///
/// NOTE(review): presumably converted to a `Duration` and passed to
/// `AuthTracker::wait_for_result` by the client — confirm at the call sites.
pub(crate) const AUTHENTICATION_TIMEOUT_SECS: u64 = 10;
29
30#[derive(Clone, Debug)]
31pub(crate) struct AuthTracker {
32    tx: Arc<Mutex<Option<AuthResultSender>>>,
33}
34
35#[allow(dead_code)]
36impl AuthTracker {
37    pub(crate) fn new() -> Self {
38        Self {
39            tx: Arc::new(Mutex::new(None)),
40        }
41    }
42
43    pub(crate) fn begin(&self) -> AuthResultReceiver {
44        let (sender, receiver) = tokio::sync::oneshot::channel();
45
46        if let Ok(mut guard) = self.tx.lock() {
47            if let Some(old) = guard.take() {
48                tracing::warn!("New authentication request superseding previous pending request");
49                let _ = old.send(Err("Authentication attempt superseded".to_string()));
50            } else {
51                tracing::debug!("Starting new authentication request");
52            }
53            *guard = Some(sender);
54        }
55
56        receiver
57    }
58
59    pub(crate) fn succeed(&self) {
60        if let Ok(mut guard) = self.tx.lock()
61            && let Some(sender) = guard.take()
62        {
63            let _ = sender.send(Ok(()));
64        }
65    }
66
67    pub(crate) fn fail(&self, error: impl Into<String>) {
68        let message = error.into();
69        if let Ok(mut guard) = self.tx.lock()
70            && let Some(sender) = guard.take()
71        {
72            let _ = sender.send(Err(message));
73        }
74    }
75
76    pub(crate) async fn wait_for_result(
77        &self,
78        timeout: Duration,
79        receiver: AuthResultReceiver,
80    ) -> Result<(), BybitWsError> {
81        match tokio::time::timeout(timeout, receiver).await {
82            Ok(Ok(Ok(()))) => Ok(()),
83            Ok(Ok(Err(msg))) => Err(BybitWsError::Authentication(msg)),
84            Ok(Err(_)) => Err(BybitWsError::Authentication(
85                "Authentication channel closed".to_string(),
86            )),
87            Err(_) => {
88                if let Ok(mut guard) = self.tx.lock() {
89                    guard.take();
90                }
91                Err(BybitWsError::Authentication(
92                    "Authentication timed out".to_string(),
93                ))
94            }
95        }
96    }
97}
98
99////////////////////////////////////////////////////////////////////////////////
100// Tests
101////////////////////////////////////////////////////////////////////////////////
102
103#[cfg(test)]
104mod tests {
105    use std::time::Duration;
106
107    use rstest::rstest;
108
109    use super::*;
110
111    #[rstest]
112    #[tokio::test]
113    async fn begin_supersedes_previous_sender() {
114        let tracker = AuthTracker::new();
115
116        let first = tracker.begin();
117        let second = tracker.begin();
118
119        // Completing the first receiver should yield an error indicating it was superseded.
120        let result = first.await.expect("oneshot closed unexpectedly");
121        assert_eq!(result, Err("Authentication attempt superseded".to_string()));
122
123        // Fulfil the second attempt to keep the mutex state clean.
124        tracker.succeed();
125        tracker
126            .wait_for_result(Duration::from_secs(1), second)
127            .await
128            .expect("expected successful authentication");
129    }
130}