nautilus_network/websocket/
auth.rs1use std::{
40 sync::{Arc, Mutex},
41 time::Duration,
42};
43
/// Sending half of the one-shot channel that carries an authentication outcome:
/// `Ok(())` on success or `Err(message)` on failure.
pub type AuthResultSender = tokio::sync::oneshot::Sender<Result<(), String>>;
/// Receiving half of the one-shot channel that carries an authentication outcome.
pub type AuthResultReceiver = tokio::sync::oneshot::Receiver<Result<(), String>>;
46
47#[derive(Clone, Debug)]
63pub struct AuthTracker {
64 tx: Arc<Mutex<Option<AuthResultSender>>>,
65}
66
67impl AuthTracker {
68 pub fn new() -> Self {
70 Self {
71 tx: Arc::new(Mutex::new(None)),
72 }
73 }
74
75 pub fn begin(&self) -> AuthResultReceiver {
81 let (sender, receiver) = tokio::sync::oneshot::channel();
82
83 if let Ok(mut guard) = self.tx.lock() {
84 if let Some(old) = guard.take() {
85 tracing::warn!("New authentication request superseding previous pending request");
86 let _ = old.send(Err("Authentication attempt superseded".to_string()));
87 } else {
88 tracing::debug!("Starting new authentication request");
89 }
90 *guard = Some(sender);
91 }
92
93 receiver
94 }
95
96 pub fn succeed(&self) {
103 if let Ok(mut guard) = self.tx.lock()
104 && let Some(sender) = guard.take()
105 {
106 let _ = sender.send(Ok(()));
107 }
108 }
109
110 pub fn fail(&self, error: impl Into<String>) {
117 let message = error.into();
118 if let Ok(mut guard) = self.tx.lock()
119 && let Some(sender) = guard.take()
120 {
121 let _ = sender.send(Err(message));
122 }
123 }
124
125 pub async fn wait_for_result<E>(
142 &self,
143 timeout: Duration,
144 receiver: AuthResultReceiver,
145 ) -> Result<(), E>
146 where
147 E: From<String>,
148 {
149 match tokio::time::timeout(timeout, receiver).await {
150 Ok(Ok(Ok(()))) => Ok(()),
151 Ok(Ok(Err(msg))) => Err(E::from(msg)),
152 Ok(Err(_)) => Err(E::from("Authentication channel closed".to_string())),
153 Err(_) => {
154 if let Ok(mut guard) = self.tx.lock() {
156 guard.take();
157 }
158 Err(E::from("Authentication timed out".to_string()))
159 }
160 }
161 }
162}
163
164impl Default for AuthTracker {
165 fn default() -> Self {
166 Self::new()
167 }
168}
169
170#[cfg(test)]
175mod tests {
176 use std::{
177 sync::atomic::{AtomicBool, Ordering},
178 time::Duration,
179 };
180
181 use rstest::rstest;
182
183 use super::*;
184
185 #[derive(Debug, PartialEq)]
186 struct TestError(String);
187
188 impl From<String> for TestError {
189 fn from(msg: String) -> Self {
190 Self(msg)
191 }
192 }
193
194 #[rstest]
195 #[tokio::test]
196 async fn test_successful_authentication() {
197 let tracker = AuthTracker::new();
198 let rx = tracker.begin();
199
200 tracker.succeed();
201
202 let result: Result<(), TestError> =
203 tracker.wait_for_result(Duration::from_secs(1), rx).await;
204
205 assert!(result.is_ok());
206 }
207
208 #[rstest]
209 #[tokio::test]
210 async fn test_failed_authentication() {
211 let tracker = AuthTracker::new();
212 let rx = tracker.begin();
213
214 tracker.fail("Invalid credentials");
215
216 let result: Result<(), TestError> =
217 tracker.wait_for_result(Duration::from_secs(1), rx).await;
218
219 assert_eq!(
220 result.unwrap_err(),
221 TestError("Invalid credentials".to_string())
222 );
223 }
224
225 #[rstest]
226 #[tokio::test]
227 async fn test_authentication_timeout() {
228 let tracker = AuthTracker::new();
229 let rx = tracker.begin();
230
231 let result: Result<(), TestError> =
234 tracker.wait_for_result(Duration::from_millis(50), rx).await;
235
236 assert_eq!(
237 result.unwrap_err(),
238 TestError("Authentication timed out".to_string())
239 );
240 }
241
242 #[rstest]
243 #[tokio::test]
244 async fn test_begin_supersedes_previous_sender() {
245 let tracker = AuthTracker::new();
246
247 let first = tracker.begin();
248 let second = tracker.begin();
249
250 let result = first.await.expect("oneshot closed unexpectedly");
252 assert_eq!(result, Err("Authentication attempt superseded".to_string()));
253
254 tracker.succeed();
256 let result: Result<(), TestError> = tracker
257 .wait_for_result(Duration::from_secs(1), second)
258 .await;
259
260 assert!(result.is_ok());
261 }
262
263 #[rstest]
264 #[tokio::test]
265 async fn test_succeed_without_pending_auth() {
266 let tracker = AuthTracker::new();
267
268 tracker.succeed();
270 }
271
272 #[rstest]
273 #[tokio::test]
274 async fn test_fail_without_pending_auth() {
275 let tracker = AuthTracker::new();
276
277 tracker.fail("Some error");
279 }
280
281 #[rstest]
282 #[tokio::test]
283 async fn test_multiple_sequential_authentications() {
284 let tracker = AuthTracker::new();
285
286 let rx1 = tracker.begin();
288 tracker.succeed();
289 let result1: Result<(), TestError> =
290 tracker.wait_for_result(Duration::from_secs(1), rx1).await;
291 assert!(result1.is_ok());
292
293 let rx2 = tracker.begin();
295 tracker.fail("Credentials expired");
296 let result2: Result<(), TestError> =
297 tracker.wait_for_result(Duration::from_secs(1), rx2).await;
298 assert_eq!(
299 result2.unwrap_err(),
300 TestError("Credentials expired".to_string())
301 );
302
303 let rx3 = tracker.begin();
305 tracker.succeed();
306 let result3: Result<(), TestError> =
307 tracker.wait_for_result(Duration::from_secs(1), rx3).await;
308 assert!(result3.is_ok());
309 }
310
311 #[rstest]
312 #[tokio::test]
313 async fn test_channel_closed_before_result() {
314 let tracker = AuthTracker::new();
315 let rx = tracker.begin();
316
317 tracker.begin();
319
320 let result: Result<(), TestError> =
322 tracker.wait_for_result(Duration::from_secs(1), rx).await;
323
324 assert_eq!(
325 result.unwrap_err(),
326 TestError("Authentication attempt superseded".to_string())
327 );
328 }
329
330 #[rstest]
331 #[tokio::test]
332 async fn test_concurrent_auth_attempts() {
333 let tracker = Arc::new(AuthTracker::new());
334 let mut handles = vec![];
335
336 for i in 0..10 {
338 let tracker_clone = Arc::clone(&tracker);
339 let handle = tokio::spawn(async move {
340 let rx = tracker_clone.begin();
341
342 if i == 9 {
344 tokio::time::sleep(Duration::from_millis(10)).await;
345 tracker_clone.succeed();
346 }
347
348 let result: Result<(), TestError> = tracker_clone
349 .wait_for_result(Duration::from_secs(1), rx)
350 .await;
351
352 (i, result)
353 });
354 handles.push(handle);
355 }
356
357 let mut successes = 0;
358 let mut superseded = 0;
359
360 for handle in handles {
361 let (i, result) = handle.await.unwrap();
362 match result {
363 Ok(()) => {
364 assert_eq!(i, 9);
366 successes += 1;
367 }
368 Err(TestError(msg)) if msg.contains("superseded") => {
369 superseded += 1;
370 }
371 Err(e) => panic!("Unexpected error: {e:?}"),
372 }
373 }
374
375 assert_eq!(successes, 1);
376 assert_eq!(superseded, 9);
377 }
378
379 #[rstest]
380 fn test_default_trait() {
381 let _tracker = AuthTracker::default();
382 }
383
384 #[rstest]
385 #[tokio::test]
386 async fn test_clone_trait() {
387 let tracker = AuthTracker::new();
388 let cloned = tracker.clone();
389
390 let rx = tracker.begin();
392 cloned.succeed(); let result: Result<(), TestError> =
394 tracker.wait_for_result(Duration::from_secs(1), rx).await;
395 assert!(result.is_ok());
396 }
397
398 #[rstest]
399 fn test_debug_trait() {
400 let tracker = AuthTracker::new();
401 let debug_str = format!("{tracker:?}");
402 assert!(debug_str.contains("AuthTracker"));
403 }
404
405 #[rstest]
406 #[tokio::test]
407 async fn test_timeout_clears_sender() {
408 let tracker = AuthTracker::new();
409
410 let rx1 = tracker.begin();
412 let result1: Result<(), TestError> = tracker
413 .wait_for_result(Duration::from_millis(50), rx1)
414 .await;
415 assert_eq!(
416 result1.unwrap_err(),
417 TestError("Authentication timed out".to_string())
418 );
419
420 let rx2 = tracker.begin();
422 tracker.succeed();
423 let result2: Result<(), TestError> =
424 tracker.wait_for_result(Duration::from_secs(1), rx2).await;
425 assert!(result2.is_ok());
426 }
427
428 #[rstest]
429 #[tokio::test]
430 async fn test_fail_clears_sender() {
431 let tracker = AuthTracker::new();
432
433 let rx1 = tracker.begin();
435 tracker.fail("Bad credentials");
436 let result1: Result<(), TestError> =
437 tracker.wait_for_result(Duration::from_secs(1), rx1).await;
438 assert!(result1.is_err());
439
440 let rx2 = tracker.begin();
442 tracker.succeed();
443 let result2: Result<(), TestError> =
444 tracker.wait_for_result(Duration::from_secs(1), rx2).await;
445 assert!(result2.is_ok());
446 }
447
448 #[rstest]
449 #[tokio::test]
450 async fn test_succeed_clears_sender() {
451 let tracker = AuthTracker::new();
452
453 let rx1 = tracker.begin();
455 tracker.succeed();
456 let result1: Result<(), TestError> =
457 tracker.wait_for_result(Duration::from_secs(1), rx1).await;
458 assert!(result1.is_ok());
459
460 let rx2 = tracker.begin();
462 tracker.succeed();
463 let result2: Result<(), TestError> =
464 tracker.wait_for_result(Duration::from_secs(1), rx2).await;
465 assert!(result2.is_ok());
466 }
467
468 #[rstest]
469 #[tokio::test]
470 async fn test_rapid_begin_succeed_cycles() {
471 let tracker = AuthTracker::new();
472
473 for _ in 0..100 {
475 let rx = tracker.begin();
476 tracker.succeed();
477 let result: Result<(), TestError> =
478 tracker.wait_for_result(Duration::from_secs(1), rx).await;
479 assert!(result.is_ok());
480 }
481 }
482
483 #[rstest]
484 #[tokio::test]
485 async fn test_double_succeed_is_safe() {
486 let tracker = AuthTracker::new();
487 let rx = tracker.begin();
488
489 tracker.succeed();
491 tracker.succeed(); let result: Result<(), TestError> =
494 tracker.wait_for_result(Duration::from_secs(1), rx).await;
495 assert!(result.is_ok());
496 }
497
498 #[rstest]
499 #[tokio::test]
500 async fn test_double_fail_is_safe() {
501 let tracker = AuthTracker::new();
502 let rx = tracker.begin();
503
504 tracker.fail("Error 1");
506 tracker.fail("Error 2"); let result: Result<(), TestError> =
509 tracker.wait_for_result(Duration::from_secs(1), rx).await;
510 assert_eq!(
511 result.unwrap_err(),
512 TestError("Error 1".to_string()) );
514 }
515
516 #[rstest]
517 #[tokio::test]
518 async fn test_succeed_after_fail_is_ignored() {
519 let tracker = AuthTracker::new();
520 let rx = tracker.begin();
521
522 tracker.fail("Auth failed");
523 tracker.succeed(); let result: Result<(), TestError> =
526 tracker.wait_for_result(Duration::from_secs(1), rx).await;
527 assert!(result.is_err()); }
529
530 #[rstest]
531 #[tokio::test]
532 async fn test_fail_after_succeed_is_ignored() {
533 let tracker = AuthTracker::new();
534 let rx = tracker.begin();
535
536 tracker.succeed();
537 tracker.fail("Auth failed"); let result: Result<(), TestError> =
540 tracker.wait_for_result(Duration::from_secs(1), rx).await;
541 assert!(result.is_ok()); }
543
544 #[rstest]
551 #[tokio::test]
552 async fn test_reconnect_flow_waits_for_auth() {
553 let tracker = Arc::new(AuthTracker::new());
554 let subscribed = Arc::new(tokio::sync::Notify::new());
555 let auth_completed = Arc::new(tokio::sync::Notify::new());
556
557 let tracker_reconnect = Arc::clone(&tracker);
559 let subscribed_reconnect = Arc::clone(&subscribed);
560 let auth_completed_reconnect = Arc::clone(&auth_completed);
561
562 let reconnect_task = tokio::spawn(async move {
563 let rx = tracker_reconnect.begin();
565
566 let tracker_resub = Arc::clone(&tracker_reconnect);
568 let subscribed_resub = Arc::clone(&subscribed_reconnect);
569 let auth_completed_resub = Arc::clone(&auth_completed_reconnect);
570
571 let resub_task = tokio::spawn(async move {
572 let result: Result<(), TestError> = tracker_resub
574 .wait_for_result(Duration::from_secs(5), rx)
575 .await;
576
577 if result.is_ok() {
578 auth_completed_resub.notify_one();
579 tokio::time::sleep(Duration::from_millis(10)).await;
581 subscribed_resub.notify_one();
582 }
583 });
584
585 resub_task.await.unwrap();
586 });
587
588 tokio::time::sleep(Duration::from_millis(100)).await;
590 tracker.succeed();
591
592 reconnect_task.await.unwrap();
594
595 tokio::select! {
597 _ = auth_completed.notified() => {
598 }
600 _ = tokio::time::sleep(Duration::from_secs(1)) => {
601 panic!("Auth never completed");
602 }
603 }
604
605 tokio::select! {
607 _ = subscribed.notified() => {
608 }
610 _ = tokio::time::sleep(Duration::from_secs(1)) => {
611 panic!("Subscription never completed");
612 }
613 }
614 }
615
616 #[rstest]
618 #[tokio::test]
619 async fn test_reconnect_flow_blocks_on_auth_failure() {
620 let tracker = Arc::new(AuthTracker::new());
621 let subscribed = Arc::new(AtomicBool::new(false));
622
623 let tracker_reconnect = Arc::clone(&tracker);
624 let subscribed_reconnect = Arc::clone(&subscribed);
625
626 let reconnect_task = tokio::spawn(async move {
627 let rx = tracker_reconnect.begin();
628
629 let tracker_resub = Arc::clone(&tracker_reconnect);
631 let subscribed_resub = Arc::clone(&subscribed_reconnect);
632
633 let resub_task = tokio::spawn(async move {
634 let result: Result<(), TestError> = tracker_resub
635 .wait_for_result(Duration::from_secs(5), rx)
636 .await;
637
638 if result.is_ok() {
640 subscribed_resub.store(true, Ordering::Relaxed);
641 }
642 });
643
644 resub_task.await.unwrap();
645 });
646
647 tokio::time::sleep(Duration::from_millis(50)).await;
649 tracker.fail("Invalid credentials");
650
651 reconnect_task.await.unwrap();
653
654 tokio::time::sleep(Duration::from_millis(100)).await;
656 assert!(!subscribed.load(Ordering::Relaxed));
657 }
658
659 #[rstest]
661 #[tokio::test]
662 async fn test_state_machine_transitions() {
663 let tracker = AuthTracker::new();
664
665 let rx1 = tracker.begin();
667
668 tracker.succeed();
670 let result1: Result<(), TestError> =
671 tracker.wait_for_result(Duration::from_secs(1), rx1).await;
672 assert!(result1.is_ok());
673
674 let rx2 = tracker.begin();
676
677 tracker.fail("Error");
679 let result2: Result<(), TestError> =
680 tracker.wait_for_result(Duration::from_secs(1), rx2).await;
681 assert!(result2.is_err());
682
683 let rx3 = tracker.begin();
685
686 let result3: Result<(), TestError> = tracker
688 .wait_for_result(Duration::from_millis(50), rx3)
689 .await;
690 assert_eq!(
691 result3.unwrap_err(),
692 TestError("Authentication timed out".to_string())
693 );
694
695 let rx4 = tracker.begin();
697
698 let rx5 = tracker.begin();
700 let result4: Result<(), TestError> =
701 tracker.wait_for_result(Duration::from_secs(1), rx4).await;
702 assert_eq!(
703 result4.unwrap_err(),
704 TestError("Authentication attempt superseded".to_string())
705 );
706
707 tracker.succeed();
709 let result5: Result<(), TestError> =
710 tracker.wait_for_result(Duration::from_secs(1), rx5).await;
711 assert!(result5.is_ok());
712 }
713
714 #[rstest]
716 #[tokio::test]
717 async fn test_no_sender_leaks() {
718 let tracker = AuthTracker::new();
719
720 for _ in 0..100 {
721 let rx = tracker.begin();
722 let _result: Result<(), TestError> =
723 tracker.wait_for_result(Duration::from_millis(1), rx).await;
724 }
725
726 let rx = tracker.begin();
727 tracker.succeed();
728 let result: Result<(), TestError> =
729 tracker.wait_for_result(Duration::from_secs(1), rx).await;
730 assert!(result.is_ok());
731 }
732
733 #[rstest]
735 #[tokio::test]
736 async fn test_concurrent_succeed_fail_calls() {
737 let tracker = Arc::new(AuthTracker::new());
738 let rx = tracker.begin();
739
740 let mut handles = vec![];
741
742 for _ in 0..50 {
744 let tracker_clone = Arc::clone(&tracker);
745 handles.push(tokio::spawn(async move {
746 tracker_clone.succeed();
747 }));
748 }
749
750 for _ in 0..50 {
752 let tracker_clone = Arc::clone(&tracker);
753 handles.push(tokio::spawn(async move {
754 tracker_clone.fail("Error");
755 }));
756 }
757
758 for handle in handles {
760 handle.await.unwrap();
761 }
762
763 let result: Result<(), TestError> =
765 tracker.wait_for_result(Duration::from_secs(1), rx).await;
766 let _ = result;
768 }
769}