1use std::fmt::Display;
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_core::correctness::{FAILED, check_predicate_true};
20use nautilus_model::{
21 data::{Bar, QuoteTick, TradeTick},
22 enums::PriceType,
23};
24
25use crate::indicator::{Indicator, MovingAverage};
26
/// Fixed capacity of the input ring buffer, and therefore the longest window
/// (`period`) the indicator can hold.
const MAX_PERIOD: usize = 8192;
28
/// An indicator that computes a weighted moving average over a rolling window of inputs.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
)]
pub struct WeightedMovingAverage {
    /// The rolling window length; validated to be positive at construction.
    pub period: usize,
    /// One weight per window slot, ordered oldest-to-newest: the last weight is
    /// applied to the most recent input.
    pub weights: Vec<f64>,
    /// Price type extracted from quote ticks (`PriceType::Last` when not specified).
    pub price_type: PriceType,
    /// The most recently computed average (`0.0` before any input / after reset).
    pub value: f64,
    /// `true` once at least `period` inputs have been received.
    pub initialized: bool,
    /// The buffered inputs, oldest at the front and newest at the back.
    pub inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
}
50
51impl Display for WeightedMovingAverage {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 write!(f, "{}({},{:?})", self.name(), self.period, self.weights)
54 }
55}
56
57impl WeightedMovingAverage {
58 #[must_use]
67 pub fn new(period: usize, weights: Vec<f64>, price_type: Option<PriceType>) -> Self {
68 Self::new_checked(period, weights, price_type).expect(FAILED)
69 }
70
71 pub fn new_checked(
80 period: usize,
81 weights: Vec<f64>,
82 price_type: Option<PriceType>,
83 ) -> anyhow::Result<Self> {
84 const EPS: f64 = f64::EPSILON;
85
86 check_predicate_true(period > 0, "`period` must be positive")?;
87
88 check_predicate_true(
89 period == weights.len(),
90 "`period` must equal `weights.len()`",
91 )?;
92
93 let weight_sum: f64 = weights.iter().copied().sum();
94 check_predicate_true(
95 weight_sum > EPS,
96 "`weights` sum must be positive and > f64::EPSILON",
97 )?;
98
99 Ok(Self {
100 period,
101 weights,
102 price_type: price_type.unwrap_or(PriceType::Last),
103 value: 0.0,
104 inputs: ArrayDeque::new(),
105 initialized: false,
106 })
107 }
108
109 fn weighted_average(&self) -> f64 {
110 let n = self.inputs.len();
111 let weights_slice = &self.weights[self.period - n..];
112
113 let mut sum = 0.0;
114 let mut weight_sum = 0.0;
115
116 for (input, weight) in self.inputs.iter().rev().zip(weights_slice.iter().rev()) {
117 sum += input * weight;
118 weight_sum += weight;
119 }
120 sum / weight_sum
121 }
122}
123
124impl Indicator for WeightedMovingAverage {
125 fn name(&self) -> String {
126 stringify!(WeightedMovingAverage).to_string()
127 }
128
129 fn has_inputs(&self) -> bool {
130 !self.inputs.is_empty()
131 }
132
133 fn initialized(&self) -> bool {
134 self.initialized
135 }
136
137 fn handle_quote(&mut self, quote: &QuoteTick) {
138 self.update_raw(quote.extract_price(self.price_type).into());
139 }
140
141 fn handle_trade(&mut self, trade: &TradeTick) {
142 self.update_raw((&trade.price).into());
143 }
144
145 fn handle_bar(&mut self, bar: &Bar) {
146 self.update_raw((&bar.close).into());
147 }
148
149 fn reset(&mut self) {
150 self.value = 0.0;
151 self.initialized = false;
152 self.inputs.clear();
153 }
154}
155
156impl MovingAverage for WeightedMovingAverage {
157 fn value(&self) -> f64 {
158 self.value
159 }
160
161 fn count(&self) -> usize {
162 self.inputs.len()
163 }
164
165 fn update_raw(&mut self, value: f64) {
166 if self.inputs.len() == self.period.min(MAX_PERIOD) {
167 self.inputs.pop_front();
168 }
169 let _ = self.inputs.push_back(value);
170
171 self.value = self.weighted_average();
172 self.initialized = self.count() >= self.period;
173 }
174}
175
// Unit tests for `WeightedMovingAverage`. The `indicator_wma_10` fixture (from
// `stubs`) has period 10 and weights [0.1, 0.2, ..., 1.0], as the Display test
// below shows.
#[cfg(test)]
mod tests {

    use arraydeque::{ArrayDeque, Wrapping};
    use rstest::rstest;

    use crate::{
        average::wma::WeightedMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    // A freshly built indicator formats correctly and has no inputs yet.
    #[rstest]
    fn test_wma_initialized(indicator_wma_10: WeightedMovingAverage) {
        let display_str = format!("{indicator_wma_10}");
        assert_eq!(
            display_str,
            "WeightedMovingAverage(10,[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])"
        );
        assert_eq!(indicator_wma_10.name(), "WeightedMovingAverage");
        assert!(!indicator_wma_10.has_inputs());
        assert!(!indicator_wma_10.initialized());
    }

    // `weights.len()` must equal `period`.
    #[rstest]
    #[should_panic]
    fn test_different_weights_len_and_period_error() {
        let _ = WeightedMovingAverage::new(10, vec![0.5, 0.5, 0.5], None);
    }

    // With a single input the average is that input, whatever its weight.
    #[rstest]
    fn test_value_with_one_input(mut indicator_wma_10: WeightedMovingAverage) {
        indicator_wma_10.update_raw(1.0);
        assert_eq!(indicator_wma_10.value, 1.0);
    }

    #[rstest]
    fn test_value_with_two_inputs_equal_weights() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        wma.update_raw(1.0);
        wma.update_raw(2.0);
        assert_eq!(wma.value, 1.5);
    }

    #[rstest]
    fn test_value_with_four_inputs_equal_weights() {
        let mut wma = WeightedMovingAverage::new(4, vec![0.25, 0.25, 0.25, 0.25], None);
        wma.update_raw(1.0);
        wma.update_raw(2.0);
        wma.update_raw(3.0);
        wma.update_raw(4.0);
        assert_eq!(wma.value, 2.5);
    }

    // During warm-up the newest input gets the last weight (1.0), the previous one 0.9.
    #[rstest]
    fn test_value_with_two_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        indicator_wma_10.update_raw(1.0);
        indicator_wma_10.update_raw(2.0);
        let result = 2.0f64.mul_add(1.0, 1.0 * 0.9) / 1.9;
        assert_eq!(indicator_wma_10.value, result);
    }

    #[rstest]
    fn test_value_with_three_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        indicator_wma_10.update_raw(1.0);
        indicator_wma_10.update_raw(2.0);
        indicator_wma_10.update_raw(3.0);
        let result = 1.0f64.mul_add(0.8, 3.0f64.mul_add(1.0, 2.0 * 0.9)) / (1.0 + 0.9 + 0.8);
        assert_eq!(indicator_wma_10.value, result);
    }

    // Inputs 1..=10 with weights i/10: sum(i^2) / sum(i) = 385 / 55 = 7.
    #[rstest]
    fn test_value_expected_with_exact_period(mut indicator_wma_10: WeightedMovingAverage) {
        for i in 1..11 {
            indicator_wma_10.update_raw(f64::from(i));
        }
        assert_eq!(indicator_wma_10.value, 7.0);
    }

    // After 11 inputs the window is 2..=11; the exact result is 8, and the asserted
    // value is what the floating-point summation order produces.
    #[rstest]
    fn test_value_expected_with_more_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        for i in 1..=11 {
            indicator_wma_10.update_raw(f64::from(i));
        }
        assert_eq!(indicator_wma_10.value(), 8.000_000_000_000_002);
    }

    #[rstest]
    fn test_reset(mut indicator_wma_10: WeightedMovingAverage) {
        indicator_wma_10.update_raw(1.0);
        indicator_wma_10.update_raw(2.0);
        indicator_wma_10.reset();
        assert_eq!(indicator_wma_10.value, 0.0);
        assert_eq!(indicator_wma_10.count(), 0);
        assert!(!indicator_wma_10.initialized);
    }

    // Constructor validation: `new` panics where `new_checked` returns `Err`.
    #[rstest]
    #[should_panic]
    fn new_panics_on_zero_period() {
        let _ = WeightedMovingAverage::new(0, vec![1.0], None);
    }

    #[rstest]
    fn new_checked_err_on_zero_period() {
        let res = WeightedMovingAverage::new_checked(0, vec![1.0], None);
        assert!(res.is_err());
    }

    #[rstest]
    #[should_panic]
    fn new_panics_on_zero_weight_sum() {
        let _ = WeightedMovingAverage::new(3, vec![0.0, 0.0, 0.0], None);
    }

    #[rstest]
    fn new_checked_err_on_zero_weight_sum() {
        let res = WeightedMovingAverage::new_checked(3, vec![0.0, 0.0, 0.0], None);
        assert!(res.is_err());
    }

    #[rstest]
    #[should_panic]
    fn new_panics_when_weight_sum_below_epsilon() {
        let tiny = f64::EPSILON / 10.0;
        let _ = WeightedMovingAverage::new(3, vec![tiny; 3], None);
    }

    // `initialized` flips to true exactly when the `period`-th input arrives.
    #[rstest]
    fn initialized_flag_transitions() {
        let period = 3;
        let weights = vec![1.0, 2.0, 3.0];
        let mut wma = WeightedMovingAverage::new(period, weights, None);

        assert!(!wma.initialized());

        for i in 0..period {
            wma.update_raw(i as f64);
            let expected = (i + 1) >= period;
            assert_eq!(wma.initialized(), expected);
        }
        assert!(wma.initialized());
    }

    #[rstest]
    fn count_matches_inputs_and_has_inputs() {
        let mut wma = WeightedMovingAverage::new(4, vec![0.25; 4], None);

        assert_eq!(wma.count(), 0);
        assert!(!wma.has_inputs());

        wma.update_raw(1.0);
        wma.update_raw(2.0);
        assert_eq!(wma.count(), 2);
        assert!(wma.has_inputs());
    }

    #[rstest]
    fn reset_restores_pristine_state() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        wma.update_raw(1.0);
        wma.update_raw(2.0);
        assert!(wma.initialized());

        wma.reset();

        assert_eq!(wma.count(), 0);
        assert_eq!(wma.value(), 0.0);
        assert!(!wma.initialized());
        assert!(!wma.has_inputs());
    }

    // (10*1 + 20*2 + 30*3) / (1 + 2 + 3) = 140 / 6 ≈ 23.333...
    #[rstest]
    fn weighted_average_with_non_uniform_weights() {
        let mut wma = WeightedMovingAverage::new(3, vec![1.0, 2.0, 3.0], None);
        wma.update_raw(10.0);
        wma.update_raw(20.0);
        wma.update_raw(30.0);
        let expected = 23.333_333_333_333_332;
        let tol = f64::EPSILON.sqrt();
        assert!(
            (wma.value() - expected).abs() < tol,
            "value = {}, expected ≈ {}",
            wma.value(),
            expected
        );
    }

    #[rstest]
    fn test_window_never_exceeds_period(mut indicator_wma_10: WeightedMovingAverage) {
        for i in 0..100 {
            indicator_wma_10.update_raw(f64::from(i));
            assert!(indicator_wma_10.count() <= indicator_wma_10.period);
        }
    }

    // Individual negative weights are allowed as long as the total is positive:
    // (-1*1 + 2*2 + 2*3) / 3 = 3.
    #[rstest]
    fn test_negative_weights_positive_sum() {
        let period = 3;
        let weights = vec![-1.0, 2.0, 2.0];
        let mut wma = WeightedMovingAverage::new(period, weights, None);
        wma.update_raw(1.0);
        wma.update_raw(2.0);
        wma.update_raw(3.0);

        let expected = 2.0f64.mul_add(3.0, 2.0f64.mul_add(2.0, -1.0)) / 3.0;
        let tol = f64::EPSILON.sqrt();
        assert!((wma.value() - expected).abs() < tol);
    }

    // Non-finite inputs are not filtered; they propagate into the average.
    #[rstest]
    fn test_nan_input_propagates() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        wma.update_raw(1.0);
        wma.update_raw(f64::NAN);

        assert!(wma.value().is_nan());
    }

    // Weights summing to (about) f64::EPSILON fail the strict `> EPSILON` check.
    #[rstest]
    #[should_panic]
    fn new_panics_when_weight_sum_equals_epsilon() {
        let eps_third = f64::EPSILON / 3.0;
        let _ = WeightedMovingAverage::new(3, vec![eps_third; 3], None);
    }

    #[rstest]
    fn new_checked_err_when_weight_sum_equals_epsilon() {
        let eps_third = f64::EPSILON / 3.0;
        let res = WeightedMovingAverage::new_checked(3, vec![eps_third; 3], None);
        assert!(res.is_err());
    }

    #[rstest]
    fn new_checked_err_when_weight_sum_below_epsilon() {
        let w = f64::EPSILON * 0.9;
        let res = WeightedMovingAverage::new_checked(1, vec![w], None);
        assert!(res.is_err());
    }

    #[rstest]
    fn new_ok_when_weight_sum_above_epsilon() {
        let w = f64::EPSILON * 1.1;
        let res = WeightedMovingAverage::new_checked(1, vec![w], None);
        assert!(res.is_ok());
    }

    // Weights that cancel to a zero sum are rejected.
    #[rstest]
    #[should_panic]
    fn new_panics_on_cancelled_weights_sum() {
        let _ = WeightedMovingAverage::new(3, vec![1.0, -1.0, 0.0], None);
    }

    #[rstest]
    fn new_checked_err_on_cancelled_weights_sum() {
        let res = WeightedMovingAverage::new_checked(3, vec![1.0, -1.0, 0.0], None);
        assert!(res.is_err());
    }

    #[rstest]
    fn single_period_returns_latest_input() {
        let mut wma = WeightedMovingAverage::new(1, vec![1.0], None);
        for i in 0..5 {
            let v = f64::from(i);
            wma.update_raw(v);
            assert_eq!(wma.value(), v);
        }
    }

    // Only the middle (second-newest) input carries weight once the window is full.
    #[rstest]
    fn value_with_sparse_weights() {
        let mut wma = WeightedMovingAverage::new(3, vec![0.0, 1.0, 0.0], None);
        wma.update_raw(10.0);
        wma.update_raw(20.0);
        wma.update_raw(30.0);
        assert_eq!(wma.value(), 20.0);
    }

    // Warm-up tests: with k < period inputs only the last k weights are used.
    #[rstest]
    fn warm_up_len1() {
        let mut wma = WeightedMovingAverage::new(4, vec![1.0, 2.0, 3.0, 4.0], None);
        wma.update_raw(42.0);
        assert_eq!(wma.value(), 42.0);
    }

    #[rstest]
    fn warm_up_len2() {
        let mut wma = WeightedMovingAverage::new(4, vec![1.0, 2.0, 3.0, 4.0], None);
        wma.update_raw(10.0);
        wma.update_raw(20.0);
        let expected = 20.0f64.mul_add(4.0, 10.0 * 3.0) / (4.0 + 3.0);
        assert_eq!(wma.value(), expected);
    }

    #[rstest]
    fn warm_up_len3() {
        let mut wma = WeightedMovingAverage::new(4, vec![1.0, 2.0, 3.0, 4.0], None);
        wma.update_raw(1.0);
        wma.update_raw(2.0);
        wma.update_raw(3.0);
        let expected = 1.0f64.mul_add(2.0, 3.0f64.mul_add(4.0, 2.0 * 3.0)) / (4.0 + 3.0 + 2.0);
        assert_eq!(wma.value(), expected);
    }

    // The buffer keeps only the most recent `period` inputs, oldest first.
    #[rstest]
    fn input_window_contains_latest_period() {
        let period = 3;
        let mut wma = WeightedMovingAverage::new(period, vec![1.0; period], None);
        let vals = [1.0, 2.0, 3.0, 4.0];
        for v in vals {
            wma.update_raw(v);
        }
        let expected: Vec<f64> = vals[vals.len() - period..].to_vec();
        assert_eq!(wma.inputs.iter().copied().collect::<Vec<_>>(), expected);
    }

    #[rstest]
    fn window_slides_correctly() {
        let mut wma = WeightedMovingAverage::new(2, vec![1.0; 2], None);
        wma.update_raw(1.0);
        assert_eq!(wma.inputs.iter().copied().collect::<Vec<_>>(), vec![1.0]);
        wma.update_raw(2.0);
        assert_eq!(
            wma.inputs.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0]
        );
        wma.update_raw(3.0);
        assert_eq!(
            wma.inputs.iter().copied().collect::<Vec<_>>(),
            vec![2.0, 3.0]
        );
    }

    #[rstest]
    fn window_len_constant_after_many_updates() {
        let period = 5;
        let mut wma = WeightedMovingAverage::new(period, vec![1.0; period], None);
        for i in 0..100 {
            wma.update_raw(i as f64);
            assert_eq!(wma.inputs.len(), period.min(i + 1));
        }
    }

    // Sanity checks of the `ArrayDeque` `Wrapping` behavior used by the input buffer:
    // pushing onto a full buffer drops the front element.
    #[rstest]
    fn arraydeque_wraps_when_full() {
        const CAP: usize = 3;
        let mut buf: ArrayDeque<usize, CAP, Wrapping> = ArrayDeque::new();
        for i in 0..=CAP {
            let _ = buf.push_back(i);
        }
        assert_eq!(buf.len(), CAP);
        assert_eq!(buf.front().copied(), Some(1));
        assert_eq!(buf.back().copied(), Some(3));
    }

    // Mirrors the pop-then-push pattern in `update_raw`.
    #[rstest]
    fn arraydeque_sliding_window_with_pop() {
        const CAP: usize = 3;
        let mut buf: ArrayDeque<usize, CAP, Wrapping> = ArrayDeque::new();
        for i in 0..10 {
            if buf.len() == CAP {
                buf.pop_front();
            }
            let _ = buf.push_back(i);
            assert!(buf.len() <= CAP);
        }
        assert_eq!(buf.len(), CAP);
    }

    // An infinite weight gives an infinite (positive) sum, which passes validation.
    #[rstest]
    fn new_ok_with_infinite_weight() {
        let res = WeightedMovingAverage::new_checked(2, vec![f64::INFINITY, 1.0], None);
        assert!(res.is_ok());
    }

    // A NaN weight makes the sum NaN, and `NaN > EPSILON` is false.
    #[rstest]
    #[should_panic]
    fn new_panics_on_nan_weight() {
        let _ = WeightedMovingAverage::new(2, vec![f64::NAN, 1.0], None);
    }

    #[rstest]
    #[should_panic]
    fn new_panics_on_empty_weights() {
        let _ = WeightedMovingAverage::new(1, Vec::new(), None);
    }

    #[rstest]
    fn inf_input_propagates() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        wma.update_raw(1.0);
        wma.update_raw(f64::INFINITY);
        assert!(wma.value().is_infinite());
    }

    // Leading zero weights are unused during this warm-up; only the trailing [1, 1] apply.
    #[rstest]
    fn warm_up_with_front_zero_weights() {
        let mut wma = WeightedMovingAverage::new(4, vec![0.0, 0.0, 1.0, 1.0], None);
        wma.update_raw(10.0);
        wma.update_raw(20.0);
        let expected = 20.0f64.mul_add(1.0, 10.0 * 1.0) / 2.0;
        assert_eq!(wma.value(), expected);
    }
}
579}