nautilus_network/websocket/
auth.rs1use std::{
40 sync::{
41 Arc, Mutex,
42 atomic::{AtomicBool, Ordering},
43 },
44 time::Duration,
45};
46
/// Sending half used to deliver the outcome of a pending authentication request:
/// `Ok(())` on success, `Err(message)` on failure or when the request is superseded.
pub type AuthResultSender = tokio::sync::oneshot::Sender<Result<(), String>>;
/// Receiving half returned by `AuthTracker::begin` and consumed by
/// `AuthTracker::wait_for_result`.
pub type AuthResultReceiver = tokio::sync::oneshot::Receiver<Result<(), String>>;
49
/// Tracks the currently pending authentication request and the latest authentication state.
///
/// `begin` installs a oneshot sender for a new request (superseding any previous one),
/// `succeed` / `fail` resolve it, and `is_authenticated` exposes the latest outcome.
/// Both fields are `Arc`-wrapped, so clones of a tracker share the same state.
#[derive(Clone, Debug)]
pub struct AuthTracker {
    // Sender for the pending request, if any; taken when the request is resolved.
    tx: Arc<Mutex<Option<AuthResultSender>>>,
    // Set by `succeed`; cleared by `begin`, `fail` and `invalidate`.
    authenticated: Arc<AtomicBool>,
}
77
78impl AuthTracker {
79 pub fn new() -> Self {
81 Self {
82 tx: Arc::new(Mutex::new(None)),
83 authenticated: Arc::new(AtomicBool::new(false)),
84 }
85 }
86
87 #[must_use]
92 pub fn is_authenticated(&self) -> bool {
93 self.authenticated.load(Ordering::Acquire)
94 }
95
96 pub fn invalidate(&self) {
101 self.authenticated.store(false, Ordering::Release);
102 }
103
104 pub fn begin(&self) -> AuthResultReceiver {
113 let (sender, receiver) = tokio::sync::oneshot::channel();
114 self.authenticated.store(false, Ordering::Release);
115
116 if let Ok(mut guard) = self.tx.lock() {
117 if let Some(old) = guard.take() {
118 log::warn!("New authentication request superseding previous pending request");
119 let _ = old.send(Err("Authentication attempt superseded".to_string()));
120 } else {
121 log::debug!("Starting new authentication request");
122 }
123 *guard = Some(sender);
124 }
125
126 receiver
127 }
128
129 pub fn succeed(&self) {
138 self.authenticated.store(true, Ordering::Release);
139 if let Ok(mut guard) = self.tx.lock()
140 && let Some(sender) = guard.take()
141 {
142 let _ = sender.send(Ok(()));
143 }
144 }
145
146 pub fn fail(&self, error: impl Into<String>) {
155 self.authenticated.store(false, Ordering::Release);
156 let message = error.into();
157 if let Ok(mut guard) = self.tx.lock()
158 && let Some(sender) = guard.take()
159 {
160 let _ = sender.send(Err(message));
161 }
162 }
163
164 pub async fn wait_for_result<E>(
181 &self,
182 timeout: Duration,
183 receiver: AuthResultReceiver,
184 ) -> Result<(), E>
185 where
186 E: From<String>,
187 {
188 match tokio::time::timeout(timeout, receiver).await {
189 Ok(Ok(Ok(()))) => Ok(()),
190 Ok(Ok(Err(msg))) => Err(E::from(msg)),
191 Ok(Err(_)) => Err(E::from("Authentication channel closed".to_string())),
192 Err(_) => {
193 if let Ok(mut guard) = self.tx.lock() {
195 guard.take();
196 }
197 Err(E::from("Authentication timed out".to_string()))
198 }
199 }
200 }
201}
202
203impl Default for AuthTracker {
204 fn default() -> Self {
205 Self::new()
206 }
207}
208
#[cfg(test)]
mod tests {
    // Unit tests for `AuthTracker`: result delivery, superseding, timeouts, and the
    // shared `is_authenticated` flag.

    use std::{
        sync::atomic::{AtomicBool, Ordering},
        time::Duration,
    };

    use rstest::rstest;

    use super::*;

    // Minimal error type satisfying the `E: From<String>` bound of `wait_for_result`.
    #[derive(Debug, PartialEq)]
    struct TestError(String);

    impl From<String> for TestError {
        fn from(msg: String) -> Self {
            Self(msg)
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_successful_authentication() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.succeed();

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;

        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_failed_authentication() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.fail("Invalid credentials");

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;

        // The message passed to `fail` is surfaced verbatim through `E::from`.
        assert_eq!(
            result.unwrap_err(),
            TestError("Invalid credentials".to_string())
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_authentication_timeout() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        // Neither `succeed` nor `fail` is called, so the wait must hit the timeout.
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_millis(50), rx).await;

        assert_eq!(
            result.unwrap_err(),
            TestError("Authentication timed out".to_string())
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_begin_supersedes_previous_sender() {
        let tracker = AuthTracker::new();

        let first = tracker.begin();
        let second = tracker.begin();

        // The first receiver is awaited directly to observe the raw superseded error.
        let result = first.await.expect("oneshot closed unexpectedly");
        assert_eq!(result, Err("Authentication attempt superseded".to_string()));

        tracker.succeed();

        let result: Result<(), TestError> = tracker
            .wait_for_result(Duration::from_secs(1), second)
            .await;

        assert!(result.is_ok());
    }

    // Calling `succeed` with no pending request must not panic.
    #[rstest]
    #[tokio::test]
    async fn test_succeed_without_pending_auth() {
        let tracker = AuthTracker::new();

        tracker.succeed();
    }

    // Calling `fail` with no pending request must not panic.
    #[rstest]
    #[tokio::test]
    async fn test_fail_without_pending_auth() {
        let tracker = AuthTracker::new();

        tracker.fail("Some error");
    }

    #[rstest]
    #[tokio::test]
    async fn test_multiple_sequential_authentications() {
        let tracker = AuthTracker::new();

        let rx1 = tracker.begin();
        tracker.succeed();
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_ok());

        let rx2 = tracker.begin();
        tracker.fail("Credentials expired");
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert_eq!(
            result2.unwrap_err(),
            TestError("Credentials expired".to_string())
        );

        let rx3 = tracker.begin();
        tracker.succeed();
        let result3: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx3).await;
        assert!(result3.is_ok());
    }

    // Despite the name, the first receiver gets the explicit "superseded" error rather
    // than a closed channel, because `begin` sends that error to the previous sender.
    #[rstest]
    #[tokio::test]
    async fn test_channel_closed_before_result() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.begin();

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;

        assert_eq!(
            result.unwrap_err(),
            TestError("Authentication attempt superseded".to_string())
        );
    }

    // Ten tasks each call `begin`; only task 9 calls `succeed`. Expects exactly one
    // success (task 9) and nine superseded results.
    // NOTE(review): this appears to rely on the default current-thread `#[tokio::test]`
    // runtime polling spawned tasks in spawn order, so task 9's `begin` is the last one.
    #[rstest]
    #[tokio::test]
    async fn test_concurrent_auth_attempts() {
        let tracker = Arc::new(AuthTracker::new());
        let mut handles = vec![];

        for i in 0..10 {
            let tracker_clone = Arc::clone(&tracker);
            let handle = tokio::spawn(async move {
                let rx = tracker_clone.begin();

                if i == 9 {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    tracker_clone.succeed();
                }

                let result: Result<(), TestError> = tracker_clone
                    .wait_for_result(Duration::from_secs(1), rx)
                    .await;

                (i, result)
            });
            handles.push(handle);
        }

        let mut successes = 0;
        let mut superseded = 0;

        for handle in handles {
            let (i, result) = handle.await.unwrap();
            match result {
                Ok(()) => {
                    assert_eq!(i, 9);
                    successes += 1;
                }
                Err(TestError(msg)) if msg.contains("superseded") => {
                    superseded += 1;
                }
                Err(e) => panic!("Unexpected error: {e:?}"),
            }
        }

        assert_eq!(successes, 1);
        assert_eq!(superseded, 9);
    }

    #[rstest]
    fn test_default_trait() {
        let _tracker = AuthTracker::default();
    }

    // `succeed` on a clone resolves a request started on the original, because both
    // share the same `Arc`-backed state.
    #[rstest]
    #[tokio::test]
    async fn test_clone_trait() {
        let tracker = AuthTracker::new();
        let cloned = tracker.clone();

        let rx = tracker.begin();
        cloned.succeed();
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    #[rstest]
    fn test_debug_trait() {
        let tracker = AuthTracker::new();
        let debug_str = format!("{tracker:?}");
        assert!(debug_str.contains("AuthTracker"));
    }

    // After a timed-out attempt, a fresh attempt still completes normally.
    #[rstest]
    #[tokio::test]
    async fn test_timeout_clears_sender() {
        let tracker = AuthTracker::new();

        let rx1 = tracker.begin();
        let result1: Result<(), TestError> = tracker
            .wait_for_result(Duration::from_millis(50), rx1)
            .await;
        assert_eq!(
            result1.unwrap_err(),
            TestError("Authentication timed out".to_string())
        );

        let rx2 = tracker.begin();
        tracker.succeed();
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_fail_clears_sender() {
        let tracker = AuthTracker::new();

        let rx1 = tracker.begin();
        tracker.fail("Bad credentials");
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_err());

        let rx2 = tracker.begin();
        tracker.succeed();
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_succeed_clears_sender() {
        let tracker = AuthTracker::new();

        let rx1 = tracker.begin();
        tracker.succeed();
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_ok());

        let rx2 = tracker.begin();
        tracker.succeed();
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_rapid_begin_succeed_cycles() {
        let tracker = AuthTracker::new();

        for _ in 0..100 {
            let rx = tracker.begin();
            tracker.succeed();
            let result: Result<(), TestError> =
                tracker.wait_for_result(Duration::from_secs(1), rx).await;
            assert!(result.is_ok());
        }
    }

    // The second `succeed` finds no pending sender and is a no-op for the receiver.
    #[rstest]
    #[tokio::test]
    async fn test_double_succeed_is_safe() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.succeed();
        tracker.succeed();
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    // Only the first `fail` reaches the receiver; the second finds no pending sender.
    #[rstest]
    #[tokio::test]
    async fn test_double_fail_is_safe() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.fail("Error 1");
        tracker.fail("Error 2");
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert_eq!(
            result.unwrap_err(),
            TestError("Error 1".to_string())
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_succeed_after_fail_is_ignored() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.fail("Auth failed");
        tracker.succeed();
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_err());
    }

    #[rstest]
    #[tokio::test]
    async fn test_fail_after_succeed_is_ignored() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.succeed();
        tracker.fail("Auth failed");
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    // Models a reconnect: the reconnect task starts auth, and a nested resubscribe task
    // waits for the auth result before signalling "auth completed" and then "subscribed".
    // NOTE(review): the later `notified()` awaits appear to rely on `Notify::notify_one`
    // storing a permit when no task is waiting yet.
    #[rstest]
    #[tokio::test]
    async fn test_reconnect_flow_waits_for_auth() {
        let tracker = Arc::new(AuthTracker::new());
        let subscribed = Arc::new(tokio::sync::Notify::new());
        let auth_completed = Arc::new(tokio::sync::Notify::new());

        let tracker_reconnect = Arc::clone(&tracker);
        let subscribed_reconnect = Arc::clone(&subscribed);
        let auth_completed_reconnect = Arc::clone(&auth_completed);

        let reconnect_task = tokio::spawn(async move {
            let rx = tracker_reconnect.begin();

            let tracker_resub = Arc::clone(&tracker_reconnect);
            let subscribed_resub = Arc::clone(&subscribed_reconnect);
            let auth_completed_resub = Arc::clone(&auth_completed_reconnect);

            let resub_task = tokio::spawn(async move {
                let result: Result<(), TestError> = tracker_resub
                    .wait_for_result(Duration::from_secs(5), rx)
                    .await;

                if result.is_ok() {
                    auth_completed_resub.notify_one();
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    subscribed_resub.notify_one();
                }
            });

            resub_task.await.unwrap();
        });

        // Give the reconnect task time to call `begin` before authentication succeeds.
        tokio::time::sleep(Duration::from_millis(100)).await;
        tracker.succeed();

        reconnect_task.await.unwrap();

        tokio::select! {
            () = auth_completed.notified() => {
            }
            () = tokio::time::sleep(Duration::from_secs(1)) => {
                panic!("Auth never completed");
            }
        }

        tokio::select! {
            () = subscribed.notified() => {
            }
            () = tokio::time::sleep(Duration::from_secs(1)) => {
                panic!("Subscription never completed");
            }
        }
    }

    // Same shape as the previous test, but auth fails, so the "subscribed" flag must
    // never be set.
    #[rstest]
    #[tokio::test]
    async fn test_reconnect_flow_blocks_on_auth_failure() {
        let tracker = Arc::new(AuthTracker::new());
        let subscribed = Arc::new(AtomicBool::new(false));

        let tracker_reconnect = Arc::clone(&tracker);
        let subscribed_reconnect = Arc::clone(&subscribed);

        let reconnect_task = tokio::spawn(async move {
            let rx = tracker_reconnect.begin();

            let tracker_resub = Arc::clone(&tracker_reconnect);
            let subscribed_resub = Arc::clone(&subscribed_reconnect);

            let resub_task = tokio::spawn(async move {
                let result: Result<(), TestError> = tracker_resub
                    .wait_for_result(Duration::from_secs(5), rx)
                    .await;

                if result.is_ok() {
                    subscribed_resub.store(true, Ordering::Relaxed);
                }
            });

            resub_task.await.unwrap();
        });

        tokio::time::sleep(Duration::from_millis(50)).await;
        tracker.fail("Invalid credentials");

        reconnect_task.await.unwrap();

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(!subscribed.load(Ordering::Relaxed));
    }

    // Walks the tracker through success -> failure -> timeout -> superseded -> success.
    #[rstest]
    #[tokio::test]
    async fn test_state_machine_transitions() {
        let tracker = AuthTracker::new();

        let rx1 = tracker.begin();

        tracker.succeed();
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_ok());

        let rx2 = tracker.begin();

        tracker.fail("Error");
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_err());

        let rx3 = tracker.begin();

        let result3: Result<(), TestError> = tracker
            .wait_for_result(Duration::from_millis(50), rx3)
            .await;
        assert_eq!(
            result3.unwrap_err(),
            TestError("Authentication timed out".to_string())
        );

        let rx4 = tracker.begin();

        let rx5 = tracker.begin();
        let result4: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx4).await;
        assert_eq!(
            result4.unwrap_err(),
            TestError("Authentication attempt superseded".to_string())
        );

        tracker.succeed();
        let result5: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx5).await;
        assert!(result5.is_ok());
    }

    // 100 timed-out attempts followed by one successful attempt, which must still succeed.
    #[rstest]
    #[tokio::test]
    async fn test_no_sender_leaks() {
        let tracker = AuthTracker::new();

        for _ in 0..100 {
            let rx = tracker.begin();
            let _result: Result<(), TestError> =
                tracker.wait_for_result(Duration::from_millis(1), rx).await;
        }

        let rx = tracker.begin();
        tracker.succeed();
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    // Races 50 `succeed` and 50 `fail` calls against one pending request. The result is
    // intentionally ignored; the test only checks that nothing panics or deadlocks.
    #[rstest]
    #[tokio::test]
    async fn test_concurrent_succeed_fail_calls() {
        let tracker = Arc::new(AuthTracker::new());
        let rx = tracker.begin();

        let mut handles = vec![];

        for _ in 0..50 {
            let tracker_clone = Arc::clone(&tracker);
            handles.push(tokio::spawn(async move {
                tracker_clone.succeed();
            }));
        }

        for _ in 0..50 {
            let tracker_clone = Arc::clone(&tracker);
            handles.push(tokio::spawn(async move {
                tracker_clone.fail("Error");
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        let _ = result;
    }

    #[rstest]
    fn test_is_authenticated_initial_state() {
        let tracker = AuthTracker::new();
        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_is_authenticated_after_succeed() {
        let tracker = AuthTracker::new();
        assert!(!tracker.is_authenticated());

        let _rx = tracker.begin();
        assert!(!tracker.is_authenticated());

        tracker.succeed();
        assert!(tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_is_authenticated_after_fail() {
        let tracker = AuthTracker::new();
        let _rx = tracker.begin();
        tracker.fail("error");
        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_invalidate_clears_auth_state() {
        let tracker = AuthTracker::new();
        let _rx = tracker.begin();
        tracker.succeed();
        assert!(tracker.is_authenticated());

        tracker.invalidate();
        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_begin_clears_auth_state() {
        let tracker = AuthTracker::new();
        let _rx1 = tracker.begin();
        tracker.succeed();
        assert!(tracker.is_authenticated());

        let _rx2 = tracker.begin();
        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    fn test_is_authenticated_shared_across_clones() {
        let tracker = AuthTracker::new();
        let cloned = tracker.clone();

        let _rx = tracker.begin();
        tracker.succeed();

        assert!(cloned.is_authenticated());
    }

    #[rstest]
    fn test_invalidate_shared_across_clones() {
        let tracker = AuthTracker::new();
        let cloned = tracker.clone();

        let _rx = tracker.begin();
        tracker.succeed();
        assert!(tracker.is_authenticated());

        cloned.invalidate();
        assert!(!tracker.is_authenticated());
    }

    // `succeed` sets the flag even with no pending request.
    #[rstest]
    fn test_succeed_without_begin_still_updates_auth_state() {
        let tracker = AuthTracker::new();
        assert!(!tracker.is_authenticated());

        tracker.succeed();
        assert!(tracker.is_authenticated());
    }

    // `fail` clears the flag even with no pending request.
    #[rstest]
    fn test_fail_without_begin_still_updates_auth_state() {
        let tracker = AuthTracker::new();
        tracker.succeed();
        assert!(tracker.is_authenticated());

        tracker.fail("error");
        assert!(!tracker.is_authenticated());
    }

    // A late `succeed` arriving after the waiter timed out still flips the flag to true.
    #[rstest]
    #[tokio::test]
    async fn test_auth_state_false_after_timeout_until_late_response() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();
        assert!(!tracker.is_authenticated());

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_millis(10), rx).await;

        assert!(result.is_err());
        assert!(!tracker.is_authenticated());

        tracker.succeed();
        assert!(tracker.is_authenticated());
    }
}