// nautilus_okx/websocket/auth.rs

use std::{
    sync::{Arc, Mutex, MutexGuard, PoisonError},
    time::Duration,
};

use super::error::OKXWsError;

28pub(crate) type AuthResultSender = tokio::sync::oneshot::Sender<Result<(), String>>;
29pub(crate) type AuthResultReceiver = tokio::sync::oneshot::Receiver<Result<(), String>>;
30
31pub(crate) const AUTHENTICATION_TIMEOUT_SECS: u64 = 10;
32
33#[derive(Clone, Debug)]
34pub(crate) struct AuthTracker {
35 tx: Arc<Mutex<Option<AuthResultSender>>>,
36}
37
38impl AuthTracker {
39 pub(crate) fn new() -> Self {
40 Self {
41 tx: Arc::new(Mutex::new(None)),
42 }
43 }
44
45 pub(crate) fn begin(&self) -> AuthResultReceiver {
46 let (sender, receiver) = tokio::sync::oneshot::channel();
47
48 if let Ok(mut guard) = self.tx.lock() {
49 if let Some(old) = guard.take() {
50 tracing::warn!("New authentication request superseding previous pending request");
51 let _ = old.send(Err("Authentication attempt superseded".to_string()));
52 } else {
53 tracing::debug!("Starting new authentication request");
54 }
55 *guard = Some(sender);
56 }
57
58 receiver
59 }
60
61 pub(crate) fn succeed(&self) {
62 if let Ok(mut guard) = self.tx.lock()
63 && let Some(sender) = guard.take()
64 {
65 let _ = sender.send(Ok(()));
66 }
67 }
68
69 pub(crate) fn fail(&self, error: impl Into<String>) {
70 let message = error.into();
71 if let Ok(mut guard) = self.tx.lock()
72 && let Some(sender) = guard.take()
73 {
74 let _ = sender.send(Err(message));
75 }
76 }
77
78 pub(crate) async fn wait_for_result(
79 &self,
80 timeout: Duration,
81 receiver: AuthResultReceiver,
82 ) -> Result<(), OKXWsError> {
83 match tokio::time::timeout(timeout, receiver).await {
84 Ok(Ok(Ok(()))) => Ok(()),
85 Ok(Ok(Err(msg))) => Err(OKXWsError::AuthenticationError(msg)),
86 Ok(Err(_)) => Err(OKXWsError::AuthenticationError(
87 "Authentication channel closed".to_string(),
88 )),
89 Err(_) => {
90 if let Ok(mut guard) = self.tx.lock() {
91 guard.take();
92 }
93 Err(OKXWsError::AuthenticationError(
94 "Authentication timed out".to_string(),
95 ))
96 }
97 }
98 }
99}
100
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rstest::rstest;

    use super::*;

    /// Starting a second attempt must fail the first one with a "superseded" error,
    /// while the second attempt can still be resolved successfully.
    #[rstest]
    #[tokio::test]
    async fn begin_supersedes_previous_sender() {
        let tracker = AuthTracker::new();

        let stale_rx = tracker.begin();
        let live_rx = tracker.begin();

        let stale_outcome = stale_rx.await.expect("oneshot closed unexpectedly");
        let expected = Err("Authentication attempt superseded".to_string());
        assert_eq!(stale_outcome, expected);

        tracker.succeed();

        let live_outcome = tracker
            .wait_for_result(Duration::from_secs(1), live_rx)
            .await;
        live_outcome.expect("expected successful authentication");
    }
}