nautilus_bybit/websocket/
auth.rs1use std::{
19 sync::{Arc, Mutex},
20 time::Duration,
21};
22
23use super::error::BybitWsError;
24
/// Sending half of the oneshot channel that delivers an authentication outcome:
/// `Ok(())` on success or `Err` carrying an error message.
pub(crate) type AuthResultSender = tokio::sync::oneshot::Sender<Result<(), String>>;
/// Receiving half of the oneshot channel that delivers an authentication outcome:
/// `Ok(())` on success or `Err` carrying an error message.
pub(crate) type AuthResultReceiver = tokio::sync::oneshot::Receiver<Result<(), String>>;
27
/// Authentication timeout, in seconds.
///
/// NOTE(review): not referenced in this part of the file — presumably converted to a
/// `Duration` by callers of `AuthTracker::wait_for_result`; confirm at the use sites.
pub(crate) const AUTHENTICATION_TIMEOUT_SECS: u64 = 10;
29
/// Tracks at most one pending authentication attempt and routes its outcome to
/// whoever holds the matching receiver.
///
/// The sender slot lives behind an `Arc`, so clones of the tracker share the same
/// pending attempt.
#[derive(Clone, Debug)]
pub(crate) struct AuthTracker {
    /// Sender for the pending attempt; `None` when no attempt is pending.
    tx: Arc<Mutex<Option<AuthResultSender>>>,
}
34
35#[allow(dead_code)]
36impl AuthTracker {
37 pub(crate) fn new() -> Self {
38 Self {
39 tx: Arc::new(Mutex::new(None)),
40 }
41 }
42
43 pub(crate) fn begin(&self) -> AuthResultReceiver {
44 let (sender, receiver) = tokio::sync::oneshot::channel();
45
46 if let Ok(mut guard) = self.tx.lock() {
47 if let Some(old) = guard.take() {
48 tracing::warn!("New authentication request superseding previous pending request");
49 let _ = old.send(Err("Authentication attempt superseded".to_string()));
50 } else {
51 tracing::debug!("Starting new authentication request");
52 }
53 *guard = Some(sender);
54 }
55
56 receiver
57 }
58
59 pub(crate) fn succeed(&self) {
60 if let Ok(mut guard) = self.tx.lock()
61 && let Some(sender) = guard.take()
62 {
63 let _ = sender.send(Ok(()));
64 }
65 }
66
67 pub(crate) fn fail(&self, error: impl Into<String>) {
68 let message = error.into();
69 if let Ok(mut guard) = self.tx.lock()
70 && let Some(sender) = guard.take()
71 {
72 let _ = sender.send(Err(message));
73 }
74 }
75
76 pub(crate) async fn wait_for_result(
77 &self,
78 timeout: Duration,
79 receiver: AuthResultReceiver,
80 ) -> Result<(), BybitWsError> {
81 match tokio::time::timeout(timeout, receiver).await {
82 Ok(Ok(Ok(()))) => Ok(()),
83 Ok(Ok(Err(msg))) => Err(BybitWsError::Authentication(msg)),
84 Ok(Err(_)) => Err(BybitWsError::Authentication(
85 "Authentication channel closed".to_string(),
86 )),
87 Err(_) => {
88 if let Ok(mut guard) = self.tx.lock() {
89 guard.take();
90 }
91 Err(BybitWsError::Authentication(
92 "Authentication timed out".to_string(),
93 ))
94 }
95 }
96 }
97}
98
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rstest::rstest;

    use super::*;

    /// A second `begin` fails the first receiver with the "superseded" error, while
    /// the second receiver still observes a later `succeed`.
    #[rstest]
    #[tokio::test]
    async fn begin_supersedes_previous_sender() {
        let tracker = AuthTracker::new();

        let stale_rx = tracker.begin();
        let active_rx = tracker.begin();

        // The superseded attempt resolves immediately with an error message.
        let stale_outcome = stale_rx.await.expect("oneshot closed unexpectedly");
        assert_eq!(
            stale_outcome,
            Err("Authentication attempt superseded".to_string())
        );

        // The replacement attempt is the one that `succeed` resolves.
        tracker.succeed();
        let active_outcome = tracker
            .wait_for_result(Duration::from_secs(1), active_rx)
            .await;
        active_outcome.expect("expected successful authentication");
    }
}