nautilus_common/msgbus/
stubs.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::{
17    any::Any,
18    cell::RefCell,
19    fmt::Debug,
20    rc::Rc,
21    sync::{
22        Arc,
23        atomic::{AtomicBool, Ordering},
24    },
25};
26
27use nautilus_core::{UUID4, message::Message};
28use ustr::Ustr;
29
30use crate::msgbus::{ShareableMessageHandler, handler::MessageHandler};
31
32// Stub message handler which logs the data it receives
33pub struct StubMessageHandler {
34    id: Ustr,
35    callback: Arc<dyn Fn(Message) + Send>,
36}
37
38impl Debug for StubMessageHandler {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.debug_struct(stringify!(StubMessageHandler))
41            .field("id", &self.id)
42            .finish()
43    }
44}
45
46impl MessageHandler for StubMessageHandler {
47    fn id(&self) -> Ustr {
48        self.id
49    }
50
51    fn handle(&self, message: &dyn Any) {
52        (self.callback)(message.downcast_ref::<Message>().unwrap().clone());
53    }
54
55    fn as_any(&self) -> &dyn Any {
56        self
57    }
58}
59
60#[must_use]
61#[allow(unused_must_use)]
62pub fn get_stub_shareable_handler(id: Option<Ustr>) -> ShareableMessageHandler {
63    // TODO: This reduces the need to come up with ID strings in tests.
64    // In Python we do something like `hash((self.topic, str(self.handler)))` for the hash
65    // which includes the memory address, just went with a UUID4 here.
66    let unique_id = id.unwrap_or_else(|| Ustr::from(UUID4::new().as_str()));
67    ShareableMessageHandler(Rc::new(StubMessageHandler {
68        id: unique_id,
69        callback: Arc::new(|m: Message| {
70            format!("{m:?}");
71        }),
72    }))
73}
74
75// Stub message handler which checks if handle was called
76#[derive(Debug)]
77pub struct CallCheckMessageHandler {
78    id: Ustr,
79    called: Arc<AtomicBool>,
80}
81
82impl CallCheckMessageHandler {
83    #[must_use]
84    pub fn was_called(&self) -> bool {
85        self.called.load(Ordering::SeqCst)
86    }
87}
88
89impl MessageHandler for CallCheckMessageHandler {
90    fn id(&self) -> Ustr {
91        self.id
92    }
93
94    fn handle(&self, _message: &dyn Any) {
95        self.called.store(true, Ordering::SeqCst);
96    }
97
98    fn as_any(&self) -> &dyn Any {
99        self
100    }
101}
102
103#[must_use]
104pub fn get_call_check_shareable_handler(id: Option<Ustr>) -> ShareableMessageHandler {
105    // TODO: This reduces the need to come up with ID strings in tests.
106    // In Python we do something like `hash((self.topic, str(self.handler)))` for the hash
107    // which includes the memory address, just went with a UUID4 here.
108    let unique_id = id.unwrap_or_else(|| Ustr::from(UUID4::new().as_str()));
109    ShareableMessageHandler(Rc::new(CallCheckMessageHandler {
110        id: unique_id,
111        called: Arc::new(AtomicBool::new(false)),
112    }))
113}
114
115/// Returns whether the given `CallCheckMessageHandler` has been invoked at least once.
116///
117/// # Panics
118///
119/// Panics if the provided `handler` is not a `CallCheckMessageHandler`.
120#[must_use]
121pub fn check_handler_was_called(call_check_handler: ShareableMessageHandler) -> bool {
122    call_check_handler
123        .0
124        .as_ref()
125        .as_any()
126        .downcast_ref::<CallCheckMessageHandler>()
127        .unwrap()
128        .was_called()
129}
130
131// Handler which saves the messages it receives
132#[derive(Debug, Clone)]
133pub struct MessageSavingHandler<T> {
134    id: Ustr,
135    messages: Rc<RefCell<Vec<T>>>,
136}
137
138impl<T: Clone + 'static> MessageSavingHandler<T> {
139    #[must_use]
140    pub fn get_messages(&self) -> Vec<T> {
141        self.messages.borrow().clone()
142    }
143}
144
145impl<T: Clone + 'static> MessageHandler for MessageSavingHandler<T> {
146    fn id(&self) -> Ustr {
147        self.id
148    }
149
150    /// Handles an incoming message by saving it.
151    ///
152    /// # Panics
153    ///
154    /// Panics if the provided `message` is not of the expected type `T`.
155    fn handle(&self, message: &dyn Any) {
156        let mut messages = self.messages.borrow_mut();
157        match message.downcast_ref::<T>() {
158            Some(m) => messages.push(m.clone()),
159            None => panic!("MessageSavingHandler: message type mismatch {message:?}"),
160        }
161    }
162
163    fn as_any(&self) -> &dyn Any {
164        self
165    }
166}
167
168#[must_use]
169pub fn get_message_saving_handler<T: Clone + 'static>(id: Option<Ustr>) -> ShareableMessageHandler {
170    // TODO: This reduces the need to come up with ID strings in tests.
171    // In Python we do something like `hash((self.topic, str(self.handler)))` for the hash
172    // which includes the memory address, just went with a UUID4 here.
173    let unique_id = id.unwrap_or_else(|| Ustr::from(UUID4::new().as_str()));
174    ShareableMessageHandler(Rc::new(MessageSavingHandler::<T> {
175        id: unique_id,
176        messages: Rc::new(RefCell::new(Vec::new())),
177    }))
178}
179
180/// Retrieves the messages saved by a [`MessageSavingHandler`].
181///
182/// # Panics
183///
184/// Panics if the provided `handler` is not a `MessageSavingHandler<T>`.
185#[must_use]
186pub fn get_saved_messages<T: Clone + 'static>(handler: ShareableMessageHandler) -> Vec<T> {
187    handler
188        .0
189        .as_ref()
190        .as_any()
191        .downcast_ref::<MessageSavingHandler<T>>()
192        .unwrap()
193        .get_messages()
194}
195
196/// Clears all messages saved by a [`MessageSavingHandler`].
197///
198/// # Panics
199///
200/// Panics if the provided `handler` is not a `MessageSavingHandler<T>`.
201pub fn clear_saved_messages<T: Clone + 'static>(handler: ShareableMessageHandler) {
202    handler
203        .0
204        .as_ref()
205        .as_any()
206        .downcast_ref::<MessageSavingHandler<T>>()
207        .unwrap()
208        .messages
209        .borrow_mut()
210        .clear();
211}