nautilus_okx/websocket/
auth.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
//! Authentication coordination for the OKX WebSocket client.
//!
//! This module ensures each login attempt produces a fresh success or
//! failure signal before subscriptions resume.
20
21use std::{
22    sync::{Arc, Mutex},
23    time::Duration,
24};
25
26use super::error::OKXWsError;
27
28pub(crate) type AuthResultSender = tokio::sync::oneshot::Sender<Result<(), String>>;
29pub(crate) type AuthResultReceiver = tokio::sync::oneshot::Receiver<Result<(), String>>;
30
31pub(crate) const AUTHENTICATION_TIMEOUT_SECS: u64 = 10;
32
33#[derive(Clone, Debug)]
34pub(crate) struct AuthTracker {
35    tx: Arc<Mutex<Option<AuthResultSender>>>,
36}
37
38impl AuthTracker {
39    pub(crate) fn new() -> Self {
40        Self {
41            tx: Arc::new(Mutex::new(None)),
42        }
43    }
44
45    pub(crate) fn begin(&self) -> AuthResultReceiver {
46        let (sender, receiver) = tokio::sync::oneshot::channel();
47
48        if let Ok(mut guard) = self.tx.lock() {
49            if let Some(old) = guard.take() {
50                tracing::warn!("New authentication request superseding previous pending request");
51                let _ = old.send(Err("Authentication attempt superseded".to_string()));
52            } else {
53                tracing::debug!("Starting new authentication request");
54            }
55            *guard = Some(sender);
56        }
57
58        receiver
59    }
60
61    pub(crate) fn succeed(&self) {
62        if let Ok(mut guard) = self.tx.lock()
63            && let Some(sender) = guard.take()
64        {
65            let _ = sender.send(Ok(()));
66        }
67    }
68
69    pub(crate) fn fail(&self, error: impl Into<String>) {
70        let message = error.into();
71        if let Ok(mut guard) = self.tx.lock()
72            && let Some(sender) = guard.take()
73        {
74            let _ = sender.send(Err(message));
75        }
76    }
77
78    pub(crate) async fn wait_for_result(
79        &self,
80        timeout: Duration,
81        receiver: AuthResultReceiver,
82    ) -> Result<(), OKXWsError> {
83        match tokio::time::timeout(timeout, receiver).await {
84            Ok(Ok(Ok(()))) => Ok(()),
85            Ok(Ok(Err(msg))) => Err(OKXWsError::AuthenticationError(msg)),
86            Ok(Err(_)) => Err(OKXWsError::AuthenticationError(
87                "Authentication channel closed".to_string(),
88            )),
89            Err(_) => {
90                if let Ok(mut guard) = self.tx.lock() {
91                    guard.take();
92                }
93                Err(OKXWsError::AuthenticationError(
94                    "Authentication timed out".to_string(),
95                ))
96            }
97        }
98    }
99}
100
////////////////////////////////////////////////////////////////////////////////
// Tests
////////////////////////////////////////////////////////////////////////////////
104
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rstest::rstest;

    use super::*;

    #[rstest]
    #[tokio::test]
    async fn begin_supersedes_previous_sender() {
        let tracker = AuthTracker::new();

        let first = tracker.begin();
        let second = tracker.begin();

        // Completing the first receiver should yield an error indicating it was superseded
        let result = first.await.expect("oneshot closed unexpectedly");
        assert_eq!(result, Err("Authentication attempt superseded".to_string()));

        // Fulfil the second attempt to keep the mutex state clean
        tracker.succeed();
        tracker
            .wait_for_result(Duration::from_secs(1), second)
            .await
            .expect("expected successful authentication");
    }

    #[rstest]
    #[tokio::test]
    async fn fail_propagates_error_message() {
        let tracker = AuthTracker::new();
        let receiver = tracker.begin();

        tracker.fail("invalid signature");

        let result = tracker
            .wait_for_result(Duration::from_secs(1), receiver)
            .await;
        assert!(matches!(
            result,
            Err(OKXWsError::AuthenticationError(ref msg)) if msg == "invalid signature"
        ));
    }

    #[rstest]
    #[tokio::test]
    async fn timeout_reports_error_and_clears_pending_sender() {
        let tracker = AuthTracker::new();
        let receiver = tracker.begin();

        let result = tracker
            .wait_for_result(Duration::from_millis(10), receiver)
            .await;
        assert!(matches!(
            result,
            Err(OKXWsError::AuthenticationError(ref msg)) if msg == "Authentication timed out"
        ));

        // The abandoned attempt's sender must not linger in the tracker
        assert!(tracker.tx.lock().unwrap().is_none());
    }

    #[rstest]
    #[tokio::test]
    async fn resolving_without_pending_attempt_is_noop() {
        let tracker = AuthTracker::new();

        tracker.succeed();
        tracker.fail("ignored");

        assert!(tracker.tx.lock().unwrap().is_none());
    }
}
132}