1use std::{
38 num::NonZeroUsize,
39 sync::{Arc, LazyLock},
40};
41
42use ahash::AHashSet;
43use dashmap::DashMap;
44use ustr::Ustr;
45
46pub(crate) static CHANNEL_LEVEL_MARKER: LazyLock<Ustr> = LazyLock::new(|| Ustr::from(""));
51
52#[derive(Clone, Debug)]
69pub struct SubscriptionState {
70 confirmed: Arc<DashMap<Ustr, AHashSet<Ustr>>>,
72 pending_subscribe: Arc<DashMap<Ustr, AHashSet<Ustr>>>,
74 pending_unsubscribe: Arc<DashMap<Ustr, AHashSet<Ustr>>>,
76 reference_counts: Arc<DashMap<Ustr, NonZeroUsize>>,
78 delimiter: char,
80}
81
82impl SubscriptionState {
83 pub fn new(delimiter: char) -> Self {
85 Self {
86 confirmed: Arc::new(DashMap::new()),
87 pending_subscribe: Arc::new(DashMap::new()),
88 pending_unsubscribe: Arc::new(DashMap::new()),
89 reference_counts: Arc::new(DashMap::new()),
90 delimiter,
91 }
92 }
93
94 pub fn delimiter(&self) -> char {
96 self.delimiter
97 }
98
99 pub fn confirmed(&self) -> Arc<DashMap<Ustr, AHashSet<Ustr>>> {
101 Arc::clone(&self.confirmed)
102 }
103
104 pub fn pending_subscribe(&self) -> Arc<DashMap<Ustr, AHashSet<Ustr>>> {
106 Arc::clone(&self.pending_subscribe)
107 }
108
109 pub fn pending_unsubscribe(&self) -> Arc<DashMap<Ustr, AHashSet<Ustr>>> {
111 Arc::clone(&self.pending_unsubscribe)
112 }
113
114 pub fn len(&self) -> usize {
118 self.confirmed.iter().map(|entry| entry.value().len()).sum()
119 }
120
121 pub fn is_empty(&self) -> bool {
123 self.confirmed.is_empty()
124 && self.pending_subscribe.is_empty()
125 && self.pending_unsubscribe.is_empty()
126 }
127
128 pub fn is_subscribed(&self, channel: &Ustr, symbol: &Ustr) -> bool {
130 if let Some(symbols) = self.confirmed.get(channel)
131 && symbols.contains(symbol)
132 {
133 return true;
134 }
135 if let Some(symbols) = self.pending_subscribe.get(channel)
136 && symbols.contains(symbol)
137 {
138 return true;
139 }
140 false
141 }
142
143 pub fn mark_subscribe(&self, topic: &str) {
149 let (channel, symbol) = split_topic(topic, self.delimiter);
150
151 if is_tracked(&self.confirmed, channel, symbol) {
153 return;
154 }
155
156 untrack_topic(&self.pending_unsubscribe, channel, symbol);
158
159 track_topic(&self.pending_subscribe, channel, symbol);
160 }
161
162 pub fn mark_unsubscribe(&self, topic: &str) {
168 let (channel, symbol) = split_topic(topic, self.delimiter);
169 track_topic(&self.pending_unsubscribe, channel, symbol);
170 untrack_topic(&self.confirmed, channel, symbol);
171 untrack_topic(&self.pending_subscribe, channel, symbol);
172 }
173
174 pub fn confirm_subscribe(&self, topic: &str) {
180 let (channel, symbol) = split_topic(topic, self.delimiter);
181
182 if is_tracked(&self.pending_unsubscribe, channel, symbol) {
184 return;
185 }
186
187 untrack_topic(&self.pending_subscribe, channel, symbol);
188 track_topic(&self.confirmed, channel, symbol);
189 }
190
191 pub fn confirm_unsubscribe(&self, topic: &str) {
202 let (channel, symbol) = split_topic(topic, self.delimiter);
203
204 if !is_tracked(&self.pending_unsubscribe, channel, symbol) {
207 return; }
209
210 untrack_topic(&self.pending_unsubscribe, channel, symbol);
211 untrack_topic(&self.confirmed, channel, symbol);
212 }
214
215 pub fn mark_failure(&self, topic: &str) {
220 let (channel, symbol) = split_topic(topic, self.delimiter);
221
222 if is_tracked(&self.pending_unsubscribe, channel, symbol) {
224 return;
225 }
226
227 untrack_topic(&self.confirmed, channel, symbol);
228 track_topic(&self.pending_subscribe, channel, symbol);
229 }
230
231 pub fn pending_subscribe_topics(&self) -> Vec<String> {
233 self.topics_from_map(&self.pending_subscribe)
234 }
235
236 pub fn pending_unsubscribe_topics(&self) -> Vec<String> {
238 self.topics_from_map(&self.pending_unsubscribe)
239 }
240
241 pub fn all_topics(&self) -> Vec<String> {
248 let mut topics = Vec::new();
249 topics.extend(self.topics_from_map(&self.confirmed));
250 topics.extend(self.topics_from_map(&self.pending_subscribe));
251 topics
252 }
253
254 fn topics_from_map(&self, map: &DashMap<Ustr, AHashSet<Ustr>>) -> Vec<String> {
256 let mut topics = Vec::new();
257 let marker = *CHANNEL_LEVEL_MARKER;
258
259 for entry in map {
260 let channel = entry.key();
261 let symbols = entry.value();
262
263 if symbols.contains(&marker) {
265 topics.push(channel.to_string());
266 }
267
268 for symbol in symbols {
270 if *symbol != marker {
271 topics.push(format!(
272 "{}{}{}",
273 channel.as_str(),
274 self.delimiter,
275 symbol.as_str()
276 ));
277 }
278 }
279 }
280
281 topics
282 }
283
284 pub fn add_reference(&self, topic: &str) -> bool {
293 let mut should_subscribe = false;
294 let topic_ustr = Ustr::from(topic);
295
296 self.reference_counts
297 .entry(topic_ustr)
298 .and_modify(|count| {
299 *count = NonZeroUsize::new(count.get() + 1).expect("reference count overflow");
300 })
301 .or_insert_with(|| {
302 should_subscribe = true;
303 NonZeroUsize::new(1).expect("NonZeroUsize::new(1) should never fail")
304 });
305
306 should_subscribe
307 }
308
309 pub fn remove_reference(&self, topic: &str) -> bool {
319 let topic_ustr = Ustr::from(topic);
320
321 if let dashmap::mapref::entry::Entry::Occupied(mut entry) =
324 self.reference_counts.entry(topic_ustr)
325 {
326 let current = entry.get().get();
327
328 if current == 1 {
329 entry.remove();
330 return true;
331 }
332
333 *entry.get_mut() = NonZeroUsize::new(current - 1)
334 .expect("reference count should never reach zero here");
335 }
336
337 false
338 }
339
340 pub fn get_reference_count(&self, topic: &str) -> usize {
344 let topic_ustr = Ustr::from(topic);
345 self.reference_counts
346 .get(&topic_ustr)
347 .map_or(0, |count| count.get())
348 }
349
350 pub fn clear(&self) {
354 self.confirmed.clear();
355 self.pending_subscribe.clear();
356 self.pending_unsubscribe.clear();
357 self.reference_counts.clear();
358 }
359}
360
/// Splits `topic` at the first occurrence of `delimiter`.
///
/// Returns `(channel, Some(symbol))` when the delimiter is present, where `symbol` is
/// everything after the first delimiter (later delimiters stay inside the symbol), and
/// `(topic, None)` when the delimiter does not occur.
pub fn split_topic(topic: &str, delimiter: char) -> (&str, Option<&str>) {
    match topic.split_once(delimiter) {
        Some((channel, symbol)) => (channel, Some(symbol)),
        None => (topic, None),
    }
}
367
368fn track_topic(map: &DashMap<Ustr, AHashSet<Ustr>>, channel: &str, symbol: Option<&str>) {
373 let channel_ustr = Ustr::from(channel);
374 let mut entry = map.entry(channel_ustr).or_default();
375
376 if let Some(symbol) = symbol {
377 entry.insert(Ustr::from(symbol));
378 } else {
379 entry.insert(*CHANNEL_LEVEL_MARKER);
380 }
381}
382
383fn untrack_topic(map: &DashMap<Ustr, AHashSet<Ustr>>, channel: &str, symbol: Option<&str>) {
387 let channel_ustr = Ustr::from(channel);
388 let symbol_to_remove = if let Some(symbol) = symbol {
389 Ustr::from(symbol)
390 } else {
391 *CHANNEL_LEVEL_MARKER
392 };
393
394 if let dashmap::mapref::entry::Entry::Occupied(mut entry) = map.entry(channel_ustr) {
397 entry.get_mut().remove(&symbol_to_remove);
398 if entry.get().is_empty() {
399 entry.remove();
400 }
401 }
402}
403
404fn is_tracked(map: &DashMap<Ustr, AHashSet<Ustr>>, channel: &str, symbol: Option<&str>) -> bool {
406 let channel_ustr = Ustr::from(channel);
407 let symbol_to_check = if let Some(symbol) = symbol {
408 Ustr::from(symbol)
409 } else {
410 *CHANNEL_LEVEL_MARKER
411 };
412
413 if let Some(entry) = map.get(&channel_ustr) {
414 entry.contains(&symbol_to_check)
415 } else {
416 false
417 }
418}
419
420#[cfg(test)]
421mod tests {
422 use rstest::rstest;
423
424 use super::*;
425
426 #[rstest]
427 fn test_split_topic_with_symbol() {
428 let (channel, symbol) = split_topic("tickers.BTCUSDT", '.');
429 assert_eq!(channel, "tickers");
430 assert_eq!(symbol, Some("BTCUSDT"));
431
432 let (channel, symbol) = split_topic("orderBookL2:XBTUSD", ':');
433 assert_eq!(channel, "orderBookL2");
434 assert_eq!(symbol, Some("XBTUSD"));
435 }
436
437 #[rstest]
438 fn test_split_topic_without_symbol() {
439 let (channel, symbol) = split_topic("orderbook", '.');
440 assert_eq!(channel, "orderbook");
441 assert_eq!(symbol, None);
442 }
443
444 #[rstest]
445 fn test_new_state_is_empty() {
446 let state = SubscriptionState::new('.');
447 assert!(state.is_empty());
448 assert_eq!(state.len(), 0);
449 }
450
451 #[rstest]
452 fn test_mark_subscribe() {
453 let state = SubscriptionState::new('.');
454 state.mark_subscribe("tickers.BTCUSDT");
455
456 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
457 assert_eq!(state.len(), 0); }
459
460 #[rstest]
461 fn test_confirm_subscribe() {
462 let state = SubscriptionState::new('.');
463 state.mark_subscribe("tickers.BTCUSDT");
464 state.confirm_subscribe("tickers.BTCUSDT");
465
466 assert!(state.pending_subscribe_topics().is_empty());
467 assert_eq!(state.len(), 1);
468 }
469
470 #[rstest]
471 fn test_is_subscribed_empty_state() {
472 let state = SubscriptionState::new('.');
473 let channel = Ustr::from("tickers");
474 let symbol = Ustr::from("BTCUSDT");
475
476 assert!(!state.is_subscribed(&channel, &symbol));
477 }
478
479 #[rstest]
480 fn test_is_subscribed_pending() {
481 let state = SubscriptionState::new('.');
482 let channel = Ustr::from("tickers");
483 let symbol = Ustr::from("BTCUSDT");
484
485 state.mark_subscribe("tickers.BTCUSDT");
486
487 assert!(state.is_subscribed(&channel, &symbol));
488 }
489
490 #[rstest]
491 fn test_is_subscribed_confirmed() {
492 let state = SubscriptionState::new('.');
493 let channel = Ustr::from("tickers");
494 let symbol = Ustr::from("BTCUSDT");
495
496 state.mark_subscribe("tickers.BTCUSDT");
497 state.confirm_subscribe("tickers.BTCUSDT");
498
499 assert!(state.is_subscribed(&channel, &symbol));
500 }
501
502 #[rstest]
503 fn test_is_subscribed_after_unsubscribe() {
504 let state = SubscriptionState::new('.');
505 let channel = Ustr::from("tickers");
506 let symbol = Ustr::from("BTCUSDT");
507
508 state.mark_subscribe("tickers.BTCUSDT");
509 state.confirm_subscribe("tickers.BTCUSDT");
510 state.mark_unsubscribe("tickers.BTCUSDT");
511
512 assert!(!state.is_subscribed(&channel, &symbol));
514 }
515
516 #[rstest]
517 fn test_is_subscribed_after_confirm_unsubscribe() {
518 let state = SubscriptionState::new('.');
519 let channel = Ustr::from("tickers");
520 let symbol = Ustr::from("BTCUSDT");
521
522 state.mark_subscribe("tickers.BTCUSDT");
523 state.confirm_subscribe("tickers.BTCUSDT");
524 state.mark_unsubscribe("tickers.BTCUSDT");
525 state.confirm_unsubscribe("tickers.BTCUSDT");
526
527 assert!(!state.is_subscribed(&channel, &symbol));
528 }
529
530 #[rstest]
531 fn test_mark_unsubscribe() {
532 let state = SubscriptionState::new('.');
533 state.mark_subscribe("tickers.BTCUSDT");
534 state.confirm_subscribe("tickers.BTCUSDT");
535 state.mark_unsubscribe("tickers.BTCUSDT");
536
537 assert_eq!(state.len(), 0); assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
539 }
540
541 #[rstest]
542 fn test_confirm_unsubscribe() {
543 let state = SubscriptionState::new('.');
544 state.mark_subscribe("tickers.BTCUSDT");
545 state.confirm_subscribe("tickers.BTCUSDT");
546 state.mark_unsubscribe("tickers.BTCUSDT");
547 state.confirm_unsubscribe("tickers.BTCUSDT");
548
549 assert!(state.is_empty());
550 }
551
552 #[rstest]
553 fn test_resubscribe_before_unsubscribe_ack() {
554 let state = SubscriptionState::new('.');
558
559 state.mark_subscribe("tickers.BTCUSDT");
560 state.confirm_subscribe("tickers.BTCUSDT");
561 assert_eq!(state.len(), 1);
562
563 state.mark_unsubscribe("tickers.BTCUSDT");
564 assert_eq!(state.len(), 0);
565 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
566
567 state.mark_subscribe("tickers.BTCUSDT");
569 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
570
571 state.confirm_unsubscribe("tickers.BTCUSDT");
573 assert!(state.pending_unsubscribe_topics().is_empty());
574 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]); state.confirm_subscribe("tickers.BTCUSDT");
578 assert_eq!(state.len(), 1);
579 assert!(state.pending_subscribe_topics().is_empty());
580
581 let all = state.all_topics();
583 assert_eq!(all.len(), 1);
584 assert!(all.contains(&"tickers.BTCUSDT".to_string()));
585 }
586
587 #[rstest]
588 fn test_stale_unsubscribe_ack_after_resubscribe_confirmed() {
589 let state = SubscriptionState::new('.');
594
595 state.mark_subscribe("tickers.BTCUSDT");
597 state.confirm_subscribe("tickers.BTCUSDT");
598 assert_eq!(state.len(), 1);
599
600 state.mark_unsubscribe("tickers.BTCUSDT");
602 assert_eq!(state.len(), 0);
603 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
604
605 state.mark_subscribe("tickers.BTCUSDT");
607 assert!(state.pending_unsubscribe_topics().is_empty()); assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
609
610 state.confirm_subscribe("tickers.BTCUSDT");
612 assert_eq!(state.len(), 1); assert!(state.pending_subscribe_topics().is_empty());
614
615 state.confirm_unsubscribe("tickers.BTCUSDT");
618
619 assert_eq!(state.len(), 1); assert!(state.pending_unsubscribe_topics().is_empty());
622 assert!(state.pending_subscribe_topics().is_empty());
623
624 let all = state.all_topics();
626 assert_eq!(all.len(), 1);
627 assert!(all.contains(&"tickers.BTCUSDT".to_string()));
628 }
629
630 #[rstest]
631 fn test_mark_failure() {
632 let state = SubscriptionState::new('.');
633 state.mark_subscribe("tickers.BTCUSDT");
634 state.confirm_subscribe("tickers.BTCUSDT");
635 state.mark_failure("tickers.BTCUSDT");
636
637 assert_eq!(state.len(), 0);
638 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
639 }
640
641 #[rstest]
642 fn test_all_topics_includes_confirmed_and_pending_subscribe() {
643 let state = SubscriptionState::new('.');
644 state.mark_subscribe("tickers.BTCUSDT");
645 state.confirm_subscribe("tickers.BTCUSDT");
646 state.mark_subscribe("tickers.ETHUSDT");
647
648 let topics = state.all_topics();
649 assert_eq!(topics.len(), 2);
650 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
651 assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
652 }
653
654 #[rstest]
655 fn test_all_topics_excludes_pending_unsubscribe() {
656 let state = SubscriptionState::new('.');
657 state.mark_subscribe("tickers.BTCUSDT");
658 state.confirm_subscribe("tickers.BTCUSDT");
659 state.mark_unsubscribe("tickers.BTCUSDT");
660
661 let topics = state.all_topics();
662 assert!(topics.is_empty());
663 }
664
665 #[rstest]
666 fn test_reference_counting_single_topic() {
667 let state = SubscriptionState::new('.');
668
669 assert!(state.add_reference("tickers.BTCUSDT"));
670 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 1);
671
672 assert!(!state.add_reference("tickers.BTCUSDT"));
673 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 2);
674
675 assert!(!state.remove_reference("tickers.BTCUSDT"));
676 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 1);
677
678 assert!(state.remove_reference("tickers.BTCUSDT"));
679 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 0);
680 }
681
682 #[rstest]
683 fn test_reference_counting_multiple_topics() {
684 let state = SubscriptionState::new('.');
685
686 assert!(state.add_reference("tickers.BTCUSDT"));
687 assert!(state.add_reference("tickers.ETHUSDT"));
688
689 assert!(!state.add_reference("tickers.BTCUSDT"));
690 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 2);
691 assert_eq!(state.get_reference_count("tickers.ETHUSDT"), 1);
692
693 assert!(!state.remove_reference("tickers.BTCUSDT"));
694 assert!(state.remove_reference("tickers.ETHUSDT"));
695 }
696
697 #[rstest]
698 fn test_topic_without_symbol() {
699 let state = SubscriptionState::new('.');
700 state.mark_subscribe("orderbook");
701 state.confirm_subscribe("orderbook");
702
703 assert_eq!(state.len(), 1);
704 assert_eq!(state.all_topics(), vec!["orderbook"]);
705 }
706
707 #[rstest]
708 fn test_different_delimiters() {
709 let state_dot = SubscriptionState::new('.');
710 state_dot.mark_subscribe("tickers.BTCUSDT");
711 assert_eq!(
712 state_dot.pending_subscribe_topics(),
713 vec!["tickers.BTCUSDT"]
714 );
715
716 let state_colon = SubscriptionState::new(':');
717 state_colon.mark_subscribe("orderBookL2:XBTUSD");
718 assert_eq!(
719 state_colon.pending_subscribe_topics(),
720 vec!["orderBookL2:XBTUSD"]
721 );
722 }
723
724 #[rstest]
725 fn test_clear() {
726 let state = SubscriptionState::new('.');
727 state.mark_subscribe("tickers.BTCUSDT");
728 state.confirm_subscribe("tickers.BTCUSDT");
729 state.add_reference("tickers.BTCUSDT");
730
731 state.clear();
732
733 assert!(state.is_empty());
734 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 0);
735 }
736
737 #[rstest]
738 fn test_multiple_symbols_same_channel() {
739 let state = SubscriptionState::new('.');
740 state.mark_subscribe("tickers.BTCUSDT");
741 state.mark_subscribe("tickers.ETHUSDT");
742 state.confirm_subscribe("tickers.BTCUSDT");
743 state.confirm_subscribe("tickers.ETHUSDT");
744
745 assert_eq!(state.len(), 2);
746 let topics = state.all_topics();
747 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
748 assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
749 }
750
751 #[rstest]
752 fn test_mixed_channel_and_symbol_subscriptions() {
753 let state = SubscriptionState::new('.');
754
755 state.mark_subscribe("tickers");
757 state.confirm_subscribe("tickers");
758 assert_eq!(state.len(), 1);
759 assert_eq!(state.all_topics(), vec!["tickers"]);
760
761 state.mark_subscribe("tickers.BTCUSDT");
763 state.confirm_subscribe("tickers.BTCUSDT");
764 assert_eq!(state.len(), 2);
765
766 let topics = state.all_topics();
768 assert_eq!(topics.len(), 2);
769 assert!(topics.contains(&"tickers".to_string()));
770 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
771
772 state.mark_subscribe("tickers.ETHUSDT");
774 state.confirm_subscribe("tickers.ETHUSDT");
775 assert_eq!(state.len(), 3);
776
777 let topics = state.all_topics();
778 assert_eq!(topics.len(), 3);
779 assert!(topics.contains(&"tickers".to_string()));
780 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
781 assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
782
783 state.mark_unsubscribe("tickers");
785 state.confirm_unsubscribe("tickers");
786 assert_eq!(state.len(), 2);
787
788 let topics = state.all_topics();
789 assert_eq!(topics.len(), 2);
790 assert!(!topics.contains(&"tickers".to_string()));
791 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
792 assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
793 }
794
795 #[rstest]
796 fn test_symbol_subscription_before_channel() {
797 let state = SubscriptionState::new('.');
798
799 state.mark_subscribe("tickers.BTCUSDT");
801 state.confirm_subscribe("tickers.BTCUSDT");
802 assert_eq!(state.len(), 1);
803
804 state.mark_subscribe("tickers");
806 state.confirm_subscribe("tickers");
807 assert_eq!(state.len(), 2);
808
809 let topics = state.all_topics();
811 assert_eq!(topics.len(), 2);
812 assert!(topics.contains(&"tickers".to_string()));
813 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
814 }
815
816 #[rstest]
817 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
818 async fn test_concurrent_subscribe_same_topic() {
819 let state = Arc::new(SubscriptionState::new('.'));
820 let mut handles = vec![];
821
822 for _ in 0..10 {
824 let state_clone = Arc::clone(&state);
825 let handle = tokio::spawn(async move {
826 state_clone.add_reference("tickers.BTCUSDT");
827 state_clone.mark_subscribe("tickers.BTCUSDT");
828 state_clone.confirm_subscribe("tickers.BTCUSDT");
829 });
830 handles.push(handle);
831 }
832
833 for handle in handles {
834 handle.await.unwrap();
835 }
836
837 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 10);
839 assert_eq!(state.len(), 1);
840 }
841
842 #[rstest]
843 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
844 async fn test_concurrent_subscribe_unsubscribe() {
845 let state = Arc::new(SubscriptionState::new('.'));
846 let mut handles = vec![];
847
848 for i in 0..20 {
851 let state_clone = Arc::clone(&state);
852 let handle = tokio::spawn(async move {
853 let topic = format!("tickers.SYMBOL{i}");
854 state_clone.add_reference(&topic);
856 state_clone.add_reference(&topic);
857 state_clone.mark_subscribe(&topic);
858 state_clone.confirm_subscribe(&topic);
859
860 state_clone.remove_reference(&topic);
862 });
863 handles.push(handle);
864 }
865
866 for handle in handles {
867 handle.await.unwrap();
868 }
869
870 for i in 0..20 {
872 let topic = format!("tickers.SYMBOL{i}");
873 assert_eq!(state.get_reference_count(&topic), 1);
874 }
875
876 assert_eq!(state.len(), 20);
878 assert!(!state.is_empty());
879 }
880
881 #[rstest]
882 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
883 async fn test_concurrent_reference_counting_same_topic() {
884 let state = Arc::new(SubscriptionState::new('.'));
885 let topic = "tickers.BTCUSDT";
886 let mut handles = vec![];
887
888 for _ in 0..10 {
890 let state_clone = Arc::clone(&state);
891 let handle = tokio::spawn(async move {
892 for _ in 0..10 {
893 state_clone.add_reference(topic);
894 }
895 });
896 handles.push(handle);
897 }
898
899 for handle in handles {
900 handle.await.unwrap();
901 }
902
903 assert_eq!(state.get_reference_count(topic), 100);
905
906 for _ in 0..50 {
908 state.remove_reference(topic);
909 }
910
911 assert_eq!(state.get_reference_count(topic), 50);
913 }
914
915 #[rstest]
916 fn test_reconnection_scenario() {
917 let state = SubscriptionState::new('.');
918
919 state.add_reference("tickers.BTCUSDT");
921 state.mark_subscribe("tickers.BTCUSDT");
922 state.confirm_subscribe("tickers.BTCUSDT");
923
924 state.add_reference("tickers.ETHUSDT");
925 state.mark_subscribe("tickers.ETHUSDT");
926 state.confirm_subscribe("tickers.ETHUSDT");
927
928 state.add_reference("orderbook");
929 state.mark_subscribe("orderbook");
930 state.confirm_subscribe("orderbook");
931
932 assert_eq!(state.len(), 3);
933
934 let topics_to_resubscribe = state.all_topics();
936 assert_eq!(topics_to_resubscribe.len(), 3);
937 assert!(topics_to_resubscribe.contains(&"tickers.BTCUSDT".to_string()));
938 assert!(topics_to_resubscribe.contains(&"tickers.ETHUSDT".to_string()));
939 assert!(topics_to_resubscribe.contains(&"orderbook".to_string()));
940
941 for topic in &topics_to_resubscribe {
943 state.mark_subscribe(topic);
944 }
945
946 for topic in &topics_to_resubscribe {
948 state.confirm_subscribe(topic);
949 }
950
951 assert_eq!(state.len(), 3);
953 assert_eq!(state.all_topics().len(), 3);
954 }
955
956 #[rstest]
957 fn test_state_machine_invalid_transitions() {
958 let state = SubscriptionState::new('.');
959
960 state.confirm_subscribe("tickers.BTCUSDT");
962 assert_eq!(state.len(), 1); state.confirm_unsubscribe("tickers.ETHUSDT");
966 assert_eq!(state.len(), 1); state.mark_subscribe("orderbook");
970 state.confirm_subscribe("orderbook");
971 state.confirm_subscribe("orderbook"); assert_eq!(state.len(), 2);
973
974 state.mark_unsubscribe("nonexistent");
976 state.confirm_unsubscribe("nonexistent");
977 assert_eq!(state.len(), 2); }
979
980 #[rstest]
981 fn test_mark_failure_moves_to_pending() {
982 let state = SubscriptionState::new('.');
983
984 state.mark_subscribe("tickers.BTCUSDT");
986 state.confirm_subscribe("tickers.BTCUSDT");
987 assert_eq!(state.len(), 1);
988 assert!(state.pending_subscribe_topics().is_empty());
989
990 state.mark_failure("tickers.BTCUSDT");
992
993 assert_eq!(state.len(), 0);
995 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
996
997 assert_eq!(state.all_topics(), vec!["tickers.BTCUSDT"]);
999 }
1000
1001 #[rstest]
1002 fn test_pending_subscribe_excludes_pending_unsubscribe() {
1003 let state = SubscriptionState::new('.');
1004
1005 state.mark_subscribe("tickers.BTCUSDT");
1007 state.confirm_subscribe("tickers.BTCUSDT");
1008
1009 state.mark_unsubscribe("tickers.BTCUSDT");
1011
1012 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
1014 assert!(state.all_topics().is_empty());
1015 assert_eq!(state.len(), 0);
1016 }
1017
1018 #[rstest]
1019 fn test_remove_reference_nonexistent_topic() {
1020 let state = SubscriptionState::new('.');
1021
1022 let should_unsubscribe = state.remove_reference("nonexistent");
1024
1025 assert!(!should_unsubscribe);
1027 assert_eq!(state.get_reference_count("nonexistent"), 0);
1028 }
1029
1030 #[rstest]
1031 fn test_edge_case_empty_channel_name() {
1032 let state = SubscriptionState::new('.');
1033
1034 state.mark_subscribe("");
1036 state.confirm_subscribe("");
1037
1038 assert_eq!(state.len(), 1);
1039 assert_eq!(state.all_topics(), vec![""]);
1040 }
1041
1042 #[rstest]
1043 fn test_special_characters_in_topics() {
1044 let state = SubscriptionState::new('.');
1045
1046 let special_topics = vec![
1048 "channel.symbol-with-dash",
1049 "channel.SYMBOL_WITH_UNDERSCORE",
1050 "channel.symbol123",
1051 "channel.symbol@special",
1052 ];
1053
1054 for topic in &special_topics {
1055 state.mark_subscribe(topic);
1056 state.confirm_subscribe(topic);
1057 }
1058
1059 assert_eq!(state.len(), special_topics.len());
1060
1061 let all_topics = state.all_topics();
1062 for topic in &special_topics {
1063 assert!(
1064 all_topics.contains(&(*topic).to_string()),
1065 "Missing topic: {topic}"
1066 );
1067 }
1068 }
1069
1070 #[rstest]
1071 fn test_clear_resets_all_state() {
1072 let state = SubscriptionState::new('.');
1073
1074 for i in 0..10 {
1076 let topic = format!("channel{i}.SYMBOL");
1077 state.add_reference(&topic);
1078 state.add_reference(&topic); state.mark_subscribe(&topic);
1080 state.confirm_subscribe(&topic);
1081 }
1082
1083 assert_eq!(state.len(), 10);
1084 assert!(!state.is_empty());
1085
1086 state.clear();
1088
1089 assert_eq!(state.len(), 0);
1091 assert!(state.is_empty());
1092 assert!(state.all_topics().is_empty());
1093 assert!(state.pending_subscribe_topics().is_empty());
1094 assert!(state.pending_unsubscribe_topics().is_empty());
1095
1096 for i in 0..10 {
1098 let topic = format!("channel{i}.SYMBOL");
1099 assert_eq!(state.get_reference_count(&topic), 0);
1100 }
1101 }
1102
1103 #[rstest]
1104 fn test_different_delimiter_does_not_affect_storage() {
1105 let state_dot = SubscriptionState::new('.');
1107 let state_colon = SubscriptionState::new(':');
1108
1109 state_dot.mark_subscribe("channel.SYMBOL");
1111 state_colon.mark_subscribe("channel:SYMBOL");
1112
1113 assert_eq!(state_dot.pending_subscribe_topics(), vec!["channel.SYMBOL"]);
1115 assert_eq!(
1116 state_colon.pending_subscribe_topics(),
1117 vec!["channel:SYMBOL"]
1118 );
1119 }
1120
1121 #[rstest]
1122 fn test_unsubscribe_before_subscribe_confirmed() {
1123 let state = SubscriptionState::new('.');
1124
1125 state.mark_subscribe("tickers.BTCUSDT");
1127 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
1128
1129 state.mark_unsubscribe("tickers.BTCUSDT");
1131
1132 assert!(state.pending_subscribe_topics().is_empty());
1134 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
1135
1136 state.confirm_unsubscribe("tickers.BTCUSDT");
1138
1139 assert!(state.is_empty());
1141 assert!(state.all_topics().is_empty());
1142 assert_eq!(state.len(), 0);
1143 }
1144
1145 #[rstest]
1146 fn test_late_subscribe_confirmation_after_unsubscribe() {
1147 let state = SubscriptionState::new('.');
1148
1149 state.mark_subscribe("tickers.BTCUSDT");
1151
1152 state.mark_unsubscribe("tickers.BTCUSDT");
1154
1155 state.confirm_subscribe("tickers.BTCUSDT");
1157
1158 assert_eq!(state.len(), 0);
1160 assert!(state.pending_subscribe_topics().is_empty());
1161
1162 state.confirm_unsubscribe("tickers.BTCUSDT");
1164
1165 assert!(state.is_empty());
1167 assert!(state.all_topics().is_empty());
1168 }
1169
1170 #[rstest]
1171 fn test_unsubscribe_clears_all_states() {
1172 let state = SubscriptionState::new('.');
1173
1174 state.mark_subscribe("tickers.BTCUSDT");
1176 state.confirm_subscribe("tickers.BTCUSDT");
1177 assert_eq!(state.len(), 1);
1178
1179 state.mark_unsubscribe("tickers.BTCUSDT");
1181
1182 assert_eq!(state.len(), 0);
1184 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
1185
1186 state.confirm_subscribe("tickers.BTCUSDT");
1188
1189 state.confirm_unsubscribe("tickers.BTCUSDT");
1191
1192 assert!(state.is_empty());
1194 assert_eq!(state.len(), 0);
1195 assert!(state.pending_subscribe_topics().is_empty());
1196 assert!(state.pending_unsubscribe_topics().is_empty());
1197 assert!(state.all_topics().is_empty());
1198 }
1199
1200 #[rstest]
1201 fn test_mark_failure_respects_pending_unsubscribe() {
1202 let state = SubscriptionState::new('.');
1203
1204 state.mark_subscribe("tickers.BTCUSDT");
1206 state.confirm_subscribe("tickers.BTCUSDT");
1207 assert_eq!(state.len(), 1);
1208
1209 state.mark_unsubscribe("tickers.BTCUSDT");
1211 assert_eq!(state.len(), 0);
1212 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
1213
1214 state.mark_failure("tickers.BTCUSDT");
1216
1217 assert!(state.pending_subscribe_topics().is_empty());
1219 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
1220
1221 assert!(state.all_topics().is_empty());
1223
1224 state.confirm_unsubscribe("tickers.BTCUSDT");
1226 assert!(state.is_empty());
1227 }
1228
1229 #[rstest]
1230 #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
1231 async fn test_concurrent_stress_mixed_operations() {
1232 let state = Arc::new(SubscriptionState::new('.'));
1233 let mut handles = vec![];
1234
1235 for i in 0..50 {
1237 let state_clone = Arc::clone(&state);
1238 let handle = tokio::spawn(async move {
1239 let topic1 = format!("channel.SYMBOL{i}");
1240 let topic2 = format!("channel.SYMBOL{}", i + 100);
1241
1242 state_clone.add_reference(&topic1);
1244 state_clone.add_reference(&topic2);
1245
1246 state_clone.mark_subscribe(&topic1);
1248 state_clone.confirm_subscribe(&topic1);
1249 state_clone.mark_subscribe(&topic2);
1250
1251 if i % 3 == 0 {
1253 state_clone.mark_unsubscribe(&topic1);
1254 state_clone.confirm_unsubscribe(&topic1);
1255 }
1256
1257 state_clone.add_reference(&topic2);
1259 state_clone.remove_reference(&topic2);
1260
1261 state_clone.confirm_subscribe(&topic2);
1263 });
1264 handles.push(handle);
1265 }
1266
1267 for handle in handles {
1268 handle.await.unwrap();
1269 }
1270
1271 let all = state.all_topics();
1273 let confirmed_count = state.len();
1274
1275 assert!(confirmed_count > 50); assert!(confirmed_count <= 100); assert_eq!(
1280 all.len(),
1281 confirmed_count + state.pending_subscribe_topics().len()
1282 );
1283 }
1284
1285 #[rstest]
1286 fn test_edge_case_malformed_topics() {
1287 let state = SubscriptionState::new('.');
1288
1289 state.mark_subscribe("channel.symbol.extra");
1291 state.confirm_subscribe("channel.symbol.extra");
1292 let topics = state.all_topics();
1293 assert!(topics.contains(&"channel.symbol.extra".to_string()));
1294
1295 state.mark_subscribe(".channel");
1297 state.confirm_subscribe(".channel");
1298 assert_eq!(state.len(), 2);
1299
1300 state.mark_subscribe("channel.");
1303 state.confirm_subscribe("channel.");
1304 assert_eq!(state.len(), 3);
1305
1306 state.mark_subscribe("tickers");
1308 state.confirm_subscribe("tickers");
1309 assert_eq!(state.len(), 4);
1310
1311 let all = state.all_topics();
1313 assert_eq!(all.len(), 4);
1314 assert!(all.contains(&"channel.symbol.extra".to_string()));
1315 assert!(all.contains(&".channel".to_string()));
1316 assert!(all.contains(&"channel".to_string())); assert!(all.contains(&"tickers".to_string()));
1318 }
1319
1320 #[rstest]
1321 fn test_reference_count_underflow_safety() {
1322 let state = SubscriptionState::new('.');
1323
1324 assert!(!state.remove_reference("never.added"));
1326 assert_eq!(state.get_reference_count("never.added"), 0);
1327
1328 state.add_reference("once.added");
1330 assert_eq!(state.get_reference_count("once.added"), 1);
1331
1332 assert!(state.remove_reference("once.added")); assert_eq!(state.get_reference_count("once.added"), 0);
1334
1335 assert!(!state.remove_reference("once.added")); assert!(!state.remove_reference("once.added")); assert_eq!(state.get_reference_count("once.added"), 0);
1338
1339 assert!(state.add_reference("once.added"));
1341 assert_eq!(state.get_reference_count("once.added"), 1);
1342 }
1343
1344 #[rstest]
1345 fn test_reconnection_with_partial_state() {
1346 let state = SubscriptionState::new('.');
1347
1348 state.mark_subscribe("confirmed.BTCUSDT");
1351 state.confirm_subscribe("confirmed.BTCUSDT");
1352
1353 state.mark_subscribe("pending.ETHUSDT");
1355
1356 state.mark_subscribe("cancelled.XRPUSDT");
1358 state.confirm_subscribe("cancelled.XRPUSDT");
1359 state.mark_unsubscribe("cancelled.XRPUSDT");
1360
1361 assert_eq!(state.len(), 1); let all = state.all_topics();
1364 assert_eq!(all.len(), 2); assert!(all.contains(&"confirmed.BTCUSDT".to_string()));
1366 assert!(all.contains(&"pending.ETHUSDT".to_string()));
1367 assert!(!all.contains(&"cancelled.XRPUSDT".to_string())); let topics_to_resubscribe = state.all_topics();
1371
1372 state.confirmed().clear();
1374
1375 for topic in &topics_to_resubscribe {
1377 state.mark_subscribe(topic);
1378 }
1379
1380 for topic in &topics_to_resubscribe {
1382 state.confirm_subscribe(topic);
1383 }
1384
1385 assert_eq!(state.len(), 2); let final_topics = state.all_topics();
1388 assert_eq!(final_topics.len(), 2);
1389 assert!(final_topics.contains(&"confirmed.BTCUSDT".to_string()));
1390 assert!(final_topics.contains(&"pending.ETHUSDT".to_string()));
1391 assert!(!final_topics.contains(&"cancelled.XRPUSDT".to_string()));
1392 }
1393
1394 fn check_invariants(state: &SubscriptionState, label: &str) {
1405 let confirmed_topics: AHashSet<String> = state
1407 .topics_from_map(&state.confirmed)
1408 .into_iter()
1409 .collect();
1410 let pending_sub_topics: AHashSet<String> =
1411 state.pending_subscribe_topics().into_iter().collect();
1412 let pending_unsub_topics: AHashSet<String> =
1413 state.pending_unsubscribe_topics().into_iter().collect();
1414
1415 let confirmed_and_pending_sub: Vec<_> =
1417 confirmed_topics.intersection(&pending_sub_topics).collect();
1418 assert!(
1419 confirmed_and_pending_sub.is_empty(),
1420 "{label}: Topic in both confirmed and pending_subscribe: {confirmed_and_pending_sub:?}"
1421 );
1422
1423 let confirmed_and_pending_unsub: Vec<_> = confirmed_topics
1424 .intersection(&pending_unsub_topics)
1425 .collect();
1426 assert!(
1427 confirmed_and_pending_unsub.is_empty(),
1428 "{label}: Topic in both confirmed and pending_unsubscribe: {confirmed_and_pending_unsub:?}"
1429 );
1430
1431 let pending_sub_and_unsub: Vec<_> = pending_sub_topics
1432 .intersection(&pending_unsub_topics)
1433 .collect();
1434 assert!(
1435 pending_sub_and_unsub.is_empty(),
1436 "{label}: Topic in both pending_subscribe and pending_unsubscribe: {pending_sub_and_unsub:?}"
1437 );
1438
1439 let all_topics: AHashSet<String> = state.all_topics().into_iter().collect();
1441 let expected_all: AHashSet<String> = confirmed_topics
1442 .union(&pending_sub_topics)
1443 .cloned()
1444 .collect();
1445 assert_eq!(
1446 all_topics, expected_all,
1447 "{label}: all_topics() doesn't match confirmed ∪ pending_subscribe"
1448 );
1449
1450 for topic in &pending_unsub_topics {
1452 assert!(
1453 !all_topics.contains(topic),
1454 "{label}: pending_unsubscribe topic {topic} incorrectly in all_topics()"
1455 );
1456 }
1457
1458 let expected_len: usize = state
1460 .confirmed
1461 .iter()
1462 .map(|entry| entry.value().len())
1463 .sum();
1464 assert_eq!(
1465 state.len(),
1466 expected_len,
1467 "{label}: len() mismatch. Expected {expected_len}, was {}",
1468 state.len()
1469 );
1470
1471 let should_be_empty = state.confirmed.is_empty()
1473 && pending_sub_topics.is_empty()
1474 && pending_unsub_topics.is_empty();
1475 assert_eq!(
1476 state.is_empty(),
1477 should_be_empty,
1478 "{label}: is_empty() inconsistent. Maps empty: {should_be_empty}, is_empty(): {}",
1479 state.is_empty()
1480 );
1481
1482 for entry in state.reference_counts.iter() {
1484 let count = entry.value().get();
1485 assert!(
1486 count > 0,
1487 "{label}: Reference count should be NonZeroUsize (> 0), was {count} for {:?}",
1488 entry.key()
1489 );
1490 }
1491 }
1492
1493 fn check_topic_exclusivity(state: &SubscriptionState, topic: &str, label: &str) {
1495 let (channel, symbol) = split_topic(topic, state.delimiter);
1496
1497 let in_confirmed = is_tracked(&state.confirmed, channel, symbol);
1498 let in_pending_sub = is_tracked(&state.pending_subscribe, channel, symbol);
1499 let in_pending_unsub = is_tracked(&state.pending_unsubscribe, channel, symbol);
1500
1501 let count = [in_confirmed, in_pending_sub, in_pending_unsub]
1502 .iter()
1503 .filter(|&&x| x)
1504 .count();
1505
1506 assert!(
1507 count <= 1,
1508 "{label}: Topic {topic} in {count} states (should be 0 or 1). \
1509 confirmed: {in_confirmed}, pending_sub: {in_pending_sub}, pending_unsub: {in_pending_unsub}"
1510 );
1511 }
1512
/// Property-based tests that drive [`SubscriptionState`] with random sequences
/// of operations and check its invariants after every step.
#[cfg(test)]
mod property_tests {
    use proptest::prelude::*;

    use super::*;

    /// A single call against [`SubscriptionState`]; the payload is the topic
    /// the call is made with.
    #[derive(Debug, Clone)]
    enum Operation {
        MarkSubscribe(String),
        ConfirmSubscribe(String),
        MarkUnsubscribe(String),
        ConfirmUnsubscribe(String),
        MarkFailure(String),
        AddReference(String),
        RemoveReference(String),
        Clear,
    }

    /// Generates topics from a small pool so that random operations frequently
    /// hit the same topic: `channelN.SYMBOLM` (N in 0..5, M in 0..10) or the
    /// delimiter-free form `channelN`.
    fn topic_strategy() -> impl Strategy<Value = String> {
        let with_symbol = (any::<u8>(), any::<u8>())
            .prop_map(|(ch, sym)| format!("channel{}.SYMBOL{}", ch % 5, sym % 10));
        let channel_only = any::<u8>().prop_map(|ch| format!("channel{}", ch % 5));

        prop_oneof![with_symbol, channel_only]
    }

    /// Generates one [`Operation`]: a topic is drawn first, then one of the
    /// eight operation kinds is chosen for it.
    fn operation_strategy() -> impl Strategy<Value = Operation> {
        topic_strategy().prop_flat_map(|topic| {
            let t = &topic;
            prop_oneof![
                Just(Operation::MarkSubscribe(t.clone())),
                Just(Operation::ConfirmSubscribe(t.clone())),
                Just(Operation::MarkUnsubscribe(t.clone())),
                Just(Operation::ConfirmUnsubscribe(t.clone())),
                Just(Operation::MarkFailure(t.clone())),
                Just(Operation::AddReference(t.clone())),
                Just(Operation::RemoveReference(t.clone())),
                Just(Operation::Clear),
            ]
        })
    }

    /// Executes `op` against `state`. Results of the reference-count calls are
    /// intentionally discarded.
    fn apply_operation(state: &SubscriptionState, op: &Operation) {
        match op {
            Operation::MarkSubscribe(topic) => state.mark_subscribe(topic),
            Operation::ConfirmSubscribe(topic) => state.confirm_subscribe(topic),
            Operation::MarkUnsubscribe(topic) => state.mark_unsubscribe(topic),
            Operation::ConfirmUnsubscribe(topic) => state.confirm_unsubscribe(topic),
            Operation::MarkFailure(topic) => state.mark_failure(topic),
            Operation::AddReference(topic) => {
                let _ = state.add_reference(topic);
            }
            Operation::RemoveReference(topic) => {
                let _ = state.remove_reference(topic);
            }
            Operation::Clear => state.clear(),
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(500))]

        /// The full invariant set holds after every operation of a random
        /// sequence, and once more at the end.
        #[rstest]
        fn prop_invariants_hold_after_operations(
            operations in prop::collection::vec(operation_strategy(), 1..50)
        ) {
            let state = SubscriptionState::new('.');

            for (step, op) in operations.iter().enumerate() {
                apply_operation(&state, op);

                let label = format!("After op {step}: {op:?}");
                check_invariants(&state, &label);
            }

            check_invariants(&state, "Final state");
        }

        /// Under random add/remove sequences every stored reference count
        /// stays strictly positive.
        #[rstest]
        fn prop_reference_counting_consistency(
            ops in prop::collection::vec(
                topic_strategy().prop_flat_map(|t| {
                    prop_oneof![
                        Just(Operation::AddReference(t.clone())),
                        Just(Operation::RemoveReference(t)),
                    ]
                }),
                1..100
            )
        ) {
            let state = SubscriptionState::new('.');

            for op in &ops {
                apply_operation(&state, op);

                assert!(
                    state
                        .reference_counts
                        .iter()
                        .all(|entry| entry.value().get() > 0)
                );
            }
        }

        /// After every operation `all_topics()` equals confirmed ∪
        /// pending-subscribe and contains no pending-unsubscribe topic.
        #[rstest]
        fn prop_all_topics_is_union(
            operations in prop::collection::vec(operation_strategy(), 1..50)
        ) {
            let state = SubscriptionState::new('.');

            for op in &operations {
                apply_operation(&state, op);

                let all_topics: AHashSet<String> = state.all_topics().into_iter().collect();
                let confirmed: AHashSet<String> =
                    state.topics_from_map(&state.confirmed).into_iter().collect();
                let pending_sub: AHashSet<String> =
                    state.pending_subscribe_topics().into_iter().collect();
                let expected: AHashSet<String> = confirmed
                    .iter()
                    .chain(pending_sub.iter())
                    .cloned()
                    .collect();

                assert_eq!(all_topics, expected);

                let pending_unsub: AHashSet<String> =
                    state.pending_unsubscribe_topics().into_iter().collect();
                for topic in &pending_unsub {
                    assert!(!all_topics.contains(topic));
                }
            }
        }

        /// `clear()` empties every map and every derived view, whatever
        /// operations preceded it.
        #[rstest]
        fn prop_clear_resets_completely(
            operations in prop::collection::vec(operation_strategy(), 1..30)
        ) {
            let state = SubscriptionState::new('.');

            operations
                .iter()
                .for_each(|op| apply_operation(&state, op));

            state.clear();

            // Public views.
            assert!(state.is_empty());
            assert_eq!(state.len(), 0);
            assert!(state.all_topics().is_empty());
            assert!(state.pending_subscribe_topics().is_empty());
            assert!(state.pending_unsubscribe_topics().is_empty());

            // Underlying maps.
            assert!(state.confirmed.is_empty());
            assert!(state.pending_subscribe.is_empty());
            assert!(state.pending_unsubscribe.is_empty());
            assert!(state.reference_counts.is_empty());
        }

        /// A randomly chosen topic is tracked by at most one of the three maps
        /// after every operation of a random sequence.
        #[rstest]
        fn prop_topic_mutual_exclusivity(
            operations in prop::collection::vec(operation_strategy(), 1..50),
            topic in topic_strategy()
        ) {
            let state = SubscriptionState::new('.');

            for (step, op) in operations.iter().enumerate() {
                apply_operation(&state, op);

                let label = format!("After op {step}: {op:?}");
                check_topic_exclusivity(&state, &topic, &label);
            }
        }
    }
}
1690
1691 #[rstest]
1692 fn test_exhaustive_two_step_transitions() {
1693 let operations = [
1694 "mark_subscribe",
1695 "confirm_subscribe",
1696 "mark_unsubscribe",
1697 "confirm_unsubscribe",
1698 "mark_failure",
1699 ];
1700
1701 for &op1 in &operations {
1702 for &op2 in &operations {
1703 let state = SubscriptionState::new('.');
1704 let topic = "test.TOPIC";
1705
1706 apply_op(&state, op1, topic);
1708 apply_op(&state, op2, topic);
1709
1710 check_invariants(&state, &format!("{op1} → {op2}"));
1712 check_topic_exclusivity(&state, topic, &format!("{op1} → {op2}"));
1713 }
1714 }
1715 }
1716
1717 fn apply_op(state: &SubscriptionState, op: &str, topic: &str) {
1718 match op {
1719 "mark_subscribe" => state.mark_subscribe(topic),
1720 "confirm_subscribe" => state.confirm_subscribe(topic),
1721 "mark_unsubscribe" => state.mark_unsubscribe(topic),
1722 "confirm_unsubscribe" => state.confirm_unsubscribe(topic),
1723 "mark_failure" => state.mark_failure(topic),
1724 _ => panic!("Unknown operation: {op}"),
1725 }
1726 }
1727
1728 #[rstest]
1729 #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
1730 async fn test_stress_rapid_resubscribe_pattern() {
1731 let state = Arc::new(SubscriptionState::new('.'));
1733 let mut handles = vec![];
1734
1735 for i in 0..100 {
1736 let state_clone = Arc::clone(&state);
1737 let handle = tokio::spawn(async move {
1738 let topic = format!("rapid.SYMBOL{}", i % 10); state_clone.mark_subscribe(&topic);
1742 state_clone.confirm_subscribe(&topic);
1743
1744 state_clone.mark_unsubscribe(&topic);
1746 state_clone.mark_subscribe(&topic);
1748 state_clone.confirm_unsubscribe(&topic);
1750 state_clone.confirm_subscribe(&topic);
1752 });
1753 handles.push(handle);
1754 }
1755
1756 for handle in handles {
1757 handle.await.unwrap();
1758 }
1759
1760 check_invariants(&state, "After rapid resubscribe stress test");
1761 }
1762
1763 #[rstest]
1764 #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
1765 async fn test_stress_failure_recovery_loop() {
1766 let state = Arc::new(SubscriptionState::new('.'));
1769 let mut handles = vec![];
1770
1771 for i in 0..30 {
1772 let state_clone = Arc::clone(&state);
1773 let handle = tokio::spawn(async move {
1774 let topic = format!("failure.SYMBOL{i}"); state_clone.mark_subscribe(&topic);
1778 state_clone.confirm_subscribe(&topic);
1779
1780 for _ in 0..5 {
1782 state_clone.mark_failure(&topic);
1783 state_clone.confirm_subscribe(&topic); }
1785 });
1786 handles.push(handle);
1787 }
1788
1789 for handle in handles {
1790 handle.await.unwrap();
1791 }
1792
1793 check_invariants(&state, "After failure recovery loops");
1794
1795 assert_eq!(state.len(), 30);
1797 }
1798}