1use std::{
17 cell::RefCell,
18 collections::HashMap,
19 fmt::Display,
20 hash::{Hash, Hasher},
21 ops::Deref,
22 rc::Rc,
23};
24
25use ahash::{AHashMap, AHashSet};
26use handler::ShareableMessageHandler;
27use indexmap::IndexMap;
28use matching::is_matching_backtracking;
29use nautilus_core::{
30 UUID4,
31 correctness::{FAILED, check_predicate_true, check_valid_string_utf8},
32};
33use nautilus_model::identifiers::TraderId;
34use serde::{Deserialize, Serialize};
35use switchboard::MessagingSwitchboard;
36use ustr::Ustr;
37
38use super::{handler, matching, set_message_bus, switchboard};
39
40#[inline(always)]
41fn check_fully_qualified_string(value: &Ustr, key: &str) -> anyhow::Result<()> {
42 check_predicate_true(
43 !value.chars().any(|c| c == '*' || c == '?'),
44 &format!("{key} `value` contained invalid characters, was {value}"),
45 )
46}
47
/// Zero-sized marker type tagging an [`MStr`] as a subscription pattern.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Pattern;
51
/// Zero-sized marker type tagging an [`MStr`] as a fully qualified topic.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Topic;
55
/// Zero-sized marker type tagging an [`MStr`] as an endpoint address.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Endpoint;
59
60#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
62#[serde(transparent)]
63pub struct MStr<T> {
64 value: Ustr,
65 #[serde(skip)]
66 _marker: std::marker::PhantomData<T>,
67}
68
69impl<T> Display for MStr<T> {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 write!(f, "{}", self.value)
72 }
73}
74
75impl<T> Deref for MStr<T> {
76 type Target = Ustr;
77
78 fn deref(&self) -> &Self::Target {
79 &self.value
80 }
81}
82
83impl<T> AsRef<str> for MStr<T> {
84 fn as_ref(&self) -> &str {
85 self.value.as_str()
86 }
87}
88
89impl MStr<Pattern> {
90 pub fn pattern<T: AsRef<str>>(value: T) -> Self {
92 let value = Ustr::from(value.as_ref());
93
94 Self {
95 value,
96 _marker: std::marker::PhantomData,
97 }
98 }
99}
100
101impl From<&str> for MStr<Pattern> {
102 fn from(value: &str) -> Self {
103 Self::pattern(value)
104 }
105}
106
107impl From<String> for MStr<Pattern> {
108 fn from(value: String) -> Self {
109 value.as_str().into()
110 }
111}
112
113impl From<&String> for MStr<Pattern> {
114 fn from(value: &String) -> Self {
115 value.as_str().into()
116 }
117}
118
119impl From<MStr<Topic>> for MStr<Pattern> {
120 fn from(value: MStr<Topic>) -> Self {
121 Self {
122 value: value.value,
123 _marker: std::marker::PhantomData,
124 }
125 }
126}
127
128impl MStr<Topic> {
129 pub fn topic<T: AsRef<str>>(value: T) -> anyhow::Result<Self> {
135 let topic = Ustr::from(value.as_ref());
136 check_valid_string_utf8(value, stringify!(value))?;
137 check_fully_qualified_string(&topic, stringify!(Topic))?;
138
139 Ok(Self {
140 value: topic,
141 _marker: std::marker::PhantomData,
142 })
143 }
144}
145
146impl From<&str> for MStr<Topic> {
147 fn from(value: &str) -> Self {
148 Self::topic(value).expect(FAILED)
149 }
150}
151
152impl From<String> for MStr<Topic> {
153 fn from(value: String) -> Self {
154 value.as_str().into()
155 }
156}
157
158impl From<&String> for MStr<Topic> {
159 fn from(value: &String) -> Self {
160 value.as_str().into()
161 }
162}
163
164impl From<Ustr> for MStr<Topic> {
165 fn from(value: Ustr) -> Self {
166 value.as_str().into()
167 }
168}
169
170impl From<&Ustr> for MStr<Topic> {
171 fn from(value: &Ustr) -> Self {
172 (*value).into()
173 }
174}
175
176impl MStr<Endpoint> {
177 pub fn endpoint<T: AsRef<str>>(value: T) -> anyhow::Result<Self> {
183 let endpoint = Ustr::from(value.as_ref());
184 check_valid_string_utf8(value, stringify!(value))?;
185 check_fully_qualified_string(&endpoint, stringify!(Endpoint))?;
186
187 Ok(Self {
188 value: endpoint,
189 _marker: std::marker::PhantomData,
190 })
191 }
192}
193
194impl From<&str> for MStr<Endpoint> {
195 fn from(value: &str) -> Self {
196 Self::endpoint(value).expect(FAILED)
197 }
198}
199
200impl From<String> for MStr<Endpoint> {
201 fn from(value: String) -> Self {
202 value.as_str().into()
203 }
204}
205
206impl From<&String> for MStr<Endpoint> {
207 fn from(value: &String) -> Self {
208 value.as_str().into()
209 }
210}
211
212impl From<Ustr> for MStr<Endpoint> {
213 fn from(value: Ustr) -> Self {
214 value.as_str().into()
215 }
216}
217
218#[derive(Clone, Debug)]
224pub struct Subscription {
225 pub handler: ShareableMessageHandler,
227 pub handler_id: Ustr,
229 pub pattern: MStr<Pattern>,
231 pub priority: u8,
235}
236
237impl Subscription {
238 #[must_use]
240 pub fn new(
241 pattern: MStr<Pattern>,
242 handler: ShareableMessageHandler,
243 priority: Option<u8>,
244 ) -> Self {
245 Self {
246 handler_id: handler.0.id(),
247 pattern,
248 handler,
249 priority: priority.unwrap_or(0),
250 }
251 }
252}
253
254impl PartialEq<Self> for Subscription {
255 fn eq(&self, other: &Self) -> bool {
256 self.pattern == other.pattern && self.handler_id == other.handler_id
257 }
258}
259
260impl Eq for Subscription {}
261
262impl PartialOrd for Subscription {
263 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
264 Some(self.cmp(other))
265 }
266}
267
268impl Ord for Subscription {
269 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
270 other
271 .priority
272 .cmp(&self.priority)
273 .then_with(|| self.pattern.cmp(&other.pattern))
274 .then_with(|| self.handler_id.cmp(&other.handler_id))
275 }
276}
277
278impl Hash for Subscription {
279 fn hash<H: Hasher>(&self, state: &mut H) {
280 self.pattern.hash(state);
281 self.handler_id.hash(state);
282 }
283}
284
285#[derive(Debug)]
306pub struct MessageBus {
307 pub trader_id: TraderId,
309 pub instance_id: UUID4,
311 pub name: String,
313 pub has_backing: bool,
315 pub switchboard: MessagingSwitchboard,
317 pub subscriptions: AHashSet<Subscription>,
319 pub topics: IndexMap<MStr<Topic>, Vec<Subscription>>,
322 pub endpoints: IndexMap<MStr<Endpoint>, ShareableMessageHandler>,
324 pub correlation_index: AHashMap<UUID4, ShareableMessageHandler>,
326}
327
328impl MessageBus {
334 #[must_use]
336 pub fn new(
337 trader_id: TraderId,
338 instance_id: UUID4,
339 name: Option<String>,
340 _config: Option<HashMap<String, serde_json::Value>>,
341 ) -> Self {
342 Self {
343 trader_id,
344 instance_id,
345 name: name.unwrap_or(stringify!(MessageBus).to_owned()),
346 switchboard: MessagingSwitchboard::default(),
347 subscriptions: AHashSet::new(),
348 topics: IndexMap::new(),
349 endpoints: IndexMap::new(),
350 correlation_index: AHashMap::new(),
351 has_backing: false,
352 }
353 }
354
355 #[must_use]
357 pub fn mem_address(&self) -> String {
358 format!("{self:p}")
359 }
360
361 #[must_use]
363 pub fn endpoints(&self) -> Vec<&str> {
364 self.endpoints.iter().map(|e| e.0.as_str()).collect()
365 }
366
367 #[must_use]
369 pub fn patterns(&self) -> Vec<&str> {
370 self.subscriptions
371 .iter()
372 .map(|s| s.pattern.as_str())
373 .collect()
374 }
375
376 pub fn has_subscribers<T: AsRef<str>>(&self, topic: T) -> bool {
378 self.subscriptions_count(topic) > 0
379 }
380
381 #[must_use]
387 pub fn subscriptions_count<T: AsRef<str>>(&self, topic: T) -> usize {
388 let topic = MStr::<Topic>::topic(topic).expect(FAILED);
389 self.topics
390 .get(&topic)
391 .map_or_else(|| self.find_topic_matches(topic).len(), |subs| subs.len())
392 }
393
394 #[must_use]
396 pub fn subscriptions(&self) -> Vec<&Subscription> {
397 self.subscriptions.iter().collect()
398 }
399
400 #[must_use]
402 pub fn subscription_handler_ids(&self) -> Vec<&str> {
403 self.subscriptions
404 .iter()
405 .map(|s| s.handler_id.as_str())
406 .collect()
407 }
408
409 #[must_use]
415 pub fn is_registered<T: Into<MStr<Endpoint>>>(&self, endpoint: T) -> bool {
416 let endpoint: MStr<Endpoint> = endpoint.into();
417 self.endpoints.contains_key(&endpoint)
418 }
419
420 #[must_use]
422 pub fn is_subscribed<T: AsRef<str>>(
423 &self,
424 pattern: T,
425 handler: ShareableMessageHandler,
426 ) -> bool {
427 let pattern = MStr::<Pattern>::pattern(pattern);
428 let sub = Subscription::new(pattern, handler, None);
429 self.subscriptions.contains(&sub)
430 }
431
432 pub const fn close(&self) -> anyhow::Result<()> {
438 Ok(())
440 }
441
442 #[must_use]
444 pub fn get_endpoint(&self, endpoint: MStr<Endpoint>) -> Option<&ShareableMessageHandler> {
445 self.endpoints.get(&endpoint)
446 }
447
448 #[must_use]
450 pub fn get_response_handler(&self, correlation_id: &UUID4) -> Option<&ShareableMessageHandler> {
451 self.correlation_index.get(correlation_id)
452 }
453
454 pub(crate) fn find_topic_matches(&self, topic: MStr<Topic>) -> Vec<Subscription> {
456 self.subscriptions
457 .iter()
458 .filter_map(|sub| {
459 if is_matching_backtracking(topic, sub.pattern) {
460 Some(sub.clone())
461 } else {
462 None
463 }
464 })
465 .collect()
466 }
467
468 #[must_use]
471 pub fn matching_subscriptions<T: Into<MStr<Topic>>>(&mut self, topic: T) -> Vec<Subscription> {
472 self.inner_matching_subscriptions(topic.into())
473 }
474
475 pub(crate) fn inner_matching_subscriptions(&mut self, topic: MStr<Topic>) -> Vec<Subscription> {
476 self.topics.get(&topic).cloned().unwrap_or_else(|| {
477 let mut matches = self.find_topic_matches(topic);
478 matches.sort();
479 self.topics.insert(topic, matches.clone());
480 matches
481 })
482 }
483
484 pub fn register_response_handler(
490 &mut self,
491 correlation_id: &UUID4,
492 handler: ShareableMessageHandler,
493 ) -> anyhow::Result<()> {
494 if self.correlation_index.contains_key(correlation_id) {
495 anyhow::bail!("Correlation ID <{correlation_id}> already has a registered handler");
496 }
497
498 self.correlation_index.insert(*correlation_id, handler);
499
500 Ok(())
501 }
502}
503
504impl MessageBus {
506 pub fn register_message_bus(self) -> Rc<RefCell<Self>> {
523 let msgbus = Rc::new(RefCell::new(self));
524 set_message_bus(msgbus.clone());
525 msgbus
526 }
527}
528
529impl Default for MessageBus {
530 fn default() -> Self {
532 Self::new(TraderId::from("TRADER-001"), UUID4::new(), None, None)
533 }
534}