Skip to main content

nautilus_common/msgbus/
stubs.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{
17    any::Any,
18    cell::RefCell,
19    fmt::Debug,
20    rc::Rc,
21    sync::{
22        Arc,
23        atomic::{AtomicBool, Ordering},
24    },
25};
26
27use ahash::AHashMap;
28use nautilus_core::{UUID4, message::Message};
29use ustr::Ustr;
30
31use crate::msgbus::{
32    Handler, IntoHandler, ShareableMessageHandler, TypedHandler, TypedIntoHandler,
33    typed_handler::shareable_handler,
34};
35
36/// Stub handler which logs messages it receives.
37#[derive(Clone)]
38pub struct StubMessageHandler {
39    id: Ustr,
40    callback: Arc<dyn Fn(Message) + Send>,
41}
42
43impl Debug for StubMessageHandler {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        f.debug_struct(stringify!(StubMessageHandler))
46            .field("id", &self.id)
47            .finish()
48    }
49}
50
51impl Handler<dyn Any> for StubMessageHandler {
52    fn id(&self) -> Ustr {
53        self.id
54    }
55
56    fn handle(&self, message: &dyn Any) {
57        if let Some(msg) = message.downcast_ref::<Message>() {
58            (self.callback)(msg.clone());
59        }
60    }
61}
62
63#[must_use]
64#[allow(unused_must_use)]
65pub fn get_stub_shareable_handler(id: Option<Ustr>) -> ShareableMessageHandler {
66    let unique_id = id.unwrap_or_else(|| Ustr::from(UUID4::new().as_str()));
67    shareable_handler(Rc::new(StubMessageHandler {
68        id: unique_id,
69        callback: Arc::new(|m: Message| {
70            format!("{m:?}");
71        }),
72    }))
73}
74
75/// Handler that tracks whether it has been called.
76#[derive(Debug, Clone)]
77pub struct CallCheckHandler {
78    id: Ustr,
79    called: Arc<AtomicBool>,
80}
81
82impl CallCheckHandler {
83    #[must_use]
84    pub fn new(id: Option<Ustr>) -> Self {
85        let unique_id = id.unwrap_or_else(|| Ustr::from(UUID4::new().as_str()));
86        Self {
87            id: unique_id,
88            called: Arc::new(AtomicBool::new(false)),
89        }
90    }
91
92    #[must_use]
93    pub fn was_called(&self) -> bool {
94        self.called.load(Ordering::SeqCst)
95    }
96
97    /// Returns a `ShareableMessageHandler` for registration.
98    #[must_use]
99    pub fn handler(&self) -> ShareableMessageHandler {
100        shareable_handler(Rc::new(self.clone()))
101    }
102}
103
104impl Handler<dyn Any> for CallCheckHandler {
105    fn id(&self) -> Ustr {
106        self.id
107    }
108
109    fn handle(&self, _message: &dyn Any) {
110        self.called.store(true, Ordering::SeqCst);
111    }
112}
113
114/// Creates a call-checking handler and returns both the handler for registration
115/// and a clone that can be used to check if it was called.
116#[must_use]
117pub fn get_call_check_handler(id: Option<Ustr>) -> (ShareableMessageHandler, CallCheckHandler) {
118    let checker = CallCheckHandler::new(id);
119    let handler = checker.handler();
120    (handler, checker)
121}
122
123/// Handler that saves messages it receives (for Any-based routing).
124#[derive(Debug, Clone)]
125pub struct AnySavingHandler<T> {
126    id: Ustr,
127    messages: Rc<RefCell<Vec<T>>>,
128}
129
130impl<T: Clone + 'static> AnySavingHandler<T> {
131    #[must_use]
132    pub fn new(id: Option<Ustr>) -> Self {
133        let unique_id = id.unwrap_or_else(|| Ustr::from(UUID4::new().as_str()));
134        Self {
135            id: unique_id,
136            messages: Rc::new(RefCell::new(Vec::new())),
137        }
138    }
139
140    #[must_use]
141    pub fn get_messages(&self) -> Vec<T> {
142        self.messages.borrow().clone()
143    }
144
145    pub fn clear(&self) {
146        self.messages.borrow_mut().clear();
147    }
148
149    /// Returns a `ShareableMessageHandler` for registration.
150    #[must_use]
151    pub fn handler(&self) -> ShareableMessageHandler {
152        shareable_handler(Rc::new(self.clone()))
153    }
154}
155
156impl<T: Clone + 'static> Handler<dyn Any> for AnySavingHandler<T> {
157    fn id(&self) -> Ustr {
158        self.id
159    }
160
161    fn handle(&self, message: &dyn Any) {
162        if let Some(m) = message.downcast_ref::<T>() {
163            self.messages.borrow_mut().push(m.clone());
164        } else {
165            log::error!(
166                "AnySavingHandler: expected {} got {:?}",
167                std::any::type_name::<T>(),
168                message.type_id()
169            );
170        }
171    }
172}
173
174/// Creates an Any-based message saving handler and returns both the handler
175/// for registration and a clone that can be used to retrieve messages.
176#[must_use]
177pub fn get_any_saving_handler<T: Clone + 'static>(
178    id: Option<Ustr>,
179) -> (ShareableMessageHandler, AnySavingHandler<T>) {
180    let saver = AnySavingHandler::new(id);
181    let handler = saver.handler();
182    (handler, saver)
183}
184
/// Alias of [`AnySavingHandler`], kept for backward compatibility.
pub type MessageSavingHandler<T> = AnySavingHandler<T>;
187
188/// Typed handler which saves the messages it receives (no downcast needed).
189#[derive(Debug, Clone)]
190pub struct TypedMessageSavingHandler<T> {
191    id: Ustr,
192    messages: Rc<RefCell<Vec<T>>>,
193}
194
195impl<T: Clone + 'static> TypedMessageSavingHandler<T> {
196    #[must_use]
197    pub fn new(id: Option<Ustr>) -> Self {
198        let unique_id = id.unwrap_or_else(|| Ustr::from(UUID4::new().as_str()));
199        Self {
200            id: unique_id,
201            messages: Rc::new(RefCell::new(Vec::new())),
202        }
203    }
204
205    #[must_use]
206    pub fn get_messages(&self) -> Vec<T> {
207        self.messages.borrow().clone()
208    }
209
210    /// Returns a `TypedHandler` that can be used for subscriptions.
211    #[must_use]
212    pub fn handler(&self) -> TypedHandler<T> {
213        TypedHandler::new(self.clone())
214    }
215}
216
217impl<T: Clone + 'static> Handler<T> for TypedMessageSavingHandler<T> {
218    fn id(&self) -> Ustr {
219        self.id
220    }
221
222    fn handle(&self, message: &T) {
223        self.messages.borrow_mut().push(message.clone());
224    }
225}
226
227/// Creates a typed message saving handler and returns both the handler for subscriptions
228/// and a clone that can be used to retrieve messages.
229#[must_use]
230pub fn get_typed_message_saving_handler<T: Clone + 'static>(
231    id: Option<Ustr>,
232) -> (TypedHandler<T>, TypedMessageSavingHandler<T>) {
233    let saving_handler = TypedMessageSavingHandler::new(id);
234    let typed_handler = saving_handler.handler();
235    (typed_handler, saving_handler)
236}
237
238/// Ownership-based typed handler which saves the messages it receives.
239///
240/// Unlike [`TypedMessageSavingHandler`] which borrows messages, this handler
241/// takes ownership which is required for `IntoEndpointMap` endpoints.
242#[derive(Debug, Clone)]
243pub struct TypedIntoMessageSavingHandler<T> {
244    id: Ustr,
245    messages: Rc<RefCell<Vec<T>>>,
246}
247
248impl<T: 'static> TypedIntoMessageSavingHandler<T> {
249    #[must_use]
250    pub fn new(id: Option<Ustr>) -> Self {
251        let unique_id = id.unwrap_or_else(|| Ustr::from(UUID4::new().as_str()));
252        Self {
253            id: unique_id,
254            messages: Rc::new(RefCell::new(Vec::new())),
255        }
256    }
257
258    /// Creates a handler backed by an existing shared messages vec.
259    #[must_use]
260    pub fn new_with_messages(id: Option<Ustr>, messages: Rc<RefCell<Vec<T>>>) -> Self {
261        let unique_id = id.unwrap_or_else(|| Ustr::from(UUID4::new().as_str()));
262        Self {
263            id: unique_id,
264            messages,
265        }
266    }
267
268    #[must_use]
269    pub fn get_messages(&self) -> Vec<T>
270    where
271        T: Clone,
272    {
273        self.messages.borrow().clone()
274    }
275
276    /// Returns a `TypedIntoHandler` that can be used for endpoint registration.
277    #[must_use]
278    pub fn handler(&self) -> TypedIntoHandler<T> {
279        TypedIntoHandler::new(Self {
280            id: self.id,
281            messages: self.messages.clone(),
282        })
283    }
284
285    pub fn clear(&self) {
286        self.messages.borrow_mut().clear();
287    }
288}
289
290impl<T: 'static> IntoHandler<T> for TypedIntoMessageSavingHandler<T> {
291    fn id(&self) -> Ustr {
292        self.id
293    }
294
295    fn handle(&self, message: T) {
296        self.messages.borrow_mut().push(message);
297    }
298}
299
300/// Creates an ownership-based typed message saving handler and returns both the handler
301/// for endpoint registration and a clone that can be used to retrieve messages.
302#[must_use]
303pub fn get_typed_into_message_saving_handler<T: 'static>(
304    id: Option<Ustr>,
305) -> (TypedIntoHandler<T>, TypedIntoMessageSavingHandler<T>) {
306    let saving_handler = TypedIntoMessageSavingHandler::new(id);
307    let typed_handler = saving_handler.handler();
308    (typed_handler, saving_handler)
309}
310
311// Legacy API for tests that use the old pattern with thread_local storage.
312// These wrap AnySavingHandler in thread_local for simpler test usage.
313
314thread_local! {
315    static SAVING_HANDLERS: RefCell<AHashMap<Ustr, Box<dyn std::any::Any>>> = RefCell::new(AHashMap::new());
316}
317
318/// Creates a message saving handler and stores it for later retrieval.
319#[must_use]
320pub fn get_message_saving_handler<T: Clone + 'static>(id: Option<Ustr>) -> ShareableMessageHandler {
321    let (handler, saver) = get_any_saving_handler::<T>(id);
322    let handler_id = handler.0.id();
323    SAVING_HANDLERS.with(|handlers| {
324        handlers.borrow_mut().insert(handler_id, Box::new(saver));
325    });
326    handler
327}
328
329/// Retrieves saved messages from a handler created by `get_message_saving_handler`.
330#[must_use]
331pub fn get_saved_messages<T: Clone + 'static>(handler: ShareableMessageHandler) -> Vec<T> {
332    let handler_id = handler.0.id();
333    SAVING_HANDLERS.with(|handlers| {
334        let handlers = handlers.borrow();
335        if let Some(saver) = handlers.get(&handler_id)
336            && let Some(saver) = saver.downcast_ref::<AnySavingHandler<T>>()
337        {
338            return saver.get_messages();
339        }
340        Vec::new()
341    })
342}