1use std::{
38 num::NonZeroUsize,
39 sync::{Arc, LazyLock},
40};
41
42use ahash::AHashSet;
43use dashmap::DashMap;
44use ustr::Ustr;
45
46pub(crate) static CHANNEL_LEVEL_MARKER: LazyLock<Ustr> = LazyLock::new(|| Ustr::from(""));
51
/// Concurrent (DashMap-backed) tracker of topic subscription state and per-topic
/// reference counts.
///
/// Topics are split into a channel and an optional symbol on `delimiter` (see
/// [`split_topic`]). Each state map is keyed by channel and holds the set of symbols;
/// `CHANNEL_LEVEL_MARKER` stands in for channel-only topics.
///
/// Every map is held behind an `Arc`, so a clone shares (rather than copies) the state.
#[derive(Clone, Debug)]
pub struct SubscriptionState {
    /// Topics confirmed via `confirm_subscribe` (channel -> symbols).
    confirmed: Arc<DashMap<Ustr, AHashSet<Ustr>>>,
    /// Topics marked for subscription but not yet confirmed (channel -> symbols).
    pending_subscribe: Arc<DashMap<Ustr, AHashSet<Ustr>>>,
    /// Topics marked for unsubscription but not yet confirmed (channel -> symbols).
    pending_unsubscribe: Arc<DashMap<Ustr, AHashSet<Ustr>>>,
    /// Live reference count per full topic string; an absent key means zero.
    reference_counts: Arc<DashMap<Ustr, NonZeroUsize>>,
    /// Separator between channel and symbol within a topic string.
    delimiter: char,
}
81
82impl SubscriptionState {
83 pub fn new(delimiter: char) -> Self {
85 Self {
86 confirmed: Arc::new(DashMap::new()),
87 pending_subscribe: Arc::new(DashMap::new()),
88 pending_unsubscribe: Arc::new(DashMap::new()),
89 reference_counts: Arc::new(DashMap::new()),
90 delimiter,
91 }
92 }
93
94 pub fn delimiter(&self) -> char {
96 self.delimiter
97 }
98
99 pub fn confirmed(&self) -> Arc<DashMap<Ustr, AHashSet<Ustr>>> {
101 Arc::clone(&self.confirmed)
102 }
103
104 pub fn pending_subscribe(&self) -> Arc<DashMap<Ustr, AHashSet<Ustr>>> {
106 Arc::clone(&self.pending_subscribe)
107 }
108
109 pub fn pending_unsubscribe(&self) -> Arc<DashMap<Ustr, AHashSet<Ustr>>> {
111 Arc::clone(&self.pending_unsubscribe)
112 }
113
114 pub fn len(&self) -> usize {
118 self.confirmed.iter().map(|entry| entry.value().len()).sum()
119 }
120
121 pub fn is_empty(&self) -> bool {
123 self.confirmed.is_empty()
124 && self.pending_subscribe.is_empty()
125 && self.pending_unsubscribe.is_empty()
126 }
127
128 pub fn is_subscribed(&self, channel: &Ustr, symbol: &Ustr) -> bool {
130 if let Some(symbols) = self.confirmed.get(channel)
131 && symbols.contains(symbol)
132 {
133 return true;
134 }
135 if let Some(symbols) = self.pending_subscribe.get(channel)
136 && symbols.contains(symbol)
137 {
138 return true;
139 }
140 false
141 }
142
143 pub fn mark_subscribe(&self, topic: &str) {
149 let (channel, symbol) = split_topic(topic, self.delimiter);
150
151 if is_tracked(&self.confirmed, channel, symbol) {
153 return;
154 }
155
156 untrack_topic(&self.pending_unsubscribe, channel, symbol);
158
159 track_topic(&self.pending_subscribe, channel, symbol);
160 }
161
162 pub fn try_mark_subscribe(&self, topic: &str) -> bool {
169 let (channel, symbol) = split_topic(topic, self.delimiter);
170
171 if is_tracked(&self.confirmed, channel, symbol) {
173 return false;
174 }
175
176 let channel_ustr = Ustr::from(channel);
178 let symbol_ustr = symbol.map_or(*CHANNEL_LEVEL_MARKER, Ustr::from);
179
180 let mut entry = self.pending_subscribe.entry(channel_ustr).or_default();
181 let inserted = entry.insert(symbol_ustr);
182
183 if inserted {
185 untrack_topic(&self.pending_unsubscribe, channel, symbol);
186 }
187
188 inserted
189 }
190
191 pub fn mark_unsubscribe(&self, topic: &str) {
197 let (channel, symbol) = split_topic(topic, self.delimiter);
198 track_topic(&self.pending_unsubscribe, channel, symbol);
199 untrack_topic(&self.confirmed, channel, symbol);
200 untrack_topic(&self.pending_subscribe, channel, symbol);
201 }
202
203 pub fn confirm_subscribe(&self, topic: &str) {
209 let (channel, symbol) = split_topic(topic, self.delimiter);
210
211 if is_tracked(&self.pending_unsubscribe, channel, symbol) {
213 return;
214 }
215
216 untrack_topic(&self.pending_subscribe, channel, symbol);
217 track_topic(&self.confirmed, channel, symbol);
218 }
219
220 pub fn confirm_unsubscribe(&self, topic: &str) {
231 let (channel, symbol) = split_topic(topic, self.delimiter);
232
233 if !is_tracked(&self.pending_unsubscribe, channel, symbol) {
236 return; }
238
239 untrack_topic(&self.pending_unsubscribe, channel, symbol);
240 untrack_topic(&self.confirmed, channel, symbol);
241 }
243
244 pub fn mark_failure(&self, topic: &str) {
249 let (channel, symbol) = split_topic(topic, self.delimiter);
250
251 if is_tracked(&self.pending_unsubscribe, channel, symbol) {
253 return;
254 }
255
256 untrack_topic(&self.confirmed, channel, symbol);
257 track_topic(&self.pending_subscribe, channel, symbol);
258 }
259
260 pub fn pending_subscribe_topics(&self) -> Vec<String> {
262 self.topics_from_map(&self.pending_subscribe)
263 }
264
265 pub fn pending_unsubscribe_topics(&self) -> Vec<String> {
267 self.topics_from_map(&self.pending_unsubscribe)
268 }
269
270 pub fn all_topics(&self) -> Vec<String> {
277 let mut topics = Vec::new();
278 topics.extend(self.topics_from_map(&self.confirmed));
279 topics.extend(self.topics_from_map(&self.pending_subscribe));
280 topics
281 }
282
283 fn topics_from_map(&self, map: &DashMap<Ustr, AHashSet<Ustr>>) -> Vec<String> {
285 let mut topics = Vec::new();
286 let marker = *CHANNEL_LEVEL_MARKER;
287
288 for entry in map {
289 let channel = entry.key();
290 let symbols = entry.value();
291
292 if symbols.contains(&marker) {
294 topics.push(channel.to_string());
295 }
296
297 for symbol in symbols {
299 if *symbol != marker {
300 topics.push(format!(
301 "{}{}{}",
302 channel.as_str(),
303 self.delimiter,
304 symbol.as_str()
305 ));
306 }
307 }
308 }
309
310 topics
311 }
312
313 pub fn add_reference(&self, topic: &str) -> bool {
322 let mut should_subscribe = false;
323 let topic_ustr = Ustr::from(topic);
324
325 self.reference_counts
326 .entry(topic_ustr)
327 .and_modify(|count| {
328 *count = NonZeroUsize::new(count.get() + 1).expect("reference count overflow");
329 })
330 .or_insert_with(|| {
331 should_subscribe = true;
332 NonZeroUsize::new(1).expect("NonZeroUsize::new(1) should never fail")
333 });
334
335 should_subscribe
336 }
337
338 pub fn remove_reference(&self, topic: &str) -> bool {
348 let topic_ustr = Ustr::from(topic);
349
350 if let dashmap::mapref::entry::Entry::Occupied(mut entry) =
353 self.reference_counts.entry(topic_ustr)
354 {
355 let current = entry.get().get();
356
357 if current == 1 {
358 entry.remove();
359 return true;
360 }
361
362 *entry.get_mut() = NonZeroUsize::new(current - 1)
363 .expect("reference count should never reach zero here");
364 }
365
366 false
367 }
368
369 pub fn get_reference_count(&self, topic: &str) -> usize {
373 let topic_ustr = Ustr::from(topic);
374 self.reference_counts
375 .get(&topic_ustr)
376 .map_or(0, |count| count.get())
377 }
378
379 pub fn clear(&self) {
383 self.confirmed.clear();
384 self.pending_subscribe.clear();
385 self.pending_unsubscribe.clear();
386 self.reference_counts.clear();
387 }
388}
389
/// Splits `topic` at the first `delimiter` into `(channel, Some(symbol))`.
///
/// When `topic` contains no delimiter, the whole topic is the channel and the symbol
/// is `None`. Any further delimiters remain inside the symbol.
pub fn split_topic(topic: &str, delimiter: char) -> (&str, Option<&str>) {
    match topic.split_once(delimiter) {
        Some((channel, symbol)) => (channel, Some(symbol)),
        None => (topic, None),
    }
}
396
397fn track_topic(map: &DashMap<Ustr, AHashSet<Ustr>>, channel: &str, symbol: Option<&str>) {
402 let channel_ustr = Ustr::from(channel);
403 let mut entry = map.entry(channel_ustr).or_default();
404
405 if let Some(symbol) = symbol {
406 entry.insert(Ustr::from(symbol));
407 } else {
408 entry.insert(*CHANNEL_LEVEL_MARKER);
409 }
410}
411
412fn untrack_topic(map: &DashMap<Ustr, AHashSet<Ustr>>, channel: &str, symbol: Option<&str>) {
416 let channel_ustr = Ustr::from(channel);
417 let symbol_to_remove = if let Some(symbol) = symbol {
418 Ustr::from(symbol)
419 } else {
420 *CHANNEL_LEVEL_MARKER
421 };
422
423 if let dashmap::mapref::entry::Entry::Occupied(mut entry) = map.entry(channel_ustr) {
426 entry.get_mut().remove(&symbol_to_remove);
427 if entry.get().is_empty() {
428 entry.remove();
429 }
430 }
431}
432
433fn is_tracked(map: &DashMap<Ustr, AHashSet<Ustr>>, channel: &str, symbol: Option<&str>) -> bool {
435 let channel_ustr = Ustr::from(channel);
436 let symbol_to_check = if let Some(symbol) = symbol {
437 Ustr::from(symbol)
438 } else {
439 *CHANNEL_LEVEL_MARKER
440 };
441
442 if let Some(entry) = map.get(&channel_ustr) {
443 entry.contains(&symbol_to_check)
444 } else {
445 false
446 }
447}
448
449#[cfg(test)]
450mod tests {
451 use rstest::rstest;
452
453 use super::*;
454
455 #[rstest]
456 fn test_split_topic_with_symbol() {
457 let (channel, symbol) = split_topic("tickers.BTCUSDT", '.');
458 assert_eq!(channel, "tickers");
459 assert_eq!(symbol, Some("BTCUSDT"));
460
461 let (channel, symbol) = split_topic("orderBookL2:XBTUSD", ':');
462 assert_eq!(channel, "orderBookL2");
463 assert_eq!(symbol, Some("XBTUSD"));
464 }
465
466 #[rstest]
467 fn test_split_topic_without_symbol() {
468 let (channel, symbol) = split_topic("orderbook", '.');
469 assert_eq!(channel, "orderbook");
470 assert_eq!(symbol, None);
471 }
472
473 #[rstest]
474 fn test_new_state_is_empty() {
475 let state = SubscriptionState::new('.');
476 assert!(state.is_empty());
477 assert_eq!(state.len(), 0);
478 }
479
480 #[rstest]
481 fn test_mark_subscribe() {
482 let state = SubscriptionState::new('.');
483 state.mark_subscribe("tickers.BTCUSDT");
484
485 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
486 assert_eq!(state.len(), 0); }
488
489 #[rstest]
490 fn test_confirm_subscribe() {
491 let state = SubscriptionState::new('.');
492 state.mark_subscribe("tickers.BTCUSDT");
493 state.confirm_subscribe("tickers.BTCUSDT");
494
495 assert!(state.pending_subscribe_topics().is_empty());
496 assert_eq!(state.len(), 1);
497 }
498
499 #[rstest]
500 fn test_is_subscribed_empty_state() {
501 let state = SubscriptionState::new('.');
502 let channel = Ustr::from("tickers");
503 let symbol = Ustr::from("BTCUSDT");
504
505 assert!(!state.is_subscribed(&channel, &symbol));
506 }
507
508 #[rstest]
509 fn test_is_subscribed_pending() {
510 let state = SubscriptionState::new('.');
511 let channel = Ustr::from("tickers");
512 let symbol = Ustr::from("BTCUSDT");
513
514 state.mark_subscribe("tickers.BTCUSDT");
515
516 assert!(state.is_subscribed(&channel, &symbol));
517 }
518
519 #[rstest]
520 fn test_is_subscribed_confirmed() {
521 let state = SubscriptionState::new('.');
522 let channel = Ustr::from("tickers");
523 let symbol = Ustr::from("BTCUSDT");
524
525 state.mark_subscribe("tickers.BTCUSDT");
526 state.confirm_subscribe("tickers.BTCUSDT");
527
528 assert!(state.is_subscribed(&channel, &symbol));
529 }
530
531 #[rstest]
532 fn test_is_subscribed_after_unsubscribe() {
533 let state = SubscriptionState::new('.');
534 let channel = Ustr::from("tickers");
535 let symbol = Ustr::from("BTCUSDT");
536
537 state.mark_subscribe("tickers.BTCUSDT");
538 state.confirm_subscribe("tickers.BTCUSDT");
539 state.mark_unsubscribe("tickers.BTCUSDT");
540
541 assert!(!state.is_subscribed(&channel, &symbol));
543 }
544
545 #[rstest]
546 fn test_is_subscribed_after_confirm_unsubscribe() {
547 let state = SubscriptionState::new('.');
548 let channel = Ustr::from("tickers");
549 let symbol = Ustr::from("BTCUSDT");
550
551 state.mark_subscribe("tickers.BTCUSDT");
552 state.confirm_subscribe("tickers.BTCUSDT");
553 state.mark_unsubscribe("tickers.BTCUSDT");
554 state.confirm_unsubscribe("tickers.BTCUSDT");
555
556 assert!(!state.is_subscribed(&channel, &symbol));
557 }
558
559 #[rstest]
560 fn test_mark_unsubscribe() {
561 let state = SubscriptionState::new('.');
562 state.mark_subscribe("tickers.BTCUSDT");
563 state.confirm_subscribe("tickers.BTCUSDT");
564 state.mark_unsubscribe("tickers.BTCUSDT");
565
566 assert_eq!(state.len(), 0); assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
568 }
569
570 #[rstest]
571 fn test_confirm_unsubscribe() {
572 let state = SubscriptionState::new('.');
573 state.mark_subscribe("tickers.BTCUSDT");
574 state.confirm_subscribe("tickers.BTCUSDT");
575 state.mark_unsubscribe("tickers.BTCUSDT");
576 state.confirm_unsubscribe("tickers.BTCUSDT");
577
578 assert!(state.is_empty());
579 }
580
581 #[rstest]
582 fn test_resubscribe_before_unsubscribe_ack() {
583 let state = SubscriptionState::new('.');
587
588 state.mark_subscribe("tickers.BTCUSDT");
589 state.confirm_subscribe("tickers.BTCUSDT");
590 assert_eq!(state.len(), 1);
591
592 state.mark_unsubscribe("tickers.BTCUSDT");
593 assert_eq!(state.len(), 0);
594 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
595
596 state.mark_subscribe("tickers.BTCUSDT");
598 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
599
600 state.confirm_unsubscribe("tickers.BTCUSDT");
602 assert!(state.pending_unsubscribe_topics().is_empty());
603 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]); state.confirm_subscribe("tickers.BTCUSDT");
607 assert_eq!(state.len(), 1);
608 assert!(state.pending_subscribe_topics().is_empty());
609
610 let all = state.all_topics();
612 assert_eq!(all.len(), 1);
613 assert!(all.contains(&"tickers.BTCUSDT".to_string()));
614 }
615
616 #[rstest]
617 fn test_stale_unsubscribe_ack_after_resubscribe_confirmed() {
618 let state = SubscriptionState::new('.');
623
624 state.mark_subscribe("tickers.BTCUSDT");
626 state.confirm_subscribe("tickers.BTCUSDT");
627 assert_eq!(state.len(), 1);
628
629 state.mark_unsubscribe("tickers.BTCUSDT");
631 assert_eq!(state.len(), 0);
632 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
633
634 state.mark_subscribe("tickers.BTCUSDT");
636 assert!(state.pending_unsubscribe_topics().is_empty()); assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
638
639 state.confirm_subscribe("tickers.BTCUSDT");
641 assert_eq!(state.len(), 1); assert!(state.pending_subscribe_topics().is_empty());
643
644 state.confirm_unsubscribe("tickers.BTCUSDT");
647
648 assert_eq!(state.len(), 1); assert!(state.pending_unsubscribe_topics().is_empty());
651 assert!(state.pending_subscribe_topics().is_empty());
652
653 let all = state.all_topics();
655 assert_eq!(all.len(), 1);
656 assert!(all.contains(&"tickers.BTCUSDT".to_string()));
657 }
658
659 #[rstest]
660 fn test_mark_failure() {
661 let state = SubscriptionState::new('.');
662 state.mark_subscribe("tickers.BTCUSDT");
663 state.confirm_subscribe("tickers.BTCUSDT");
664 state.mark_failure("tickers.BTCUSDT");
665
666 assert_eq!(state.len(), 0);
667 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
668 }
669
670 #[rstest]
671 fn test_all_topics_includes_confirmed_and_pending_subscribe() {
672 let state = SubscriptionState::new('.');
673 state.mark_subscribe("tickers.BTCUSDT");
674 state.confirm_subscribe("tickers.BTCUSDT");
675 state.mark_subscribe("tickers.ETHUSDT");
676
677 let topics = state.all_topics();
678 assert_eq!(topics.len(), 2);
679 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
680 assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
681 }
682
683 #[rstest]
684 fn test_all_topics_excludes_pending_unsubscribe() {
685 let state = SubscriptionState::new('.');
686 state.mark_subscribe("tickers.BTCUSDT");
687 state.confirm_subscribe("tickers.BTCUSDT");
688 state.mark_unsubscribe("tickers.BTCUSDT");
689
690 let topics = state.all_topics();
691 assert!(topics.is_empty());
692 }
693
694 #[rstest]
695 fn test_reference_counting_single_topic() {
696 let state = SubscriptionState::new('.');
697
698 assert!(state.add_reference("tickers.BTCUSDT"));
699 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 1);
700
701 assert!(!state.add_reference("tickers.BTCUSDT"));
702 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 2);
703
704 assert!(!state.remove_reference("tickers.BTCUSDT"));
705 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 1);
706
707 assert!(state.remove_reference("tickers.BTCUSDT"));
708 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 0);
709 }
710
711 #[rstest]
712 fn test_reference_counting_multiple_topics() {
713 let state = SubscriptionState::new('.');
714
715 assert!(state.add_reference("tickers.BTCUSDT"));
716 assert!(state.add_reference("tickers.ETHUSDT"));
717
718 assert!(!state.add_reference("tickers.BTCUSDT"));
719 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 2);
720 assert_eq!(state.get_reference_count("tickers.ETHUSDT"), 1);
721
722 assert!(!state.remove_reference("tickers.BTCUSDT"));
723 assert!(state.remove_reference("tickers.ETHUSDT"));
724 }
725
726 #[rstest]
727 fn test_topic_without_symbol() {
728 let state = SubscriptionState::new('.');
729 state.mark_subscribe("orderbook");
730 state.confirm_subscribe("orderbook");
731
732 assert_eq!(state.len(), 1);
733 assert_eq!(state.all_topics(), vec!["orderbook"]);
734 }
735
736 #[rstest]
737 fn test_different_delimiters() {
738 let state_dot = SubscriptionState::new('.');
739 state_dot.mark_subscribe("tickers.BTCUSDT");
740 assert_eq!(
741 state_dot.pending_subscribe_topics(),
742 vec!["tickers.BTCUSDT"]
743 );
744
745 let state_colon = SubscriptionState::new(':');
746 state_colon.mark_subscribe("orderBookL2:XBTUSD");
747 assert_eq!(
748 state_colon.pending_subscribe_topics(),
749 vec!["orderBookL2:XBTUSD"]
750 );
751 }
752
753 #[rstest]
754 fn test_clear() {
755 let state = SubscriptionState::new('.');
756 state.mark_subscribe("tickers.BTCUSDT");
757 state.confirm_subscribe("tickers.BTCUSDT");
758 state.add_reference("tickers.BTCUSDT");
759
760 state.clear();
761
762 assert!(state.is_empty());
763 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 0);
764 }
765
766 #[rstest]
767 fn test_multiple_symbols_same_channel() {
768 let state = SubscriptionState::new('.');
769 state.mark_subscribe("tickers.BTCUSDT");
770 state.mark_subscribe("tickers.ETHUSDT");
771 state.confirm_subscribe("tickers.BTCUSDT");
772 state.confirm_subscribe("tickers.ETHUSDT");
773
774 assert_eq!(state.len(), 2);
775 let topics = state.all_topics();
776 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
777 assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
778 }
779
780 #[rstest]
781 fn test_mixed_channel_and_symbol_subscriptions() {
782 let state = SubscriptionState::new('.');
783
784 state.mark_subscribe("tickers");
786 state.confirm_subscribe("tickers");
787 assert_eq!(state.len(), 1);
788 assert_eq!(state.all_topics(), vec!["tickers"]);
789
790 state.mark_subscribe("tickers.BTCUSDT");
792 state.confirm_subscribe("tickers.BTCUSDT");
793 assert_eq!(state.len(), 2);
794
795 let topics = state.all_topics();
797 assert_eq!(topics.len(), 2);
798 assert!(topics.contains(&"tickers".to_string()));
799 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
800
801 state.mark_subscribe("tickers.ETHUSDT");
803 state.confirm_subscribe("tickers.ETHUSDT");
804 assert_eq!(state.len(), 3);
805
806 let topics = state.all_topics();
807 assert_eq!(topics.len(), 3);
808 assert!(topics.contains(&"tickers".to_string()));
809 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
810 assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
811
812 state.mark_unsubscribe("tickers");
814 state.confirm_unsubscribe("tickers");
815 assert_eq!(state.len(), 2);
816
817 let topics = state.all_topics();
818 assert_eq!(topics.len(), 2);
819 assert!(!topics.contains(&"tickers".to_string()));
820 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
821 assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
822 }
823
824 #[rstest]
825 fn test_symbol_subscription_before_channel() {
826 let state = SubscriptionState::new('.');
827
828 state.mark_subscribe("tickers.BTCUSDT");
830 state.confirm_subscribe("tickers.BTCUSDT");
831 assert_eq!(state.len(), 1);
832
833 state.mark_subscribe("tickers");
835 state.confirm_subscribe("tickers");
836 assert_eq!(state.len(), 2);
837
838 let topics = state.all_topics();
840 assert_eq!(topics.len(), 2);
841 assert!(topics.contains(&"tickers".to_string()));
842 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
843 }
844
845 #[rstest]
846 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
847 async fn test_concurrent_subscribe_same_topic() {
848 let state = Arc::new(SubscriptionState::new('.'));
849 let mut handles = vec![];
850
851 for _ in 0..10 {
853 let state_clone = Arc::clone(&state);
854 let handle = tokio::spawn(async move {
855 state_clone.add_reference("tickers.BTCUSDT");
856 state_clone.mark_subscribe("tickers.BTCUSDT");
857 state_clone.confirm_subscribe("tickers.BTCUSDT");
858 });
859 handles.push(handle);
860 }
861
862 for handle in handles {
863 handle.await.unwrap();
864 }
865
866 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 10);
868 assert_eq!(state.len(), 1);
869 }
870
871 #[rstest]
872 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
873 async fn test_concurrent_subscribe_unsubscribe() {
874 let state = Arc::new(SubscriptionState::new('.'));
875 let mut handles = vec![];
876
877 for i in 0..20 {
880 let state_clone = Arc::clone(&state);
881 let handle = tokio::spawn(async move {
882 let topic = format!("tickers.SYMBOL{i}");
883 state_clone.add_reference(&topic);
885 state_clone.add_reference(&topic);
886 state_clone.mark_subscribe(&topic);
887 state_clone.confirm_subscribe(&topic);
888
889 state_clone.remove_reference(&topic);
891 });
892 handles.push(handle);
893 }
894
895 for handle in handles {
896 handle.await.unwrap();
897 }
898
899 for i in 0..20 {
901 let topic = format!("tickers.SYMBOL{i}");
902 assert_eq!(state.get_reference_count(&topic), 1);
903 }
904
905 assert_eq!(state.len(), 20);
907 assert!(!state.is_empty());
908 }
909
910 #[rstest]
911 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
912 async fn test_concurrent_reference_counting_same_topic() {
913 let state = Arc::new(SubscriptionState::new('.'));
914 let topic = "tickers.BTCUSDT";
915 let mut handles = vec![];
916
917 for _ in 0..10 {
919 let state_clone = Arc::clone(&state);
920 let handle = tokio::spawn(async move {
921 for _ in 0..10 {
922 state_clone.add_reference(topic);
923 }
924 });
925 handles.push(handle);
926 }
927
928 for handle in handles {
929 handle.await.unwrap();
930 }
931
932 assert_eq!(state.get_reference_count(topic), 100);
934
935 for _ in 0..50 {
937 state.remove_reference(topic);
938 }
939
940 assert_eq!(state.get_reference_count(topic), 50);
942 }
943
944 #[rstest]
945 fn test_reconnection_scenario() {
946 let state = SubscriptionState::new('.');
947
948 state.add_reference("tickers.BTCUSDT");
950 state.mark_subscribe("tickers.BTCUSDT");
951 state.confirm_subscribe("tickers.BTCUSDT");
952
953 state.add_reference("tickers.ETHUSDT");
954 state.mark_subscribe("tickers.ETHUSDT");
955 state.confirm_subscribe("tickers.ETHUSDT");
956
957 state.add_reference("orderbook");
958 state.mark_subscribe("orderbook");
959 state.confirm_subscribe("orderbook");
960
961 assert_eq!(state.len(), 3);
962
963 let topics_to_resubscribe = state.all_topics();
965 assert_eq!(topics_to_resubscribe.len(), 3);
966 assert!(topics_to_resubscribe.contains(&"tickers.BTCUSDT".to_string()));
967 assert!(topics_to_resubscribe.contains(&"tickers.ETHUSDT".to_string()));
968 assert!(topics_to_resubscribe.contains(&"orderbook".to_string()));
969
970 for topic in &topics_to_resubscribe {
972 state.mark_subscribe(topic);
973 }
974
975 for topic in &topics_to_resubscribe {
977 state.confirm_subscribe(topic);
978 }
979
980 assert_eq!(state.len(), 3);
982 assert_eq!(state.all_topics().len(), 3);
983 }
984
985 #[rstest]
986 fn test_state_machine_invalid_transitions() {
987 let state = SubscriptionState::new('.');
988
989 state.confirm_subscribe("tickers.BTCUSDT");
991 assert_eq!(state.len(), 1); state.confirm_unsubscribe("tickers.ETHUSDT");
995 assert_eq!(state.len(), 1); state.mark_subscribe("orderbook");
999 state.confirm_subscribe("orderbook");
1000 state.confirm_subscribe("orderbook"); assert_eq!(state.len(), 2);
1002
1003 state.mark_unsubscribe("nonexistent");
1005 state.confirm_unsubscribe("nonexistent");
1006 assert_eq!(state.len(), 2); }
1008
1009 #[rstest]
1010 fn test_mark_failure_moves_to_pending() {
1011 let state = SubscriptionState::new('.');
1012
1013 state.mark_subscribe("tickers.BTCUSDT");
1015 state.confirm_subscribe("tickers.BTCUSDT");
1016 assert_eq!(state.len(), 1);
1017 assert!(state.pending_subscribe_topics().is_empty());
1018
1019 state.mark_failure("tickers.BTCUSDT");
1021
1022 assert_eq!(state.len(), 0);
1024 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
1025
1026 assert_eq!(state.all_topics(), vec!["tickers.BTCUSDT"]);
1028 }
1029
1030 #[rstest]
1031 fn test_pending_subscribe_excludes_pending_unsubscribe() {
1032 let state = SubscriptionState::new('.');
1033
1034 state.mark_subscribe("tickers.BTCUSDT");
1036 state.confirm_subscribe("tickers.BTCUSDT");
1037
1038 state.mark_unsubscribe("tickers.BTCUSDT");
1040
1041 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
1043 assert!(state.all_topics().is_empty());
1044 assert_eq!(state.len(), 0);
1045 }
1046
1047 #[rstest]
1048 fn test_remove_reference_nonexistent_topic() {
1049 let state = SubscriptionState::new('.');
1050
1051 let should_unsubscribe = state.remove_reference("nonexistent");
1053
1054 assert!(!should_unsubscribe);
1056 assert_eq!(state.get_reference_count("nonexistent"), 0);
1057 }
1058
1059 #[rstest]
1060 fn test_edge_case_empty_channel_name() {
1061 let state = SubscriptionState::new('.');
1062
1063 state.mark_subscribe("");
1065 state.confirm_subscribe("");
1066
1067 assert_eq!(state.len(), 1);
1068 assert_eq!(state.all_topics(), vec![""]);
1069 }
1070
1071 #[rstest]
1072 fn test_special_characters_in_topics() {
1073 let state = SubscriptionState::new('.');
1074
1075 let special_topics = vec![
1077 "channel.symbol-with-dash",
1078 "channel.SYMBOL_WITH_UNDERSCORE",
1079 "channel.symbol123",
1080 "channel.symbol@special",
1081 ];
1082
1083 for topic in &special_topics {
1084 state.mark_subscribe(topic);
1085 state.confirm_subscribe(topic);
1086 }
1087
1088 assert_eq!(state.len(), special_topics.len());
1089
1090 let all_topics = state.all_topics();
1091 for topic in &special_topics {
1092 assert!(
1093 all_topics.contains(&(*topic).to_string()),
1094 "Missing topic: {topic}"
1095 );
1096 }
1097 }
1098
1099 #[rstest]
1100 fn test_clear_resets_all_state() {
1101 let state = SubscriptionState::new('.');
1102
1103 for i in 0..10 {
1105 let topic = format!("channel{i}.SYMBOL");
1106 state.add_reference(&topic);
1107 state.add_reference(&topic); state.mark_subscribe(&topic);
1109 state.confirm_subscribe(&topic);
1110 }
1111
1112 assert_eq!(state.len(), 10);
1113 assert!(!state.is_empty());
1114
1115 state.clear();
1117
1118 assert_eq!(state.len(), 0);
1120 assert!(state.is_empty());
1121 assert!(state.all_topics().is_empty());
1122 assert!(state.pending_subscribe_topics().is_empty());
1123 assert!(state.pending_unsubscribe_topics().is_empty());
1124
1125 for i in 0..10 {
1127 let topic = format!("channel{i}.SYMBOL");
1128 assert_eq!(state.get_reference_count(&topic), 0);
1129 }
1130 }
1131
1132 #[rstest]
1133 fn test_different_delimiter_does_not_affect_storage() {
1134 let state_dot = SubscriptionState::new('.');
1136 let state_colon = SubscriptionState::new(':');
1137
1138 state_dot.mark_subscribe("channel.SYMBOL");
1140 state_colon.mark_subscribe("channel:SYMBOL");
1141
1142 assert_eq!(state_dot.pending_subscribe_topics(), vec!["channel.SYMBOL"]);
1144 assert_eq!(
1145 state_colon.pending_subscribe_topics(),
1146 vec!["channel:SYMBOL"]
1147 );
1148 }
1149
1150 #[rstest]
1151 fn test_unsubscribe_before_subscribe_confirmed() {
1152 let state = SubscriptionState::new('.');
1153
1154 state.mark_subscribe("tickers.BTCUSDT");
1156 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
1157
1158 state.mark_unsubscribe("tickers.BTCUSDT");
1160
1161 assert!(state.pending_subscribe_topics().is_empty());
1163 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
1164
1165 state.confirm_unsubscribe("tickers.BTCUSDT");
1167
1168 assert!(state.is_empty());
1170 assert!(state.all_topics().is_empty());
1171 assert_eq!(state.len(), 0);
1172 }
1173
1174 #[rstest]
1175 fn test_late_subscribe_confirmation_after_unsubscribe() {
1176 let state = SubscriptionState::new('.');
1177
1178 state.mark_subscribe("tickers.BTCUSDT");
1180
1181 state.mark_unsubscribe("tickers.BTCUSDT");
1183
1184 state.confirm_subscribe("tickers.BTCUSDT");
1186
1187 assert_eq!(state.len(), 0);
1189 assert!(state.pending_subscribe_topics().is_empty());
1190
1191 state.confirm_unsubscribe("tickers.BTCUSDT");
1193
1194 assert!(state.is_empty());
1196 assert!(state.all_topics().is_empty());
1197 }
1198
1199 #[rstest]
1200 fn test_unsubscribe_clears_all_states() {
1201 let state = SubscriptionState::new('.');
1202
1203 state.mark_subscribe("tickers.BTCUSDT");
1205 state.confirm_subscribe("tickers.BTCUSDT");
1206 assert_eq!(state.len(), 1);
1207
1208 state.mark_unsubscribe("tickers.BTCUSDT");
1210
1211 assert_eq!(state.len(), 0);
1213 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
1214
1215 state.confirm_subscribe("tickers.BTCUSDT");
1217
1218 state.confirm_unsubscribe("tickers.BTCUSDT");
1220
1221 assert!(state.is_empty());
1223 assert_eq!(state.len(), 0);
1224 assert!(state.pending_subscribe_topics().is_empty());
1225 assert!(state.pending_unsubscribe_topics().is_empty());
1226 assert!(state.all_topics().is_empty());
1227 }
1228
1229 #[rstest]
1230 fn test_mark_failure_respects_pending_unsubscribe() {
1231 let state = SubscriptionState::new('.');
1232
1233 state.mark_subscribe("tickers.BTCUSDT");
1235 state.confirm_subscribe("tickers.BTCUSDT");
1236 assert_eq!(state.len(), 1);
1237
1238 state.mark_unsubscribe("tickers.BTCUSDT");
1240 assert_eq!(state.len(), 0);
1241 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
1242
1243 state.mark_failure("tickers.BTCUSDT");
1245
1246 assert!(state.pending_subscribe_topics().is_empty());
1248 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
1249
1250 assert!(state.all_topics().is_empty());
1252
1253 state.confirm_unsubscribe("tickers.BTCUSDT");
1255 assert!(state.is_empty());
1256 }
1257
1258 #[rstest]
1259 #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
1260 async fn test_concurrent_stress_mixed_operations() {
1261 let state = Arc::new(SubscriptionState::new('.'));
1262 let mut handles = vec![];
1263
1264 for i in 0..50 {
1266 let state_clone = Arc::clone(&state);
1267 let handle = tokio::spawn(async move {
1268 let topic1 = format!("channel.SYMBOL{i}");
1269 let topic2 = format!("channel.SYMBOL{}", i + 100);
1270
1271 state_clone.add_reference(&topic1);
1273 state_clone.add_reference(&topic2);
1274
1275 state_clone.mark_subscribe(&topic1);
1277 state_clone.confirm_subscribe(&topic1);
1278 state_clone.mark_subscribe(&topic2);
1279
1280 if i % 3 == 0 {
1282 state_clone.mark_unsubscribe(&topic1);
1283 state_clone.confirm_unsubscribe(&topic1);
1284 }
1285
1286 state_clone.add_reference(&topic2);
1288 state_clone.remove_reference(&topic2);
1289
1290 state_clone.confirm_subscribe(&topic2);
1292 });
1293 handles.push(handle);
1294 }
1295
1296 for handle in handles {
1297 handle.await.unwrap();
1298 }
1299
1300 let all = state.all_topics();
1302 let confirmed_count = state.len();
1303
1304 assert!(confirmed_count > 50); assert!(confirmed_count <= 100); assert_eq!(
1309 all.len(),
1310 confirmed_count + state.pending_subscribe_topics().len()
1311 );
1312 }
1313
1314 #[rstest]
1315 fn test_edge_case_malformed_topics() {
1316 let state = SubscriptionState::new('.');
1317
1318 state.mark_subscribe("channel.symbol.extra");
1320 state.confirm_subscribe("channel.symbol.extra");
1321 let topics = state.all_topics();
1322 assert!(topics.contains(&"channel.symbol.extra".to_string()));
1323
1324 state.mark_subscribe(".channel");
1326 state.confirm_subscribe(".channel");
1327 assert_eq!(state.len(), 2);
1328
1329 state.mark_subscribe("channel.");
1332 state.confirm_subscribe("channel.");
1333 assert_eq!(state.len(), 3);
1334
1335 state.mark_subscribe("tickers");
1337 state.confirm_subscribe("tickers");
1338 assert_eq!(state.len(), 4);
1339
1340 let all = state.all_topics();
1342 assert_eq!(all.len(), 4);
1343 assert!(all.contains(&"channel.symbol.extra".to_string()));
1344 assert!(all.contains(&".channel".to_string()));
1345 assert!(all.contains(&"channel".to_string())); assert!(all.contains(&"tickers".to_string()));
1347 }
1348
1349 #[rstest]
1350 fn test_reference_count_underflow_safety() {
1351 let state = SubscriptionState::new('.');
1352
1353 assert!(!state.remove_reference("never.added"));
1355 assert_eq!(state.get_reference_count("never.added"), 0);
1356
1357 state.add_reference("once.added");
1359 assert_eq!(state.get_reference_count("once.added"), 1);
1360
1361 assert!(state.remove_reference("once.added")); assert_eq!(state.get_reference_count("once.added"), 0);
1363
1364 assert!(!state.remove_reference("once.added")); assert!(!state.remove_reference("once.added")); assert_eq!(state.get_reference_count("once.added"), 0);
1367
1368 assert!(state.add_reference("once.added"));
1370 assert_eq!(state.get_reference_count("once.added"), 1);
1371 }
1372
1373 #[rstest]
1374 fn test_reconnection_with_partial_state() {
1375 let state = SubscriptionState::new('.');
1376
1377 state.mark_subscribe("confirmed.BTCUSDT");
1380 state.confirm_subscribe("confirmed.BTCUSDT");
1381
1382 state.mark_subscribe("pending.ETHUSDT");
1384
1385 state.mark_subscribe("cancelled.XRPUSDT");
1387 state.confirm_subscribe("cancelled.XRPUSDT");
1388 state.mark_unsubscribe("cancelled.XRPUSDT");
1389
1390 assert_eq!(state.len(), 1); let all = state.all_topics();
1393 assert_eq!(all.len(), 2); assert!(all.contains(&"confirmed.BTCUSDT".to_string()));
1395 assert!(all.contains(&"pending.ETHUSDT".to_string()));
1396 assert!(!all.contains(&"cancelled.XRPUSDT".to_string())); let topics_to_resubscribe = state.all_topics();
1400
1401 state.confirmed().clear();
1403
1404 for topic in &topics_to_resubscribe {
1406 state.mark_subscribe(topic);
1407 }
1408
1409 for topic in &topics_to_resubscribe {
1411 state.confirm_subscribe(topic);
1412 }
1413
1414 assert_eq!(state.len(), 2); let final_topics = state.all_topics();
1417 assert_eq!(final_topics.len(), 2);
1418 assert!(final_topics.contains(&"confirmed.BTCUSDT".to_string()));
1419 assert!(final_topics.contains(&"pending.ETHUSDT".to_string()));
1420 assert!(!final_topics.contains(&"cancelled.XRPUSDT".to_string()));
1421 }
1422
1423 fn check_invariants(state: &SubscriptionState, label: &str) {
1434 let confirmed_topics: AHashSet<String> = state
1436 .topics_from_map(&state.confirmed)
1437 .into_iter()
1438 .collect();
1439 let pending_sub_topics: AHashSet<String> =
1440 state.pending_subscribe_topics().into_iter().collect();
1441 let pending_unsub_topics: AHashSet<String> =
1442 state.pending_unsubscribe_topics().into_iter().collect();
1443
1444 let confirmed_and_pending_sub: Vec<_> =
1446 confirmed_topics.intersection(&pending_sub_topics).collect();
1447 assert!(
1448 confirmed_and_pending_sub.is_empty(),
1449 "{label}: Topic in both confirmed and pending_subscribe: {confirmed_and_pending_sub:?}"
1450 );
1451
1452 let confirmed_and_pending_unsub: Vec<_> = confirmed_topics
1453 .intersection(&pending_unsub_topics)
1454 .collect();
1455 assert!(
1456 confirmed_and_pending_unsub.is_empty(),
1457 "{label}: Topic in both confirmed and pending_unsubscribe: {confirmed_and_pending_unsub:?}"
1458 );
1459
1460 let pending_sub_and_unsub: Vec<_> = pending_sub_topics
1461 .intersection(&pending_unsub_topics)
1462 .collect();
1463 assert!(
1464 pending_sub_and_unsub.is_empty(),
1465 "{label}: Topic in both pending_subscribe and pending_unsubscribe: {pending_sub_and_unsub:?}"
1466 );
1467
1468 let all_topics: AHashSet<String> = state.all_topics().into_iter().collect();
1470 let expected_all: AHashSet<String> = confirmed_topics
1471 .union(&pending_sub_topics)
1472 .cloned()
1473 .collect();
1474 assert_eq!(
1475 all_topics, expected_all,
1476 "{label}: all_topics() doesn't match confirmed ∪ pending_subscribe"
1477 );
1478
1479 for topic in &pending_unsub_topics {
1481 assert!(
1482 !all_topics.contains(topic),
1483 "{label}: pending_unsubscribe topic {topic} incorrectly in all_topics()"
1484 );
1485 }
1486
1487 let expected_len: usize = state
1489 .confirmed
1490 .iter()
1491 .map(|entry| entry.value().len())
1492 .sum();
1493 assert_eq!(
1494 state.len(),
1495 expected_len,
1496 "{label}: len() mismatch. Expected {expected_len}, was {}",
1497 state.len()
1498 );
1499
1500 let should_be_empty = state.confirmed.is_empty()
1502 && pending_sub_topics.is_empty()
1503 && pending_unsub_topics.is_empty();
1504 assert_eq!(
1505 state.is_empty(),
1506 should_be_empty,
1507 "{label}: is_empty() inconsistent. Maps empty: {should_be_empty}, is_empty(): {}",
1508 state.is_empty()
1509 );
1510
1511 for entry in state.reference_counts.iter() {
1513 let count = entry.value().get();
1514 assert!(
1515 count > 0,
1516 "{label}: Reference count should be NonZeroUsize (> 0), was {count} for {:?}",
1517 entry.key()
1518 );
1519 }
1520 }
1521
1522 fn check_topic_exclusivity(state: &SubscriptionState, topic: &str, label: &str) {
1524 let (channel, symbol) = split_topic(topic, state.delimiter);
1525
1526 let in_confirmed = is_tracked(&state.confirmed, channel, symbol);
1527 let in_pending_sub = is_tracked(&state.pending_subscribe, channel, symbol);
1528 let in_pending_unsub = is_tracked(&state.pending_unsubscribe, channel, symbol);
1529
1530 let count = [in_confirmed, in_pending_sub, in_pending_unsub]
1531 .iter()
1532 .filter(|&&x| x)
1533 .count();
1534
1535 assert!(
1536 count <= 1,
1537 "{label}: Topic {topic} in {count} states (should be 0 or 1). \
1538 confirmed: {in_confirmed}, pending_sub: {in_pending_sub}, pending_unsub: {in_pending_unsub}"
1539 );
1540 }
1541
#[cfg(test)]
mod property_tests {
    use proptest::prelude::*;

    use super::*;

    /// A single step applied to a `SubscriptionState` by the property tests.
    #[derive(Debug, Clone)]
    enum Operation {
        MarkSubscribe(String),
        ConfirmSubscribe(String),
        MarkUnsubscribe(String),
        ConfirmUnsubscribe(String),
        MarkFailure(String),
        AddReference(String),
        RemoveReference(String),
        Clear,
    }

    /// Topics drawn from a deliberately small alphabet (5 channels, 10 symbols,
    /// plus bare channel topics) so random sequences hit the same topic often.
    fn topic_strategy() -> impl Strategy<Value = String> {
        prop_oneof![
            (any::<u8>(), any::<u8>())
                .prop_map(|(ch, sym)| { format!("channel{}.SYMBOL{}", ch % 5, sym % 10) }),
            any::<u8>().prop_map(|ch| format!("channel{}", ch % 5)),
        ]
    }

    /// Any single `Operation` over a topic from `topic_strategy`.
    fn operation_strategy() -> impl Strategy<Value = Operation> {
        topic_strategy().prop_flat_map(|topic| {
            prop_oneof![
                Just(Operation::MarkSubscribe(topic.clone())),
                Just(Operation::ConfirmSubscribe(topic.clone())),
                Just(Operation::MarkUnsubscribe(topic.clone())),
                Just(Operation::ConfirmUnsubscribe(topic.clone())),
                Just(Operation::MarkFailure(topic.clone())),
                Just(Operation::AddReference(topic.clone())),
                Just(Operation::RemoveReference(topic)),
                Just(Operation::Clear),
            ]
        })
    }

    /// Dispatches `op` to the matching `SubscriptionState` method, discarding any result.
    fn apply_operation(state: &SubscriptionState, op: &Operation) {
        match op {
            Operation::MarkSubscribe(topic) => state.mark_subscribe(topic),
            Operation::ConfirmSubscribe(topic) => state.confirm_subscribe(topic),
            Operation::MarkUnsubscribe(topic) => state.mark_unsubscribe(topic),
            Operation::ConfirmUnsubscribe(topic) => state.confirm_unsubscribe(topic),
            Operation::MarkFailure(topic) => state.mark_failure(topic),
            Operation::AddReference(topic) => {
                state.add_reference(topic);
            }
            Operation::RemoveReference(topic) => {
                state.remove_reference(topic);
            }
            Operation::Clear => state.clear(),
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(500))]

        // All structural invariants hold after every step of a random sequence.
        #[rstest]
        fn prop_invariants_hold_after_operations(
            operations in prop::collection::vec(operation_strategy(), 1..50)
        ) {
            let state = SubscriptionState::new('.');

            for (i, op) in operations.iter().enumerate() {
                apply_operation(&state, op);

                check_invariants(&state, &format!("After op {i}: {op:?}"));
            }

            check_invariants(&state, "Final state");
        }

        // Reference counts track a simple model: +1 per add, -1 per remove,
        // never below zero, with no entry stored for a topic at zero.
        #[rstest]
        fn prop_reference_counting_consistency(
            ops in prop::collection::vec(
                topic_strategy().prop_flat_map(|t| {
                    prop_oneof![
                        Just(Operation::AddReference(t.clone())),
                        Just(Operation::RemoveReference(t)),
                    ]
                }),
                1..100
            )
        ) {
            let state = SubscriptionState::new('.');
            let mut expected_counts: std::collections::HashMap<String, usize> =
                std::collections::HashMap::new();

            for op in &ops {
                apply_operation(&state, op);

                match op {
                    Operation::AddReference(topic) => {
                        *expected_counts.entry(topic.clone()).or_default() += 1;
                    }
                    Operation::RemoveReference(topic) => {
                        // Removing an absent reference leaves the count at zero.
                        if let Some(count) = expected_counts.get_mut(topic) {
                            *count = count.saturating_sub(1);
                        }
                    }
                    other => unreachable!("strategy only yields reference operations, was {other:?}"),
                }

                for entry in state.reference_counts.iter() {
                    prop_assert!(entry.value().get() > 0);
                }

                for (topic, expected) in &expected_counts {
                    prop_assert_eq!(
                        state.get_reference_count(topic),
                        *expected,
                        "reference count mismatch for {} after {:?}",
                        topic,
                        op
                    );
                }

                // Released topics must be dropped from the map, not kept around.
                let live = expected_counts.values().filter(|&&count| count > 0).count();
                prop_assert_eq!(state.reference_counts.len(), live);
            }
        }

        // all_topics() is exactly confirmed ∪ pending_subscribe after every step.
        #[rstest]
        fn prop_all_topics_is_union(
            operations in prop::collection::vec(operation_strategy(), 1..50)
        ) {
            let state = SubscriptionState::new('.');

            for op in &operations {
                apply_operation(&state, op);

                let all_topics: AHashSet<String> = state.all_topics().into_iter().collect();
                let confirmed: AHashSet<String> =
                    state.topics_from_map(&state.confirmed).into_iter().collect();
                let pending_sub: AHashSet<String> =
                    state.pending_subscribe_topics().into_iter().collect();
                let expected: AHashSet<String> = confirmed.union(&pending_sub).cloned().collect();

                prop_assert_eq!(&all_topics, &expected);

                let pending_unsub: AHashSet<String> =
                    state.pending_unsubscribe_topics().into_iter().collect();
                for topic in pending_unsub {
                    prop_assert!(!all_topics.contains(&topic));
                }
            }
        }

        // clear() leaves every internal map empty regardless of prior history.
        #[rstest]
        fn prop_clear_resets_completely(
            operations in prop::collection::vec(operation_strategy(), 1..30)
        ) {
            let state = SubscriptionState::new('.');

            for op in &operations {
                apply_operation(&state, op);
            }

            state.clear();

            prop_assert!(state.is_empty());
            prop_assert_eq!(state.len(), 0);
            prop_assert!(state.all_topics().is_empty());
            prop_assert!(state.pending_subscribe_topics().is_empty());
            prop_assert!(state.pending_unsubscribe_topics().is_empty());
            prop_assert!(state.confirmed.is_empty());
            prop_assert!(state.pending_subscribe.is_empty());
            prop_assert!(state.pending_unsubscribe.is_empty());
            prop_assert!(state.reference_counts.is_empty());
        }

        // A single probe topic is never tracked in more than one map at a time.
        #[rstest]
        fn prop_topic_mutual_exclusivity(
            operations in prop::collection::vec(operation_strategy(), 1..50),
            topic in topic_strategy()
        ) {
            let state = SubscriptionState::new('.');

            for (i, op) in operations.iter().enumerate() {
                apply_operation(&state, op);
                check_topic_exclusivity(&state, &topic, &format!("After op {i}: {op:?}"));
            }
        }
    }
}
1719
1720 #[rstest]
1721 fn test_exhaustive_two_step_transitions() {
1722 let operations = [
1723 "mark_subscribe",
1724 "confirm_subscribe",
1725 "mark_unsubscribe",
1726 "confirm_unsubscribe",
1727 "mark_failure",
1728 ];
1729
1730 for &op1 in &operations {
1731 for &op2 in &operations {
1732 let state = SubscriptionState::new('.');
1733 let topic = "test.TOPIC";
1734
1735 apply_op(&state, op1, topic);
1737 apply_op(&state, op2, topic);
1738
1739 check_invariants(&state, &format!("{op1} → {op2}"));
1741 check_topic_exclusivity(&state, topic, &format!("{op1} → {op2}"));
1742 }
1743 }
1744 }
1745
1746 fn apply_op(state: &SubscriptionState, op: &str, topic: &str) {
1747 match op {
1748 "mark_subscribe" => state.mark_subscribe(topic),
1749 "confirm_subscribe" => state.confirm_subscribe(topic),
1750 "mark_unsubscribe" => state.mark_unsubscribe(topic),
1751 "confirm_unsubscribe" => state.confirm_unsubscribe(topic),
1752 "mark_failure" => state.mark_failure(topic),
1753 _ => panic!("Unknown operation: {op}"),
1754 }
1755 }
1756
1757 #[rstest]
1758 #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
1759 async fn test_stress_rapid_resubscribe_pattern() {
1760 let state = Arc::new(SubscriptionState::new('.'));
1762 let mut handles = vec![];
1763
1764 for i in 0..100 {
1765 let state_clone = Arc::clone(&state);
1766 let handle = tokio::spawn(async move {
1767 let topic = format!("rapid.SYMBOL{}", i % 10); state_clone.mark_subscribe(&topic);
1771 state_clone.confirm_subscribe(&topic);
1772
1773 state_clone.mark_unsubscribe(&topic);
1775 state_clone.mark_subscribe(&topic);
1777 state_clone.confirm_unsubscribe(&topic);
1779 state_clone.confirm_subscribe(&topic);
1781 });
1782 handles.push(handle);
1783 }
1784
1785 for handle in handles {
1786 handle.await.unwrap();
1787 }
1788
1789 check_invariants(&state, "After rapid resubscribe stress test");
1790 }
1791
1792 #[rstest]
1793 #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
1794 async fn test_stress_failure_recovery_loop() {
1795 let state = Arc::new(SubscriptionState::new('.'));
1798 let mut handles = vec![];
1799
1800 for i in 0..30 {
1801 let state_clone = Arc::clone(&state);
1802 let handle = tokio::spawn(async move {
1803 let topic = format!("failure.SYMBOL{i}"); state_clone.mark_subscribe(&topic);
1807 state_clone.confirm_subscribe(&topic);
1808
1809 for _ in 0..5 {
1811 state_clone.mark_failure(&topic);
1812 state_clone.confirm_subscribe(&topic); }
1814 });
1815 handles.push(handle);
1816 }
1817
1818 for handle in handles {
1819 handle.await.unwrap();
1820 }
1821
1822 check_invariants(&state, "After failure recovery loops");
1823
1824 assert_eq!(state.len(), 30);
1826 }
1827}