nautilus_common/msgbus/
stubs.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::{
17    any::Any,
18    cell::RefCell,
19    rc::Rc,
20    sync::{
21        Arc,
22        atomic::{AtomicBool, Ordering},
23    },
24};
25
26use nautilus_core::message::Message;
27use ustr::Ustr;
28use uuid::Uuid;
29
30use crate::msgbus::{ShareableMessageHandler, handler::MessageHandler};
31
32// Stub message handler which logs the data it receives
33pub struct StubMessageHandler {
34    id: Ustr,
35    callback: Arc<dyn Fn(Message) + Send>,
36}
37
38impl MessageHandler for StubMessageHandler {
39    fn id(&self) -> Ustr {
40        self.id
41    }
42
43    fn handle(&self, message: &dyn Any) {
44        (self.callback)(message.downcast_ref::<Message>().unwrap().clone());
45    }
46
47    fn as_any(&self) -> &dyn Any {
48        self
49    }
50}
51
52#[must_use]
53#[allow(unused_must_use)] // TODO: Temporary to fix docs build
54pub fn get_stub_shareable_handler(id: Option<Ustr>) -> ShareableMessageHandler {
55    // TODO: This reduces the need to come up with ID strings in tests.
56    // In Python we do something like `hash((self.topic, str(self.handler)))` for the hash
57    // which includes the memory address, just went with a UUID4 here.
58    let unique_id = id.unwrap_or_else(|| Ustr::from(&Uuid::new_v4().to_string()));
59    ShareableMessageHandler(Rc::new(StubMessageHandler {
60        id: unique_id,
61        callback: Arc::new(|m: Message| {
62            format!("{m:?}");
63        }),
64    }))
65}
66
67// Stub message handler which checks if handle was called
68pub struct CallCheckMessageHandler {
69    id: Ustr,
70    called: Arc<AtomicBool>,
71}
72
73impl CallCheckMessageHandler {
74    #[must_use]
75    pub fn was_called(&self) -> bool {
76        self.called.load(Ordering::SeqCst)
77    }
78}
79
80impl MessageHandler for CallCheckMessageHandler {
81    fn id(&self) -> Ustr {
82        self.id
83    }
84
85    fn handle(&self, _message: &dyn Any) {
86        self.called.store(true, Ordering::SeqCst);
87    }
88
89    fn as_any(&self) -> &dyn Any {
90        self
91    }
92}
93
94#[must_use]
95pub fn get_call_check_shareable_handler(id: Option<Ustr>) -> ShareableMessageHandler {
96    // TODO: This reduces the need to come up with ID strings in tests.
97    // In Python we do something like `hash((self.topic, str(self.handler)))` for the hash
98    // which includes the memory address, just went with a UUID4 here.
99    let unique_id = id.unwrap_or_else(|| Ustr::from(&Uuid::new_v4().to_string()));
100    ShareableMessageHandler(Rc::new(CallCheckMessageHandler {
101        id: unique_id,
102        called: Arc::new(AtomicBool::new(false)),
103    }))
104}
105
106#[must_use]
107pub fn check_handler_was_called(call_check_handler: ShareableMessageHandler) -> bool {
108    call_check_handler
109        .0
110        .as_ref()
111        .as_any()
112        .downcast_ref::<CallCheckMessageHandler>()
113        .unwrap()
114        .was_called()
115}
116
117// Handler which saves the messages it receives
118#[derive(Debug, Clone)]
119pub struct MessageSavingHandler<T> {
120    id: Ustr,
121    messages: Rc<RefCell<Vec<T>>>,
122}
123
124impl<T: Clone + 'static> MessageSavingHandler<T> {
125    #[must_use]
126    pub fn get_messages(&self) -> Vec<T> {
127        self.messages.borrow().clone()
128    }
129}
130
131impl<T: Clone + 'static> MessageHandler for MessageSavingHandler<T> {
132    fn id(&self) -> Ustr {
133        self.id
134    }
135
136    fn handle(&self, message: &dyn Any) {
137        let mut messages = self.messages.borrow_mut();
138        match message.downcast_ref::<T>() {
139            Some(m) => messages.push(m.clone()),
140            None => panic!("MessageSavingHandler: message type mismatch {message:?}"),
141        }
142    }
143
144    fn as_any(&self) -> &dyn Any {
145        self
146    }
147}
148
149#[must_use]
150pub fn get_message_saving_handler<T: Clone + 'static>(id: Option<Ustr>) -> ShareableMessageHandler {
151    // TODO: This reduces the need to come up with ID strings in tests.
152    // In Python we do something like `hash((self.topic, str(self.handler)))` for the hash
153    // which includes the memory address, just went with a UUID4 here.
154    let unique_id = id.unwrap_or_else(|| Ustr::from(&Uuid::new_v4().to_string()));
155    ShareableMessageHandler(Rc::new(MessageSavingHandler::<T> {
156        id: unique_id,
157        messages: Rc::new(RefCell::new(Vec::new())),
158    }))
159}
160
161#[must_use]
162pub fn get_saved_messages<T: Clone + 'static>(handler: ShareableMessageHandler) -> Vec<T> {
163    handler
164        .0
165        .as_ref()
166        .as_any()
167        .downcast_ref::<MessageSavingHandler<T>>()
168        .unwrap()
169        .get_messages()
170}