nautilus_common/msgbus/
stubs.rs1use std::{
17 any::Any,
18 cell::RefCell,
19 fmt::Debug,
20 rc::Rc,
21 sync::{
22 Arc,
23 atomic::{AtomicBool, Ordering},
24 },
25};
26
27use ahash::AHashMap;
28use nautilus_core::{UUID4, message::Message};
29use ustr::Ustr;
30
31use crate::msgbus::{
32 Handler, IntoHandler, ShareableMessageHandler, TypedHandler, TypedIntoHandler,
33 typed_handler::shareable_handler,
34};
35
36#[derive(Clone)]
38pub struct StubMessageHandler {
39 id: Ustr,
40 callback: Arc<dyn Fn(Message) + Send>,
41}
42
43impl Debug for StubMessageHandler {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 f.debug_struct(stringify!(StubMessageHandler))
46 .field("id", &self.id)
47 .finish()
48 }
49}
50
51impl Handler<dyn Any> for StubMessageHandler {
52 fn id(&self) -> Ustr {
53 self.id
54 }
55
56 fn handle(&self, message: &dyn Any) {
57 if let Some(msg) = message.downcast_ref::<Message>() {
58 (self.callback)(msg.clone());
59 }
60 }
61}
62
63#[must_use]
64#[allow(unused_must_use)]
65pub fn get_stub_shareable_handler(id: Option<Ustr>) -> ShareableMessageHandler {
66 let unique_id = id.unwrap_or_else(|| Ustr::from(UUID4::new().as_str()));
67 shareable_handler(Rc::new(StubMessageHandler {
68 id: unique_id,
69 callback: Arc::new(|m: Message| {
70 format!("{m:?}");
71 }),
72 }))
73}
74
75#[derive(Debug, Clone)]
77pub struct CallCheckHandler {
78 id: Ustr,
79 called: Arc<AtomicBool>,
80}
81
82impl CallCheckHandler {
83 #[must_use]
84 pub fn new(id: Option<Ustr>) -> Self {
85 let unique_id = id.unwrap_or_else(|| Ustr::from(UUID4::new().as_str()));
86 Self {
87 id: unique_id,
88 called: Arc::new(AtomicBool::new(false)),
89 }
90 }
91
92 #[must_use]
93 pub fn was_called(&self) -> bool {
94 self.called.load(Ordering::SeqCst)
95 }
96
97 #[must_use]
99 pub fn handler(&self) -> ShareableMessageHandler {
100 shareable_handler(Rc::new(self.clone()))
101 }
102}
103
104impl Handler<dyn Any> for CallCheckHandler {
105 fn id(&self) -> Ustr {
106 self.id
107 }
108
109 fn handle(&self, _message: &dyn Any) {
110 self.called.store(true, Ordering::SeqCst);
111 }
112}
113
114#[must_use]
117pub fn get_call_check_handler(id: Option<Ustr>) -> (ShareableMessageHandler, CallCheckHandler) {
118 let checker = CallCheckHandler::new(id);
119 let handler = checker.handler();
120 (handler, checker)
121}
122
123#[derive(Debug, Clone)]
125pub struct AnySavingHandler<T> {
126 id: Ustr,
127 messages: Rc<RefCell<Vec<T>>>,
128}
129
130impl<T: Clone + 'static> AnySavingHandler<T> {
131 #[must_use]
132 pub fn new(id: Option<Ustr>) -> Self {
133 let unique_id = id.unwrap_or_else(|| Ustr::from(UUID4::new().as_str()));
134 Self {
135 id: unique_id,
136 messages: Rc::new(RefCell::new(Vec::new())),
137 }
138 }
139
140 #[must_use]
141 pub fn get_messages(&self) -> Vec<T> {
142 self.messages.borrow().clone()
143 }
144
145 pub fn clear(&self) {
146 self.messages.borrow_mut().clear();
147 }
148
149 #[must_use]
151 pub fn handler(&self) -> ShareableMessageHandler {
152 shareable_handler(Rc::new(self.clone()))
153 }
154}
155
156impl<T: Clone + 'static> Handler<dyn Any> for AnySavingHandler<T> {
157 fn id(&self) -> Ustr {
158 self.id
159 }
160
161 fn handle(&self, message: &dyn Any) {
162 if let Some(m) = message.downcast_ref::<T>() {
163 self.messages.borrow_mut().push(m.clone());
164 } else {
165 log::error!(
166 "AnySavingHandler: expected {} got {:?}",
167 std::any::type_name::<T>(),
168 message.type_id()
169 );
170 }
171 }
172}
173
174#[must_use]
177pub fn get_any_saving_handler<T: Clone + 'static>(
178 id: Option<Ustr>,
179) -> (ShareableMessageHandler, AnySavingHandler<T>) {
180 let saver = AnySavingHandler::new(id);
181 let handler = saver.handler();
182 (handler, saver)
183}
184
185pub type MessageSavingHandler<T> = AnySavingHandler<T>;
187
188#[derive(Debug, Clone)]
190pub struct TypedMessageSavingHandler<T> {
191 id: Ustr,
192 messages: Rc<RefCell<Vec<T>>>,
193}
194
195impl<T: Clone + 'static> TypedMessageSavingHandler<T> {
196 #[must_use]
197 pub fn new(id: Option<Ustr>) -> Self {
198 let unique_id = id.unwrap_or_else(|| Ustr::from(UUID4::new().as_str()));
199 Self {
200 id: unique_id,
201 messages: Rc::new(RefCell::new(Vec::new())),
202 }
203 }
204
205 #[must_use]
206 pub fn get_messages(&self) -> Vec<T> {
207 self.messages.borrow().clone()
208 }
209
210 #[must_use]
212 pub fn handler(&self) -> TypedHandler<T> {
213 TypedHandler::new(self.clone())
214 }
215}
216
217impl<T: Clone + 'static> Handler<T> for TypedMessageSavingHandler<T> {
218 fn id(&self) -> Ustr {
219 self.id
220 }
221
222 fn handle(&self, message: &T) {
223 self.messages.borrow_mut().push(message.clone());
224 }
225}
226
227#[must_use]
230pub fn get_typed_message_saving_handler<T: Clone + 'static>(
231 id: Option<Ustr>,
232) -> (TypedHandler<T>, TypedMessageSavingHandler<T>) {
233 let saving_handler = TypedMessageSavingHandler::new(id);
234 let typed_handler = saving_handler.handler();
235 (typed_handler, saving_handler)
236}
237
238#[derive(Debug, Clone)]
243pub struct TypedIntoMessageSavingHandler<T> {
244 id: Ustr,
245 messages: Rc<RefCell<Vec<T>>>,
246}
247
248impl<T: 'static> TypedIntoMessageSavingHandler<T> {
249 #[must_use]
250 pub fn new(id: Option<Ustr>) -> Self {
251 let unique_id = id.unwrap_or_else(|| Ustr::from(UUID4::new().as_str()));
252 Self {
253 id: unique_id,
254 messages: Rc::new(RefCell::new(Vec::new())),
255 }
256 }
257
258 #[must_use]
260 pub fn new_with_messages(id: Option<Ustr>, messages: Rc<RefCell<Vec<T>>>) -> Self {
261 let unique_id = id.unwrap_or_else(|| Ustr::from(UUID4::new().as_str()));
262 Self {
263 id: unique_id,
264 messages,
265 }
266 }
267
268 #[must_use]
269 pub fn get_messages(&self) -> Vec<T>
270 where
271 T: Clone,
272 {
273 self.messages.borrow().clone()
274 }
275
276 #[must_use]
278 pub fn handler(&self) -> TypedIntoHandler<T> {
279 TypedIntoHandler::new(Self {
280 id: self.id,
281 messages: self.messages.clone(),
282 })
283 }
284
285 pub fn clear(&self) {
286 self.messages.borrow_mut().clear();
287 }
288}
289
290impl<T: 'static> IntoHandler<T> for TypedIntoMessageSavingHandler<T> {
291 fn id(&self) -> Ustr {
292 self.id
293 }
294
295 fn handle(&self, message: T) {
296 self.messages.borrow_mut().push(message);
297 }
298}
299
300#[must_use]
303pub fn get_typed_into_message_saving_handler<T: 'static>(
304 id: Option<Ustr>,
305) -> (TypedIntoHandler<T>, TypedIntoMessageSavingHandler<T>) {
306 let saving_handler = TypedIntoMessageSavingHandler::new(id);
307 let typed_handler = saving_handler.handler();
308 (typed_handler, saving_handler)
309}
310
311thread_local! {
315 static SAVING_HANDLERS: RefCell<AHashMap<Ustr, Box<dyn std::any::Any>>> = RefCell::new(AHashMap::new());
316}
317
318#[must_use]
320pub fn get_message_saving_handler<T: Clone + 'static>(id: Option<Ustr>) -> ShareableMessageHandler {
321 let (handler, saver) = get_any_saving_handler::<T>(id);
322 let handler_id = handler.0.id();
323 SAVING_HANDLERS.with(|handlers| {
324 handlers.borrow_mut().insert(handler_id, Box::new(saver));
325 });
326 handler
327}
328
329#[must_use]
331pub fn get_saved_messages<T: Clone + 'static>(handler: ShareableMessageHandler) -> Vec<T> {
332 let handler_id = handler.0.id();
333 SAVING_HANDLERS.with(|handlers| {
334 let handlers = handlers.borrow();
335 if let Some(saver) = handlers.get(&handler_id)
336 && let Some(saver) = saver.downcast_ref::<AnySavingHandler<T>>()
337 {
338 return saver.get_messages();
339 }
340 Vec::new()
341 })
342}