nautilus_indicators/momentum/
amat.rs1use std::fmt::{Debug, Display};
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20
21use crate::{
22 average::{MovingAverageFactory, MovingAverageType},
23 indicator::{Indicator, MovingAverage},
24};
25
/// Moving-average type used by `ArcherMovingAveragesTrends::new` when the caller passes `None`.
const DEFAULT_MA_TYPE: MovingAverageType = MovingAverageType::Exponential;
/// Largest `signal_period` accepted by `ArcherMovingAveragesTrends::new`; it also fixes the
/// compile-time capacity of `SignalBuf`.
const MAX_SIGNAL: usize = 1_024;

/// Fixed-capacity deque of `f64` values with room for `MAX_SIGNAL + 1` entries, configured with
/// arraydeque's `Wrapping` behavior marker.
type SignalBuf = ArrayDeque<f64, { MAX_SIGNAL + 1 }, Wrapping>;
30
/// Indicator that maintains a fast and a slow moving average of the input price and exposes
/// `long_run` / `short_run` trend flags derived from the fast average.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators", unsendable)
)]
pub struct ArcherMovingAveragesTrends {
    /// Period handed to the fast moving average on construction.
    pub fast_period: usize,
    /// Period handed to the slow moving average on construction; `new` requires it to be
    /// strictly greater than `fast_period`.
    pub slow_period: usize,
    /// Look-back distance for the trend comparison; the value buffers keep at most
    /// `signal_period + 1` entries.
    pub signal_period: usize,
    /// Moving-average type used for both the fast and the slow average.
    pub ma_type: MovingAverageType,
    /// Set to `true` when the newest fast-average value exceeds the oldest buffered one.
    /// Once set it stays `true` until `reset` is called.
    pub long_run: bool,
    /// Set to `true` when the newest fast-average value is below the oldest buffered one.
    /// Once set it stays `true` until `reset` is called.
    pub short_run: bool,
    /// Becomes `true` once the slow moving average is initialized and the slow buffer holds
    /// `signal_period + 1` values.
    pub initialized: bool,
    /// Fast moving average, updated with every raw price.
    fast_ma: Box<dyn MovingAverage + Send + 'static>,
    /// Slow moving average, updated with every raw price; buffering starts once it reports
    /// itself initialized.
    slow_ma: Box<dyn MovingAverage + Send + 'static>,
    /// Most recent fast-average values (oldest at the front).
    fast_ma_price: SignalBuf,
    /// Most recent slow-average values (oldest at the front). In the code visible in this
    /// file only its length is read (to decide `initialized`).
    slow_ma_price: SignalBuf,
    /// Set by `update_raw` once a price has been received; cleared by `reset`.
    has_inputs: bool,
}
51
52impl Display for ArcherMovingAveragesTrends {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(
55 f,
56 "{}({},{},{},{})",
57 self.name(),
58 self.fast_period,
59 self.slow_period,
60 self.signal_period,
61 self.ma_type,
62 )
63 }
64}
65
66impl Indicator for ArcherMovingAveragesTrends {
67 fn name(&self) -> String {
68 stringify!(ArcherMovingAveragesTrends).into()
69 }
70
71 fn has_inputs(&self) -> bool {
72 self.has_inputs
73 }
74
75 fn initialized(&self) -> bool {
76 self.initialized
77 }
78
79 fn handle_bar(&mut self, bar: &Bar) {
80 self.update_raw(bar.close.into());
81 }
82
83 fn reset(&mut self) {
84 self.fast_ma.reset();
85 self.slow_ma.reset();
86 self.long_run = false;
87 self.short_run = false;
88 self.fast_ma_price.clear();
89 self.slow_ma_price.clear();
90 self.has_inputs = false;
91 self.initialized = false;
92 }
93}
94
95impl ArcherMovingAveragesTrends {
96 #[must_use]
105 pub fn new(
106 fast_period: usize,
107 slow_period: usize,
108 signal_period: usize,
109 ma_type: Option<MovingAverageType>,
110 ) -> Self {
111 assert!(
112 fast_period > 0,
113 "fast_period must be positive (got {fast_period})"
114 );
115 assert!(
116 slow_period > 0,
117 "slow_period must be positive (got {slow_period})"
118 );
119 assert!(
120 signal_period > 0,
121 "signal_period must be positive (got {signal_period})"
122 );
123 assert!(
124 slow_period > fast_period,
125 "slow_period ({slow_period}) must be greater than fast_period ({fast_period})"
126 );
127 assert!(
128 signal_period <= MAX_SIGNAL,
129 "signal_period ({signal_period}) must not exceed MAX_SIGNAL ({MAX_SIGNAL})"
130 );
131
132 let ma_type = ma_type.unwrap_or(DEFAULT_MA_TYPE);
133
134 Self {
135 fast_period,
136 slow_period,
137 signal_period,
138 ma_type,
139 long_run: false,
140 short_run: false,
141 fast_ma: MovingAverageFactory::create(ma_type, fast_period),
142 slow_ma: MovingAverageFactory::create(ma_type, slow_period),
143 fast_ma_price: SignalBuf::new(),
144 slow_ma_price: SignalBuf::new(),
145 has_inputs: false,
146 initialized: false,
147 }
148 }
149
150 pub fn update_raw(&mut self, close: f64) {
155 self.fast_ma.update_raw(close);
156 self.slow_ma.update_raw(close);
157
158 if self.slow_ma.initialized() {
159 self.fast_ma_price.push_back(self.fast_ma.value());
160 self.slow_ma_price.push_back(self.slow_ma.value());
161
162 let max_len = self.signal_period + 1;
163 if self.fast_ma_price.len() > max_len {
164 self.fast_ma_price.pop_front();
165 self.slow_ma_price.pop_front();
166 }
167
168 let fast_back = self.fast_ma.value();
169 let fast_front = *self
170 .fast_ma_price
171 .front()
172 .expect("buffer has at least one element");
173
174 let fast_diff = fast_back - fast_front;
175 self.long_run = fast_diff > 0.0 || self.long_run;
176 self.short_run = fast_diff < 0.0 || self.short_run;
177 }
178
179 if !self.initialized {
180 self.has_inputs = true;
181 let max_len = self.signal_period + 1;
182 if self.slow_ma_price.len() == max_len && self.slow_ma.initialized() {
183 self.initialized = true;
184 }
185 }
186 }
187}
188
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::stubs::amat_345;

    /// Constructs an indicator with the default moving-average type and discards it.
    /// Used by the tests that only care about constructor panics.
    fn make(fast: usize, slow: usize, signal: usize) {
        drop(ArcherMovingAveragesTrends::new(fast, slow, signal, None));
    }

    /// Builds the `(3, 4, 5)` indicator with the default moving-average type.
    fn default_345() -> ArcherMovingAveragesTrends {
        ArcherMovingAveragesTrends::new(3, 4, 5, None)
    }

    /// Feeds `count` prices forming the arithmetic sequence `start, start + step, ...`.
    fn feed_sequence(ind: &mut ArcherMovingAveragesTrends, start: i64, count: usize, step: i64) {
        for i in 0..count {
            let price = start + i as i64 * step;
            ind.update_raw(price as f64);
        }
    }

    // --- construction and configuration -------------------------------------------------

    #[rstest]
    fn default_ma_type_is_exponential() {
        let amat = default_345();
        assert_eq!(amat.ma_type, MovingAverageType::Exponential);
    }

    #[rstest]
    fn test_name_returns_expected_string(amat_345: ArcherMovingAveragesTrends) {
        assert_eq!(amat_345.name(), "ArcherMovingAveragesTrends");
    }

    #[rstest]
    fn test_str_repr_returns_expected_string(amat_345: ArcherMovingAveragesTrends) {
        let rendered = amat_345.to_string();
        assert_eq!(rendered, "ArcherMovingAveragesTrends(3,4,5,SIMPLE)");
    }

    #[rstest]
    fn test_period_returns_expected_value(amat_345: ArcherMovingAveragesTrends) {
        assert_eq!(amat_345.fast_period, 3);
        assert_eq!(amat_345.slow_period, 4);
        assert_eq!(amat_345.signal_period, 5);
    }

    #[rstest]
    fn test_initialized_without_inputs_returns_false(amat_345: ArcherMovingAveragesTrends) {
        assert!(!amat_345.initialized());
    }

    // --- constructor preconditions ------------------------------------------------------

    #[rstest]
    #[should_panic(expected = "fast_period must be positive")]
    fn new_panics_on_zero_fast_period() {
        make(0, 4, 5);
    }

    #[rstest]
    #[should_panic(expected = "slow_period must be positive")]
    fn new_panics_on_zero_slow_period() {
        make(3, 0, 5);
    }

    #[rstest]
    #[should_panic(expected = "signal_period must be positive")]
    fn new_panics_on_zero_signal_period() {
        make(3, 5, 0);
    }

    #[rstest]
    #[should_panic(expected = "slow_period (3) must be greater than fast_period (3)")]
    fn new_panics_when_slow_not_greater_than_fast() {
        make(3, 3, 5);
    }

    #[rstest]
    #[should_panic(expected = "slow_period (2) must be greater than fast_period (3)")]
    fn new_panics_when_slow_less_than_fast() {
        make(3, 2, 5);
    }

    #[rstest]
    #[should_panic(expected = "signal_period (1025) must not exceed MAX_SIGNAL (1024)")]
    fn new_panics_when_signal_exceeds_max() {
        make(3, 4, MAX_SIGNAL + 1);
    }

    #[rstest]
    fn ma_type_override_is_respected() {
        let amat = ArcherMovingAveragesTrends::new(3, 4, 5, Some(MovingAverageType::Simple));
        assert_eq!(amat.ma_type, MovingAverageType::Simple);
    }

    // --- update behaviour ---------------------------------------------------------------

    #[rstest]
    fn buffer_len_never_exceeds_signal_plus_one() {
        let mut amat = default_345();
        feed_sequence(&mut amat, 0, 100, 1);

        let expected_len = amat.signal_period + 1;
        assert_eq!(amat.fast_ma_price.len(), expected_len);
        assert_eq!(amat.slow_ma_price.len(), expected_len);
    }

    #[rstest]
    fn initialized_becomes_true_after_slow_ready_and_buffer_full() {
        let mut amat = default_345();
        feed_sequence(&mut amat, 0, 11, 1);
        assert!(amat.initialized());
    }

    #[rstest]
    fn long_run_flag_sets_on_bullish_trend() {
        let mut amat = default_345();
        feed_sequence(&mut amat, 0, 60, 1);
        assert!(amat.long_run, "Expected long_run=TRUE on up-trend");
        assert!(!amat.short_run, "short_run should remain FALSE here");
    }

    #[rstest]
    fn short_run_flag_sets_on_bearish_trend() {
        let mut amat = default_345();
        feed_sequence(&mut amat, 100, 60, -1);
        assert!(amat.short_run, "Expected short_run=TRUE on down-trend");
        assert!(!amat.long_run, "long_run should remain FALSE here");
    }

    #[rstest]
    fn reset_clears_internal_state() {
        let mut amat = default_345();
        feed_sequence(&mut amat, 0, 50, 1);
        assert!(amat.long_run || amat.short_run);
        assert!(!amat.fast_ma_price.is_empty());

        amat.reset();

        assert!(!(amat.long_run || amat.short_run));
        assert!(amat.fast_ma_price.is_empty());
        assert!(amat.slow_ma_price.is_empty());
        assert!(!amat.initialized());
        assert!(!amat.has_inputs());
    }
}