nautilus_indicators/momentum/
cmo.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2025 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::fmt::Display;
17
18use nautilus_model::data::{Bar, QuoteTick, TradeTick};
19
20use crate::{
21    average::{MovingAverageFactory, MovingAverageType},
22    indicator::{Indicator, MovingAverage},
23};
24
25#[repr(C)]
26#[derive(Debug)]
27#[cfg_attr(
28    feature = "python",
29    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators", unsendable)
30)]
31pub struct ChandeMomentumOscillator {
32    pub period: usize,
33    pub ma_type: MovingAverageType,
34    pub value: f64,
35    pub count: usize,
36    pub initialized: bool,
37    previous_close: f64,
38    average_gain: Box<dyn MovingAverage + Send + 'static>,
39    average_loss: Box<dyn MovingAverage + Send + 'static>,
40    has_inputs: bool,
41}
42
43impl Display for ChandeMomentumOscillator {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        write!(f, "{}({})", self.name(), self.period)
46    }
47}
48
49impl Indicator for ChandeMomentumOscillator {
50    fn name(&self) -> String {
51        stringify!(ChandeMomentumOscillator).to_string()
52    }
53
54    fn has_inputs(&self) -> bool {
55        self.has_inputs
56    }
57
58    fn initialized(&self) -> bool {
59        self.initialized
60    }
61
62    fn handle_quote(&mut self, _quote: &QuoteTick) {}
63
64    fn handle_trade(&mut self, _trade: &TradeTick) {}
65
66    fn handle_bar(&mut self, bar: &Bar) {
67        self.update_raw((&bar.close).into());
68    }
69
70    fn reset(&mut self) {
71        self.value = 0.0;
72        self.count = 0;
73        self.has_inputs = false;
74        self.initialized = false;
75        self.previous_close = 0.0;
76        self.average_gain.reset();
77        self.average_loss.reset();
78    }
79}
80
81impl ChandeMomentumOscillator {
82    /// Creates a new [`ChandeMomentumOscillator`] instance.
83    ///
84    /// # Panics
85    ///
86    /// Panics if `period` is not positive (> 0).
87    #[must_use]
88    pub fn new(period: usize, ma_type: Option<MovingAverageType>) -> Self {
89        assert!(period > 0, "ChandeMomentumOscillator: period must be > 0");
90        let ma_type = ma_type.unwrap_or(MovingAverageType::Wilder);
91        Self {
92            period,
93            ma_type,
94            average_gain: MovingAverageFactory::create(ma_type, period),
95            average_loss: MovingAverageFactory::create(ma_type, period),
96            previous_close: 0.0,
97            value: 0.0,
98            count: 0,
99            initialized: false,
100            has_inputs: false,
101        }
102    }
103
104    pub fn update_raw(&mut self, close: f64) {
105        self.count += 1;
106        if !self.has_inputs {
107            self.previous_close = close;
108            self.has_inputs = true;
109        }
110
111        let gain: f64 = close - self.previous_close;
112        if gain > 0.0 {
113            self.average_gain.update_raw(gain);
114            self.average_loss.update_raw(0.0);
115        } else if gain < 0.0 {
116            self.average_gain.update_raw(0.0);
117            self.average_loss.update_raw(-gain);
118        } else {
119            self.average_gain.update_raw(0.0);
120            self.average_loss.update_raw(0.0);
121        }
122
123        if !self.initialized && self.average_gain.initialized() && self.average_loss.initialized() {
124            self.initialized = true;
125        }
126        if self.initialized {
127            let divisor = self.average_gain.value() + self.average_loss.value();
128            if divisor == 0.0 {
129                self.value = 0.0;
130            } else {
131                self.value =
132                    100.0 * (self.average_gain.value() - self.average_loss.value()) / divisor;
133            }
134        }
135        self.previous_close = close;
136    }
137}
138
#[cfg(test)]
mod tests {
    use nautilus_model::data::{Bar, QuoteTick};
    use rstest::rstest;

    use crate::{
        average::MovingAverageType, indicator::Indicator, momentum::cmo::ChandeMomentumOscillator,
        stubs::*,
    };

    /// Pushes every close in `closes` through the indicator, in order.
    fn feed(cmo: &mut ChandeMomentumOscillator, closes: &[f64]) {
        for &close in closes {
            cmo.update_raw(close);
        }
    }

    // A fresh indicator displays as `Name(period)` and is not yet initialized.
    #[rstest]
    fn test_cmo_initialized(cmo_10: ChandeMomentumOscillator) {
        assert_eq!(cmo_10.to_string(), "ChandeMomentumOscillator(10)");
        assert_eq!(cmo_10.period, 10);
        assert!(!cmo_10.initialized);
    }

    // Twelve inputs are enough to initialize a period-10 indicator.
    #[rstest]
    fn test_initialized_with_required_inputs_returns_true(mut cmo_10: ChandeMomentumOscillator) {
        for step in 0_u32..12 {
            cmo_10.update_raw(f64::from(step));
        }
        assert!(cmo_10.initialized);
    }

    // NOTE(review): despite the name, the closes below move both up and down.
    #[rstest]
    fn test_value_all_higher_inputs_returns_expected_value(mut cmo_10: ChandeMomentumOscillator) {
        let closes = [
            109.93, 110.0, 109.77, 109.96, 110.29, 110.53, 110.27, 110.21, 110.06, 110.19,
            109.83, 109.9, 110.0, 110.03, 110.13, 109.95, 109.75, 110.15, 109.9, 110.04,
        ];
        feed(&mut cmo_10, &closes);
        assert_eq!(cmo_10.value, 2.089_629_456_238_705_4);
    }

    // A single input cannot initialize the indicator, so the value stays at zero.
    #[rstest]
    fn test_value_with_one_input_returns_expected_value(mut cmo_10: ChandeMomentumOscillator) {
        feed(&mut cmo_10, &[1.0]);
        assert_eq!(cmo_10.value, 0.0);
    }

    // `reset` clears the counters, the value and the stored previous close.
    #[rstest]
    fn test_reset(mut cmo_10: ChandeMomentumOscillator) {
        feed(&mut cmo_10, &[1.00020, 1.00030, 1.00050]);
        cmo_10.reset();
        assert!(!cmo_10.initialized());
        assert_eq!(cmo_10.count, 0);
        assert_eq!(cmo_10.value, 0.0);
        assert_eq!(cmo_10.previous_close, 0.0);
    }

    // Quote ticks are ignored entirely.
    #[rstest]
    fn test_handle_quote_tick(mut cmo_10: ChandeMomentumOscillator, stub_quote: QuoteTick) {
        cmo_10.handle_quote(&stub_quote);
        assert_eq!(cmo_10.count, 0);
        assert_eq!(cmo_10.value, 0.0);
    }

    // A bar counts as one input; one input leaves the value at zero.
    #[rstest]
    fn test_handle_bar(mut cmo_10: ChandeMomentumOscillator, bar_ethusdt_binance_minute_bid: Bar) {
        cmo_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(cmo_10.count, 1);
        assert_eq!(cmo_10.value, 0.0);
    }

    // The same closes produce different values under different averaging schemes.
    #[rstest]
    fn test_ma_type_affects_value() {
        let closes = [1.0, 2.0, 3.0, 2.5, 3.5];
        let mut cmo_sma = ChandeMomentumOscillator::new(3, Some(MovingAverageType::Simple));
        let mut cmo_wilder = ChandeMomentumOscillator::new(3, Some(MovingAverageType::Wilder));
        feed(&mut cmo_sma, &closes);
        feed(&mut cmo_wilder, &closes);
        assert_ne!(cmo_sma.value, cmo_wilder.value);
    }

    // Every raw update bumps `count` by one.
    #[rstest]
    fn test_count_increments(mut cmo_10: ChandeMomentumOscillator) {
        for step in 0_u32..5 {
            cmo_10.update_raw(f64::from(step));
        }
        assert_eq!(cmo_10.count, 5);
    }

    // `reset` must propagate to both internal moving averages.
    #[rstest]
    fn test_reset_resets_inner_mas() {
        let mut cmo = ChandeMomentumOscillator::new(3, None);
        feed(&mut cmo, &[1.0, 2.0, 3.0]);
        assert!(cmo.average_gain.initialized());
        assert!(cmo.average_loss.initialized());
        assert_ne!(cmo.average_gain.value(), 0.0);

        cmo.reset();

        assert!(!cmo.average_gain.initialized());
        assert!(!cmo.average_loss.initialized());
        assert_eq!(cmo.average_gain.value(), 0.0);
        assert_eq!(cmo.average_loss.value(), 0.0);
    }

    // A zero period violates the constructor's precondition.
    #[rstest]
    #[should_panic]
    fn test_invalid_period_panics() {
        let _ = ChandeMomentumOscillator::new(0, None);
    }

    // The explicitly requested moving average type is stored on the indicator.
    #[rstest]
    fn test_ma_type_propagation() {
        let cmo = ChandeMomentumOscillator::new(5, Some(MovingAverageType::Simple));
        assert_eq!(cmo.ma_type, MovingAverageType::Simple);
    }

    // Flat prices give zero gain and zero loss; the guarded division yields zero.
    #[rstest]
    fn test_zero_divisor_returns_zero() {
        let mut cmo = ChandeMomentumOscillator::new(3, None);
        feed(&mut cmo, &[100.0; 5]);
        assert!(cmo.initialized);
        assert_eq!(cmo.value, 0.0);
    }

    // After a mixed up/down series the final value lies within [-100, 100].
    #[rstest]
    fn test_random_walk_values_within_bounds() {
        let closes = [
            100.0, 100.5, 99.8, 100.3, 101.0, 100.7, 101.5, 101.2, 100.6, 101.1, 100.9, 101.4,
            100.8, 101.2, 100.6, 100.9, 101.3, 101.0, 100.5, 101.1, 100.7, 101.4, 100.9, 100.8,
            101.2, 100.6, 100.9, 101.3, 101.0, 100.5,
        ];
        let mut cmo = ChandeMomentumOscillator::new(10, None);
        feed(&mut cmo, &closes);
        assert!(cmo.initialized);
        assert!((-100.0..=100.0).contains(&cmo.value));
    }
}