nautilus_indicators/momentum/
cmo.rs1use std::fmt::Display;
17
18use nautilus_model::data::{Bar, QuoteTick, TradeTick};
19
20use crate::{
21 average::{MovingAverageFactory, MovingAverageType},
22 indicator::{Indicator, MovingAverage},
23};
24
/// Chande Momentum Oscillator (CMO).
///
/// Splits each close-to-close change into a gain part and a loss part, smooths each part with
/// a moving average of type `ma_type` over `period` inputs, and reports
/// `100 * (avg_gain - avg_loss) / (avg_gain + avg_loss)` (or `0.0` when that sum is zero).
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators", unsendable)
)]
pub struct ChandeMomentumOscillator {
    /// Lookback period given to both the gain and the loss moving averages.
    pub period: usize,
    /// Moving average type used for both the gain and the loss averages.
    pub ma_type: MovingAverageType,
    /// Latest oscillator value; left at `0.0` until the indicator is initialized.
    pub value: f64,
    /// Number of inputs received since construction or the last reset.
    pub count: usize,
    /// Set once both inner moving averages report themselves initialized; cleared by `reset`.
    pub initialized: bool,
    /// Close of the previous input, used to compute the next close-to-close change.
    previous_close: f64,
    /// Moving average fed the positive part of each change (`0.0` for non-rising inputs).
    average_gain: Box<dyn MovingAverage + Send + 'static>,
    /// Moving average fed the magnitude of each negative change (`0.0` for non-falling inputs).
    average_loss: Box<dyn MovingAverage + Send + 'static>,
    /// Whether at least one input has been received since construction or the last reset.
    has_inputs: bool,
}
42
43impl Display for ChandeMomentumOscillator {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 write!(f, "{}({})", self.name(), self.period)
46 }
47}
48
49impl Indicator for ChandeMomentumOscillator {
50 fn name(&self) -> String {
51 stringify!(ChandeMomentumOscillator).to_string()
52 }
53
54 fn has_inputs(&self) -> bool {
55 self.has_inputs
56 }
57
58 fn initialized(&self) -> bool {
59 self.initialized
60 }
61
62 fn handle_quote(&mut self, _quote: &QuoteTick) {}
63
64 fn handle_trade(&mut self, _trade: &TradeTick) {}
65
66 fn handle_bar(&mut self, bar: &Bar) {
67 self.update_raw((&bar.close).into());
68 }
69
70 fn reset(&mut self) {
71 self.value = 0.0;
72 self.count = 0;
73 self.has_inputs = false;
74 self.initialized = false;
75 self.previous_close = 0.0;
76 self.average_gain.reset();
77 self.average_loss.reset();
78 }
79}
80
81impl ChandeMomentumOscillator {
82 #[must_use]
88 pub fn new(period: usize, ma_type: Option<MovingAverageType>) -> Self {
89 assert!(period > 0, "ChandeMomentumOscillator: period must be > 0");
90 let ma_type = ma_type.unwrap_or(MovingAverageType::Wilder);
91 Self {
92 period,
93 ma_type,
94 average_gain: MovingAverageFactory::create(ma_type, period),
95 average_loss: MovingAverageFactory::create(ma_type, period),
96 previous_close: 0.0,
97 value: 0.0,
98 count: 0,
99 initialized: false,
100 has_inputs: false,
101 }
102 }
103
104 pub fn update_raw(&mut self, close: f64) {
105 self.count += 1;
106 if !self.has_inputs {
107 self.previous_close = close;
108 self.has_inputs = true;
109 }
110
111 let gain: f64 = close - self.previous_close;
112 if gain > 0.0 {
113 self.average_gain.update_raw(gain);
114 self.average_loss.update_raw(0.0);
115 } else if gain < 0.0 {
116 self.average_gain.update_raw(0.0);
117 self.average_loss.update_raw(-gain);
118 } else {
119 self.average_gain.update_raw(0.0);
120 self.average_loss.update_raw(0.0);
121 }
122
123 if !self.initialized && self.average_gain.initialized() && self.average_loss.initialized() {
124 self.initialized = true;
125 }
126 if self.initialized {
127 let divisor = self.average_gain.value() + self.average_loss.value();
128 if divisor == 0.0 {
129 self.value = 0.0;
130 } else {
131 self.value =
132 100.0 * (self.average_gain.value() - self.average_loss.value()) / divisor;
133 }
134 }
135 self.previous_close = close;
136 }
137}
138
139#[cfg(test)]
140mod tests {
141 use nautilus_model::data::{Bar, QuoteTick};
142 use rstest::rstest;
143
144 use crate::{
145 average::MovingAverageType, indicator::Indicator, momentum::cmo::ChandeMomentumOscillator,
146 stubs::*,
147 };
148
149 #[rstest]
150 fn test_cmo_initialized(cmo_10: ChandeMomentumOscillator) {
151 let display_str = format!("{cmo_10}");
152 assert_eq!(display_str, "ChandeMomentumOscillator(10)");
153 assert_eq!(cmo_10.period, 10);
154 assert!(!cmo_10.initialized);
155 }
156
157 #[rstest]
158 fn test_initialized_with_required_inputs_returns_true(mut cmo_10: ChandeMomentumOscillator) {
159 for i in 0..12 {
160 cmo_10.update_raw(f64::from(i));
161 }
162 assert!(cmo_10.initialized);
163 }
164
165 #[rstest]
166 fn test_value_all_higher_inputs_returns_expected_value(mut cmo_10: ChandeMomentumOscillator) {
167 cmo_10.update_raw(109.93);
168 cmo_10.update_raw(110.0);
169 cmo_10.update_raw(109.77);
170 cmo_10.update_raw(109.96);
171 cmo_10.update_raw(110.29);
172 cmo_10.update_raw(110.53);
173 cmo_10.update_raw(110.27);
174 cmo_10.update_raw(110.21);
175 cmo_10.update_raw(110.06);
176 cmo_10.update_raw(110.19);
177 cmo_10.update_raw(109.83);
178 cmo_10.update_raw(109.9);
179 cmo_10.update_raw(110.0);
180 cmo_10.update_raw(110.03);
181 cmo_10.update_raw(110.13);
182 cmo_10.update_raw(109.95);
183 cmo_10.update_raw(109.75);
184 cmo_10.update_raw(110.15);
185 cmo_10.update_raw(109.9);
186 cmo_10.update_raw(110.04);
187 assert_eq!(cmo_10.value, 2.089_629_456_238_705_4);
188 }
189
190 #[rstest]
191 fn test_value_with_one_input_returns_expected_value(mut cmo_10: ChandeMomentumOscillator) {
192 cmo_10.update_raw(1.00000);
193 assert_eq!(cmo_10.value, 0.0);
194 }
195
196 #[rstest]
197 fn test_reset(mut cmo_10: ChandeMomentumOscillator) {
198 cmo_10.update_raw(1.00020);
199 cmo_10.update_raw(1.00030);
200 cmo_10.update_raw(1.00050);
201 cmo_10.reset();
202 assert!(!cmo_10.initialized());
203 assert_eq!(cmo_10.count, 0);
204 assert_eq!(cmo_10.value, 0.0);
205 assert_eq!(cmo_10.previous_close, 0.0);
206 }
207
208 #[rstest]
209 fn test_handle_quote_tick(mut cmo_10: ChandeMomentumOscillator, stub_quote: QuoteTick) {
210 cmo_10.handle_quote(&stub_quote);
211 assert_eq!(cmo_10.count, 0);
212 assert_eq!(cmo_10.value, 0.0);
213 }
214
215 #[rstest]
216 fn test_handle_bar(mut cmo_10: ChandeMomentumOscillator, bar_ethusdt_binance_minute_bid: Bar) {
217 cmo_10.handle_bar(&bar_ethusdt_binance_minute_bid);
218 assert_eq!(cmo_10.count, 1);
219 assert_eq!(cmo_10.value, 0.0);
220 }
221
222 #[rstest]
223 fn test_ma_type_affects_value() {
224 let mut cmo_sma = ChandeMomentumOscillator::new(3, Some(MovingAverageType::Simple));
225 let mut cmo_wilder = ChandeMomentumOscillator::new(3, Some(MovingAverageType::Wilder));
226 let prices = [1.0, 2.0, 3.0, 2.5, 3.5];
227 for price in prices {
228 cmo_sma.update_raw(price);
229 cmo_wilder.update_raw(price);
230 }
231 assert_ne!(cmo_sma.value, cmo_wilder.value);
232 }
233
234 #[rstest]
235 fn test_count_increments(mut cmo_10: ChandeMomentumOscillator) {
236 for i in 0..5 {
237 cmo_10.update_raw(f64::from(i));
238 }
239 assert_eq!(cmo_10.count, 5);
240 }
241
242 #[rstest]
243 fn test_reset_resets_inner_mas() {
244 let mut cmo = ChandeMomentumOscillator::new(3, None);
245 for price in [1.0, 2.0, 3.0] {
246 cmo.update_raw(price);
247 }
248 assert!(cmo.average_gain.initialized());
249 assert!(cmo.average_loss.initialized());
250 assert_ne!(cmo.average_gain.value(), 0.0);
251 cmo.reset();
252 assert!(!cmo.average_gain.initialized());
253 assert!(!cmo.average_loss.initialized());
254 assert_eq!(cmo.average_gain.value(), 0.0);
255 assert_eq!(cmo.average_loss.value(), 0.0);
256 }
257
258 #[rstest]
259 #[should_panic]
260 fn test_invalid_period_panics() {
261 let _ = ChandeMomentumOscillator::new(0, None);
262 }
263
264 #[rstest]
265 fn test_ma_type_propagation() {
266 let cmo = ChandeMomentumOscillator::new(5, Some(MovingAverageType::Simple));
267 assert_eq!(cmo.ma_type, MovingAverageType::Simple);
268 }
269
270 #[rstest]
271 fn test_zero_divisor_returns_zero() {
272 let mut cmo = ChandeMomentumOscillator::new(3, None);
273 for _ in 0..5 {
274 cmo.update_raw(100.0);
275 }
276 assert!(cmo.initialized);
277 assert_eq!(cmo.value, 0.0);
278 }
279
280 #[rstest]
281 fn test_random_walk_values_within_bounds() {
282 let prices = [
283 100.0, 100.5, 99.8, 100.3, 101.0, 100.7, 101.5, 101.2, 100.6, 101.1, 100.9, 101.4,
284 100.8, 101.2, 100.6, 100.9, 101.3, 101.0, 100.5, 101.1, 100.7, 101.4, 100.9, 100.8,
285 101.2, 100.6, 100.9, 101.3, 101.0, 100.5,
286 ];
287 let mut cmo = ChandeMomentumOscillator::new(10, None);
288 for price in prices {
289 cmo.update_raw(price);
290 }
291 assert!(cmo.initialized);
292 assert!(cmo.value <= 100.0 && cmo.value >= -100.0);
293 }
294}