nautilus_network/websocket/
auth.rs1use std::{
40 sync::{Arc, Mutex},
41 time::Duration,
42};
43
44pub type AuthResultSender = tokio::sync::oneshot::Sender<Result<(), String>>;
45pub type AuthResultReceiver = tokio::sync::oneshot::Receiver<Result<(), String>>;
46
47#[derive(Clone, Debug)]
63pub struct AuthTracker {
64 tx: Arc<Mutex<Option<AuthResultSender>>>,
65}
66
67impl AuthTracker {
68 pub fn new() -> Self {
70 Self {
71 tx: Arc::new(Mutex::new(None)),
72 }
73 }
74
75 pub fn begin(&self) -> AuthResultReceiver {
81 let (sender, receiver) = tokio::sync::oneshot::channel();
82
83 if let Ok(mut guard) = self.tx.lock() {
84 if let Some(old) = guard.take() {
85 tracing::warn!("New authentication request superseding previous pending request");
86 let _ = old.send(Err("Authentication attempt superseded".to_string()));
87 } else {
88 tracing::debug!("Starting new authentication request");
89 }
90 *guard = Some(sender);
91 }
92
93 receiver
94 }
95
96 pub fn succeed(&self) {
103 if let Ok(mut guard) = self.tx.lock()
104 && let Some(sender) = guard.take()
105 {
106 let _ = sender.send(Ok(()));
107 }
108 }
109
110 pub fn fail(&self, error: impl Into<String>) {
117 let message = error.into();
118 if let Ok(mut guard) = self.tx.lock()
119 && let Some(sender) = guard.take()
120 {
121 let _ = sender.send(Err(message));
122 }
123 }
124
125 pub async fn wait_for_result<E>(
142 &self,
143 timeout: Duration,
144 receiver: AuthResultReceiver,
145 ) -> Result<(), E>
146 where
147 E: From<String>,
148 {
149 match tokio::time::timeout(timeout, receiver).await {
150 Ok(Ok(Ok(()))) => Ok(()),
151 Ok(Ok(Err(msg))) => Err(E::from(msg)),
152 Ok(Err(_)) => Err(E::from("Authentication channel closed".to_string())),
153 Err(_) => {
154 if let Ok(mut guard) = self.tx.lock() {
156 guard.take();
157 }
158 Err(E::from("Authentication timed out".to_string()))
159 }
160 }
161 }
162}
163
164impl Default for AuthTracker {
165 fn default() -> Self {
166 Self::new()
167 }
168}
169
#[cfg(test)]
mod tests {
    use std::{
        sync::atomic::{AtomicBool, Ordering},
        time::Duration,
    };

    use rstest::rstest;

    use super::*;

    /// Minimal error type satisfying the `E: From<String>` bound of `wait_for_result`.
    #[derive(Debug, PartialEq)]
    struct TestError(String);

    impl From<String> for TestError {
        fn from(msg: String) -> Self {
            Self(msg)
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_successful_authentication() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.succeed();

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;

        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_failed_authentication() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.fail("Invalid credentials");

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;

        assert_eq!(
            result.unwrap_err(),
            TestError("Invalid credentials".to_string())
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_authentication_timeout() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_millis(50), rx).await;

        assert_eq!(
            result.unwrap_err(),
            TestError("Authentication timed out".to_string())
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_begin_supersedes_previous_sender() {
        let tracker = AuthTracker::new();

        let first = tracker.begin();
        let second = tracker.begin();

        let result = first.await.expect("oneshot closed unexpectedly");
        assert_eq!(result, Err("Authentication attempt superseded".to_string()));

        tracker.succeed();

        let result: Result<(), TestError> = tracker
            .wait_for_result(Duration::from_secs(1), second)
            .await;

        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_succeed_without_pending_auth() {
        let tracker = AuthTracker::new();

        // Must not panic, and must not pre-resolve an attempt that starts afterwards.
        tracker.succeed();

        let mut rx = tracker.begin();
        assert!(matches!(
            rx.try_recv(),
            Err(tokio::sync::oneshot::error::TryRecvError::Empty)
        ));
    }

    #[rstest]
    #[tokio::test]
    async fn test_fail_without_pending_auth() {
        let tracker = AuthTracker::new();

        // Must not panic, and must not pre-resolve an attempt that starts afterwards.
        tracker.fail("Some error");

        let mut rx = tracker.begin();
        assert!(matches!(
            rx.try_recv(),
            Err(tokio::sync::oneshot::error::TryRecvError::Empty)
        ));
    }

    #[rstest]
    #[tokio::test]
    async fn test_multiple_sequential_authentications() {
        let tracker = AuthTracker::new();

        let rx1 = tracker.begin();
        tracker.succeed();
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_ok());

        let rx2 = tracker.begin();
        tracker.fail("Credentials expired");
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert_eq!(
            result2.unwrap_err(),
            TestError("Credentials expired".to_string())
        );

        let rx3 = tracker.begin();
        tracker.succeed();
        let result3: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx3).await;
        assert!(result3.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_channel_closed_before_result() {
        let tracker = AuthTracker::new();
        let (tx, rx): (AuthResultSender, AuthResultReceiver) = tokio::sync::oneshot::channel();

        // Dropping the sender without sending anything closes the channel.
        drop(tx);

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;

        assert_eq!(
            result.unwrap_err(),
            TestError("Authentication channel closed".to_string())
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_superseded_while_pending() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        // A second begin replaces the first attempt before it resolves.
        tracker.begin();

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;

        assert_eq!(
            result.unwrap_err(),
            TestError("Authentication attempt superseded".to_string())
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_concurrent_auth_attempts() {
        let tracker = Arc::new(AuthTracker::new());
        let mut handles = vec![];

        // Assumes the default current-thread `#[tokio::test]` runtime runs the spawned tasks in
        // spawn order, so task 9 is the last to call `begin()` and supersedes all the others.
        for i in 0..10 {
            let tracker_clone = Arc::clone(&tracker);
            let handle = tokio::spawn(async move {
                let rx = tracker_clone.begin();

                if i == 9 {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    tracker_clone.succeed();
                }

                let result: Result<(), TestError> = tracker_clone
                    .wait_for_result(Duration::from_secs(1), rx)
                    .await;

                (i, result)
            });
            handles.push(handle);
        }

        let mut successes = 0;
        let mut superseded = 0;

        for handle in handles {
            let (i, result) = handle.await.unwrap();
            match result {
                Ok(()) => {
                    assert_eq!(i, 9);
                    successes += 1;
                }
                Err(TestError(msg)) if msg.contains("superseded") => {
                    superseded += 1;
                }
                Err(e) => panic!("Unexpected error: {e:?}"),
            }
        }

        assert_eq!(successes, 1);
        assert_eq!(superseded, 9);
    }

    #[rstest]
    fn test_default_trait() {
        let _tracker = AuthTracker::default();
    }

    #[rstest]
    #[tokio::test]
    async fn test_clone_trait() {
        let tracker = AuthTracker::new();
        let cloned = tracker.clone();

        let rx = tracker.begin();
        // The clone shares the pending slot, so it can resolve the original's attempt.
        cloned.succeed();

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    #[rstest]
    fn test_debug_trait() {
        let tracker = AuthTracker::new();
        let debug_str = format!("{tracker:?}");
        assert!(debug_str.contains("AuthTracker"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_timeout_clears_sender() {
        let tracker = AuthTracker::new();

        let rx1 = tracker.begin();
        let result1: Result<(), TestError> = tracker
            .wait_for_result(Duration::from_millis(50), rx1)
            .await;
        assert_eq!(
            result1.unwrap_err(),
            TestError("Authentication timed out".to_string())
        );

        let rx2 = tracker.begin();
        tracker.succeed();
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_fail_clears_sender() {
        let tracker = AuthTracker::new();

        let rx1 = tracker.begin();
        tracker.fail("Bad credentials");
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_err());

        let rx2 = tracker.begin();
        tracker.succeed();
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_succeed_clears_sender() {
        let tracker = AuthTracker::new();

        let rx1 = tracker.begin();
        tracker.succeed();
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_ok());

        let rx2 = tracker.begin();
        tracker.succeed();
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_rapid_begin_succeed_cycles() {
        let tracker = AuthTracker::new();

        for _ in 0..100 {
            let rx = tracker.begin();
            tracker.succeed();
            let result: Result<(), TestError> =
                tracker.wait_for_result(Duration::from_secs(1), rx).await;
            assert!(result.is_ok());
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_double_succeed_is_safe() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.succeed();
        tracker.succeed();

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_double_fail_is_safe() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.fail("Error 1");
        tracker.fail("Error 2");

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        // Only the first resolution is delivered.
        assert_eq!(result.unwrap_err(), TestError("Error 1".to_string()));
    }

    #[rstest]
    #[tokio::test]
    async fn test_succeed_after_fail_is_ignored() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.fail("Auth failed");
        tracker.succeed();

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_err());
    }

    #[rstest]
    #[tokio::test]
    async fn test_fail_after_succeed_is_ignored() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.succeed();
        tracker.fail("Auth failed");

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_reconnect_flow_waits_for_auth() {
        let tracker = Arc::new(AuthTracker::new());
        let subscribed = Arc::new(tokio::sync::Notify::new());
        let auth_completed = Arc::new(tokio::sync::Notify::new());

        let tracker_reconnect = Arc::clone(&tracker);
        let subscribed_reconnect = Arc::clone(&subscribed);
        let auth_completed_reconnect = Arc::clone(&auth_completed);

        let reconnect_task = tokio::spawn(async move {
            let rx = tracker_reconnect.begin();

            let tracker_resub = Arc::clone(&tracker_reconnect);
            let subscribed_resub = Arc::clone(&subscribed_reconnect);
            let auth_completed_resub = Arc::clone(&auth_completed_reconnect);

            let resub_task = tokio::spawn(async move {
                let result: Result<(), TestError> = tracker_resub
                    .wait_for_result(Duration::from_secs(5), rx)
                    .await;

                if result.is_ok() {
                    auth_completed_resub.notify_one();
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    subscribed_resub.notify_one();
                }
            });

            resub_task.await.unwrap();
        });

        tokio::time::sleep(Duration::from_millis(100)).await;
        tracker.succeed();

        reconnect_task.await.unwrap();

        tokio::select! {
            _ = auth_completed.notified() => {}
            _ = tokio::time::sleep(Duration::from_secs(1)) => {
                panic!("Auth never completed");
            }
        }

        tokio::select! {
            _ = subscribed.notified() => {}
            _ = tokio::time::sleep(Duration::from_secs(1)) => {
                panic!("Subscription never completed");
            }
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_reconnect_flow_blocks_on_auth_failure() {
        let tracker = Arc::new(AuthTracker::new());
        let subscribed = Arc::new(AtomicBool::new(false));

        let tracker_reconnect = Arc::clone(&tracker);
        let subscribed_reconnect = Arc::clone(&subscribed);

        let reconnect_task = tokio::spawn(async move {
            let rx = tracker_reconnect.begin();

            let tracker_resub = Arc::clone(&tracker_reconnect);
            let subscribed_resub = Arc::clone(&subscribed_reconnect);

            let resub_task = tokio::spawn(async move {
                let result: Result<(), TestError> = tracker_resub
                    .wait_for_result(Duration::from_secs(5), rx)
                    .await;

                if result.is_ok() {
                    subscribed_resub.store(true, Ordering::Relaxed);
                }
            });

            resub_task.await.unwrap();
        });

        tokio::time::sleep(Duration::from_millis(50)).await;
        tracker.fail("Invalid credentials");

        reconnect_task.await.unwrap();

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(!subscribed.load(Ordering::Relaxed));
    }

    #[rstest]
    #[tokio::test]
    async fn test_state_machine_transitions() {
        let tracker = AuthTracker::new();

        let rx1 = tracker.begin();
        tracker.succeed();
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_ok());

        let rx2 = tracker.begin();
        tracker.fail("Error");
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_err());

        let rx3 = tracker.begin();
        let result3: Result<(), TestError> = tracker
            .wait_for_result(Duration::from_millis(50), rx3)
            .await;
        assert_eq!(
            result3.unwrap_err(),
            TestError("Authentication timed out".to_string())
        );

        let rx4 = tracker.begin();
        let rx5 = tracker.begin();
        let result4: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx4).await;
        assert_eq!(
            result4.unwrap_err(),
            TestError("Authentication attempt superseded".to_string())
        );

        tracker.succeed();
        let result5: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx5).await;
        assert!(result5.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_no_sender_leaks() {
        let tracker = AuthTracker::new();

        for _ in 0..100 {
            let rx = tracker.begin();
            let _result: Result<(), TestError> =
                tracker.wait_for_result(Duration::from_millis(1), rx).await;
        }

        let rx = tracker.begin();
        tracker.succeed();
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_concurrent_succeed_fail_calls() {
        let tracker = Arc::new(AuthTracker::new());
        let rx = tracker.begin();

        let mut handles = vec![];

        for _ in 0..50 {
            let tracker_clone = Arc::clone(&tracker);
            handles.push(tokio::spawn(async move {
                tracker_clone.succeed();
            }));
        }

        for _ in 0..50 {
            let tracker_clone = Arc::clone(&tracker);
            handles.push(tokio::spawn(async move {
                tracker_clone.fail("Error");
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;

        // Whichever call won the race, exactly one well-formed outcome must be delivered.
        assert!(
            result == Ok(()) || result == Err(TestError("Error".to_string())),
            "unexpected result: {result:?}"
        );
    }
}